Improve Trader V4 training pipeline

Align entry labels with max future edge, tune direction labeling, and harden regression evaluation.

Add training diagnostics, price-plan search, feature screening, and nonlinear benchmark scripts.
This commit is contained in:
Codex
2026-06-27 19:57:29 +08:00
parent e58e4a5572
commit 9acb3460a1
27 changed files with 2059 additions and 341 deletions
@@ -29,7 +29,7 @@ import java.util.stream.StreamSupport;
public class TraderArtifactLoader {
private static final Logger log = LoggerFactory.getLogger(TraderArtifactLoader.class);
private static final Set<String> REQUIRED_MODELS = Set.of("DIRECTION", "ENTRY", "CONTINUE", "EXIT", "RISK");
private static final int REQUIRED_FEATURE_COUNT = 39;
private static final int REQUIRED_FEATURE_COUNT = 54;
private static final int REQUIRED_ONNX_OPSET_VERSION = 17;
private static final Map<String, Set<String>> REQUIRED_OUTPUT_MAPPING_KEYS = Map.of(
"DIRECTION", Set.of("long_prob", "short_prob", "neutral_prob"),
@@ -27,7 +27,7 @@ import java.util.stream.StreamSupport;
@Component
public class TraderFeatureVectorBuilder {
private static final Logger log = LoggerFactory.getLogger(TraderFeatureVectorBuilder.class);
private static final int REQUIRED_FEATURE_COUNT = 39;
private static final int REQUIRED_FEATURE_COUNT = 54;
private final TraderProperties properties;
private final ObjectMapper objectMapper;
@@ -106,7 +106,7 @@ public final class TestFixtures {
public static TraderMarketSnapshot snapshot(boolean dataReady, String depthNotional5Bps) {
return new TraderMarketSnapshot("snapshot-1", "run-1", "cycle-1", "BTC-USDT-PERP", T0,
"feature-v4-p0", bd("100"), bd("99.5"), bd("1.2"), BigDecimal.ZERO,
"feature-v4-p2-book-cross", bd("100"), bd("99.5"), bd("1.2"), BigDecimal.ZERO,
bd(depthNotional5Bps), bd(depthNotional5Bps), bd(depthNotional5Bps), dataReady,
featureJson(), dataQualityJson());
}
@@ -114,7 +114,7 @@ public final class TestFixtures {
public static ReplayMarketEvent replayEvent(String runId, Instant eventTime, String markPrice,
String indexPrice, String depthNotional) {
BigDecimal depth = bd(depthNotional);
return new ReplayMarketEvent(runId, "BTC-USDT-PERP", eventTime, "feature-v4-p0",
return new ReplayMarketEvent(runId, "BTC-USDT-PERP", eventTime, "feature-v4-p2-book-cross",
bd(markPrice), bd(indexPrice), bd("1.2"), bd("0.5"),
depth, depth.multiply(new BigDecimal("1.4")), depth.multiply(new BigDecimal("2.2")),
depth.compareTo(BigDecimal.ZERO) > 0, featureJson(), dataQualityJson());
@@ -165,6 +165,21 @@ public final class TestFixtures {
features.put("minute_of_day_sin", bd("0.0"));
features.put("minute_of_day_cos", bd("1.0"));
features.put("minutes_to_next_funding", bd("120"));
features.put("book_top_imbalance", bd("0.18"));
features.put("book_microprice_basis_bps", bd("0.04"));
features.put("book_bid_depth_l5_quote", bd("5200000"));
features.put("book_ask_depth_l5_quote", bd("4700000"));
features.put("book_depth_imbalance_l5", bd("0.05"));
features.put("book_depth_imbalance_l20", bd("0.08"));
features.put("book_depth_concentration_l5_l20", bd("0.03"));
features.put("book_pressure_spread_ratio", bd("0.033333"));
features.put("book_pressure_taker_1m", bd("0.0032"));
features.put("book_pressure_taker_5m", bd("0.0048"));
features.put("book_l20_imbalance_taker_15m", bd("0.0032"));
features.put("book_l20_imbalance_ret_15m", bd("0.656"));
features.put("book_pressure_vol_adjusted", bd("0.004255"));
features.put("book_depth_pressure_gap", bd("-0.03"));
features.put("book_pressure_reversal_15m", bd("-0.328"));
return Map.copyOf(features);
}
@@ -344,7 +359,14 @@ public final class TestFixtures {
"liquidation_buy_notional_1m","liquidation_sell_notional_1m",
"liquidation_imbalance_15m","liquidation_notional_zscore_15m",
"liquidation_available","minute_of_day_sin","minute_of_day_cos",
"minutes_to_next_funding"
"minutes_to_next_funding","book_top_imbalance","book_microprice_basis_bps",
"book_bid_depth_l5_quote","book_ask_depth_l5_quote",
"book_depth_imbalance_l5","book_depth_imbalance_l20",
"book_depth_concentration_l5_l20","book_pressure_spread_ratio",
"book_pressure_taker_1m","book_pressure_taker_5m",
"book_l20_imbalance_taker_15m","book_l20_imbalance_ret_15m",
"book_pressure_vol_adjusted","book_depth_pressure_gap",
"book_pressure_reversal_15m"
]
""");
String outputSchemaHash = writeArtifact(artifactRoot.resolve("schemas/outputs.json"), """
@@ -418,7 +440,7 @@ public final class TestFixtures {
"manifest_schema_version": "trader-model-bundle-v1",
"model_bundle_version": "trader-v4-btc-p0",
"calibration_bundle_version": "cal-v4-btc-p0",
"feature_version": "feature-v4-p0",
"feature_version": "feature-v4-p2-book-cross",
"label_version": "label-v4-p0",
"split_version": "split-v4-p0",
"training_run_id": "train-run-p0",
@@ -441,9 +463,9 @@ public final class TestFixtures {
"symbol_scope_json":["BTC-USDT-PERP"],"bar_interval":"1m","horizon_minutes":45,
"model_format":"ONNX","model_runtime":"ONNX_RUNTIME_JAVA","model_runtime_version":"1.22.0",
"onnx_opset_version":17,"producer_name":"test-exporter","producer_version":"p0",
"feature_version":"feature-v4-p0","feature_schema_path":"schemas/features.json",
"feature_version":"feature-v4-p2-book-cross","feature_schema_path":"schemas/features.json",
"feature_schema_hash":"%6$s","feature_order_path":"schemas/feature_order.json","feature_order_hash":"%8$s",
"input_tensor_name":"features","input_dtype":"FLOAT32","input_shape_json":{"batch":1,"features":39},
"input_tensor_name":"features","input_dtype":"FLOAT32","input_shape_json":{"batch":1,"features":54},
"input_example_path":"examples/input.json","output_schema_path":"schemas/outputs.json",
"output_schema_hash":"%7$s","output_tensor_names_json":["probabilities"],
"output_mapping_json":{"long_prob":"probabilities[0]","short_prob":"probabilities[1]","neutral_prob":"probabilities[2]"},
@@ -461,9 +483,9 @@ public final class TestFixtures {
"symbol_scope_json":["BTC-USDT-PERP"],"bar_interval":"1m","horizon_minutes":45,
"model_format":"ONNX","model_runtime":"ONNX_RUNTIME_JAVA","model_runtime_version":"1.22.0",
"onnx_opset_version":17,"producer_name":"test-exporter","producer_version":"p0",
"feature_version":"feature-v4-p0","feature_schema_path":"schemas/features.json",
"feature_version":"feature-v4-p2-book-cross","feature_schema_path":"schemas/features.json",
"feature_schema_hash":"%6$s","feature_order_path":"schemas/feature_order.json","feature_order_hash":"%8$s",
"input_tensor_name":"features","input_dtype":"FLOAT32","input_shape_json":{"batch":1,"features":39},
"input_tensor_name":"features","input_dtype":"FLOAT32","input_shape_json":{"batch":1,"features":54},
"input_example_path":"examples/input.json","output_schema_path":"schemas/outputs.json",
"output_schema_hash":"%7$s","output_tensor_names_json":["entry"],
"output_mapping_json":{"long_entry_prob":"entry[0]","short_entry_prob":"entry[1]","long_expected_net_edge_bps":"entry[2]","short_expected_net_edge_bps":"entry[3]"},
@@ -481,9 +503,9 @@ public final class TestFixtures {
"symbol_scope_json":["BTC-USDT-PERP"],"bar_interval":"1m","horizon_minutes":45,
"model_format":"ONNX","model_runtime":"ONNX_RUNTIME_JAVA","model_runtime_version":"1.22.0",
"onnx_opset_version":17,"producer_name":"test-exporter","producer_version":"p0",
"feature_version":"feature-v4-p0","feature_schema_path":"schemas/features.json",
"feature_version":"feature-v4-p2-book-cross","feature_schema_path":"schemas/features.json",
"feature_schema_hash":"%6$s","feature_order_path":"schemas/feature_order.json","feature_order_hash":"%8$s",
"input_tensor_name":"features","input_dtype":"FLOAT32","input_shape_json":{"batch":1,"features":39},
"input_tensor_name":"features","input_dtype":"FLOAT32","input_shape_json":{"batch":1,"features":54},
"input_example_path":"examples/input.json","output_schema_path":"schemas/outputs.json",
"output_schema_hash":"%7$s","output_tensor_names_json":["continue"],
"output_mapping_json":{"long_continue_prob":"continue[0]","short_continue_prob":"continue[1]","long_expected_continue_edge_bps":"continue[2]","short_expected_continue_edge_bps":"continue[3]"},
@@ -501,9 +523,9 @@ public final class TestFixtures {
"symbol_scope_json":["BTC-USDT-PERP"],"bar_interval":"1m","horizon_minutes":45,
"model_format":"ONNX","model_runtime":"ONNX_RUNTIME_JAVA","model_runtime_version":"1.22.0",
"onnx_opset_version":17,"producer_name":"test-exporter","producer_version":"p0",
"feature_version":"feature-v4-p0","feature_schema_path":"schemas/features.json",
"feature_version":"feature-v4-p2-book-cross","feature_schema_path":"schemas/features.json",
"feature_schema_hash":"%6$s","feature_order_path":"schemas/feature_order.json","feature_order_hash":"%8$s",
"input_tensor_name":"features","input_dtype":"FLOAT32","input_shape_json":{"batch":1,"features":39},
"input_tensor_name":"features","input_dtype":"FLOAT32","input_shape_json":{"batch":1,"features":54},
"input_example_path":"examples/input.json","output_schema_path":"schemas/outputs.json",
"output_schema_hash":"%7$s","output_tensor_names_json":["exit"],
"output_mapping_json":{"long_exit_prob":"exit[0]","short_exit_prob":"exit[1]","long_adverse_move_bps":"exit[2]","short_adverse_move_bps":"exit[3]","adverse_move_prob":"exit[4]","reversal_prob":"exit[5]","stop_hit_prob":"exit[6]","stagnation_prob":"exit[7]"},
@@ -521,9 +543,9 @@ public final class TestFixtures {
"symbol_scope_json":["BTC-USDT-PERP"],"bar_interval":"1m","horizon_minutes":45,
"model_format":"ONNX","model_runtime":"ONNX_RUNTIME_JAVA","model_runtime_version":"1.22.0",
"onnx_opset_version":17,"producer_name":"test-exporter","producer_version":"p0",
"feature_version":"feature-v4-p0","feature_schema_path":"schemas/features.json",
"feature_version":"feature-v4-p2-book-cross","feature_schema_path":"schemas/features.json",
"feature_schema_hash":"%6$s","feature_order_path":"schemas/feature_order.json","feature_order_hash":"%8$s",
"input_tensor_name":"features","input_dtype":"FLOAT32","input_shape_json":{"batch":1,"features":39},
"input_tensor_name":"features","input_dtype":"FLOAT32","input_shape_json":{"batch":1,"features":54},
"input_example_path":"examples/input.json","output_schema_path":"schemas/outputs.json",
"output_schema_hash":"%7$s","output_tensor_names_json":["risk"],
"output_mapping_json":{"market_risk_prob":"risk[0]","long_position_risk_prob":"risk[1]","short_position_risk_prob":"risk[2]","market_path_risk_bps":"risk[3]","long_position_path_risk_bps":"risk[4]","short_position_path_risk_bps":"risk[5]","market_drawdown_prob":"risk[6]","volatility_expansion_prob":"risk[7]","spike_prob":"risk[8]","liquidity_deterioration_prob":"risk[9]","position_drawdown_prob":"risk[10]"},
@@ -29,7 +29,7 @@ class TraderArtifactLoaderTest {
assertThat(bundle.pricePlanContext().pricePlanConfigHash()).isEqualTo("p0-price-plan-hash");
assertThat(bundle.modelManifests()).allSatisfy(manifest -> {
assertThat(manifest.featureOrderPath()).isEqualTo("schemas/feature_order.json");
assertThat(manifest.inputShapeJson()).containsEntry("features", 39);
assertThat(manifest.inputShapeJson()).containsEntry("features", 54);
assertThat(manifest.onnxOpsetVersion()).isEqualTo(17);
});
assertThat(bundle.requireReplayModelFixture().entry().longExpectedNetEdgeBps()).isEqualByComparingTo("12.0");
@@ -89,7 +89,7 @@ class TraderArtifactLoaderTest {
void rejectsNonV4InputShape() throws IOException {
writeArtifactBundle(artifactRoot);
Path manifest = artifactRoot.resolve("manifests/model_manifest.json");
Files.writeString(manifest, Files.readString(manifest).replace("\"features\":39", "\"features\":38"));
Files.writeString(manifest, Files.readString(manifest).replace("\"features\":54", "\"features\":53"));
assertThatThrownBy(() -> new TraderArtifactLoader(propertiesWithArtifactRoot(artifactRoot), objectMapper()).loadActiveBundle())
.isInstanceOf(com.quantai.trader.domain.TraderException.class)
@@ -27,9 +27,11 @@ class TraderFeatureVectorBuilderTest {
float[] values = builder.build(snapshot(), bundle);
assertThat(values).hasSize(39);
assertThat(values).hasSize(54);
assertThat(values[0]).isEqualTo(1.1f);
assertThat(values[38]).isEqualTo(120.0f);
assertThat(values[45]).isEqualTo(0.03f);
assertThat(values[53]).isEqualTo(-0.328f);
}
@Test
@@ -41,7 +43,7 @@ class TraderFeatureVectorBuilderTest {
Map<String, Object> features = new LinkedHashMap<>(featureJson());
features.remove("ret_1m_bps");
TraderMarketSnapshot badSnapshot = new TraderMarketSnapshot("snapshot-1", "run-1", "cycle-1",
"BTC-USDT-PERP", T0, "feature-v4-p0", bd("100"), bd("99.5"), bd("1.2"),
"BTC-USDT-PERP", T0, "feature-v4-p2-book-cross", bd("100"), bd("99.5"), bd("1.2"),
bd("0.5"), bd("1000"), bd("1400"), bd("2200"), true, features, dataQualityJson());
assertThatThrownBy(() -> builder.build(badSnapshot, bundle))
@@ -58,7 +60,7 @@ class TraderFeatureVectorBuilderTest {
Map<String, Object> features = new LinkedHashMap<>(featureJson());
features.put("account_balance", bd("1000"));
TraderMarketSnapshot badSnapshot = new TraderMarketSnapshot("snapshot-1", "run-1", "cycle-1",
"BTC-USDT-PERP", T0, "feature-v4-p0", bd("100"), bd("99.5"), bd("1.2"),
"BTC-USDT-PERP", T0, "feature-v4-p2-book-cross", bd("100"), bd("99.5"), bd("1.2"),
bd("0.5"), bd("1000"), bd("1400"), bd("2200"), true, features, dataQualityJson());
assertThatThrownBy(() -> builder.build(badSnapshot, bundle))
@@ -52,7 +52,7 @@ class OnnxTraderModelServiceTest {
new FakeInferenceClient());
var snapshot = new com.quantai.trader.domain.TraderMarketSnapshot("snapshot-1", "run-1", "cycle-1",
"BTC-USDT-PERP", T0, "feature-v4-p0", bd("100"), bd("99.5"), bd("1.2"),
"BTC-USDT-PERP", T0, "feature-v4-p2-book-cross", bd("100"), bd("99.5"), bd("1.2"),
bd("0.5"), bd("1000"), bd("1400"), bd("2200"), true, featureJson(), Map.of());
assertThatThrownBy(() -> service.evaluate(snapshot, bundle))
@@ -25,7 +25,7 @@ class OrtTraderOnnxInferenceClientTest {
.orElseThrow();
assertThatThrownBy(() -> new OrtTraderOnnxInferenceClient()
.infer(manifest, artifactRoot.resolve(manifest.artifactPath()), new float[39]))
.infer(manifest, artifactRoot.resolve(manifest.artifactPath()), new float[54]))
.isInstanceOf(com.quantai.trader.domain.TraderException.class)
.hasMessageContaining("ONNX model cannot be loaded");
}