Implement Trader V4 training artifact pipeline

This commit is contained in:
Codex
2026-06-27 16:15:23 +08:00
parent dad6b831b4
commit e58e4a5572
113 changed files with 7959 additions and 477 deletions
@@ -0,0 +1,77 @@
package com.quantai.trader.model;
import com.quantai.trader.artifact.TraderArtifactLoader;
import com.quantai.trader.artifact.TraderModelManifest;
import com.quantai.trader.feature.TraderFeatureVectorBuilder;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir;
import java.io.IOException;
import java.nio.file.Path;
import java.util.Map;
import static com.quantai.trader.TestFixtures.*;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
class OnnxTraderModelServiceTest {
@TempDir
Path artifactRoot;
@Test
void evaluatesFiveModelFamiliesWithFeatureOrderCalibrationAndOutputBounds() throws IOException {
writeArtifactBundle(artifactRoot);
var properties = propertiesWithArtifactRoot(artifactRoot);
var bundle = new TraderArtifactLoader(properties, objectMapper()).loadActiveBundle();
OnnxTraderModelService service = new OnnxTraderModelService(
properties,
objectMapper(),
new TraderFeatureVectorBuilder(properties, objectMapper()),
new FakeInferenceClient());
var output = service.evaluate(snapshot(), bundle);
assertThat(output.direction().longProb()).isEqualByComparingTo("0.62");
assertThat(output.entry().longExpectedNetEdgeBps()).isEqualByComparingTo("12");
assertThat(output.continuation().shortExpectedContinueEdgeBps()).isEqualByComparingTo("1.5");
assertThat(output.exit().reasonScore("reversal_prob")).isEqualByComparingTo("0.25");
assertThat(output.risk().reasonScore("liquidity_deterioration_prob")).isEqualByComparingTo("0.12");
assertThat(output.metadata().uncertainty()).isEqualByComparingTo("0.38");
assertThat(output.metadata().oodScore()).isEqualByComparingTo("0.05");
}
@Test
void failsWhenShadowSnapshotHasNoOodScore() throws IOException {
writeArtifactBundle(artifactRoot);
var properties = propertiesWithArtifactRoot(artifactRoot);
var bundle = new TraderArtifactLoader(properties, objectMapper()).loadActiveBundle();
OnnxTraderModelService service = new OnnxTraderModelService(
properties,
objectMapper(),
new TraderFeatureVectorBuilder(properties, objectMapper()),
new FakeInferenceClient());
var snapshot = new com.quantai.trader.domain.TraderMarketSnapshot("snapshot-1", "run-1", "cycle-1",
"BTC-USDT-PERP", T0, "feature-v4-p0", bd("100"), bd("99.5"), bd("1.2"),
bd("0.5"), bd("1000"), bd("1400"), bd("2200"), true, featureJson(), Map.of());
assertThatThrownBy(() -> service.evaluate(snapshot, bundle))
.isInstanceOf(com.quantai.trader.domain.TraderException.class)
.hasMessageContaining("ood_score");
}
private static final class FakeInferenceClient implements TraderOnnxInferenceClient {
@Override
public Map<String, float[]> infer(TraderModelManifest manifest, Path modelPath, float[] features) {
return switch (manifest.modelType()) {
case "DIRECTION" -> Map.of("probabilities", new float[]{0.5f, 0.5f, 0.5f});
case "ENTRY" -> Map.of("entry", new float[]{0.5f, 0.5f, 12.0f, 8.0f});
case "CONTINUE" -> Map.of("continue", new float[]{0.5f, 0.5f, 3.0f, 1.5f});
case "EXIT" -> Map.of("exit", new float[]{0.5f, 0.5f, 10.0f, 18.0f, 0.5f, 0.5f, 0.5f, 0.5f});
case "RISK" -> Map.of("risk", new float[]{0.5f, 0.5f, 0.5f, 20.0f, 20.0f, 30.0f,
0.5f, 0.5f, 0.5f, 0.5f, 0.5f});
default -> throw new IllegalArgumentException("unexpected model type: " + manifest.modelType());
};
}
}
}
@@ -0,0 +1,32 @@
package com.quantai.trader.model;
import com.quantai.trader.artifact.TraderArtifactLoader;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir;
import java.io.IOException;
import java.nio.file.Path;
import static com.quantai.trader.TestFixtures.*;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
class OrtTraderOnnxInferenceClientTest {
@TempDir
Path artifactRoot;
@Test
void failsClearlyWhenOnnxFileIsNotLoadable() throws IOException {
writeArtifactBundle(artifactRoot);
var properties = propertiesWithArtifactRoot(artifactRoot);
var bundle = new TraderArtifactLoader(properties, objectMapper()).loadActiveBundle();
var manifest = bundle.modelManifests().stream()
.filter(item -> item.modelType().equals("DIRECTION"))
.findFirst()
.orElseThrow();
assertThatThrownBy(() -> new OrtTraderOnnxInferenceClient()
.infer(manifest, artifactRoot.resolve(manifest.artifactPath()), new float[39]))
.isInstanceOf(com.quantai.trader.domain.TraderException.class)
.hasMessageContaining("ONNX model cannot be loaded");
}
}
@@ -0,0 +1,84 @@
package com.quantai.trader.model;
import com.quantai.trader.artifact.TraderArtifactLoader;
import com.quantai.trader.domain.TraderException;
import com.quantai.trader.enums.TraderExecutionMode;
import com.quantai.trader.enums.TraderRunMode;
import com.quantai.trader.feature.TraderFeatureVectorBuilder;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir;
import java.io.IOException;
import java.nio.file.Path;
import static com.quantai.trader.TestFixtures.*;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
class RoutingTraderModelServiceTest {
@TempDir
Path artifactRoot;
@Test
void replaySimUsesReplayFixtureOnly() throws IOException {
writeArtifactBundle(artifactRoot);
var properties = propertiesWithArtifactRoot(artifactRoot, TraderRunMode.REPLAY_SIM, TraderExecutionMode.REPLAY_SIM);
var bundle = new TraderArtifactLoader(properties, objectMapper()).loadActiveBundle();
var service = new RoutingTraderModelService(
properties,
new ReplayFixtureTraderModelService(properties),
new OnnxTraderModelService(properties, objectMapper(),
new TraderFeatureVectorBuilder(properties, objectMapper()),
(manifest, modelPath, features) -> {
throw new AssertionError("REPLAY_SIM must not call ONNX");
}));
var output = service.evaluate(snapshot(), bundle);
assertThat(output.direction().longProb()).isEqualByComparingTo("0.62");
}
@Test
void shadowUsesOnnxService() throws IOException {
writeArtifactBundle(artifactRoot);
var properties = propertiesWithArtifactRoot(artifactRoot, TraderRunMode.SHADOW, TraderExecutionMode.SHADOW);
var bundle = new TraderArtifactLoader(properties, objectMapper()).loadActiveBundle();
var service = new RoutingTraderModelService(
properties,
new ReplayFixtureTraderModelService(properties),
new OnnxTraderModelService(properties, objectMapper(),
new TraderFeatureVectorBuilder(properties, objectMapper()),
(manifest, modelPath, features) -> switch (manifest.modelType()) {
case "DIRECTION" -> java.util.Map.of("probabilities", new float[]{0.5f, 0.5f, 0.5f});
case "ENTRY" -> java.util.Map.of("entry", new float[]{0.5f, 0.5f, 12.0f, 8.0f});
case "CONTINUE" -> java.util.Map.of("continue", new float[]{0.5f, 0.5f, 3.0f, 1.5f});
case "EXIT" -> java.util.Map.of("exit", new float[]{0.5f, 0.5f, 10.0f, 18.0f, 0.5f, 0.5f, 0.5f, 0.5f});
case "RISK" -> java.util.Map.of("risk", new float[]{0.5f, 0.5f, 0.5f, 20.0f, 20.0f, 30.0f,
0.5f, 0.5f, 0.5f, 0.5f, 0.5f});
default -> throw new IllegalArgumentException("unexpected model type: " + manifest.modelType());
}));
var output = service.evaluate(snapshot(), bundle);
assertThat(output.risk().marketRiskProb()).isEqualByComparingTo("0.20");
}
@Test
void rejectsRunModesOutsideP0Scope() throws IOException {
writeArtifactBundle(artifactRoot);
var bundle = new TraderArtifactLoader(
propertiesWithArtifactRoot(artifactRoot, TraderRunMode.REPLAY_SIM, TraderExecutionMode.REPLAY_SIM),
objectMapper()).loadActiveBundle();
var properties = propertiesWithArtifactRoot(artifactRoot, TraderRunMode.PAPER, TraderExecutionMode.PAPER);
var service = new RoutingTraderModelService(
properties,
new ReplayFixtureTraderModelService(properties),
new OnnxTraderModelService(properties, objectMapper(),
new TraderFeatureVectorBuilder(properties, objectMapper()),
(manifest, modelPath, features) -> java.util.Map.of()));
assertThatThrownBy(() -> service.evaluate(snapshot(), bundle))
.isInstanceOf(TraderException.class)
.hasMessageContaining("only supports REPLAY_SIM and SHADOW");
}
}
@@ -0,0 +1,55 @@
package com.quantai.trader.model;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir;
import java.io.IOException;
import java.math.BigDecimal;
import java.nio.file.Files;
import java.nio.file.Path;
import static com.quantai.trader.TestFixtures.objectMapper;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
class TraderOutputSchemaBoundsTest {
@TempDir
Path tempDir;
@Test
void clipsBpsByOutputSchemaRange() throws IOException {
Path path = tempDir.resolve("output_schema.json");
Files.writeString(path, """
{"entry":{"longExpectedNetEdgeBps":{"type":"decimal","range":[-10.0,10.0]}}}
""");
TraderOutputSchemaBounds bounds = TraderOutputSchemaBounds.read(objectMapper(), path);
assertThat(bounds.clip("longExpectedNetEdgeBps", new BigDecimal("12"))).isEqualByComparingTo("10");
assertThat(bounds.clip("longExpectedNetEdgeBps", new BigDecimal("-12"))).isEqualByComparingTo("-10");
assertThat(bounds.clip("longExpectedNetEdgeBps", new BigDecimal("2"))).isEqualByComparingTo("2");
}
@Test
void rejectsMissingFieldRange() throws IOException {
Path path = tempDir.resolve("output_schema.json");
Files.writeString(path, """
{"entry":{"longExpectedNetEdgeBps":{"type":"decimal","range":[-10.0,10.0]}}}
""");
TraderOutputSchemaBounds bounds = TraderOutputSchemaBounds.read(objectMapper(), path);
assertThatThrownBy(() -> bounds.clip("shortExpectedNetEdgeBps", BigDecimal.ONE))
.isInstanceOf(com.quantai.trader.domain.TraderException.class)
.hasMessageContaining("does not define range");
}
@Test
void rejectsSchemaWithoutRanges() throws IOException {
Path path = tempDir.resolve("output_schema.json");
Files.writeString(path, "{\"entry\":{}}");
assertThatThrownBy(() -> TraderOutputSchemaBounds.read(objectMapper(), path))
.isInstanceOf(com.quantai.trader.domain.TraderException.class)
.hasMessageContaining("must define field ranges");
}
}
@@ -0,0 +1,58 @@
package com.quantai.trader.model;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir;
import java.io.IOException;
import java.math.BigDecimal;
import java.nio.file.Files;
import java.nio.file.Path;
import static com.quantai.trader.TestFixtures.objectMapper;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
class TraderProbabilityCalibratorTest {
@TempDir
Path tempDir;
@Test
void appliesBinningCalibration() throws IOException {
Path path = tempDir.resolve("calibrator.json");
Files.writeString(path, """
{"method":"BINNING","targets":{"longProb":{"bins":[{"min":0.0,"max":1.0,"calibrated":0.62}]}},"clip":{"min":0.0,"max":1.0}}
""");
TraderProbabilityCalibrator calibrator = TraderProbabilityCalibrator.read(objectMapper(), path, "DIRECTION");
assertThat(calibrator.calibrate("longProb", new BigDecimal("0.40"))).isEqualByComparingTo("0.62");
}
@Test
void rejectsUnsupportedCalibrationMethod() throws IOException {
Path path = tempDir.resolve("calibrator.json");
Files.writeString(path, """
{"method":"NONE","targets":{"longProb":{"bins":[{"min":0.0,"max":1.0,"calibrated":0.62}]}},"clip":{"min":0.0,"max":1.0}}
""");
assertThatThrownBy(() -> TraderProbabilityCalibrator.read(objectMapper(), path, "DIRECTION"))
.isInstanceOf(com.quantai.trader.domain.TraderException.class)
.hasMessageContaining("must be BINNING");
}
@Test
void rejectsMissingTargetAndOutOfRangeBins() throws IOException {
Path path = tempDir.resolve("calibrator.json");
Files.writeString(path, """
{"method":"BINNING","targets":{"longProb":{"bins":[{"min":0.0,"max":0.1,"calibrated":0.62}]}},"clip":{"min":0.0,"max":1.0}}
""");
TraderProbabilityCalibrator calibrator = TraderProbabilityCalibrator.read(objectMapper(), path, "DIRECTION");
assertThatThrownBy(() -> calibrator.calibrate("shortProb", new BigDecimal("0.05")))
.isInstanceOf(com.quantai.trader.domain.TraderException.class)
.hasMessageContaining("target is missing");
assertThatThrownBy(() -> calibrator.calibrate("longProb", new BigDecimal("0.50")))
.isInstanceOf(com.quantai.trader.domain.TraderException.class)
.hasMessageContaining("does not match any");
}
}