Implement Trader V4 training artifact pipeline
This commit is contained in:
@@ -0,0 +1,77 @@
|
||||
package com.quantai.trader.model;
|
||||
|
||||
import com.quantai.trader.artifact.TraderArtifactLoader;
|
||||
import com.quantai.trader.artifact.TraderModelManifest;
|
||||
import com.quantai.trader.feature.TraderFeatureVectorBuilder;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.api.io.TempDir;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.nio.file.Path;
|
||||
import java.util.Map;
|
||||
|
||||
import static com.quantai.trader.TestFixtures.*;
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
import static org.assertj.core.api.Assertions.assertThatThrownBy;
|
||||
|
||||
class OnnxTraderModelServiceTest {
|
||||
@TempDir
|
||||
Path artifactRoot;
|
||||
|
||||
@Test
|
||||
void evaluatesFiveModelFamiliesWithFeatureOrderCalibrationAndOutputBounds() throws IOException {
|
||||
writeArtifactBundle(artifactRoot);
|
||||
var properties = propertiesWithArtifactRoot(artifactRoot);
|
||||
var bundle = new TraderArtifactLoader(properties, objectMapper()).loadActiveBundle();
|
||||
OnnxTraderModelService service = new OnnxTraderModelService(
|
||||
properties,
|
||||
objectMapper(),
|
||||
new TraderFeatureVectorBuilder(properties, objectMapper()),
|
||||
new FakeInferenceClient());
|
||||
|
||||
var output = service.evaluate(snapshot(), bundle);
|
||||
|
||||
assertThat(output.direction().longProb()).isEqualByComparingTo("0.62");
|
||||
assertThat(output.entry().longExpectedNetEdgeBps()).isEqualByComparingTo("12");
|
||||
assertThat(output.continuation().shortExpectedContinueEdgeBps()).isEqualByComparingTo("1.5");
|
||||
assertThat(output.exit().reasonScore("reversal_prob")).isEqualByComparingTo("0.25");
|
||||
assertThat(output.risk().reasonScore("liquidity_deterioration_prob")).isEqualByComparingTo("0.12");
|
||||
assertThat(output.metadata().uncertainty()).isEqualByComparingTo("0.38");
|
||||
assertThat(output.metadata().oodScore()).isEqualByComparingTo("0.05");
|
||||
}
|
||||
|
||||
@Test
|
||||
void failsWhenShadowSnapshotHasNoOodScore() throws IOException {
|
||||
writeArtifactBundle(artifactRoot);
|
||||
var properties = propertiesWithArtifactRoot(artifactRoot);
|
||||
var bundle = new TraderArtifactLoader(properties, objectMapper()).loadActiveBundle();
|
||||
OnnxTraderModelService service = new OnnxTraderModelService(
|
||||
properties,
|
||||
objectMapper(),
|
||||
new TraderFeatureVectorBuilder(properties, objectMapper()),
|
||||
new FakeInferenceClient());
|
||||
|
||||
var snapshot = new com.quantai.trader.domain.TraderMarketSnapshot("snapshot-1", "run-1", "cycle-1",
|
||||
"BTC-USDT-PERP", T0, "feature-v4-p0", bd("100"), bd("99.5"), bd("1.2"),
|
||||
bd("0.5"), bd("1000"), bd("1400"), bd("2200"), true, featureJson(), Map.of());
|
||||
|
||||
assertThatThrownBy(() -> service.evaluate(snapshot, bundle))
|
||||
.isInstanceOf(com.quantai.trader.domain.TraderException.class)
|
||||
.hasMessageContaining("ood_score");
|
||||
}
|
||||
|
||||
private static final class FakeInferenceClient implements TraderOnnxInferenceClient {
|
||||
@Override
|
||||
public Map<String, float[]> infer(TraderModelManifest manifest, Path modelPath, float[] features) {
|
||||
return switch (manifest.modelType()) {
|
||||
case "DIRECTION" -> Map.of("probabilities", new float[]{0.5f, 0.5f, 0.5f});
|
||||
case "ENTRY" -> Map.of("entry", new float[]{0.5f, 0.5f, 12.0f, 8.0f});
|
||||
case "CONTINUE" -> Map.of("continue", new float[]{0.5f, 0.5f, 3.0f, 1.5f});
|
||||
case "EXIT" -> Map.of("exit", new float[]{0.5f, 0.5f, 10.0f, 18.0f, 0.5f, 0.5f, 0.5f, 0.5f});
|
||||
case "RISK" -> Map.of("risk", new float[]{0.5f, 0.5f, 0.5f, 20.0f, 20.0f, 30.0f,
|
||||
0.5f, 0.5f, 0.5f, 0.5f, 0.5f});
|
||||
default -> throw new IllegalArgumentException("unexpected model type: " + manifest.modelType());
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user