Improve Trader V4 training pipeline
Align entry labels with max future edge, tune direction labeling, and harden regression evaluation. Add training diagnostics, price-plan search, feature screening, and nonlinear benchmark scripts.
This commit is contained in:
@@ -29,7 +29,7 @@ import java.util.stream.StreamSupport;
|
||||
public class TraderArtifactLoader {
|
||||
private static final Logger log = LoggerFactory.getLogger(TraderArtifactLoader.class);
|
||||
private static final Set<String> REQUIRED_MODELS = Set.of("DIRECTION", "ENTRY", "CONTINUE", "EXIT", "RISK");
|
||||
private static final int REQUIRED_FEATURE_COUNT = 39;
|
||||
private static final int REQUIRED_FEATURE_COUNT = 54;
|
||||
private static final int REQUIRED_ONNX_OPSET_VERSION = 17;
|
||||
private static final Map<String, Set<String>> REQUIRED_OUTPUT_MAPPING_KEYS = Map.of(
|
||||
"DIRECTION", Set.of("long_prob", "short_prob", "neutral_prob"),
|
||||
|
||||
@@ -27,7 +27,7 @@ import java.util.stream.StreamSupport;
|
||||
@Component
|
||||
public class TraderFeatureVectorBuilder {
|
||||
private static final Logger log = LoggerFactory.getLogger(TraderFeatureVectorBuilder.class);
|
||||
private static final int REQUIRED_FEATURE_COUNT = 39;
|
||||
private static final int REQUIRED_FEATURE_COUNT = 54;
|
||||
|
||||
private final TraderProperties properties;
|
||||
private final ObjectMapper objectMapper;
|
||||
|
||||
@@ -106,7 +106,7 @@ public final class TestFixtures {
|
||||
|
||||
public static TraderMarketSnapshot snapshot(boolean dataReady, String depthNotional5Bps) {
|
||||
return new TraderMarketSnapshot("snapshot-1", "run-1", "cycle-1", "BTC-USDT-PERP", T0,
|
||||
"feature-v4-p0", bd("100"), bd("99.5"), bd("1.2"), BigDecimal.ZERO,
|
||||
"feature-v4-p2-book-cross", bd("100"), bd("99.5"), bd("1.2"), BigDecimal.ZERO,
|
||||
bd(depthNotional5Bps), bd(depthNotional5Bps), bd(depthNotional5Bps), dataReady,
|
||||
featureJson(), dataQualityJson());
|
||||
}
|
||||
@@ -114,7 +114,7 @@ public final class TestFixtures {
|
||||
public static ReplayMarketEvent replayEvent(String runId, Instant eventTime, String markPrice,
|
||||
String indexPrice, String depthNotional) {
|
||||
BigDecimal depth = bd(depthNotional);
|
||||
return new ReplayMarketEvent(runId, "BTC-USDT-PERP", eventTime, "feature-v4-p0",
|
||||
return new ReplayMarketEvent(runId, "BTC-USDT-PERP", eventTime, "feature-v4-p2-book-cross",
|
||||
bd(markPrice), bd(indexPrice), bd("1.2"), bd("0.5"),
|
||||
depth, depth.multiply(new BigDecimal("1.4")), depth.multiply(new BigDecimal("2.2")),
|
||||
depth.compareTo(BigDecimal.ZERO) > 0, featureJson(), dataQualityJson());
|
||||
@@ -165,6 +165,21 @@ public final class TestFixtures {
|
||||
features.put("minute_of_day_sin", bd("0.0"));
|
||||
features.put("minute_of_day_cos", bd("1.0"));
|
||||
features.put("minutes_to_next_funding", bd("120"));
|
||||
features.put("book_top_imbalance", bd("0.18"));
|
||||
features.put("book_microprice_basis_bps", bd("0.04"));
|
||||
features.put("book_bid_depth_l5_quote", bd("5200000"));
|
||||
features.put("book_ask_depth_l5_quote", bd("4700000"));
|
||||
features.put("book_depth_imbalance_l5", bd("0.05"));
|
||||
features.put("book_depth_imbalance_l20", bd("0.08"));
|
||||
features.put("book_depth_concentration_l5_l20", bd("0.03"));
|
||||
features.put("book_pressure_spread_ratio", bd("0.033333"));
|
||||
features.put("book_pressure_taker_1m", bd("0.0032"));
|
||||
features.put("book_pressure_taker_5m", bd("0.0048"));
|
||||
features.put("book_l20_imbalance_taker_15m", bd("0.0032"));
|
||||
features.put("book_l20_imbalance_ret_15m", bd("0.656"));
|
||||
features.put("book_pressure_vol_adjusted", bd("0.004255"));
|
||||
features.put("book_depth_pressure_gap", bd("-0.03"));
|
||||
features.put("book_pressure_reversal_15m", bd("-0.328"));
|
||||
return Map.copyOf(features);
|
||||
}
|
||||
|
||||
@@ -344,7 +359,14 @@ public final class TestFixtures {
|
||||
"liquidation_buy_notional_1m","liquidation_sell_notional_1m",
|
||||
"liquidation_imbalance_15m","liquidation_notional_zscore_15m",
|
||||
"liquidation_available","minute_of_day_sin","minute_of_day_cos",
|
||||
"minutes_to_next_funding"
|
||||
"minutes_to_next_funding","book_top_imbalance","book_microprice_basis_bps",
|
||||
"book_bid_depth_l5_quote","book_ask_depth_l5_quote",
|
||||
"book_depth_imbalance_l5","book_depth_imbalance_l20",
|
||||
"book_depth_concentration_l5_l20","book_pressure_spread_ratio",
|
||||
"book_pressure_taker_1m","book_pressure_taker_5m",
|
||||
"book_l20_imbalance_taker_15m","book_l20_imbalance_ret_15m",
|
||||
"book_pressure_vol_adjusted","book_depth_pressure_gap",
|
||||
"book_pressure_reversal_15m"
|
||||
]
|
||||
""");
|
||||
String outputSchemaHash = writeArtifact(artifactRoot.resolve("schemas/outputs.json"), """
|
||||
@@ -418,7 +440,7 @@ public final class TestFixtures {
|
||||
"manifest_schema_version": "trader-model-bundle-v1",
|
||||
"model_bundle_version": "trader-v4-btc-p0",
|
||||
"calibration_bundle_version": "cal-v4-btc-p0",
|
||||
"feature_version": "feature-v4-p0",
|
||||
"feature_version": "feature-v4-p2-book-cross",
|
||||
"label_version": "label-v4-p0",
|
||||
"split_version": "split-v4-p0",
|
||||
"training_run_id": "train-run-p0",
|
||||
@@ -441,9 +463,9 @@ public final class TestFixtures {
|
||||
"symbol_scope_json":["BTC-USDT-PERP"],"bar_interval":"1m","horizon_minutes":45,
|
||||
"model_format":"ONNX","model_runtime":"ONNX_RUNTIME_JAVA","model_runtime_version":"1.22.0",
|
||||
"onnx_opset_version":17,"producer_name":"test-exporter","producer_version":"p0",
|
||||
"feature_version":"feature-v4-p0","feature_schema_path":"schemas/features.json",
|
||||
"feature_version":"feature-v4-p2-book-cross","feature_schema_path":"schemas/features.json",
|
||||
"feature_schema_hash":"%6$s","feature_order_path":"schemas/feature_order.json","feature_order_hash":"%8$s",
|
||||
"input_tensor_name":"features","input_dtype":"FLOAT32","input_shape_json":{"batch":1,"features":39},
|
||||
"input_tensor_name":"features","input_dtype":"FLOAT32","input_shape_json":{"batch":1,"features":54},
|
||||
"input_example_path":"examples/input.json","output_schema_path":"schemas/outputs.json",
|
||||
"output_schema_hash":"%7$s","output_tensor_names_json":["probabilities"],
|
||||
"output_mapping_json":{"long_prob":"probabilities[0]","short_prob":"probabilities[1]","neutral_prob":"probabilities[2]"},
|
||||
@@ -461,9 +483,9 @@ public final class TestFixtures {
|
||||
"symbol_scope_json":["BTC-USDT-PERP"],"bar_interval":"1m","horizon_minutes":45,
|
||||
"model_format":"ONNX","model_runtime":"ONNX_RUNTIME_JAVA","model_runtime_version":"1.22.0",
|
||||
"onnx_opset_version":17,"producer_name":"test-exporter","producer_version":"p0",
|
||||
"feature_version":"feature-v4-p0","feature_schema_path":"schemas/features.json",
|
||||
"feature_version":"feature-v4-p2-book-cross","feature_schema_path":"schemas/features.json",
|
||||
"feature_schema_hash":"%6$s","feature_order_path":"schemas/feature_order.json","feature_order_hash":"%8$s",
|
||||
"input_tensor_name":"features","input_dtype":"FLOAT32","input_shape_json":{"batch":1,"features":39},
|
||||
"input_tensor_name":"features","input_dtype":"FLOAT32","input_shape_json":{"batch":1,"features":54},
|
||||
"input_example_path":"examples/input.json","output_schema_path":"schemas/outputs.json",
|
||||
"output_schema_hash":"%7$s","output_tensor_names_json":["entry"],
|
||||
"output_mapping_json":{"long_entry_prob":"entry[0]","short_entry_prob":"entry[1]","long_expected_net_edge_bps":"entry[2]","short_expected_net_edge_bps":"entry[3]"},
|
||||
@@ -481,9 +503,9 @@ public final class TestFixtures {
|
||||
"symbol_scope_json":["BTC-USDT-PERP"],"bar_interval":"1m","horizon_minutes":45,
|
||||
"model_format":"ONNX","model_runtime":"ONNX_RUNTIME_JAVA","model_runtime_version":"1.22.0",
|
||||
"onnx_opset_version":17,"producer_name":"test-exporter","producer_version":"p0",
|
||||
"feature_version":"feature-v4-p0","feature_schema_path":"schemas/features.json",
|
||||
"feature_version":"feature-v4-p2-book-cross","feature_schema_path":"schemas/features.json",
|
||||
"feature_schema_hash":"%6$s","feature_order_path":"schemas/feature_order.json","feature_order_hash":"%8$s",
|
||||
"input_tensor_name":"features","input_dtype":"FLOAT32","input_shape_json":{"batch":1,"features":39},
|
||||
"input_tensor_name":"features","input_dtype":"FLOAT32","input_shape_json":{"batch":1,"features":54},
|
||||
"input_example_path":"examples/input.json","output_schema_path":"schemas/outputs.json",
|
||||
"output_schema_hash":"%7$s","output_tensor_names_json":["continue"],
|
||||
"output_mapping_json":{"long_continue_prob":"continue[0]","short_continue_prob":"continue[1]","long_expected_continue_edge_bps":"continue[2]","short_expected_continue_edge_bps":"continue[3]"},
|
||||
@@ -501,9 +523,9 @@ public final class TestFixtures {
|
||||
"symbol_scope_json":["BTC-USDT-PERP"],"bar_interval":"1m","horizon_minutes":45,
|
||||
"model_format":"ONNX","model_runtime":"ONNX_RUNTIME_JAVA","model_runtime_version":"1.22.0",
|
||||
"onnx_opset_version":17,"producer_name":"test-exporter","producer_version":"p0",
|
||||
"feature_version":"feature-v4-p0","feature_schema_path":"schemas/features.json",
|
||||
"feature_version":"feature-v4-p2-book-cross","feature_schema_path":"schemas/features.json",
|
||||
"feature_schema_hash":"%6$s","feature_order_path":"schemas/feature_order.json","feature_order_hash":"%8$s",
|
||||
"input_tensor_name":"features","input_dtype":"FLOAT32","input_shape_json":{"batch":1,"features":39},
|
||||
"input_tensor_name":"features","input_dtype":"FLOAT32","input_shape_json":{"batch":1,"features":54},
|
||||
"input_example_path":"examples/input.json","output_schema_path":"schemas/outputs.json",
|
||||
"output_schema_hash":"%7$s","output_tensor_names_json":["exit"],
|
||||
"output_mapping_json":{"long_exit_prob":"exit[0]","short_exit_prob":"exit[1]","long_adverse_move_bps":"exit[2]","short_adverse_move_bps":"exit[3]","adverse_move_prob":"exit[4]","reversal_prob":"exit[5]","stop_hit_prob":"exit[6]","stagnation_prob":"exit[7]"},
|
||||
@@ -521,9 +543,9 @@ public final class TestFixtures {
|
||||
"symbol_scope_json":["BTC-USDT-PERP"],"bar_interval":"1m","horizon_minutes":45,
|
||||
"model_format":"ONNX","model_runtime":"ONNX_RUNTIME_JAVA","model_runtime_version":"1.22.0",
|
||||
"onnx_opset_version":17,"producer_name":"test-exporter","producer_version":"p0",
|
||||
"feature_version":"feature-v4-p0","feature_schema_path":"schemas/features.json",
|
||||
"feature_version":"feature-v4-p2-book-cross","feature_schema_path":"schemas/features.json",
|
||||
"feature_schema_hash":"%6$s","feature_order_path":"schemas/feature_order.json","feature_order_hash":"%8$s",
|
||||
"input_tensor_name":"features","input_dtype":"FLOAT32","input_shape_json":{"batch":1,"features":39},
|
||||
"input_tensor_name":"features","input_dtype":"FLOAT32","input_shape_json":{"batch":1,"features":54},
|
||||
"input_example_path":"examples/input.json","output_schema_path":"schemas/outputs.json",
|
||||
"output_schema_hash":"%7$s","output_tensor_names_json":["risk"],
|
||||
"output_mapping_json":{"market_risk_prob":"risk[0]","long_position_risk_prob":"risk[1]","short_position_risk_prob":"risk[2]","market_path_risk_bps":"risk[3]","long_position_path_risk_bps":"risk[4]","short_position_path_risk_bps":"risk[5]","market_drawdown_prob":"risk[6]","volatility_expansion_prob":"risk[7]","spike_prob":"risk[8]","liquidity_deterioration_prob":"risk[9]","position_drawdown_prob":"risk[10]"},
|
||||
|
||||
@@ -29,7 +29,7 @@ class TraderArtifactLoaderTest {
|
||||
assertThat(bundle.pricePlanContext().pricePlanConfigHash()).isEqualTo("p0-price-plan-hash");
|
||||
assertThat(bundle.modelManifests()).allSatisfy(manifest -> {
|
||||
assertThat(manifest.featureOrderPath()).isEqualTo("schemas/feature_order.json");
|
||||
assertThat(manifest.inputShapeJson()).containsEntry("features", 39);
|
||||
assertThat(manifest.inputShapeJson()).containsEntry("features", 54);
|
||||
assertThat(manifest.onnxOpsetVersion()).isEqualTo(17);
|
||||
});
|
||||
assertThat(bundle.requireReplayModelFixture().entry().longExpectedNetEdgeBps()).isEqualByComparingTo("12.0");
|
||||
@@ -89,7 +89,7 @@ class TraderArtifactLoaderTest {
|
||||
void rejectsNonV4InputShape() throws IOException {
|
||||
writeArtifactBundle(artifactRoot);
|
||||
Path manifest = artifactRoot.resolve("manifests/model_manifest.json");
|
||||
Files.writeString(manifest, Files.readString(manifest).replace("\"features\":39", "\"features\":38"));
|
||||
Files.writeString(manifest, Files.readString(manifest).replace("\"features\":54", "\"features\":53"));
|
||||
|
||||
assertThatThrownBy(() -> new TraderArtifactLoader(propertiesWithArtifactRoot(artifactRoot), objectMapper()).loadActiveBundle())
|
||||
.isInstanceOf(com.quantai.trader.domain.TraderException.class)
|
||||
|
||||
@@ -27,9 +27,11 @@ class TraderFeatureVectorBuilderTest {
|
||||
|
||||
float[] values = builder.build(snapshot(), bundle);
|
||||
|
||||
assertThat(values).hasSize(39);
|
||||
assertThat(values).hasSize(54);
|
||||
assertThat(values[0]).isEqualTo(1.1f);
|
||||
assertThat(values[38]).isEqualTo(120.0f);
|
||||
assertThat(values[45]).isEqualTo(0.03f);
|
||||
assertThat(values[53]).isEqualTo(-0.328f);
|
||||
}
|
||||
|
||||
@Test
|
||||
@@ -41,7 +43,7 @@ class TraderFeatureVectorBuilderTest {
|
||||
Map<String, Object> features = new LinkedHashMap<>(featureJson());
|
||||
features.remove("ret_1m_bps");
|
||||
TraderMarketSnapshot badSnapshot = new TraderMarketSnapshot("snapshot-1", "run-1", "cycle-1",
|
||||
"BTC-USDT-PERP", T0, "feature-v4-p0", bd("100"), bd("99.5"), bd("1.2"),
|
||||
"BTC-USDT-PERP", T0, "feature-v4-p2-book-cross", bd("100"), bd("99.5"), bd("1.2"),
|
||||
bd("0.5"), bd("1000"), bd("1400"), bd("2200"), true, features, dataQualityJson());
|
||||
|
||||
assertThatThrownBy(() -> builder.build(badSnapshot, bundle))
|
||||
@@ -58,7 +60,7 @@ class TraderFeatureVectorBuilderTest {
|
||||
Map<String, Object> features = new LinkedHashMap<>(featureJson());
|
||||
features.put("account_balance", bd("1000"));
|
||||
TraderMarketSnapshot badSnapshot = new TraderMarketSnapshot("snapshot-1", "run-1", "cycle-1",
|
||||
"BTC-USDT-PERP", T0, "feature-v4-p0", bd("100"), bd("99.5"), bd("1.2"),
|
||||
"BTC-USDT-PERP", T0, "feature-v4-p2-book-cross", bd("100"), bd("99.5"), bd("1.2"),
|
||||
bd("0.5"), bd("1000"), bd("1400"), bd("2200"), true, features, dataQualityJson());
|
||||
|
||||
assertThatThrownBy(() -> builder.build(badSnapshot, bundle))
|
||||
|
||||
@@ -52,7 +52,7 @@ class OnnxTraderModelServiceTest {
|
||||
new FakeInferenceClient());
|
||||
|
||||
var snapshot = new com.quantai.trader.domain.TraderMarketSnapshot("snapshot-1", "run-1", "cycle-1",
|
||||
"BTC-USDT-PERP", T0, "feature-v4-p0", bd("100"), bd("99.5"), bd("1.2"),
|
||||
"BTC-USDT-PERP", T0, "feature-v4-p2-book-cross", bd("100"), bd("99.5"), bd("1.2"),
|
||||
bd("0.5"), bd("1000"), bd("1400"), bd("2200"), true, featureJson(), Map.of());
|
||||
|
||||
assertThatThrownBy(() -> service.evaluate(snapshot, bundle))
|
||||
|
||||
@@ -25,7 +25,7 @@ class OrtTraderOnnxInferenceClientTest {
|
||||
.orElseThrow();
|
||||
|
||||
assertThatThrownBy(() -> new OrtTraderOnnxInferenceClient()
|
||||
.infer(manifest, artifactRoot.resolve(manifest.artifactPath()), new float[39]))
|
||||
.infer(manifest, artifactRoot.resolve(manifest.artifactPath()), new float[54]))
|
||||
.isInstanceOf(com.quantai.trader.domain.TraderException.class)
|
||||
.hasMessageContaining("ONNX model cannot be loaded");
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user