Implement Trader V4 training artifact pipeline
This commit is contained in:
@@ -0,0 +1,77 @@
|
||||
package com.quantai.trader.model;
|
||||
|
||||
import com.quantai.trader.artifact.TraderArtifactLoader;
|
||||
import com.quantai.trader.artifact.TraderModelManifest;
|
||||
import com.quantai.trader.feature.TraderFeatureVectorBuilder;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.api.io.TempDir;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.nio.file.Path;
|
||||
import java.util.Map;
|
||||
|
||||
import static com.quantai.trader.TestFixtures.*;
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
import static org.assertj.core.api.Assertions.assertThatThrownBy;
|
||||
|
||||
class OnnxTraderModelServiceTest {
|
||||
@TempDir
|
||||
Path artifactRoot;
|
||||
|
||||
@Test
|
||||
void evaluatesFiveModelFamiliesWithFeatureOrderCalibrationAndOutputBounds() throws IOException {
|
||||
writeArtifactBundle(artifactRoot);
|
||||
var properties = propertiesWithArtifactRoot(artifactRoot);
|
||||
var bundle = new TraderArtifactLoader(properties, objectMapper()).loadActiveBundle();
|
||||
OnnxTraderModelService service = new OnnxTraderModelService(
|
||||
properties,
|
||||
objectMapper(),
|
||||
new TraderFeatureVectorBuilder(properties, objectMapper()),
|
||||
new FakeInferenceClient());
|
||||
|
||||
var output = service.evaluate(snapshot(), bundle);
|
||||
|
||||
assertThat(output.direction().longProb()).isEqualByComparingTo("0.62");
|
||||
assertThat(output.entry().longExpectedNetEdgeBps()).isEqualByComparingTo("12");
|
||||
assertThat(output.continuation().shortExpectedContinueEdgeBps()).isEqualByComparingTo("1.5");
|
||||
assertThat(output.exit().reasonScore("reversal_prob")).isEqualByComparingTo("0.25");
|
||||
assertThat(output.risk().reasonScore("liquidity_deterioration_prob")).isEqualByComparingTo("0.12");
|
||||
assertThat(output.metadata().uncertainty()).isEqualByComparingTo("0.38");
|
||||
assertThat(output.metadata().oodScore()).isEqualByComparingTo("0.05");
|
||||
}
|
||||
|
||||
@Test
|
||||
void failsWhenShadowSnapshotHasNoOodScore() throws IOException {
|
||||
writeArtifactBundle(artifactRoot);
|
||||
var properties = propertiesWithArtifactRoot(artifactRoot);
|
||||
var bundle = new TraderArtifactLoader(properties, objectMapper()).loadActiveBundle();
|
||||
OnnxTraderModelService service = new OnnxTraderModelService(
|
||||
properties,
|
||||
objectMapper(),
|
||||
new TraderFeatureVectorBuilder(properties, objectMapper()),
|
||||
new FakeInferenceClient());
|
||||
|
||||
var snapshot = new com.quantai.trader.domain.TraderMarketSnapshot("snapshot-1", "run-1", "cycle-1",
|
||||
"BTC-USDT-PERP", T0, "feature-v4-p0", bd("100"), bd("99.5"), bd("1.2"),
|
||||
bd("0.5"), bd("1000"), bd("1400"), bd("2200"), true, featureJson(), Map.of());
|
||||
|
||||
assertThatThrownBy(() -> service.evaluate(snapshot, bundle))
|
||||
.isInstanceOf(com.quantai.trader.domain.TraderException.class)
|
||||
.hasMessageContaining("ood_score");
|
||||
}
|
||||
|
||||
private static final class FakeInferenceClient implements TraderOnnxInferenceClient {
|
||||
@Override
|
||||
public Map<String, float[]> infer(TraderModelManifest manifest, Path modelPath, float[] features) {
|
||||
return switch (manifest.modelType()) {
|
||||
case "DIRECTION" -> Map.of("probabilities", new float[]{0.5f, 0.5f, 0.5f});
|
||||
case "ENTRY" -> Map.of("entry", new float[]{0.5f, 0.5f, 12.0f, 8.0f});
|
||||
case "CONTINUE" -> Map.of("continue", new float[]{0.5f, 0.5f, 3.0f, 1.5f});
|
||||
case "EXIT" -> Map.of("exit", new float[]{0.5f, 0.5f, 10.0f, 18.0f, 0.5f, 0.5f, 0.5f, 0.5f});
|
||||
case "RISK" -> Map.of("risk", new float[]{0.5f, 0.5f, 0.5f, 20.0f, 20.0f, 30.0f,
|
||||
0.5f, 0.5f, 0.5f, 0.5f, 0.5f});
|
||||
default -> throw new IllegalArgumentException("unexpected model type: " + manifest.modelType());
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,32 @@
|
||||
package com.quantai.trader.model;
|
||||
|
||||
import com.quantai.trader.artifact.TraderArtifactLoader;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.api.io.TempDir;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.nio.file.Path;
|
||||
|
||||
import static com.quantai.trader.TestFixtures.*;
|
||||
import static org.assertj.core.api.Assertions.assertThatThrownBy;
|
||||
|
||||
class OrtTraderOnnxInferenceClientTest {
|
||||
@TempDir
|
||||
Path artifactRoot;
|
||||
|
||||
@Test
|
||||
void failsClearlyWhenOnnxFileIsNotLoadable() throws IOException {
|
||||
writeArtifactBundle(artifactRoot);
|
||||
var properties = propertiesWithArtifactRoot(artifactRoot);
|
||||
var bundle = new TraderArtifactLoader(properties, objectMapper()).loadActiveBundle();
|
||||
var manifest = bundle.modelManifests().stream()
|
||||
.filter(item -> item.modelType().equals("DIRECTION"))
|
||||
.findFirst()
|
||||
.orElseThrow();
|
||||
|
||||
assertThatThrownBy(() -> new OrtTraderOnnxInferenceClient()
|
||||
.infer(manifest, artifactRoot.resolve(manifest.artifactPath()), new float[39]))
|
||||
.isInstanceOf(com.quantai.trader.domain.TraderException.class)
|
||||
.hasMessageContaining("ONNX model cannot be loaded");
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,84 @@
|
||||
package com.quantai.trader.model;
|
||||
|
||||
import com.quantai.trader.artifact.TraderArtifactLoader;
|
||||
import com.quantai.trader.domain.TraderException;
|
||||
import com.quantai.trader.enums.TraderExecutionMode;
|
||||
import com.quantai.trader.enums.TraderRunMode;
|
||||
import com.quantai.trader.feature.TraderFeatureVectorBuilder;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.api.io.TempDir;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.nio.file.Path;
|
||||
|
||||
import static com.quantai.trader.TestFixtures.*;
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
import static org.assertj.core.api.Assertions.assertThatThrownBy;
|
||||
|
||||
class RoutingTraderModelServiceTest {
|
||||
@TempDir
|
||||
Path artifactRoot;
|
||||
|
||||
@Test
|
||||
void replaySimUsesReplayFixtureOnly() throws IOException {
|
||||
writeArtifactBundle(artifactRoot);
|
||||
var properties = propertiesWithArtifactRoot(artifactRoot, TraderRunMode.REPLAY_SIM, TraderExecutionMode.REPLAY_SIM);
|
||||
var bundle = new TraderArtifactLoader(properties, objectMapper()).loadActiveBundle();
|
||||
var service = new RoutingTraderModelService(
|
||||
properties,
|
||||
new ReplayFixtureTraderModelService(properties),
|
||||
new OnnxTraderModelService(properties, objectMapper(),
|
||||
new TraderFeatureVectorBuilder(properties, objectMapper()),
|
||||
(manifest, modelPath, features) -> {
|
||||
throw new AssertionError("REPLAY_SIM must not call ONNX");
|
||||
}));
|
||||
|
||||
var output = service.evaluate(snapshot(), bundle);
|
||||
|
||||
assertThat(output.direction().longProb()).isEqualByComparingTo("0.62");
|
||||
}
|
||||
|
||||
@Test
|
||||
void shadowUsesOnnxService() throws IOException {
|
||||
writeArtifactBundle(artifactRoot);
|
||||
var properties = propertiesWithArtifactRoot(artifactRoot, TraderRunMode.SHADOW, TraderExecutionMode.SHADOW);
|
||||
var bundle = new TraderArtifactLoader(properties, objectMapper()).loadActiveBundle();
|
||||
var service = new RoutingTraderModelService(
|
||||
properties,
|
||||
new ReplayFixtureTraderModelService(properties),
|
||||
new OnnxTraderModelService(properties, objectMapper(),
|
||||
new TraderFeatureVectorBuilder(properties, objectMapper()),
|
||||
(manifest, modelPath, features) -> switch (manifest.modelType()) {
|
||||
case "DIRECTION" -> java.util.Map.of("probabilities", new float[]{0.5f, 0.5f, 0.5f});
|
||||
case "ENTRY" -> java.util.Map.of("entry", new float[]{0.5f, 0.5f, 12.0f, 8.0f});
|
||||
case "CONTINUE" -> java.util.Map.of("continue", new float[]{0.5f, 0.5f, 3.0f, 1.5f});
|
||||
case "EXIT" -> java.util.Map.of("exit", new float[]{0.5f, 0.5f, 10.0f, 18.0f, 0.5f, 0.5f, 0.5f, 0.5f});
|
||||
case "RISK" -> java.util.Map.of("risk", new float[]{0.5f, 0.5f, 0.5f, 20.0f, 20.0f, 30.0f,
|
||||
0.5f, 0.5f, 0.5f, 0.5f, 0.5f});
|
||||
default -> throw new IllegalArgumentException("unexpected model type: " + manifest.modelType());
|
||||
}));
|
||||
|
||||
var output = service.evaluate(snapshot(), bundle);
|
||||
|
||||
assertThat(output.risk().marketRiskProb()).isEqualByComparingTo("0.20");
|
||||
}
|
||||
|
||||
@Test
|
||||
void rejectsRunModesOutsideP0Scope() throws IOException {
|
||||
writeArtifactBundle(artifactRoot);
|
||||
var bundle = new TraderArtifactLoader(
|
||||
propertiesWithArtifactRoot(artifactRoot, TraderRunMode.REPLAY_SIM, TraderExecutionMode.REPLAY_SIM),
|
||||
objectMapper()).loadActiveBundle();
|
||||
var properties = propertiesWithArtifactRoot(artifactRoot, TraderRunMode.PAPER, TraderExecutionMode.PAPER);
|
||||
var service = new RoutingTraderModelService(
|
||||
properties,
|
||||
new ReplayFixtureTraderModelService(properties),
|
||||
new OnnxTraderModelService(properties, objectMapper(),
|
||||
new TraderFeatureVectorBuilder(properties, objectMapper()),
|
||||
(manifest, modelPath, features) -> java.util.Map.of()));
|
||||
|
||||
assertThatThrownBy(() -> service.evaluate(snapshot(), bundle))
|
||||
.isInstanceOf(TraderException.class)
|
||||
.hasMessageContaining("only supports REPLAY_SIM and SHADOW");
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,55 @@
|
||||
package com.quantai.trader.model;
|
||||
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.api.io.TempDir;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.math.BigDecimal;
|
||||
import java.nio.file.Files;
|
||||
import java.nio.file.Path;
|
||||
|
||||
import static com.quantai.trader.TestFixtures.objectMapper;
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
import static org.assertj.core.api.Assertions.assertThatThrownBy;
|
||||
|
||||
class TraderOutputSchemaBoundsTest {
|
||||
@TempDir
|
||||
Path tempDir;
|
||||
|
||||
@Test
|
||||
void clipsBpsByOutputSchemaRange() throws IOException {
|
||||
Path path = tempDir.resolve("output_schema.json");
|
||||
Files.writeString(path, """
|
||||
{"entry":{"longExpectedNetEdgeBps":{"type":"decimal","range":[-10.0,10.0]}}}
|
||||
""");
|
||||
|
||||
TraderOutputSchemaBounds bounds = TraderOutputSchemaBounds.read(objectMapper(), path);
|
||||
|
||||
assertThat(bounds.clip("longExpectedNetEdgeBps", new BigDecimal("12"))).isEqualByComparingTo("10");
|
||||
assertThat(bounds.clip("longExpectedNetEdgeBps", new BigDecimal("-12"))).isEqualByComparingTo("-10");
|
||||
assertThat(bounds.clip("longExpectedNetEdgeBps", new BigDecimal("2"))).isEqualByComparingTo("2");
|
||||
}
|
||||
|
||||
@Test
|
||||
void rejectsMissingFieldRange() throws IOException {
|
||||
Path path = tempDir.resolve("output_schema.json");
|
||||
Files.writeString(path, """
|
||||
{"entry":{"longExpectedNetEdgeBps":{"type":"decimal","range":[-10.0,10.0]}}}
|
||||
""");
|
||||
TraderOutputSchemaBounds bounds = TraderOutputSchemaBounds.read(objectMapper(), path);
|
||||
|
||||
assertThatThrownBy(() -> bounds.clip("shortExpectedNetEdgeBps", BigDecimal.ONE))
|
||||
.isInstanceOf(com.quantai.trader.domain.TraderException.class)
|
||||
.hasMessageContaining("does not define range");
|
||||
}
|
||||
|
||||
@Test
|
||||
void rejectsSchemaWithoutRanges() throws IOException {
|
||||
Path path = tempDir.resolve("output_schema.json");
|
||||
Files.writeString(path, "{\"entry\":{}}");
|
||||
|
||||
assertThatThrownBy(() -> TraderOutputSchemaBounds.read(objectMapper(), path))
|
||||
.isInstanceOf(com.quantai.trader.domain.TraderException.class)
|
||||
.hasMessageContaining("must define field ranges");
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,58 @@
|
||||
package com.quantai.trader.model;
|
||||
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.api.io.TempDir;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.math.BigDecimal;
|
||||
import java.nio.file.Files;
|
||||
import java.nio.file.Path;
|
||||
|
||||
import static com.quantai.trader.TestFixtures.objectMapper;
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
import static org.assertj.core.api.Assertions.assertThatThrownBy;
|
||||
|
||||
class TraderProbabilityCalibratorTest {
|
||||
@TempDir
|
||||
Path tempDir;
|
||||
|
||||
@Test
|
||||
void appliesBinningCalibration() throws IOException {
|
||||
Path path = tempDir.resolve("calibrator.json");
|
||||
Files.writeString(path, """
|
||||
{"method":"BINNING","targets":{"longProb":{"bins":[{"min":0.0,"max":1.0,"calibrated":0.62}]}},"clip":{"min":0.0,"max":1.0}}
|
||||
""");
|
||||
|
||||
TraderProbabilityCalibrator calibrator = TraderProbabilityCalibrator.read(objectMapper(), path, "DIRECTION");
|
||||
|
||||
assertThat(calibrator.calibrate("longProb", new BigDecimal("0.40"))).isEqualByComparingTo("0.62");
|
||||
}
|
||||
|
||||
@Test
|
||||
void rejectsUnsupportedCalibrationMethod() throws IOException {
|
||||
Path path = tempDir.resolve("calibrator.json");
|
||||
Files.writeString(path, """
|
||||
{"method":"NONE","targets":{"longProb":{"bins":[{"min":0.0,"max":1.0,"calibrated":0.62}]}},"clip":{"min":0.0,"max":1.0}}
|
||||
""");
|
||||
|
||||
assertThatThrownBy(() -> TraderProbabilityCalibrator.read(objectMapper(), path, "DIRECTION"))
|
||||
.isInstanceOf(com.quantai.trader.domain.TraderException.class)
|
||||
.hasMessageContaining("must be BINNING");
|
||||
}
|
||||
|
||||
@Test
|
||||
void rejectsMissingTargetAndOutOfRangeBins() throws IOException {
|
||||
Path path = tempDir.resolve("calibrator.json");
|
||||
Files.writeString(path, """
|
||||
{"method":"BINNING","targets":{"longProb":{"bins":[{"min":0.0,"max":0.1,"calibrated":0.62}]}},"clip":{"min":0.0,"max":1.0}}
|
||||
""");
|
||||
TraderProbabilityCalibrator calibrator = TraderProbabilityCalibrator.read(objectMapper(), path, "DIRECTION");
|
||||
|
||||
assertThatThrownBy(() -> calibrator.calibrate("shortProb", new BigDecimal("0.05")))
|
||||
.isInstanceOf(com.quantai.trader.domain.TraderException.class)
|
||||
.hasMessageContaining("target is missing");
|
||||
assertThatThrownBy(() -> calibrator.calibrate("longProb", new BigDecimal("0.50")))
|
||||
.isInstanceOf(com.quantai.trader.domain.TraderException.class)
|
||||
.hasMessageContaining("does not match any");
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user