78 lines
3.8 KiB
Java
78 lines
3.8 KiB
Java
package com.quantai.trader.model;
|
|
|
|
import com.quantai.trader.artifact.TraderArtifactLoader;
|
|
import com.quantai.trader.artifact.TraderModelManifest;
|
|
import com.quantai.trader.feature.TraderFeatureVectorBuilder;
|
|
import org.junit.jupiter.api.Test;
|
|
import org.junit.jupiter.api.io.TempDir;
|
|
|
|
import java.io.IOException;
|
|
import java.nio.file.Path;
|
|
import java.util.Map;
|
|
|
|
import static com.quantai.trader.TestFixtures.*;
|
|
import static org.assertj.core.api.Assertions.assertThat;
|
|
import static org.assertj.core.api.Assertions.assertThatThrownBy;
|
|
|
|
class OnnxTraderModelServiceTest {
|
|
@TempDir
|
|
Path artifactRoot;
|
|
|
|
@Test
|
|
void evaluatesFiveModelFamiliesWithFeatureOrderCalibrationAndOutputBounds() throws IOException {
|
|
writeArtifactBundle(artifactRoot);
|
|
var properties = propertiesWithArtifactRoot(artifactRoot);
|
|
var bundle = new TraderArtifactLoader(properties, objectMapper()).loadActiveBundle();
|
|
OnnxTraderModelService service = new OnnxTraderModelService(
|
|
properties,
|
|
objectMapper(),
|
|
new TraderFeatureVectorBuilder(properties, objectMapper()),
|
|
new FakeInferenceClient());
|
|
|
|
var output = service.evaluate(snapshot(), bundle);
|
|
|
|
assertThat(output.direction().longProb()).isEqualByComparingTo("0.62");
|
|
assertThat(output.entry().longExpectedNetEdgeBps()).isEqualByComparingTo("12");
|
|
assertThat(output.continuation().shortExpectedContinueEdgeBps()).isEqualByComparingTo("1.5");
|
|
assertThat(output.exit().reasonScore("reversal_prob")).isEqualByComparingTo("0.25");
|
|
assertThat(output.risk().reasonScore("liquidity_deterioration_prob")).isEqualByComparingTo("0.12");
|
|
assertThat(output.metadata().uncertainty()).isEqualByComparingTo("0.38");
|
|
assertThat(output.metadata().oodScore()).isEqualByComparingTo("0.05");
|
|
}
|
|
|
|
@Test
|
|
void failsWhenShadowSnapshotHasNoOodScore() throws IOException {
|
|
writeArtifactBundle(artifactRoot);
|
|
var properties = propertiesWithArtifactRoot(artifactRoot);
|
|
var bundle = new TraderArtifactLoader(properties, objectMapper()).loadActiveBundle();
|
|
OnnxTraderModelService service = new OnnxTraderModelService(
|
|
properties,
|
|
objectMapper(),
|
|
new TraderFeatureVectorBuilder(properties, objectMapper()),
|
|
new FakeInferenceClient());
|
|
|
|
var snapshot = new com.quantai.trader.domain.TraderMarketSnapshot("snapshot-1", "run-1", "cycle-1",
|
|
"BTC-USDT-PERP", T0, "feature-v4-p0", bd("100"), bd("99.5"), bd("1.2"),
|
|
bd("0.5"), bd("1000"), bd("1400"), bd("2200"), true, featureJson(), Map.of());
|
|
|
|
assertThatThrownBy(() -> service.evaluate(snapshot, bundle))
|
|
.isInstanceOf(com.quantai.trader.domain.TraderException.class)
|
|
.hasMessageContaining("ood_score");
|
|
}
|
|
|
|
private static final class FakeInferenceClient implements TraderOnnxInferenceClient {
|
|
@Override
|
|
public Map<String, float[]> infer(TraderModelManifest manifest, Path modelPath, float[] features) {
|
|
return switch (manifest.modelType()) {
|
|
case "DIRECTION" -> Map.of("probabilities", new float[]{0.5f, 0.5f, 0.5f});
|
|
case "ENTRY" -> Map.of("entry", new float[]{0.5f, 0.5f, 12.0f, 8.0f});
|
|
case "CONTINUE" -> Map.of("continue", new float[]{0.5f, 0.5f, 3.0f, 1.5f});
|
|
case "EXIT" -> Map.of("exit", new float[]{0.5f, 0.5f, 10.0f, 18.0f, 0.5f, 0.5f, 0.5f, 0.5f});
|
|
case "RISK" -> Map.of("risk", new float[]{0.5f, 0.5f, 0.5f, 20.0f, 20.0f, 30.0f,
|
|
0.5f, 0.5f, 0.5f, 0.5f, 0.5f});
|
|
default -> throw new IllegalArgumentException("unexpected model type: " + manifest.modelType());
|
|
};
|
|
}
|
|
}
|
|
}
|