package com.quantai.trader.model; import com.quantai.trader.artifact.TraderArtifactLoader; import com.quantai.trader.artifact.TraderModelManifest; import com.quantai.trader.feature.TraderFeatureVectorBuilder; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.io.TempDir; import java.io.IOException; import java.nio.file.Path; import java.util.Map; import static com.quantai.trader.TestFixtures.*; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; class OnnxTraderModelServiceTest { @TempDir Path artifactRoot; @Test void evaluatesFiveModelFamiliesWithFeatureOrderCalibrationAndOutputBounds() throws IOException { writeArtifactBundle(artifactRoot); var properties = propertiesWithArtifactRoot(artifactRoot); var bundle = new TraderArtifactLoader(properties, objectMapper()).loadActiveBundle(); OnnxTraderModelService service = new OnnxTraderModelService( properties, objectMapper(), new TraderFeatureVectorBuilder(properties, objectMapper()), new FakeInferenceClient()); var output = service.evaluate(snapshot(), bundle); assertThat(output.direction().longProb()).isEqualByComparingTo("0.62"); assertThat(output.entry().longExpectedNetEdgeBps()).isEqualByComparingTo("12"); assertThat(output.continuation().shortExpectedContinueEdgeBps()).isEqualByComparingTo("1.5"); assertThat(output.exit().reasonScore("reversal_prob")).isEqualByComparingTo("0.25"); assertThat(output.risk().reasonScore("liquidity_deterioration_prob")).isEqualByComparingTo("0.12"); assertThat(output.metadata().uncertainty()).isEqualByComparingTo("0.38"); assertThat(output.metadata().oodScore()).isEqualByComparingTo("0.05"); } @Test void failsWhenShadowSnapshotHasNoOodScore() throws IOException { writeArtifactBundle(artifactRoot); var properties = propertiesWithArtifactRoot(artifactRoot); var bundle = new TraderArtifactLoader(properties, objectMapper()).loadActiveBundle(); OnnxTraderModelService service = new OnnxTraderModelService( properties, objectMapper(), new TraderFeatureVectorBuilder(properties, objectMapper()), new FakeInferenceClient()); var snapshot = new com.quantai.trader.domain.TraderMarketSnapshot("snapshot-1", "run-1", "cycle-1", "BTC-USDT-PERP", T0, "feature-v4-p0", bd("100"), bd("99.5"), bd("1.2"), bd("0.5"), bd("1000"), bd("1400"), bd("2200"), true, featureJson(), Map.of()); assertThatThrownBy(() -> service.evaluate(snapshot, bundle)) .isInstanceOf(com.quantai.trader.domain.TraderException.class) .hasMessageContaining("ood_score"); } private static final class FakeInferenceClient implements TraderOnnxInferenceClient { @Override public Map infer(TraderModelManifest manifest, Path modelPath, float[] features) { return switch (manifest.modelType()) { case "DIRECTION" -> Map.of("probabilities", new float[]{0.5f, 0.5f, 0.5f}); case "ENTRY" -> Map.of("entry", new float[]{0.5f, 0.5f, 12.0f, 8.0f}); case "CONTINUE" -> Map.of("continue", new float[]{0.5f, 0.5f, 3.0f, 1.5f}); case "EXIT" -> Map.of("exit", new float[]{0.5f, 0.5f, 10.0f, 18.0f, 0.5f, 0.5f, 0.5f, 0.5f}); case "RISK" -> Map.of("risk", new float[]{0.5f, 0.5f, 0.5f, 20.0f, 20.0f, 30.0f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f}); default -> throw new IllegalArgumentException("unexpected model type: " + manifest.modelType()); }; } } }