Files
quant-trader-service/src/test/java/com/quantai/trader/model/OnnxTraderModelServiceTest.java
T

78 lines
3.8 KiB
Java

package com.quantai.trader.model;
import com.quantai.trader.artifact.TraderArtifactLoader;
import com.quantai.trader.artifact.TraderModelManifest;
import com.quantai.trader.feature.TraderFeatureVectorBuilder;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir;
import java.io.IOException;
import java.nio.file.Path;
import java.util.Map;
import static com.quantai.trader.TestFixtures.*;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
/**
 * Exercises {@link OnnxTraderModelService} against an artifact bundle written into a
 * temporary directory, with ONNX inference replaced by {@link FakeInferenceClient}.
 */
class OnnxTraderModelServiceTest {

    /** Temporary directory injected by JUnit; the artifact bundle is written here. */
    @TempDir
    Path artifactRoot;

    /**
     * Runs a full evaluation of the fixture snapshot and checks one representative value from
     * each output section (direction, entry, continuation, exit, risk, metadata).
     *
     * <p>NOTE(review): the fake client returns 0.5 for every probability, yet the expected
     * values differ (e.g. 0.62) — presumably the calibration shipped in the fixture bundle
     * is applied by the service; confirm against the fixture data.
     */
    @Test
    void evaluatesFiveModelFamiliesWithFeatureOrderCalibrationAndOutputBounds() throws IOException {
        writeArtifactBundle(artifactRoot);
        var props = propertiesWithArtifactRoot(artifactRoot);
        var loader = new TraderArtifactLoader(props, objectMapper());
        var activeBundle = loader.loadActiveBundle();

        var serviceMapper = objectMapper();
        var featureBuilder = new TraderFeatureVectorBuilder(props, objectMapper());
        var modelService = new OnnxTraderModelService(
                props, serviceMapper, featureBuilder, new FakeInferenceClient());

        var result = modelService.evaluate(snapshot(), activeBundle);

        assertThat(result.direction().longProb())
                .isEqualByComparingTo("0.62");
        assertThat(result.entry().longExpectedNetEdgeBps())
                .isEqualByComparingTo("12");
        assertThat(result.continuation().shortExpectedContinueEdgeBps())
                .isEqualByComparingTo("1.5");
        assertThat(result.exit().reasonScore("reversal_prob"))
                .isEqualByComparingTo("0.25");
        assertThat(result.risk().reasonScore("liquidity_deterioration_prob"))
                .isEqualByComparingTo("0.12");
        assertThat(result.metadata().uncertainty())
                .isEqualByComparingTo("0.38");
        assertThat(result.metadata().oodScore())
                .isEqualByComparingTo("0.05");
    }

    /**
     * Builds a snapshot by hand (ending in an empty map) and expects evaluation to fail with a
     * {@code TraderException} whose message mentions {@code ood_score}.
     *
     * <p>NOTE(review): which constructor argument marks the snapshot as "shadow" and where the
     * OOD score would normally live is not visible from this test — verify against
     * {@code TraderMarketSnapshot}.
     */
    @Test
    void failsWhenShadowSnapshotHasNoOodScore() throws IOException {
        writeArtifactBundle(artifactRoot);
        var props = propertiesWithArtifactRoot(artifactRoot);
        var loader = new TraderArtifactLoader(props, objectMapper());
        var activeBundle = loader.loadActiveBundle();

        var serviceMapper = objectMapper();
        var featureBuilder = new TraderFeatureVectorBuilder(props, objectMapper());
        var modelService = new OnnxTraderModelService(
                props, serviceMapper, featureBuilder, new FakeInferenceClient());

        var marketSnapshot = new com.quantai.trader.domain.TraderMarketSnapshot(
                "snapshot-1", "run-1", "cycle-1", "BTC-USDT-PERP", T0, "feature-v4-p0",
                bd("100"), bd("99.5"), bd("1.2"), bd("0.5"),
                bd("1000"), bd("1400"), bd("2200"),
                true, featureJson(), Map.of());

        assertThatThrownBy(() -> modelService.evaluate(marketSnapshot, activeBundle))
                .isInstanceOf(com.quantai.trader.domain.TraderException.class)
                .hasMessageContaining("ood_score");
    }

    /**
     * Inference stub that returns a fixed, single-entry output map per model type and never
     * touches the model file or the feature vector.
     */
    private static final class FakeInferenceClient implements TraderOnnxInferenceClient {

        /**
         * Returns the canned output for the manifest's model type.
         *
         * @param manifest  manifest whose {@code modelType()} selects the canned output
         * @param modelPath ignored
         * @param features  ignored
         * @return a one-entry map from output name to its fixed values
         * @throws IllegalArgumentException if the model type is not one of the five handled here
         */
        @Override
        public Map<String, float[]> infer(TraderModelManifest manifest, Path modelPath, float[] features) {
            String modelType = manifest.modelType();
            return switch (modelType) {
                case "DIRECTION" -> tensor("probabilities", 0.5f, 0.5f, 0.5f);
                case "ENTRY" -> tensor("entry", 0.5f, 0.5f, 12.0f, 8.0f);
                case "CONTINUE" -> tensor("continue", 0.5f, 0.5f, 3.0f, 1.5f);
                case "EXIT" -> tensor("exit",
                        0.5f, 0.5f, 10.0f, 18.0f, 0.5f, 0.5f, 0.5f, 0.5f);
                case "RISK" -> tensor("risk",
                        0.5f, 0.5f, 0.5f, 20.0f, 20.0f, 30.0f,
                        0.5f, 0.5f, 0.5f, 0.5f, 0.5f);
                default -> throw new IllegalArgumentException("unexpected model type: " + modelType);
            };
        }

        /** Wraps the given values in an immutable single-entry map under {@code name}. */
        private static Map<String, float[]> tensor(String name, float... values) {
            return Map.of(name, values);
        }
    }
}