From e58e4a557277210ca201fbcde769002d5661b11f Mon Sep 17 00:00:00 2001 From: Codex Date: Sat, 27 Jun 2026 16:15:23 +0800 Subject: [PATCH] Implement Trader V4 training artifact pipeline --- .gitignore | 2 + pom.xml | 9 + .../JdbcTraderArtifactManifestRepository.java | 177 ++++++ .../trader/artifact/TraderArtifactBundle.java | 25 +- .../trader/artifact/TraderArtifactLoader.java | 281 ++++++++- .../TraderArtifactManifestRepository.java | 5 + .../artifact/TraderArtifactModelPolicy.java | 76 --- .../artifact/TraderCalibrationManifest.java | 25 + .../artifact/TraderModelBundleManifest.java | 7 + .../trader/artifact/TraderModelManifest.java | 59 ++ .../artifact/TraderPmConfigManifest.java | 5 + .../artifact/TraderReplayModelFixture.java | 66 ++ .../trader/config/TraderProperties.java | 60 +- .../controller/TraderApiExceptionHandler.java | 2 +- .../controller/TraderFeedbackController.java | 7 +- .../quantai/trader/domain/ContinueOutput.java | 30 +- .../trader/domain/DirectionOutput.java | 29 +- .../quantai/trader/domain/EntryOutput.java | 40 +- .../com/quantai/trader/domain/ExitOutput.java | 42 +- .../trader/domain/PositionManagerInput.java | 2 + .../com/quantai/trader/domain/RiskOutput.java | 79 ++- .../trader/domain/TraderModelOutput.java | 16 +- .../domain/TraderModelOutputMetadata.java | 39 ++ .../quantai/trader/domain/TraderPmConfig.java | 8 +- .../trader/domain/TraderPricePlanContext.java | 25 + .../quantai/trader/enums/TraderErrorCode.java | 3 + .../feature/TraderFeatureVectorBuilder.java | 144 +++++ .../JdbcTraderFeedbackRepository.java | 38 ++ .../feedback/TraderFeedbackRepository.java | 7 + .../model/ArtifactTraderModelService.java | 72 --- .../trader/model/OnnxTraderModelService.java | 267 ++++++++ .../model/OrtTraderOnnxInferenceClient.java | 78 +++ .../ReplayFixtureTraderModelService.java | 94 +++ .../model/RoutingTraderModelService.java | 49 ++ .../model/TraderOnnxInferenceClient.java | 10 + .../model/TraderOutputSchemaBounds.java | 84 +++ .../model/TraderProbabilityCalibrator.java | 129 ++++ .../outbox/JdbcTraderOutboxRepository.java | 28 + .../trader/outbox/P0OutboxDispatcher.java | 76 +++ .../trader/outbox/TraderOutboxDispatcher.java | 9 + .../trader/outbox/TraderOutboxRepository.java | 6 + .../JdbcTraderDecisionTraceWriter.java | 72 ++- .../TraderDecisionTraceWriter.java | 2 + .../position/TraderPositionManager.java | 71 ++- .../trader/replay/ReplayMarketEvent.java | 31 +- .../trader/replay/TraderP0CycleRunner.java | 107 +++- .../JdbcTraderPostActionStateRepository.java | 70 +++ .../replay/state/P0ReplayStateStore.java | 3 +- .../TraderPostActionStateRepository.java | 5 + .../com/quantai/trader/risk/RiskLimits.java | 3 +- .../quantai/trader/risk/TraderRiskGate.java | 5 +- .../RedisTraderRuntimeControlService.java | 93 +++ .../runtime/TraderRuntimeControlDecision.java | 14 + .../runtime/TraderRuntimeControlService.java | 13 + .../quantai/trader/util/TraderNumbers.java | 11 +- .../db/migration/V1__trader_v4_p0_schema.sql | 202 +++++- .../java/com/quantai/trader/TestFixtures.java | 496 +++++++++++++-- ...cTraderArtifactManifestRepositoryTest.java | 35 ++ .../artifact/TraderArtifactLoaderTest.java | 59 +- .../trader/config/JacksonConfigTest.java | 17 + .../controller/TraderControllerTest.java | 25 +- .../TraderReplayControllerTest.java | 45 ++ .../domain/ModelOutputContractTest.java | 3 +- .../JdbcTraderEvidenceRepositoryTest.java | 29 + .../TraderFeatureVectorBuilderTest.java | 68 ++ .../JdbcTraderFeedbackRepositoryTest.java | 40 ++ .../model/OnnxTraderModelServiceTest.java | 77 +++ .../OrtTraderOnnxInferenceClientTest.java | 32 + .../model/RoutingTraderModelServiceTest.java | 84 +++ .../model/TraderOutputSchemaBoundsTest.java | 55 ++ .../TraderProbabilityCalibratorTest.java | 58 ++ .../trader/outbox/P0OutboxDispatcherTest.java | 82 +++ .../JdbcTraderDecisionTraceWriterTest.java | 4 +- .../position/TraderPositionManagerTest.java | 28 +- .../replay/TraderP0CycleRunnerTest.java | 101 ++- ...bcTraderPostActionStateRepositoryTest.java | 37 ++ .../trader/risk/TraderRiskGateTest.java | 2 +- .../RedisTraderRuntimeControlServiceTest.java | 95 +++ .../runtime/StartupValidationRunnerTest.java | 18 + training/README.md | 33 + training/requirements.txt | 7 + training/scripts/01_audit_source_data.py | 23 + training/scripts/02_build_replay_1m.py | 26 + training/scripts/03_build_splits.py | 31 + training/scripts/04_build_feature_frame.py | 23 + .../scripts/05_build_price_plan_context.py | 23 + training/scripts/06_build_direction_labels.py | 23 + training/scripts/07_build_entry_labels.py | 25 + .../08_build_position_state_samples.py | 21 + .../09_build_continue_exit_risk_labels.py | 25 + training/scripts/10_build_train_datasets.py | 26 + training/scripts/11_train_small_models.py | 20 + training/scripts/12_calibrate_models.py | 19 + training/scripts/13_search_pm_thresholds.py | 19 + training/scripts/14_integrated_backtest.py | 19 + training/scripts/15_export_artifact_bundle.py | 21 + .../scripts/16_validate_artifact_bundle.py | 22 + .../scripts/17_promote_artifact_bundle.py | 21 + training/scripts/_bootstrap.py | 8 + training/tests/test_training_contract.py | 146 +++++ training/trader_training/__init__.py | 1 + training/trader_training/datasets.py | 112 ++++ training/trader_training/exporter.py | 244 ++++++++ training/trader_training/features.py | 342 +++++++++++ training/trader_training/io_utils.py | 162 +++++ training/trader_training/labels.py | 417 +++++++++++++ training/trader_training/onnx_export.py | 73 +++ training/trader_training/pm.py | 541 ++++++++++++++++ training/trader_training/promote.py | 129 ++++ training/trader_training/replay.py | 496 +++++++++++++++ training/trader_training/schemas.py | 206 +++++++ training/trader_training/training.py | 581 ++++++++++++++++++ training/trader_training/validator.py | 149 +++++ 113 files changed, 7959 insertions(+), 477 deletions(-) create mode 100644 src/main/java/com/quantai/trader/artifact/JdbcTraderArtifactManifestRepository.java create mode 100644 src/main/java/com/quantai/trader/artifact/TraderArtifactManifestRepository.java delete mode 100644 src/main/java/com/quantai/trader/artifact/TraderArtifactModelPolicy.java create mode 100644 src/main/java/com/quantai/trader/artifact/TraderCalibrationManifest.java create mode 100644 src/main/java/com/quantai/trader/artifact/TraderModelManifest.java create mode 100644 src/main/java/com/quantai/trader/artifact/TraderReplayModelFixture.java create mode 100644 src/main/java/com/quantai/trader/domain/TraderModelOutputMetadata.java create mode 100644 src/main/java/com/quantai/trader/domain/TraderPricePlanContext.java create mode 100644 src/main/java/com/quantai/trader/feature/TraderFeatureVectorBuilder.java create mode 100644 src/main/java/com/quantai/trader/feedback/JdbcTraderFeedbackRepository.java create mode 100644 src/main/java/com/quantai/trader/feedback/TraderFeedbackRepository.java delete mode 100644 src/main/java/com/quantai/trader/model/ArtifactTraderModelService.java create mode 100644 src/main/java/com/quantai/trader/model/OnnxTraderModelService.java create mode 100644 src/main/java/com/quantai/trader/model/OrtTraderOnnxInferenceClient.java create mode 100644 src/main/java/com/quantai/trader/model/ReplayFixtureTraderModelService.java create mode 100644 src/main/java/com/quantai/trader/model/RoutingTraderModelService.java create mode 100644 src/main/java/com/quantai/trader/model/TraderOnnxInferenceClient.java create mode 100644 src/main/java/com/quantai/trader/model/TraderOutputSchemaBounds.java create mode 100644 src/main/java/com/quantai/trader/model/TraderProbabilityCalibrator.java create mode 100644 src/main/java/com/quantai/trader/outbox/P0OutboxDispatcher.java create mode 100644 src/main/java/com/quantai/trader/outbox/TraderOutboxDispatcher.java create mode 100644 src/main/java/com/quantai/trader/replay/state/JdbcTraderPostActionStateRepository.java create mode 100644 src/main/java/com/quantai/trader/replay/state/TraderPostActionStateRepository.java create mode 100644 src/main/java/com/quantai/trader/runtime/RedisTraderRuntimeControlService.java create mode 100644 src/main/java/com/quantai/trader/runtime/TraderRuntimeControlDecision.java create mode 100644 src/main/java/com/quantai/trader/runtime/TraderRuntimeControlService.java create mode 100644 src/test/java/com/quantai/trader/artifact/JdbcTraderArtifactManifestRepositoryTest.java create mode 100644 src/test/java/com/quantai/trader/config/JacksonConfigTest.java create mode 100644 src/test/java/com/quantai/trader/controller/TraderReplayControllerTest.java create mode 100644 src/test/java/com/quantai/trader/evidence/JdbcTraderEvidenceRepositoryTest.java create mode 100644 src/test/java/com/quantai/trader/feature/TraderFeatureVectorBuilderTest.java create mode 100644 src/test/java/com/quantai/trader/feedback/JdbcTraderFeedbackRepositoryTest.java create mode 100644 src/test/java/com/quantai/trader/model/OnnxTraderModelServiceTest.java create mode 100644 src/test/java/com/quantai/trader/model/OrtTraderOnnxInferenceClientTest.java create mode 100644 src/test/java/com/quantai/trader/model/RoutingTraderModelServiceTest.java create mode 100644 src/test/java/com/quantai/trader/model/TraderOutputSchemaBoundsTest.java create mode 100644 src/test/java/com/quantai/trader/model/TraderProbabilityCalibratorTest.java create mode 100644 src/test/java/com/quantai/trader/outbox/P0OutboxDispatcherTest.java create mode 100644 src/test/java/com/quantai/trader/replay/state/JdbcTraderPostActionStateRepositoryTest.java create mode 100644 src/test/java/com/quantai/trader/runtime/RedisTraderRuntimeControlServiceTest.java create mode 100644 src/test/java/com/quantai/trader/runtime/StartupValidationRunnerTest.java create mode 100644 training/README.md create mode 100644 training/requirements.txt create mode 100644 training/scripts/01_audit_source_data.py create mode 100644 training/scripts/02_build_replay_1m.py create mode 100644 training/scripts/03_build_splits.py create mode 100644 training/scripts/04_build_feature_frame.py create mode 100644 training/scripts/05_build_price_plan_context.py create mode 100644 training/scripts/06_build_direction_labels.py create mode 100644 training/scripts/07_build_entry_labels.py create mode 100644 training/scripts/08_build_position_state_samples.py create mode 100644 training/scripts/09_build_continue_exit_risk_labels.py create mode 100644 training/scripts/10_build_train_datasets.py create mode 100644 training/scripts/11_train_small_models.py create mode 100644 training/scripts/12_calibrate_models.py create mode 100644 training/scripts/13_search_pm_thresholds.py create mode 100644 training/scripts/14_integrated_backtest.py create mode 100644 training/scripts/15_export_artifact_bundle.py create mode 100644 training/scripts/16_validate_artifact_bundle.py create mode 100644 training/scripts/17_promote_artifact_bundle.py create mode 100644 training/scripts/_bootstrap.py create mode 100644 training/tests/test_training_contract.py create mode 100644 training/trader_training/__init__.py create mode 100644 training/trader_training/datasets.py create mode 100644 training/trader_training/exporter.py create mode 100644 training/trader_training/features.py create mode 100644 training/trader_training/io_utils.py create mode 100644 training/trader_training/labels.py create mode 100644 training/trader_training/onnx_export.py create mode 100644 training/trader_training/pm.py create mode 100644 training/trader_training/promote.py create mode 100644 training/trader_training/replay.py create mode 100644 training/trader_training/schemas.py create mode 100644 training/trader_training/training.py create mode 100644 training/trader_training/validator.py diff --git a/.gitignore b/.gitignore index fd6e6c3..6dd7d02 100644 --- a/.gitignore +++ b/.gitignore @@ -3,6 +3,8 @@ target/ *.iml *.log .DS_Store +__pycache__/ +*.pyc # Runtime and local data stay outside source control. logs/ diff --git a/pom.xml b/pom.xml index d4e2f78..2012b3e 100644 --- a/pom.xml +++ b/pom.xml @@ -47,10 +47,19 @@ org.springframework.boot spring-boot-starter-jdbc + + org.springframework.boot + spring-boot-starter-data-redis + org.springframework.boot spring-boot-starter-flyway + + com.microsoft.onnxruntime + onnxruntime + 1.22.0 + org.flywaydb flyway-mysql diff --git a/src/main/java/com/quantai/trader/artifact/JdbcTraderArtifactManifestRepository.java b/src/main/java/com/quantai/trader/artifact/JdbcTraderArtifactManifestRepository.java new file mode 100644 index 0000000..e8b04ae --- /dev/null +++ b/src/main/java/com/quantai/trader/artifact/JdbcTraderArtifactManifestRepository.java @@ -0,0 +1,177 @@ +package com.quantai.trader.artifact; + +import com.fasterxml.jackson.databind.ObjectMapper; +import com.quantai.trader.enums.TraderRunMode; +import com.quantai.trader.persistence.TraderJsonCodec; +import org.springframework.jdbc.core.JdbcTemplate; +import org.springframework.stereotype.Repository; + +@Repository +public class JdbcTraderArtifactManifestRepository implements TraderArtifactManifestRepository { + private final JdbcTemplate jdbcTemplate; + private final TraderJsonCodec jsonCodec; + + public JdbcTraderArtifactManifestRepository(JdbcTemplate jdbcTemplate, ObjectMapper objectMapper) { + this.jdbcTemplate = jdbcTemplate; + this.jsonCodec = new TraderJsonCodec(objectMapper); + } + + @Override + public void upsertActiveBundle(TraderArtifactBundle bundle) { + upsertModelBundle(bundle.modelBundleManifest()); + bundle.modelManifests().forEach(this::upsertModelManifest); + bundle.calibrationManifests().forEach(this::upsertCalibrationManifest); + upsertPmConfigManifest(bundle.pmConfigManifest()); + } + + private void upsertModelBundle(TraderModelBundleManifest manifest) { + jdbcTemplate.update(""" + insert into trader_model_bundle_manifest + (manifest_schema_version, model_bundle_version, calibration_bundle_version, + feature_version, label_version, split_version, training_run_id, training_export_id, + backtest_manifest_id, required_models_json, provided_models_json, missing_models_json, + allowed_run_modes_json, bundle_hash_sha256, complete, status) + values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + on duplicate key update + manifest_schema_version = values(manifest_schema_version), + feature_version = values(feature_version), + label_version = values(label_version), + split_version = values(split_version), + training_run_id = values(training_run_id), + training_export_id = values(training_export_id), + backtest_manifest_id = values(backtest_manifest_id), + required_models_json = values(required_models_json), + provided_models_json = values(provided_models_json), + missing_models_json = values(missing_models_json), + allowed_run_modes_json = values(allowed_run_modes_json), + bundle_hash_sha256 = values(bundle_hash_sha256), + complete = values(complete), + status = values(status) + """, + manifest.manifestSchemaVersion(), manifest.modelBundleVersion(), manifest.calibrationBundleVersion(), + manifest.featureVersion(), manifest.labelVersion(), manifest.splitVersion(), + manifest.trainingRunId(), manifest.trainingExportId(), manifest.backtestManifestId(), + jsonCodec.toJson(manifest.requiredModels()), + jsonCodec.toJson(manifest.providedModels()), jsonCodec.toJson(manifest.missingModels()), + jsonCodec.toJson(manifest.allowedRunModes().stream().map(TraderRunMode::name).toList()), + manifest.bundleHashSha256(), manifest.complete(), manifest.status()); + } + + private void upsertModelManifest(TraderModelManifest manifest) { + jdbcTemplate.update(""" + insert into trader_model_manifest + (model_bundle_version, calibration_bundle_version, model_name, model_type, side, + symbol_scope_json, bar_interval, horizon_minutes, model_format, model_runtime, + model_runtime_version, onnx_opset_version, producer_name, producer_version, + artifact_path, artifact_hash_sha256, source_hash, feature_version, feature_schema_path, + feature_schema_hash, feature_order_path, feature_order_hash, input_tensor_name, input_dtype, input_shape_json, + input_example_path, output_schema_path, output_schema_hash, output_tensor_names_json, + output_mapping_json, output_value_rules_json, label_version, split_version, training_fold, + train_start, train_end, validation_start, validation_end, test_start, test_end, + metrics_json, status) + values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + on duplicate key update + model_type = values(model_type), + symbol_scope_json = values(symbol_scope_json), + bar_interval = values(bar_interval), + model_format = values(model_format), + model_runtime = values(model_runtime), + model_runtime_version = values(model_runtime_version), + onnx_opset_version = values(onnx_opset_version), + producer_name = values(producer_name), + producer_version = values(producer_version), + artifact_path = values(artifact_path), + artifact_hash_sha256 = values(artifact_hash_sha256), + source_hash = values(source_hash), + feature_version = values(feature_version), + feature_schema_path = values(feature_schema_path), + feature_schema_hash = values(feature_schema_hash), + feature_order_path = values(feature_order_path), + feature_order_hash = values(feature_order_hash), + input_tensor_name = values(input_tensor_name), + input_dtype = values(input_dtype), + input_shape_json = values(input_shape_json), + input_example_path = values(input_example_path), + output_schema_path = values(output_schema_path), + output_schema_hash = values(output_schema_hash), + output_tensor_names_json = values(output_tensor_names_json), + output_mapping_json = values(output_mapping_json), + output_value_rules_json = values(output_value_rules_json), + label_version = values(label_version), + split_version = values(split_version), + training_fold = values(training_fold), + train_start = values(train_start), + train_end = values(train_end), + validation_start = values(validation_start), + validation_end = values(validation_end), + test_start = values(test_start), + test_end = values(test_end), + metrics_json = values(metrics_json), + status = values(status) + """, + manifest.modelBundleVersion(), manifest.calibrationBundleVersion(), manifest.modelName(), + manifest.modelType(), manifest.side(), jsonCodec.toJson(manifest.symbolScope()), + manifest.barInterval(), manifest.horizonMinutes(), manifest.modelFormat(), manifest.modelRuntime(), + manifest.modelRuntimeVersion(), manifest.onnxOpsetVersion(), manifest.producerName(), + manifest.producerVersion(), manifest.artifactPath(), manifest.artifactHashSha256(), + manifest.sourceHash(), manifest.featureVersion(), manifest.featureSchemaPath(), + manifest.featureSchemaHash(), manifest.featureOrderPath(), manifest.featureOrderHash(), manifest.inputTensorName(), + manifest.inputDtype(), jsonCodec.toJson(manifest.inputShapeJson()), manifest.inputExamplePath(), + manifest.outputSchemaPath(), manifest.outputSchemaHash(), jsonCodec.toJson(manifest.outputTensorNames()), + jsonCodec.toJson(manifest.outputMapping()), jsonCodec.toJson(manifest.outputValueRules()), + manifest.labelVersion(), manifest.splitVersion(), manifest.trainingFold(), + java.sql.Timestamp.from(manifest.trainStart()), java.sql.Timestamp.from(manifest.trainEnd()), + java.sql.Timestamp.from(manifest.validationStart()), java.sql.Timestamp.from(manifest.validationEnd()), + java.sql.Timestamp.from(manifest.testStart()), java.sql.Timestamp.from(manifest.testEnd()), + jsonCodec.toJson(manifest.metricsJson()), manifest.status()); + } + + private void upsertCalibrationManifest(TraderCalibrationManifest manifest) { + jdbcTemplate.update(""" + insert into trader_calibration_manifest + (calibration_bundle_version, model_bundle_version, model_name, calibrator_version, + calibration_method, calibrator_path, calibrator_hash_sha256, + calibration_window_from, calibration_window_to, calibration_metrics_json, + bucket_metrics_json, output_after_calibration_schema_hash, status) + values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + on duplicate key update + calibrator_version = values(calibrator_version), + calibration_method = values(calibration_method), + calibrator_path = values(calibrator_path), + calibrator_hash_sha256 = values(calibrator_hash_sha256), + calibration_window_from = values(calibration_window_from), + calibration_window_to = values(calibration_window_to), + calibration_metrics_json = values(calibration_metrics_json), + bucket_metrics_json = values(bucket_metrics_json), + output_after_calibration_schema_hash = values(output_after_calibration_schema_hash), + status = values(status) + """, + manifest.calibrationBundleVersion(), manifest.modelBundleVersion(), manifest.modelName(), + manifest.calibratorVersion(), manifest.calibrationMethod(), manifest.calibratorPath(), + manifest.calibratorHashSha256(), java.sql.Timestamp.from(manifest.calibrationWindowFrom()), + java.sql.Timestamp.from(manifest.calibrationWindowTo()), jsonCodec.toJson(manifest.calibrationMetrics()), + jsonCodec.toJson(manifest.bucketMetricsJson()), manifest.outputAfterCalibrationSchemaHash(), + manifest.status()); + } + + private void upsertPmConfigManifest(TraderPmConfigManifest manifest) { + jdbcTemplate.update(""" + insert into trader_pm_config_manifest + (pm_config_version, model_bundle_version, calibration_bundle_version, threshold_stability_json, + allowed_run_modes_json, config_json, config_hash_sha256, status) + values (?, ?, ?, ?, ?, ?, ?, ?) + on duplicate key update + model_bundle_version = values(model_bundle_version), + calibration_bundle_version = values(calibration_bundle_version), + threshold_stability_json = values(threshold_stability_json), + allowed_run_modes_json = values(allowed_run_modes_json), + config_json = values(config_json), + config_hash_sha256 = values(config_hash_sha256), + status = values(status) + """, + manifest.pmConfigVersion(), manifest.modelBundleVersion(), manifest.calibrationBundleVersion(), + jsonCodec.toJson(manifest.thresholdStabilityJson()), + jsonCodec.toJson(manifest.allowedRunModes().stream().map(TraderRunMode::name).toList()), + jsonCodec.toJson(manifest.config()), manifest.configHashSha256(), manifest.status()); + } +} diff --git a/src/main/java/com/quantai/trader/artifact/TraderArtifactBundle.java b/src/main/java/com/quantai/trader/artifact/TraderArtifactBundle.java index 190ee58..5f31004 100644 --- a/src/main/java/com/quantai/trader/artifact/TraderArtifactBundle.java +++ b/src/main/java/com/quantai/trader/artifact/TraderArtifactBundle.java @@ -1,7 +1,11 @@ package com.quantai.trader.artifact; +import com.quantai.trader.domain.TraderException; import com.quantai.trader.domain.TraderPmConfig; +import com.quantai.trader.domain.TraderPricePlanContext; +import com.quantai.trader.enums.TraderErrorCode; +import java.util.List; import java.util.Set; public record TraderArtifactBundle( @@ -10,14 +14,31 @@ public record TraderArtifactBundle( String pmConfigVersion, String bundleHashSha256, Set providedModels, + TraderModelBundleManifest modelBundleManifest, + List modelManifests, + List calibrationManifests, + TraderPmConfigManifest pmConfigManifest, TraderPmConfig pmConfig, - TraderArtifactModelPolicy modelPolicy + TraderPricePlanContext pricePlanContext, + TraderReplayModelFixture replayModelFixture ) { public TraderArtifactBundle { if (providedModels == null || !providedModels.containsAll(Set.of("DIRECTION", "ENTRY", "CONTINUE", "EXIT", "RISK"))) { throw new IllegalArgumentException("artifact bundle must provide all five V4 models"); } + modelBundleManifest = java.util.Objects.requireNonNull(modelBundleManifest, "modelBundleManifest is required"); + modelManifests = List.copyOf(modelManifests); + calibrationManifests = List.copyOf(calibrationManifests); + pmConfigManifest = java.util.Objects.requireNonNull(pmConfigManifest, "pmConfigManifest is required"); pmConfig = java.util.Objects.requireNonNull(pmConfig, "pmConfig is required"); - modelPolicy = java.util.Objects.requireNonNull(modelPolicy, "modelPolicy is required"); + pricePlanContext = java.util.Objects.requireNonNull(pricePlanContext, "pricePlanContext is required"); + } + + public TraderReplayModelFixture requireReplayModelFixture() { + if (replayModelFixture == null) { + throw new TraderException(TraderErrorCode.TRADER_MODEL_ARTIFACT_MISSING, + "replay model fixture is required for replay fixture inference"); + } + return replayModelFixture; } } diff --git a/src/main/java/com/quantai/trader/artifact/TraderArtifactLoader.java b/src/main/java/com/quantai/trader/artifact/TraderArtifactLoader.java index c302d18..3162f57 100644 --- a/src/main/java/com/quantai/trader/artifact/TraderArtifactLoader.java +++ b/src/main/java/com/quantai/trader/artifact/TraderArtifactLoader.java @@ -2,9 +2,11 @@ package com.quantai.trader.artifact; import com.fasterxml.jackson.databind.JsonNode; import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.core.type.TypeReference; import com.quantai.trader.config.TraderProperties; import com.quantai.trader.domain.TraderException; import com.quantai.trader.domain.TraderPmConfig; +import com.quantai.trader.domain.TraderPricePlanContext; import com.quantai.trader.enums.TraderRunMode; import com.quantai.trader.enums.TraderErrorCode; import org.slf4j.Logger; @@ -14,6 +16,11 @@ import org.springframework.stereotype.Component; import java.io.IOException; import java.nio.file.Files; import java.nio.file.Path; +import java.security.MessageDigest; +import java.security.NoSuchAlgorithmException; +import java.time.Instant; +import java.util.List; +import java.util.Map; import java.util.Set; import java.util.stream.Collectors; import java.util.stream.StreamSupport; @@ -22,6 +29,27 @@ import java.util.stream.StreamSupport; public class TraderArtifactLoader { private static final Logger log = LoggerFactory.getLogger(TraderArtifactLoader.class); private static final Set REQUIRED_MODELS = Set.of("DIRECTION", "ENTRY", "CONTINUE", "EXIT", "RISK"); + private static final int REQUIRED_FEATURE_COUNT = 39; + private static final int REQUIRED_ONNX_OPSET_VERSION = 17; + private static final Map> REQUIRED_OUTPUT_MAPPING_KEYS = Map.of( + "DIRECTION", Set.of("long_prob", "short_prob", "neutral_prob"), + "ENTRY", Set.of("long_entry_prob", "short_entry_prob", "long_expected_net_edge_bps", "short_expected_net_edge_bps"), + "CONTINUE", Set.of("long_continue_prob", "short_continue_prob", "long_expected_continue_edge_bps", "short_expected_continue_edge_bps"), + "EXIT", Set.of("long_exit_prob", "short_exit_prob", "long_adverse_move_bps", "short_adverse_move_bps", + "adverse_move_prob", "reversal_prob", "stop_hit_prob", "stagnation_prob"), + "RISK", Set.of("market_risk_prob", "long_position_risk_prob", "short_position_risk_prob", + "market_path_risk_bps", "long_position_path_risk_bps", "short_position_path_risk_bps", + "market_drawdown_prob", "volatility_expansion_prob", "spike_prob", + "liquidity_deterioration_prob", "position_drawdown_prob") + ); + private static final Set REJECTED_OUTPUT_MAPPING_KEYS = Set.of( + "expected_net_edge_bps", + "expected_continue_edge_bps", + "expected_giveback_bps", + "market_expected_shortfall_bps", + "position_expected_shortfall_bps", + "position_risk_target" + ); private final TraderProperties properties; private final ObjectMapper objectMapper; @@ -35,10 +63,15 @@ public class TraderArtifactLoader { TraderProperties.Artifact artifact = properties.artifact(); Path root = Path.of(artifact.artifactRoot()); TraderModelBundleManifest modelManifest = readModelBundleManifest(root.resolve("manifests/model_bundle_manifest.json")); + List modelManifests = readModelManifests(root.resolve("manifests/model_manifest.json")); + List calibrationManifests = readCalibrationManifests(root.resolve("manifests/calibration_manifest.json")); TraderPmConfigManifest pmManifest = readPmConfigManifest(root.resolve("manifests/position_manager_manifest.json")); - TraderArtifactModelPolicy modelPolicy = readJson(root.resolve("model_output_policy.json"), TraderArtifactModelPolicy.class); + TraderPricePlanContext pricePlanContext = readJson(root.resolve("price_plan_context.json"), TraderPricePlanContext.class); + TraderReplayModelFixture replayModelFixture = readOptionalJson(root.resolve("replay_model_fixture.json"), TraderReplayModelFixture.class); validateVersions(artifact, modelManifest, pmManifest); validateModelManifest(modelManifest); + validateModelArtifacts(root, modelManifest, modelManifests); + validateCalibrationArtifacts(root, modelManifest, calibrationManifests); validatePmManifest(pmManifest, properties.runMode()); TraderArtifactBundle bundle = new TraderArtifactBundle( modelManifest.modelBundleVersion(), @@ -46,24 +79,111 @@ public class TraderArtifactLoader { pmManifest.pmConfigVersion(), modelManifest.bundleHashSha256(), modelManifest.providedModels(), + modelManifest, + modelManifests, + calibrationManifests, + pmManifest, pmManifest.config(), - modelPolicy); + pricePlanContext, + replayModelFixture); log.info("event=trader.artifact.loaded modelBundleVersion={} calibrationBundleVersion={} pmConfigVersion={} providedModels={}", bundle.modelBundleVersion(), bundle.calibrationBundleVersion(), bundle.pmConfigVersion(), bundle.providedModels()); return bundle; } + private List readModelManifests(Path path) { + JsonNode root = readJsonNode(path); + if (!root.isArray()) { + throw new TraderException(TraderErrorCode.TRADER_MODEL_ARTIFACT_MISSING, + "model manifest must be an array: " + path); + } + return StreamSupport.stream(root.spliterator(), false) + .map(node -> new TraderModelManifest( + requiredText(node, "model_bundle_version", path), + requiredText(node, "calibration_bundle_version", path), + requiredText(node, "model_name", path), + requiredText(node, "model_type", path), + requiredText(node, "side", path), + textSet(node, "symbol_scope_json", path), + requiredText(node, "bar_interval", path), + node.path("horizon_minutes").asInt(-1), + requiredText(node, "model_format", path), + requiredText(node, "model_runtime", path), + requiredText(node, "model_runtime_version", path), + node.path("onnx_opset_version").asInt(-1), + requiredText(node, "producer_name", path), + requiredText(node, "producer_version", path), + requiredText(node, "feature_version", path), + requiredText(node, "feature_schema_path", path), + requiredText(node, "feature_schema_hash", path), + requiredText(node, "feature_order_path", path), + requiredText(node, "feature_order_hash", path), + requiredText(node, "input_tensor_name", path), + requiredText(node, "input_dtype", path), + objectMap(node, "input_shape_json", path), + requiredText(node, "input_example_path", path), + requiredText(node, "output_schema_path", path), + requiredText(node, "output_schema_hash", path), + textSet(node, "output_tensor_names_json", path), + objectMap(node, "output_mapping_json", path), + objectMap(node, "output_value_rules_json", path), + requiredText(node, "label_version", path), + requiredText(node, "split_version", path), + requiredText(node, "training_fold", path), + instant(node, "train_start", path), + instant(node, "train_end", path), + instant(node, "validation_start", path), + instant(node, "validation_end", path), + instant(node, "test_start", path), + instant(node, "test_end", path), + objectMap(node, "metrics_json", path), + requiredText(node, "artifact_path", path), + requiredText(node, "artifact_hash_sha256", path), + requiredText(node, "source_hash", path), + requiredText(node, "status", path))) + .toList(); + } + + private List readCalibrationManifests(Path path) { + JsonNode root = readJsonNode(path); + if (!root.isArray()) { + throw new TraderException(TraderErrorCode.TRADER_MODEL_ARTIFACT_MISSING, + "calibration manifest must be an array: " + path); + } + return StreamSupport.stream(root.spliterator(), false) + .map(node -> new TraderCalibrationManifest( + requiredText(node, "calibration_bundle_version", path), + requiredText(node, "model_bundle_version", path), + requiredText(node, "model_name", path), + requiredText(node, "calibrator_version", path), + requiredText(node, "calibration_method", path), + requiredText(node, "calibrator_path", path), + requiredText(node, "calibrator_hash_sha256", path), + instant(node, "calibration_window_from", path), + instant(node, "calibration_window_to", path), + objectMap(node, "calibration_metrics_json", path), + objectMap(node, "bucket_metrics_json", path), + requiredText(node, "output_after_calibration_schema_hash", path), + requiredText(node, "status", path))) + .toList(); + } + private TraderModelBundleManifest readModelBundleManifest(Path path) { JsonNode root = readJsonNode(path); return new TraderModelBundleManifest( + requiredText(root, "manifest_schema_version", path), requiredText(root, "model_bundle_version", path), requiredText(root, "calibration_bundle_version", path), requiredText(root, "feature_version", path), requiredText(root, "label_version", path), requiredText(root, "split_version", path), + requiredText(root, "training_run_id", path), + requiredText(root, "training_export_id", path), + requiredText(root, "backtest_manifest_id", path), textSet(root, "required_models_json", path), textSet(root, "provided_models_json", path), textSet(root, "missing_models_json", path), + enumSet(root, "allowed_run_modes_json", TraderRunMode.class, path), requiredText(root, "bundle_hash_sha256", path), root.path("complete").asBoolean(false), requiredText(root, "status", path)); @@ -75,6 +195,7 @@ public class TraderArtifactLoader { requiredText(root, "pm_config_version", path), requiredText(root, "model_bundle_version", path), requiredText(root, "calibration_bundle_version", path), + objectMap(root, "threshold_stability_json", path), enumSet(root, "allowed_run_modes_json", TraderRunMode.class, path), convert(root.path("config_json"), TraderPmConfig.class, path), requiredText(root, "config_hash_sha256", path), @@ -106,6 +227,86 @@ public class TraderArtifactLoader { throw new TraderException(TraderErrorCode.TRADER_MODEL_ARTIFACT_MISSING, "model bundle must provide all five V4 models with no missing model"); } + if (!manifest.allowedRunModes().contains(properties.runMode())) { + throw new TraderException(TraderErrorCode.TRADER_MODEL_ARTIFACT_MISSING, + "model bundle does not allow current run mode"); + } + } + + private void validateModelArtifacts(Path root, TraderModelBundleManifest bundleManifest, + List manifests) { + Set modelNames = manifests.stream().map(TraderModelManifest::modelName).collect(Collectors.toUnmodifiableSet()); + if (!modelNames.containsAll(REQUIRED_MODELS)) { + throw new TraderException(TraderErrorCode.TRADER_MODEL_ARTIFACT_MISSING, + "model_manifest.json must contain all five V4 model families"); + } + for (TraderModelManifest manifest : manifests) { + if (!bundleManifest.modelBundleVersion().equals(manifest.modelBundleVersion()) + || !bundleManifest.calibrationBundleVersion().equals(manifest.calibrationBundleVersion())) { + throw new TraderException(TraderErrorCode.TRADER_CALIBRATION_MISMATCH, + "model manifest version does not match bundle manifest"); + } + if (!"ACTIVE".equals(manifest.status())) { + throw new TraderException(TraderErrorCode.TRADER_MODEL_ARTIFACT_MISSING, + "model manifest must be ACTIVE: " + manifest.modelName()); + } + if (!REQUIRED_MODELS.contains(manifest.modelType()) + || !"ONNX".equals(manifest.modelFormat()) + || !"ONNX_RUNTIME_JAVA".equals(manifest.modelRuntime()) + || !"FLOAT32".equals(manifest.inputDtype()) + || manifest.horizonMinutes() <= 0 + || manifest.onnxOpsetVersion() != REQUIRED_ONNX_OPSET_VERSION + || inputFeatureCount(manifest) != REQUIRED_FEATURE_COUNT + || !"features".equals(manifest.inputTensorName())) { + throw new TraderException(TraderErrorCode.TRADER_MODEL_ARTIFACT_MISSING, + "model manifest runtime contract is invalid: " + manifest.modelName()); + } + validateOutputMapping(manifest); + validateSha256(root.resolve(manifest.artifactPath()), manifest.artifactHashSha256()); + validateSha256(root.resolve(manifest.featureSchemaPath()), manifest.featureSchemaHash()); + validateSha256(root.resolve(manifest.featureOrderPath()), manifest.featureOrderHash()); + validateFeatureOrder(root.resolve(manifest.featureOrderPath())); + validateSha256(root.resolve(manifest.outputSchemaPath()), manifest.outputSchemaHash()); + if (!Files.isRegularFile(root.resolve(manifest.inputExamplePath()))) { + throw new TraderException(TraderErrorCode.TRADER_MODEL_ARTIFACT_MISSING, + "model input example is missing: " + manifest.inputExamplePath()); + } + } + } + + private void validateOutputMapping(TraderModelManifest manifest) { + if (manifest.outputMapping().keySet().stream().anyMatch(REJECTED_OUTPUT_MAPPING_KEYS::contains)) { + throw new TraderException(TraderErrorCode.TRADER_MODEL_ARTIFACT_MISSING, + "model output mapping contains rejected legacy key: " + manifest.modelName()); + } + Set requiredKeys = REQUIRED_OUTPUT_MAPPING_KEYS.get(manifest.modelType()); + if (requiredKeys == null || !manifest.outputMapping().keySet().containsAll(requiredKeys)) { + throw new TraderException(TraderErrorCode.TRADER_MODEL_ARTIFACT_MISSING, + "model output mapping does not match V4 contract: " + manifest.modelName()); + } + } + + private void validateCalibrationArtifacts(Path root, TraderModelBundleManifest modelManifest, + List calibrationManifests) { + Set calibratedModels = calibrationManifests.stream() + .map(TraderCalibrationManifest::modelName) + .collect(Collectors.toUnmodifiableSet()); + if (!calibratedModels.containsAll(REQUIRED_MODELS)) { + throw new TraderException(TraderErrorCode.TRADER_CALIBRATION_MISMATCH, + "calibration_manifest.json must contain all five V4 model families"); + } + for (TraderCalibrationManifest calibrationManifest : calibrationManifests) { + if (!modelManifest.modelBundleVersion().equals(calibrationManifest.modelBundleVersion()) + || !modelManifest.calibrationBundleVersion().equals(calibrationManifest.calibrationBundleVersion())) { + throw new TraderException(TraderErrorCode.TRADER_CALIBRATION_MISMATCH, + "calibration manifest version does not match bundle manifest"); + } + if (!"ACTIVE".equals(calibrationManifest.status())) { + throw new TraderException(TraderErrorCode.TRADER_CALIBRATION_MISMATCH, + "calibration manifest must be ACTIVE"); + } + validateSha256(root.resolve(calibrationManifest.calibratorPath()), calibrationManifest.calibratorHashSha256()); + } } private void validatePmManifest(TraderPmConfigManifest manifest, TraderRunMode runMode) { @@ -119,6 +320,29 @@ public class TraderArtifactLoader { } } + private int inputFeatureCount(TraderModelManifest manifest) { + Object features = manifest.inputShapeJson().get("features"); + if (features instanceof Number number) { + return number.intValue(); + } + if (features instanceof String text) { + try { + return Integer.parseInt(text); + } catch (NumberFormatException ignored) { + return -1; + } + } + return -1; + } + + private void validateFeatureOrder(Path path) { + JsonNode featureOrder = readJsonNode(path); + if (!featureOrder.isArray() || featureOrder.size() != REQUIRED_FEATURE_COUNT) { + throw new TraderException(TraderErrorCode.TRADER_MODEL_ARTIFACT_MISSING, + "feature_order.json must contain exactly " + REQUIRED_FEATURE_COUNT + " fields: " + path); + } + } + private JsonNode readJsonNode(Path path) { if (!Files.isRegularFile(path)) { throw new TraderException(TraderErrorCode.TRADER_MODEL_ARTIFACT_MISSING, @@ -132,6 +356,38 @@ public class TraderArtifactLoader { } } + private void validateSha256(Path path, String expectedHash) { + if (!Files.isRegularFile(path)) { + throw new TraderException(TraderErrorCode.TRADER_MODEL_ARTIFACT_MISSING, + "artifact referenced by manifest is missing: " + path); + } + String actualHash; + try { + actualHash = sha256(path); + } catch (IOException exception) { + throw new TraderException(TraderErrorCode.TRADER_MODEL_ARTIFACT_MISSING, + "artifact referenced by manifest cannot be read: " + path); + } + if (!expectedHash.equals(actualHash)) { + throw new TraderException(TraderErrorCode.TRADER_MODEL_ARTIFACT_MISSING, + "artifact sha256 mismatch: " + path); + } + } + + private String sha256(Path path) throws IOException { + try { + MessageDigest digest = MessageDigest.getInstance("SHA-256"); + byte[] hash = digest.digest(Files.readAllBytes(path)); + StringBuilder builder = new StringBuilder(hash.length * 2); + for (byte value : hash) { + builder.append(String.format("%02x", value & 0xff)); + } + return builder.toString(); + } catch (NoSuchAlgorithmException exception) { + throw new IllegalStateException("SHA-256 is not available", exception); + } + } + private T readJson(Path path, Class type) { if (!Files.isRegularFile(path)) { throw new TraderException(TraderErrorCode.TRADER_MODEL_ARTIFACT_MISSING, @@ -145,6 +401,13 @@ public class TraderArtifactLoader { } } + private T readOptionalJson(Path path, Class type) { + if (!Files.isRegularFile(path)) { + return null; + } + return readJson(path, type); + } + private T convert(JsonNode node, Class type, Path path) { if (node == null || node.isMissingNode() || node.isNull()) { throw new TraderException(TraderErrorCode.TRADER_MODEL_ARTIFACT_MISSING, @@ -167,6 +430,20 @@ public class TraderArtifactLoader { return value; } + private Instant instant(JsonNode node, String field, Path path) { + return Instant.parse(requiredText(node, field, path)); + } + + private Map objectMap(JsonNode node, String field, Path path) { + JsonNode value = node.path(field); + if (value.isMissingNode() || value.isNull() || !value.isObject()) { + throw new TraderException(TraderErrorCode.TRADER_MODEL_ARTIFACT_MISSING, + "artifact object field is required: " + path + "#" + field); + } + return objectMapper.convertValue(value, new TypeReference<>() { + }); + } + private Set textSet(JsonNode node, String field, Path path) { JsonNode array = node.path(field); if (!array.isArray()) { diff --git a/src/main/java/com/quantai/trader/artifact/TraderArtifactManifestRepository.java b/src/main/java/com/quantai/trader/artifact/TraderArtifactManifestRepository.java new file mode 100644 index 0000000..0ee4eda --- /dev/null +++ b/src/main/java/com/quantai/trader/artifact/TraderArtifactManifestRepository.java @@ -0,0 +1,5 @@ +package com.quantai.trader.artifact; + +public interface TraderArtifactManifestRepository { + void upsertActiveBundle(TraderArtifactBundle bundle); +} diff --git a/src/main/java/com/quantai/trader/artifact/TraderArtifactModelPolicy.java b/src/main/java/com/quantai/trader/artifact/TraderArtifactModelPolicy.java deleted file mode 100644 index 9735dc2..0000000 --- a/src/main/java/com/quantai/trader/artifact/TraderArtifactModelPolicy.java +++ /dev/null @@ -1,76 +0,0 @@ -package com.quantai.trader.artifact; - -import java.math.BigDecimal; - -public record TraderArtifactModelPolicy( - DirectionPolicy direction, - EntryPolicy entry, - ContinuePolicy continuation, - ExitPolicy exit, - RiskPolicy risk, - BigDecimal uncertainty, - BigDecimal oodScore -) { - public record DirectionPolicy( - BigDecimal longProbWhenMarkGteIndex, - BigDecimal longProbWhenMarkLtIndex, - BigDecimal neutralProb, - BigDecimal expectedReturnBps, - int horizonMinutes, - String modelVersion - ) { - } - - public record EntryPolicy( - BigDecimal longEntryProb, - BigDecimal shortEntryProb, - BigDecimal entryQualityScore, - BigDecimal expectedEdgeBps, - String pricePlanId, - String pricePlanConfigHash, - BigDecimal stopDistanceBps, - BigDecimal targetDistanceBps, - int maxHoldMinutes, - BigDecimal costBps, - String modelVersion - ) { - } - - public record ContinuePolicy( - BigDecimal longContinueProb, - BigDecimal shortContinueProb, - BigDecimal trendPersistenceProb, - BigDecimal holdEdgeBps, - BigDecimal continueVsExitEdgeBps, - String modelVersion - ) { - } - - public record ExitPolicy( - BigDecimal longExitProb, - BigDecimal shortExitProb, - BigDecimal profitGivebackProb, - BigDecimal reversalProb, - BigDecimal stopRiskProb, - BigDecimal stagnationProb, - BigDecimal expectedGivebackBps, - String modelVersion - ) { - } - - public record RiskPolicy( - BigDecimal marketRiskProb, - BigDecimal positionRiskProb, - BigDecimal marketRiskSeverityBps, - BigDecimal positionRiskSeverityBps, - BigDecimal drawdownProb, - BigDecimal expectedShortfallBps, - BigDecimal volatilityExpansionProb, - BigDecimal spikeProb, - BigDecimal liquidityRiskProb, - BigDecimal liquidityCapacityRatioWhenReady, - BigDecimal liquidityCapacityRatioWhenNotReady, - String modelVersion - ) { - } -} diff --git a/src/main/java/com/quantai/trader/artifact/TraderCalibrationManifest.java b/src/main/java/com/quantai/trader/artifact/TraderCalibrationManifest.java new file mode 100644 index 0000000..b260c34 --- /dev/null +++ b/src/main/java/com/quantai/trader/artifact/TraderCalibrationManifest.java @@ -0,0 +1,25 @@ +package com.quantai.trader.artifact; + +import java.time.Instant; +import java.util.Map; + +public record TraderCalibrationManifest( + String calibrationBundleVersion, + String modelBundleVersion, + String modelName, + String calibratorVersion, + String calibrationMethod, + String calibratorPath, + String calibratorHashSha256, + Instant calibrationWindowFrom, + Instant calibrationWindowTo, + Map calibrationMetrics, + Map bucketMetricsJson, + String outputAfterCalibrationSchemaHash, + String status +) { + public TraderCalibrationManifest { + calibrationMetrics = Map.copyOf(calibrationMetrics == null ? Map.of() : calibrationMetrics); + bucketMetricsJson = Map.copyOf(bucketMetricsJson == null ? Map.of() : bucketMetricsJson); + } +} diff --git a/src/main/java/com/quantai/trader/artifact/TraderModelBundleManifest.java b/src/main/java/com/quantai/trader/artifact/TraderModelBundleManifest.java index 20645d2..e3d4126 100644 --- a/src/main/java/com/quantai/trader/artifact/TraderModelBundleManifest.java +++ b/src/main/java/com/quantai/trader/artifact/TraderModelBundleManifest.java @@ -1,16 +1,23 @@ package com.quantai.trader.artifact; +import com.quantai.trader.enums.TraderRunMode; + import java.util.Set; public record TraderModelBundleManifest( + String manifestSchemaVersion, String modelBundleVersion, String calibrationBundleVersion, String featureVersion, String labelVersion, String splitVersion, + String trainingRunId, + String trainingExportId, + String backtestManifestId, Set requiredModels, Set providedModels, Set missingModels, + Set allowedRunModes, String bundleHashSha256, boolean complete, String status diff --git a/src/main/java/com/quantai/trader/artifact/TraderModelManifest.java b/src/main/java/com/quantai/trader/artifact/TraderModelManifest.java new file mode 100644 index 0000000..027efd5 --- /dev/null +++ b/src/main/java/com/quantai/trader/artifact/TraderModelManifest.java @@ -0,0 +1,59 @@ +package com.quantai.trader.artifact; + +import java.time.Instant; +import java.util.Map; +import java.util.Set; + +public record TraderModelManifest( + String modelBundleVersion, + String calibrationBundleVersion, + String modelName, + String modelType, + String side, + Set symbolScope, + String barInterval, + int horizonMinutes, + String modelFormat, + String modelRuntime, + String modelRuntimeVersion, + int onnxOpsetVersion, + String producerName, + String producerVersion, + String featureVersion, + String featureSchemaPath, + String featureSchemaHash, + String featureOrderPath, + String featureOrderHash, + String inputTensorName, + String inputDtype, + Map inputShapeJson, + String inputExamplePath, + String outputSchemaPath, + String outputSchemaHash, + Set outputTensorNames, + Map outputMapping, + Map outputValueRules, + String labelVersion, + String splitVersion, + String trainingFold, + Instant trainStart, + Instant trainEnd, + Instant validationStart, + Instant validationEnd, + Instant testStart, + Instant testEnd, + Map metricsJson, + String artifactPath, + String artifactHashSha256, + String sourceHash, + String status +) { + public TraderModelManifest { + symbolScope = Set.copyOf(symbolScope == null ? Set.of() : symbolScope); + inputShapeJson = Map.copyOf(inputShapeJson == null ? Map.of() : inputShapeJson); + outputTensorNames = Set.copyOf(outputTensorNames == null ? Set.of() : outputTensorNames); + outputMapping = Map.copyOf(outputMapping == null ? Map.of() : outputMapping); + outputValueRules = Map.copyOf(outputValueRules == null ? Map.of() : outputValueRules); + metricsJson = Map.copyOf(metricsJson == null ? Map.of() : metricsJson); + } +} diff --git a/src/main/java/com/quantai/trader/artifact/TraderPmConfigManifest.java b/src/main/java/com/quantai/trader/artifact/TraderPmConfigManifest.java index 7d29c68..297fb7e 100644 --- a/src/main/java/com/quantai/trader/artifact/TraderPmConfigManifest.java +++ b/src/main/java/com/quantai/trader/artifact/TraderPmConfigManifest.java @@ -3,15 +3,20 @@ package com.quantai.trader.artifact; import com.quantai.trader.domain.TraderPmConfig; import com.quantai.trader.enums.TraderRunMode; +import java.util.Map; import java.util.Set; public record TraderPmConfigManifest( String pmConfigVersion, String modelBundleVersion, String calibrationBundleVersion, + Map thresholdStabilityJson, Set allowedRunModes, TraderPmConfig config, String configHashSha256, String status ) { + public TraderPmConfigManifest { + thresholdStabilityJson = Map.copyOf(thresholdStabilityJson == null ? Map.of() : thresholdStabilityJson); + } } diff --git a/src/main/java/com/quantai/trader/artifact/TraderReplayModelFixture.java b/src/main/java/com/quantai/trader/artifact/TraderReplayModelFixture.java new file mode 100644 index 0000000..f6e9423 --- /dev/null +++ b/src/main/java/com/quantai/trader/artifact/TraderReplayModelFixture.java @@ -0,0 +1,66 @@ +package com.quantai.trader.artifact; + +import java.math.BigDecimal; +import java.util.Map; + +public record TraderReplayModelFixture( + DirectionFixture direction, + EntryFixture entry, + ContinueFixture continuation, + ExitFixture exit, + RiskFixture risk, + BigDecimal uncertainty, + BigDecimal oodScore, + String featureSchemaHash, + String featureOrderHash, + String outputSchemaHash +) { + public record DirectionFixture( + BigDecimal longProbWhenMarkGteIndex, + BigDecimal longProbWhenMarkLtIndex, + BigDecimal neutralProb + ) { + } + + public record EntryFixture( + BigDecimal longEntryProb, + BigDecimal shortEntryProb, + BigDecimal longExpectedNetEdgeBps, + BigDecimal shortExpectedNetEdgeBps + ) { + } + + public record ContinueFixture( + BigDecimal longContinueProb, + BigDecimal shortContinueProb, + BigDecimal longExpectedContinueEdgeBps, + BigDecimal shortExpectedContinueEdgeBps + ) { + } + + public record ExitFixture( + BigDecimal longExitProb, + BigDecimal shortExitProb, + BigDecimal longAdverseMoveBps, + BigDecimal shortAdverseMoveBps, + Map exitReasonScores + ) { + public ExitFixture { + exitReasonScores = Map.copyOf(exitReasonScores == null ? Map.of() : exitReasonScores); + } + } + + public record RiskFixture( + BigDecimal marketRiskProb, + BigDecimal longPositionRiskProb, + BigDecimal shortPositionRiskProb, + BigDecimal marketPathRiskBps, + BigDecimal longPositionPathRiskBps, + BigDecimal shortPositionPathRiskBps, + Map riskReasonScores + ) { + public RiskFixture { + riskReasonScores = Map.copyOf(riskReasonScores == null ? Map.of() : riskReasonScores); + } + } +} diff --git a/src/main/java/com/quantai/trader/config/TraderProperties.java b/src/main/java/com/quantai/trader/config/TraderProperties.java index 11c0462..ec9fcd2 100644 --- a/src/main/java/com/quantai/trader/config/TraderProperties.java +++ b/src/main/java/com/quantai/trader/config/TraderProperties.java @@ -6,6 +6,8 @@ import org.springframework.boot.context.properties.ConfigurationProperties; import java.math.BigDecimal; +import static com.quantai.trader.util.TraderNumbers.nonNegative; +import static com.quantai.trader.util.TraderNumbers.positive; import static com.quantai.trader.util.TraderNumbers.requiredText; @ConfigurationProperties(prefix = "trader") @@ -23,21 +25,24 @@ public record TraderProperties( PositionManager positionManager ) { public TraderProperties { - serviceName = defaultText(serviceName, "quant-trader-service"); - runMode = runMode == null ? TraderRunMode.SHADOW : runMode; - symbol = defaultText(symbol, "BTC-USDT-PERP"); - artifact = artifact == null ? new Artifact("trader-v4-btc-p0", "cal-v4-btc-p0", "pm-v4-btc-p0", ".") : artifact; - feedback = feedback == null ? new Feedback(false) : feedback; - execution = execution == null ? new Execution(TraderExecutionMode.SHADOW, 3, 1500) : execution; - runtime = runtime == null ? new Runtime("trader:v4", true, false) : runtime; - outbox = outbox == null ? new Outbox(true, 5) : outbox; - release = release == null ? new Release(true, true, true) : release; - risk = risk == null ? new Risk(new BigDecimal("200"), BigDecimal.ONE, new BigDecimal("500")) : risk; - positionManager = positionManager == null ? new PositionManager(BigDecimal.ONE, BigDecimal.ONE) : positionManager; + serviceName = requiredText(serviceName, "serviceName"); + runMode = require(runMode, "runMode"); + symbol = requiredText(symbol, "symbol"); + artifact = require(artifact, "artifact"); + feedback = require(feedback, "feedback"); + execution = require(execution, "execution"); + runtime = require(runtime, "runtime"); + outbox = require(outbox, "outbox"); + release = require(release, "release"); + risk = require(risk, "risk"); + positionManager = require(positionManager, "positionManager"); } - private static String defaultText(String value, String defaultValue) { - return value == null || value.isBlank() ? defaultValue : value; + private static T require(T value, String field) { + if (value == null) { + throw new IllegalArgumentException(field + " is required"); + } + return value; } public record Artifact( @@ -57,19 +62,30 @@ public record TraderProperties( public record Feedback(boolean httpEnabled) { } - public record Execution(TraderExecutionMode mode, int maxApiErrorCount, long maxExchangeLatencyMs) { + public record Execution(TraderExecutionMode mode, Integer maxApiErrorCount, Long maxExchangeLatencyMs) { public Execution { - mode = mode == null ? TraderExecutionMode.SHADOW : mode; + mode = require(mode, "execution.mode"); + if (require(maxApiErrorCount, "execution.maxApiErrorCount") < 0) { + throw new IllegalArgumentException("execution.maxApiErrorCount must be >= 0"); + } + if (require(maxExchangeLatencyMs, "execution.maxExchangeLatencyMs") <= 0) { + throw new IllegalArgumentException("execution.maxExchangeLatencyMs must be > 0"); + } } } public record Runtime(String redisKeyPrefix, boolean requireRedisForOpenAdd, boolean tradingEnabled) { public Runtime { - redisKeyPrefix = defaultText(redisKeyPrefix, "trader:v4"); + redisKeyPrefix = requiredText(redisKeyPrefix, "runtime.redisKeyPrefix"); } } - public record Outbox(boolean enabled, int maxRetryCount) { + public record Outbox(boolean enabled, Integer maxRetryCount) { + public Outbox { + if (require(maxRetryCount, "outbox.maxRetryCount") < 0) { + throw new IllegalArgumentException("outbox.maxRetryCount must be >= 0"); + } + } } public record Release(boolean requireReviewForPaper, boolean requireReviewForLiveProbe, boolean activePointerCheckEnabled) { @@ -77,16 +93,16 @@ public record TraderProperties( public record Risk(BigDecimal maxDailyLossBps, BigDecimal maxTotalExposureRatio, BigDecimal minLiquidationBufferBps) { public Risk { - maxDailyLossBps = maxDailyLossBps == null ? new BigDecimal("200") : maxDailyLossBps; - maxTotalExposureRatio = maxTotalExposureRatio == null ? BigDecimal.ONE : maxTotalExposureRatio; - minLiquidationBufferBps = minLiquidationBufferBps == null ? new BigDecimal("500") : minLiquidationBufferBps; + maxDailyLossBps = nonNegative(maxDailyLossBps, "risk.maxDailyLossBps"); + maxTotalExposureRatio = positive(maxTotalExposureRatio, "risk.maxTotalExposureRatio"); + minLiquidationBufferBps = nonNegative(minLiquidationBufferBps, "risk.minLiquidationBufferBps"); } } public record PositionManager(BigDecimal maxSingleLegRatio, BigDecimal maxTotalPositionRatio) { public PositionManager { - maxSingleLegRatio = maxSingleLegRatio == null ? BigDecimal.ONE : maxSingleLegRatio; - maxTotalPositionRatio = maxTotalPositionRatio == null ? BigDecimal.ONE : maxTotalPositionRatio; + maxSingleLegRatio = positive(maxSingleLegRatio, "positionManager.maxSingleLegRatio"); + maxTotalPositionRatio = positive(maxTotalPositionRatio, "positionManager.maxTotalPositionRatio"); } } } diff --git a/src/main/java/com/quantai/trader/controller/TraderApiExceptionHandler.java b/src/main/java/com/quantai/trader/controller/TraderApiExceptionHandler.java index 89d3fe6..ba047b5 100644 --- a/src/main/java/com/quantai/trader/controller/TraderApiExceptionHandler.java +++ b/src/main/java/com/quantai/trader/controller/TraderApiExceptionHandler.java @@ -14,6 +14,6 @@ public class TraderApiExceptionHandler { @ExceptionHandler(IllegalArgumentException.class) ResponseEntity illegalArgument(IllegalArgumentException exception) { - return ResponseEntity.badRequest().body(new TraderApiError(com.quantai.trader.enums.TraderErrorCode.TRADER_MODEL_OUTPUT_INVALID, exception.getMessage())); + return ResponseEntity.badRequest().body(new TraderApiError(com.quantai.trader.enums.TraderErrorCode.TRADER_REQUEST_INVALID, exception.getMessage())); } } diff --git a/src/main/java/com/quantai/trader/controller/TraderFeedbackController.java b/src/main/java/com/quantai/trader/controller/TraderFeedbackController.java index 61f3adb..cd4a134 100644 --- a/src/main/java/com/quantai/trader/controller/TraderFeedbackController.java +++ b/src/main/java/com/quantai/trader/controller/TraderFeedbackController.java @@ -5,6 +5,7 @@ import com.quantai.trader.domain.FeedbackValidator; import com.quantai.trader.domain.TraderAppFeedback; import com.quantai.trader.domain.TraderException; import com.quantai.trader.enums.TraderErrorCode; +import com.quantai.trader.feedback.TraderFeedbackRepository; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.springframework.web.bind.annotation.PostMapping; @@ -18,10 +19,13 @@ public class TraderFeedbackController { private static final Logger log = LoggerFactory.getLogger(TraderFeedbackController.class); private final TraderProperties properties; private final FeedbackValidator feedbackValidator; + private final TraderFeedbackRepository feedbackRepository; - public TraderFeedbackController(TraderProperties properties, FeedbackValidator feedbackValidator) { + public TraderFeedbackController(TraderProperties properties, FeedbackValidator feedbackValidator, + TraderFeedbackRepository feedbackRepository) { this.properties = properties; this.feedbackValidator = feedbackValidator; + this.feedbackRepository = feedbackRepository; } @PostMapping("/api/trader/feedback") @@ -30,6 +34,7 @@ public class TraderFeedbackController { throw new TraderException(TraderErrorCode.TRADER_FEEDBACK_INVALID, "HTTP feedback is disabled in P0"); } feedbackValidator.validateP0(feedback); + feedbackRepository.insert(feedback); log.info("event=trader.feedback.accepted runId={} cycleId={} actionId={} source={}", feedback.runId(), feedback.cycleId(), feedback.actionId(), feedback.feedbackSource()); return Map.of("accepted", true, "feedbackId", feedback.feedbackId()); diff --git a/src/main/java/com/quantai/trader/domain/ContinueOutput.java b/src/main/java/com/quantai/trader/domain/ContinueOutput.java index 9dbf164..7d568f6 100644 --- a/src/main/java/com/quantai/trader/domain/ContinueOutput.java +++ b/src/main/java/com/quantai/trader/domain/ContinueOutput.java @@ -2,27 +2,29 @@ package com.quantai.trader.domain; import static com.quantai.trader.util.TraderNumbers.*; -import java.math.BigDecimal; -import java.util.Map; +import com.quantai.trader.enums.PositionSide; +import java.math.BigDecimal; public record ContinueOutput( BigDecimal longContinueProb, BigDecimal shortContinueProb, - BigDecimal trendPersistenceProb, - BigDecimal holdEdgeBps, - BigDecimal continueVsExitEdgeBps, - String modelVersion, - String calibrationVersion, - Map explanation + BigDecimal longExpectedContinueEdgeBps, + BigDecimal shortExpectedContinueEdgeBps ) { public ContinueOutput { longContinueProb = probability(longContinueProb, "continue.longContinueProb"); shortContinueProb = probability(shortContinueProb, "continue.shortContinueProb"); - trendPersistenceProb = probability(trendPersistenceProb, "continue.trendPersistenceProb"); - holdEdgeBps = required(holdEdgeBps, "continue.holdEdgeBps"); - continueVsExitEdgeBps = required(continueVsExitEdgeBps, "continue.continueVsExitEdgeBps"); - modelVersion = requiredText(modelVersion, "continue.modelVersion"); - calibrationVersion = requiredText(calibrationVersion, "continue.calibrationVersion"); - explanation = Map.copyOf(explanation == null ? Map.of() : explanation); + longExpectedContinueEdgeBps = required(longExpectedContinueEdgeBps, "continue.longExpectedContinueEdgeBps"); + shortExpectedContinueEdgeBps = required(shortExpectedContinueEdgeBps, "continue.shortExpectedContinueEdgeBps"); + } + + public BigDecimal continueEdgeBpsFor(PositionSide side) { + if (side == PositionSide.LONG) { + return longExpectedContinueEdgeBps; + } + if (side == PositionSide.SHORT) { + return shortExpectedContinueEdgeBps; + } + throw new IllegalArgumentException("continue edge requires LONG or SHORT side"); } } diff --git a/src/main/java/com/quantai/trader/domain/DirectionOutput.java b/src/main/java/com/quantai/trader/domain/DirectionOutput.java index f98ad9c..9d40299 100644 --- a/src/main/java/com/quantai/trader/domain/DirectionOutput.java +++ b/src/main/java/com/quantai/trader/domain/DirectionOutput.java @@ -3,32 +3,27 @@ package com.quantai.trader.domain; import static com.quantai.trader.util.TraderNumbers.*; import java.math.BigDecimal; -import java.util.Map; public record DirectionOutput( BigDecimal longProb, BigDecimal shortProb, - BigDecimal neutralProb, - BigDecimal directionConfidence, - BigDecimal directionMargin, - BigDecimal expectedReturnBps, - Integer horizonMinutes, - String modelVersion, - String calibrationVersion, - Map explanation + BigDecimal neutralProb ) { public DirectionOutput { longProb = probability(longProb, "direction.longProb"); shortProb = probability(shortProb, "direction.shortProb"); neutralProb = probability(neutralProb, "direction.neutralProb"); - directionConfidence = probability(directionConfidence, "direction.directionConfidence"); - directionMargin = nonNegative(directionMargin, "direction.directionMargin"); - expectedReturnBps = required(expectedReturnBps, "direction.expectedReturnBps"); - if (horizonMinutes == null || horizonMinutes <= 0) { - throw new IllegalArgumentException("direction.horizonMinutes must be > 0"); + BigDecimal sum = longProb.add(shortProb).add(neutralProb); + if (sum.subtract(BigDecimal.ONE).abs().compareTo(new BigDecimal("0.000001")) > 0) { + throw new IllegalArgumentException("direction probabilities must sum to 1"); } - modelVersion = requiredText(modelVersion, "direction.modelVersion"); - calibrationVersion = requiredText(calibrationVersion, "direction.calibrationVersion"); - explanation = Map.copyOf(explanation == null ? Map.of() : explanation); + } + + public BigDecimal margin() { + return longProb.subtract(shortProb).abs(); + } + + public BigDecimal confidence() { + return longProb.max(shortProb).max(neutralProb); } } diff --git a/src/main/java/com/quantai/trader/domain/EntryOutput.java b/src/main/java/com/quantai/trader/domain/EntryOutput.java index 752933d..9566831 100644 --- a/src/main/java/com/quantai/trader/domain/EntryOutput.java +++ b/src/main/java/com/quantai/trader/domain/EntryOutput.java @@ -2,39 +2,29 @@ package com.quantai.trader.domain; import static com.quantai.trader.util.TraderNumbers.*; -import java.math.BigDecimal; -import java.util.Map; +import com.quantai.trader.enums.PositionSide; +import java.math.BigDecimal; public record EntryOutput( BigDecimal longEntryProb, BigDecimal shortEntryProb, - BigDecimal entryQualityScore, - BigDecimal expectedEdgeBps, - String pricePlanId, - String pricePlanConfigHash, - BigDecimal stopDistanceBps, - BigDecimal targetDistanceBps, - Integer maxHoldMinutes, - BigDecimal costBps, - String modelVersion, - String calibrationVersion, - Map explanation + BigDecimal longExpectedNetEdgeBps, + BigDecimal shortExpectedNetEdgeBps ) { public EntryOutput { longEntryProb = probability(longEntryProb, "entry.longEntryProb"); shortEntryProb = probability(shortEntryProb, "entry.shortEntryProb"); - entryQualityScore = probability(entryQualityScore, "entry.entryQualityScore"); - expectedEdgeBps = required(expectedEdgeBps, "entry.expectedEdgeBps"); - pricePlanId = requiredText(pricePlanId, "entry.pricePlanId"); - pricePlanConfigHash = requiredText(pricePlanConfigHash, "entry.pricePlanConfigHash"); - stopDistanceBps = positive(stopDistanceBps, "entry.stopDistanceBps"); - targetDistanceBps = positive(targetDistanceBps, "entry.targetDistanceBps"); - if (maxHoldMinutes == null || maxHoldMinutes <= 0) { - throw new IllegalArgumentException("entry.maxHoldMinutes must be > 0"); + longExpectedNetEdgeBps = required(longExpectedNetEdgeBps, "entry.longExpectedNetEdgeBps"); + shortExpectedNetEdgeBps = required(shortExpectedNetEdgeBps, "entry.shortExpectedNetEdgeBps"); + } + + public BigDecimal netEdgeBpsFor(PositionSide side) { + if (side == PositionSide.LONG) { + return longExpectedNetEdgeBps; } - costBps = nonNegative(costBps, "entry.costBps"); - modelVersion = requiredText(modelVersion, "entry.modelVersion"); - calibrationVersion = requiredText(calibrationVersion, "entry.calibrationVersion"); - explanation = Map.copyOf(explanation == null ? Map.of() : explanation); + if (side == PositionSide.SHORT) { + return shortExpectedNetEdgeBps; + } + throw new IllegalArgumentException("entry edge requires LONG or SHORT side"); } } diff --git a/src/main/java/com/quantai/trader/domain/ExitOutput.java b/src/main/java/com/quantai/trader/domain/ExitOutput.java index 0828b10..0db07ea 100644 --- a/src/main/java/com/quantai/trader/domain/ExitOutput.java +++ b/src/main/java/com/quantai/trader/domain/ExitOutput.java @@ -4,29 +4,39 @@ import static com.quantai.trader.util.TraderNumbers.*; import java.math.BigDecimal; import java.util.Map; +import java.util.Set; public record ExitOutput( BigDecimal longExitProb, BigDecimal shortExitProb, - BigDecimal profitGivebackProb, - BigDecimal reversalProb, - BigDecimal stopRiskProb, - BigDecimal stagnationProb, - BigDecimal expectedGivebackBps, - String modelVersion, - String calibrationVersion, - Map explanation + BigDecimal longAdverseMoveBps, + BigDecimal shortAdverseMoveBps, + Map exitReasonScores ) { + private static final Set REQUIRED_REASON_KEYS = Set.of( + "adverse_move_prob", "reversal_prob", "stop_hit_prob", "stagnation_prob"); + public ExitOutput { longExitProb = probability(longExitProb, "exit.longExitProb"); shortExitProb = probability(shortExitProb, "exit.shortExitProb"); - profitGivebackProb = probability(profitGivebackProb, "exit.profitGivebackProb"); - reversalProb = probability(reversalProb, "exit.reversalProb"); - stopRiskProb = probability(stopRiskProb, "exit.stopRiskProb"); - stagnationProb = probability(stagnationProb, "exit.stagnationProb"); - expectedGivebackBps = nonNegative(expectedGivebackBps, "exit.expectedGivebackBps"); - modelVersion = requiredText(modelVersion, "exit.modelVersion"); - calibrationVersion = requiredText(calibrationVersion, "exit.calibrationVersion"); - explanation = Map.copyOf(explanation == null ? Map.of() : explanation); + longAdverseMoveBps = nonNegative(longAdverseMoveBps, "exit.longAdverseMoveBps"); + shortAdverseMoveBps = nonNegative(shortAdverseMoveBps, "exit.shortAdverseMoveBps"); + exitReasonScores = checkedProbabilities(exitReasonScores, "exit.exitReasonScores"); + } + + public BigDecimal reasonScore(String key) { + return exitReasonScores.get(requiredText(key, "exit reason key")); + } + + private static Map checkedProbabilities(Map scores, String field) { + Map source = scores == null ? Map.of() : scores; + if (!source.keySet().containsAll(REQUIRED_REASON_KEYS)) { + throw new IllegalArgumentException(field + " must contain " + REQUIRED_REASON_KEYS); + } + source.forEach((key, value) -> { + requiredText(key, field + ".key"); + probability(value, field + "." + key); + }); + return Map.copyOf(source); } } diff --git a/src/main/java/com/quantai/trader/domain/PositionManagerInput.java b/src/main/java/com/quantai/trader/domain/PositionManagerInput.java index 29a1060..21c4ec2 100644 --- a/src/main/java/com/quantai/trader/domain/PositionManagerInput.java +++ b/src/main/java/com/quantai/trader/domain/PositionManagerInput.java @@ -6,6 +6,7 @@ public record PositionManagerInput( TraderDecisionCycle cycle, TraderMarketSnapshot snapshot, TraderModelOutput modelOutput, + TraderPricePlanContext pricePlanContext, TraderPositionState positionState, TraderAccountState accountState, TraderExecutionState executionState, @@ -15,6 +16,7 @@ public record PositionManagerInput( cycle = Objects.requireNonNull(cycle, "cycle is required"); snapshot = Objects.requireNonNull(snapshot, "snapshot is required"); modelOutput = Objects.requireNonNull(modelOutput, "modelOutput is required"); + pricePlanContext = Objects.requireNonNull(pricePlanContext, "pricePlanContext is required"); positionState = Objects.requireNonNull(positionState, "positionState is required"); accountState = Objects.requireNonNull(accountState, "accountState is required"); executionState = Objects.requireNonNull(executionState, "executionState is required"); diff --git a/src/main/java/com/quantai/trader/domain/RiskOutput.java b/src/main/java/com/quantai/trader/domain/RiskOutput.java index 99b8828..0563bae 100644 --- a/src/main/java/com/quantai/trader/domain/RiskOutput.java +++ b/src/main/java/com/quantai/trader/domain/RiskOutput.java @@ -2,37 +2,68 @@ package com.quantai.trader.domain; import static com.quantai.trader.util.TraderNumbers.*; +import com.quantai.trader.enums.PositionSide; + import java.math.BigDecimal; import java.util.Map; +import java.util.Set; public record RiskOutput( BigDecimal marketRiskProb, - BigDecimal positionRiskProb, - BigDecimal marketRiskSeverityBps, - BigDecimal positionRiskSeverityBps, - BigDecimal drawdownProb, - BigDecimal expectedShortfallBps, - BigDecimal volatilityExpansionProb, - BigDecimal spikeProb, - BigDecimal liquidityRiskProb, - BigDecimal liquidityCapacityRatio, - String modelVersion, - String calibrationVersion, - Map explanation + BigDecimal longPositionRiskProb, + BigDecimal shortPositionRiskProb, + BigDecimal marketPathRiskBps, + BigDecimal longPositionPathRiskBps, + BigDecimal shortPositionPathRiskBps, + Map riskReasonScores ) { + private static final Set REQUIRED_REASON_KEYS = Set.of( + "market_drawdown_prob", "volatility_expansion_prob", "spike_prob", + "liquidity_deterioration_prob", "position_drawdown_prob"); + public RiskOutput { marketRiskProb = probability(marketRiskProb, "risk.marketRiskProb"); - positionRiskProb = probability(positionRiskProb, "risk.positionRiskProb"); - marketRiskSeverityBps = nonNegative(marketRiskSeverityBps, "risk.marketRiskSeverityBps"); - positionRiskSeverityBps = nonNegative(positionRiskSeverityBps, "risk.positionRiskSeverityBps"); - drawdownProb = probability(drawdownProb, "risk.drawdownProb"); - expectedShortfallBps = nonNegative(expectedShortfallBps, "risk.expectedShortfallBps"); - volatilityExpansionProb = probability(volatilityExpansionProb, "risk.volatilityExpansionProb"); - spikeProb = probability(spikeProb, "risk.spikeProb"); - liquidityRiskProb = probability(liquidityRiskProb, "risk.liquidityRiskProb"); - liquidityCapacityRatio = nonNegative(liquidityCapacityRatio, "risk.liquidityCapacityRatio"); - modelVersion = requiredText(modelVersion, "risk.modelVersion"); - calibrationVersion = requiredText(calibrationVersion, "risk.calibrationVersion"); - explanation = Map.copyOf(explanation == null ? Map.of() : explanation); + longPositionRiskProb = probability(longPositionRiskProb, "risk.longPositionRiskProb"); + shortPositionRiskProb = probability(shortPositionRiskProb, "risk.shortPositionRiskProb"); + marketPathRiskBps = nonNegative(marketPathRiskBps, "risk.marketPathRiskBps"); + longPositionPathRiskBps = nonNegative(longPositionPathRiskBps, "risk.longPositionPathRiskBps"); + shortPositionPathRiskBps = nonNegative(shortPositionPathRiskBps, "risk.shortPositionPathRiskBps"); + riskReasonScores = checkedProbabilities(riskReasonScores, "risk.riskReasonScores"); + } + + public BigDecimal reasonScore(String key) { + return riskReasonScores.get(requiredText(key, "risk reason key")); + } + + public BigDecimal sideRiskProbFor(PositionSide side) { + if (side == PositionSide.LONG) { + return longPositionRiskProb; + } + if (side == PositionSide.SHORT) { + return shortPositionRiskProb; + } + throw new IllegalArgumentException("position risk requires LONG or SHORT side"); + } + + public BigDecimal positionPathRiskBpsFor(PositionSide side) { + if (side == PositionSide.LONG) { + return longPositionPathRiskBps; + } + if (side == PositionSide.SHORT) { + return shortPositionPathRiskBps; + } + throw new IllegalArgumentException("position path risk requires LONG or SHORT side"); + } + + private static Map checkedProbabilities(Map scores, String field) { + Map source = scores == null ? Map.of() : scores; + if (!source.keySet().containsAll(REQUIRED_REASON_KEYS)) { + throw new IllegalArgumentException(field + " must contain " + REQUIRED_REASON_KEYS); + } + source.forEach((key, value) -> { + requiredText(key, field + ".key"); + probability(value, field + "." + key); + }); + return Map.copyOf(source); } } diff --git a/src/main/java/com/quantai/trader/domain/TraderModelOutput.java b/src/main/java/com/quantai/trader/domain/TraderModelOutput.java index 61800e3..30d6fae 100644 --- a/src/main/java/com/quantai/trader/domain/TraderModelOutput.java +++ b/src/main/java/com/quantai/trader/domain/TraderModelOutput.java @@ -2,38 +2,28 @@ package com.quantai.trader.domain; import static com.quantai.trader.util.TraderNumbers.*; -import java.math.BigDecimal; -import java.util.Map; import java.util.Objects; public record TraderModelOutput( String modelOutputId, String runId, String cycleId, - String modelBundleVersion, - String calibrationBundleVersion, + TraderModelOutputMetadata metadata, DirectionOutput direction, EntryOutput entry, ContinueOutput continuation, ExitOutput exit, - RiskOutput risk, - BigDecimal uncertainty, - BigDecimal oodScore, - Map explanation + RiskOutput risk ) { public TraderModelOutput { modelOutputId = requiredText(modelOutputId, "modelOutputId"); runId = requiredText(runId, "runId"); cycleId = requiredText(cycleId, "cycleId"); - modelBundleVersion = requiredText(modelBundleVersion, "modelBundleVersion"); - calibrationBundleVersion = requiredText(calibrationBundleVersion, "calibrationBundleVersion"); + metadata = Objects.requireNonNull(metadata, "metadata is required"); direction = Objects.requireNonNull(direction, "direction is required"); entry = Objects.requireNonNull(entry, "entry is required"); continuation = Objects.requireNonNull(continuation, "continuation is required"); exit = Objects.requireNonNull(exit, "exit is required"); risk = Objects.requireNonNull(risk, "risk is required"); - uncertainty = probability(uncertainty, "uncertainty"); - oodScore = probability(oodScore, "oodScore"); - explanation = Map.copyOf(explanation == null ? Map.of() : explanation); } } diff --git a/src/main/java/com/quantai/trader/domain/TraderModelOutputMetadata.java b/src/main/java/com/quantai/trader/domain/TraderModelOutputMetadata.java new file mode 100644 index 0000000..1eea066 --- /dev/null +++ b/src/main/java/com/quantai/trader/domain/TraderModelOutputMetadata.java @@ -0,0 +1,39 @@ +package com.quantai.trader.domain; + +import static com.quantai.trader.util.TraderNumbers.*; + +import java.math.BigDecimal; +import java.util.Map; + +public record TraderModelOutputMetadata( + String modelBundleVersion, + String calibrationBundleVersion, + Map modelVersions, + Map calibrationVersions, + String featureSchemaHash, + String featureOrderHash, + String outputSchemaHash, + BigDecimal uncertainty, + BigDecimal oodScore +) { + public TraderModelOutputMetadata { + modelBundleVersion = requiredText(modelBundleVersion, "metadata.modelBundleVersion"); + calibrationBundleVersion = requiredText(calibrationBundleVersion, "metadata.calibrationBundleVersion"); + modelVersions = checkedTextMap(modelVersions, "metadata.modelVersions"); + calibrationVersions = checkedTextMap(calibrationVersions, "metadata.calibrationVersions"); + featureSchemaHash = requiredText(featureSchemaHash, "metadata.featureSchemaHash"); + featureOrderHash = requiredText(featureOrderHash, "metadata.featureOrderHash"); + outputSchemaHash = requiredText(outputSchemaHash, "metadata.outputSchemaHash"); + uncertainty = probability(uncertainty, "metadata.uncertainty"); + oodScore = probability(oodScore, "metadata.oodScore"); + } + + private static Map checkedTextMap(Map values, String field) { + Map source = values == null ? Map.of() : values; + source.forEach((key, value) -> { + requiredText(key, field + ".key"); + requiredText(value, field + "." + key); + }); + return Map.copyOf(source); + } +} diff --git a/src/main/java/com/quantai/trader/domain/TraderPmConfig.java b/src/main/java/com/quantai/trader/domain/TraderPmConfig.java index 6a71b87..90eac5e 100644 --- a/src/main/java/com/quantai/trader/domain/TraderPmConfig.java +++ b/src/main/java/com/quantai/trader/domain/TraderPmConfig.java @@ -81,22 +81,22 @@ public record TraderPmConfig( BigDecimal closePositionRiskProb, BigDecimal closeMarketRiskProb, BigDecimal closeContinueMax, - BigDecimal reduceGivebackProb, + BigDecimal reduceAdverseMoveProb, BigDecimal reduceContinueMin, BigDecimal reduceContinueMax, BigDecimal minProfitForReduceBps, - BigDecimal maxExpectedShortfallBps + BigDecimal maxPositionPathRiskBps ) { public ExitRuleConfig { closeExitProb = probability(closeExitProb, "exit.closeExitProb"); closePositionRiskProb = probability(closePositionRiskProb, "exit.closePositionRiskProb"); closeMarketRiskProb = probability(closeMarketRiskProb, "exit.closeMarketRiskProb"); closeContinueMax = probability(closeContinueMax, "exit.closeContinueMax"); - reduceGivebackProb = probability(reduceGivebackProb, "exit.reduceGivebackProb"); + reduceAdverseMoveProb = probability(reduceAdverseMoveProb, "exit.reduceAdverseMoveProb"); reduceContinueMin = probability(reduceContinueMin, "exit.reduceContinueMin"); reduceContinueMax = probability(reduceContinueMax, "exit.reduceContinueMax"); minProfitForReduceBps = nonNegative(minProfitForReduceBps, "exit.minProfitForReduceBps"); - maxExpectedShortfallBps = nonNegative(maxExpectedShortfallBps, "exit.maxExpectedShortfallBps"); + maxPositionPathRiskBps = nonNegative(maxPositionPathRiskBps, "exit.maxPositionPathRiskBps"); } } diff --git a/src/main/java/com/quantai/trader/domain/TraderPricePlanContext.java b/src/main/java/com/quantai/trader/domain/TraderPricePlanContext.java new file mode 100644 index 0000000..4902edc --- /dev/null +++ b/src/main/java/com/quantai/trader/domain/TraderPricePlanContext.java @@ -0,0 +1,25 @@ +package com.quantai.trader.domain; + +import static com.quantai.trader.util.TraderNumbers.*; + +import java.math.BigDecimal; + +public record TraderPricePlanContext( + String pricePlanId, + String pricePlanConfigHash, + BigDecimal stopDistanceBps, + BigDecimal targetDistanceBps, + int maxHoldMinutes, + BigDecimal costBps +) { + public TraderPricePlanContext { + pricePlanId = requiredText(pricePlanId, "pricePlan.pricePlanId"); + pricePlanConfigHash = requiredText(pricePlanConfigHash, "pricePlan.pricePlanConfigHash"); + stopDistanceBps = positive(stopDistanceBps, "pricePlan.stopDistanceBps"); + targetDistanceBps = positive(targetDistanceBps, "pricePlan.targetDistanceBps"); + if (maxHoldMinutes <= 0) { + throw new IllegalArgumentException("pricePlan.maxHoldMinutes must be > 0"); + } + costBps = nonNegative(costBps, "pricePlan.costBps"); + } +} diff --git a/src/main/java/com/quantai/trader/enums/TraderErrorCode.java b/src/main/java/com/quantai/trader/enums/TraderErrorCode.java index c02cf46..c7544e9 100644 --- a/src/main/java/com/quantai/trader/enums/TraderErrorCode.java +++ b/src/main/java/com/quantai/trader/enums/TraderErrorCode.java @@ -9,8 +9,11 @@ public enum TraderErrorCode { TRADER_RISK_BLOCKED, TRADER_EXECUTION_BLOCKED, TRADER_FEEDBACK_INVALID, + TRADER_REQUEST_INVALID, TRADER_P0_MODE_BLOCKED, TRADER_KILL_SWITCH_ACTIVE, + TRADER_RUNTIME_CONTROL_BLOCKED, + TRADER_OUTBOX_BLOCKED, TRADER_ACTIVE_POINTER_MISMATCH, TRADER_PERSISTENCE_FAILED } diff --git a/src/main/java/com/quantai/trader/feature/TraderFeatureVectorBuilder.java b/src/main/java/com/quantai/trader/feature/TraderFeatureVectorBuilder.java new file mode 100644 index 0000000..11843ff --- /dev/null +++ b/src/main/java/com/quantai/trader/feature/TraderFeatureVectorBuilder.java @@ -0,0 +1,144 @@ +package com.quantai.trader.feature; + +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.quantai.trader.artifact.TraderArtifactBundle; +import com.quantai.trader.artifact.TraderModelManifest; +import com.quantai.trader.config.TraderProperties; +import com.quantai.trader.domain.TraderException; +import com.quantai.trader.domain.TraderMarketSnapshot; +import com.quantai.trader.enums.TraderErrorCode; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.stereotype.Component; + +import java.io.IOException; +import java.math.BigDecimal; +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.LinkedHashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentMap; +import java.util.stream.StreamSupport; + +@Component +public class TraderFeatureVectorBuilder { + private static final Logger log = LoggerFactory.getLogger(TraderFeatureVectorBuilder.class); + private static final int REQUIRED_FEATURE_COUNT = 39; + + private final TraderProperties properties; + private final ObjectMapper objectMapper; + private final ConcurrentMap> featureOrderCache = new ConcurrentHashMap<>(); + + public TraderFeatureVectorBuilder(TraderProperties properties, ObjectMapper objectMapper) { + this.properties = properties; + this.objectMapper = objectMapper; + } + + public float[] build(TraderMarketSnapshot snapshot, TraderArtifactBundle bundle) { + TraderModelManifest referenceManifest = referenceManifest(bundle); + if (!referenceManifest.featureVersion().equals(snapshot.featureVersion())) { + log.warn("event=trader.features.version_mismatch runId={} cycleId={} snapshotFeatureVersion={} modelFeatureVersion={}", + snapshot.runId(), snapshot.cycleId(), snapshot.featureVersion(), referenceManifest.featureVersion()); + throw modelInputException("snapshot feature version does not match model feature version: " + + snapshot.featureVersion() + " != " + referenceManifest.featureVersion()); + } + List featureOrder = featureOrder(referenceManifest); + rejectUnexpectedFeatures(snapshot, featureOrder); + float[] values = new float[featureOrder.size()]; + for (int index = 0; index < featureOrder.size(); index++) { + String featureName = featureOrder.get(index); + Object rawValue = snapshot.featureJson().get(featureName); + values[index] = toFloat(snapshot, rawValue, featureName, index); + } + log.debug("event=trader.features.vector_built runId={} cycleId={} featureVersion={} featureCount={} featureOrderHash={}", + snapshot.runId(), snapshot.cycleId(), snapshot.featureVersion(), values.length, referenceManifest.featureOrderHash()); + return values; + } + + public List featureOrder(TraderArtifactBundle bundle) { + return featureOrder(referenceManifest(bundle)); + } + + private TraderModelManifest referenceManifest(TraderArtifactBundle bundle) { + return bundle.modelManifests().stream() + .filter(manifest -> "DIRECTION".equals(manifest.modelType())) + .findFirst() + .orElseThrow(() -> modelInputException("DIRECTION model manifest is required for feature order")); + } + + private List featureOrder(TraderModelManifest manifest) { + String cacheKey = manifest.featureOrderHash() + "|" + manifest.featureOrderPath(); + // 特征顺序是模型包契约的一部分,按 hash 缓存,避免每轮重复读文件。 + return featureOrderCache.computeIfAbsent(cacheKey, ignored -> readFeatureOrder(manifest)); + } + + private List readFeatureOrder(TraderModelManifest manifest) { + Path path = Path.of(properties.artifact().artifactRoot()).resolve(manifest.featureOrderPath()); + if (!Files.isRegularFile(path)) { + throw modelInputException("feature_order.json is missing: " + path); + } + try { + JsonNode root = objectMapper.readTree(path.toFile()); + if (!root.isArray() || root.size() != REQUIRED_FEATURE_COUNT) { + throw modelInputException("feature_order.json must contain exactly " + REQUIRED_FEATURE_COUNT + " fields: " + path); + } + List order = StreamSupport.stream(root.spliterator(), false) + .map(JsonNode::asText) + .toList(); + Set unique = new LinkedHashSet<>(order); + if (unique.size() != order.size() || order.stream().anyMatch(String::isBlank)) { + throw modelInputException("feature_order.json contains duplicate or blank feature names: " + path); + } + log.info("event=trader.features.order_loaded featureOrderPath={} featureOrderHash={} featureCount={}", + manifest.featureOrderPath(), manifest.featureOrderHash(), order.size()); + return order; + } catch (IOException exception) { + throw modelInputException("feature_order.json cannot be read: " + path); + } + } + + private void rejectUnexpectedFeatures(TraderMarketSnapshot snapshot, List featureOrder) { + Set allowed = Set.copyOf(featureOrder); + List unexpected = snapshot.featureJson().keySet().stream() + .filter(key -> !allowed.contains(key)) + .sorted() + .toList(); + if (!unexpected.isEmpty()) { + log.warn("event=trader.features.unexpected_fields runId={} cycleId={} unexpectedFields={}", + snapshot.runId(), snapshot.cycleId(), unexpected); + throw modelInputException("snapshot featureJson contains fields outside feature_order.json: " + unexpected); + } + } + + private float toFloat(TraderMarketSnapshot snapshot, Object rawValue, String featureName, int index) { + if (rawValue == null) { + log.warn("event=trader.features.missing runId={} cycleId={} featureIndex={} featureName={}", + snapshot.runId(), snapshot.cycleId(), index + 1, featureName); + throw modelInputException("snapshot feature is missing: " + featureName); + } + double value; + if (rawValue instanceof BigDecimal decimal) { + value = decimal.doubleValue(); + } else if (rawValue instanceof Number number) { + value = number.doubleValue(); + } else { + log.warn("event=trader.features.non_numeric runId={} cycleId={} featureIndex={} featureName={} valueType={}", + snapshot.runId(), snapshot.cycleId(), index + 1, featureName, rawValue.getClass().getName()); + throw modelInputException("snapshot feature must be numeric: " + featureName); + } + if (!Double.isFinite(value)) { + log.warn("event=trader.features.non_finite runId={} cycleId={} featureIndex={} featureName={} value={}", + snapshot.runId(), snapshot.cycleId(), index + 1, featureName, value); + throw modelInputException("snapshot feature must be finite: " + featureName); + } + return (float) value; + } + + private static TraderException modelInputException(String message) { + return new TraderException(TraderErrorCode.TRADER_MODEL_ARTIFACT_MISSING, message); + } +} diff --git a/src/main/java/com/quantai/trader/feedback/JdbcTraderFeedbackRepository.java b/src/main/java/com/quantai/trader/feedback/JdbcTraderFeedbackRepository.java new file mode 100644 index 0000000..18f469f --- /dev/null +++ b/src/main/java/com/quantai/trader/feedback/JdbcTraderFeedbackRepository.java @@ -0,0 +1,38 @@ +package com.quantai.trader.feedback; + +import com.fasterxml.jackson.databind.ObjectMapper; +import com.quantai.trader.domain.TraderAppFeedback; +import com.quantai.trader.persistence.TraderJsonCodec; +import org.springframework.jdbc.core.JdbcTemplate; +import org.springframework.stereotype.Repository; + +import java.sql.Timestamp; + +@Repository +public class JdbcTraderFeedbackRepository implements TraderFeedbackRepository { + private final JdbcTemplate jdbcTemplate; + private final TraderJsonCodec jsonCodec; + + public JdbcTraderFeedbackRepository(JdbcTemplate jdbcTemplate, ObjectMapper objectMapper) { + this.jdbcTemplate = jdbcTemplate; + this.jsonCodec = new TraderJsonCodec(objectMapper); + } + + @Override + public void insert(TraderAppFeedback feedback) { + jdbcTemplate.update(""" + insert into trader_app_feedback + (run_id, cycle_id, feedback_id, action_id, feedback_source, is_real_fill, + order_id, order_status, app_received_time, exchange_ack_time, filled_time, + filled_price, filled_quantity, fee, slippage_bps, reject_reason, raw_feedback_json) + values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + """, + feedback.runId(), feedback.cycleId(), feedback.feedbackId(), feedback.actionId(), + feedback.feedbackSource().name(), feedback.realFill(), feedback.orderId(), feedback.orderStatus(), + feedback.appReceivedTime() == null ? null : Timestamp.from(feedback.appReceivedTime()), + feedback.exchangeAckTime() == null ? null : Timestamp.from(feedback.exchangeAckTime()), + feedback.filledTime() == null ? null : Timestamp.from(feedback.filledTime()), + feedback.filledPrice(), feedback.filledQuantity(), feedback.fee(), feedback.slippageBps(), + feedback.rejectReason(), jsonCodec.toJson(feedback.rawFeedbackJson())); + } +} diff --git a/src/main/java/com/quantai/trader/feedback/TraderFeedbackRepository.java b/src/main/java/com/quantai/trader/feedback/TraderFeedbackRepository.java new file mode 100644 index 0000000..6f0cf9f --- /dev/null +++ b/src/main/java/com/quantai/trader/feedback/TraderFeedbackRepository.java @@ -0,0 +1,7 @@ +package com.quantai.trader.feedback; + +import com.quantai.trader.domain.TraderAppFeedback; + +public interface TraderFeedbackRepository { + void insert(TraderAppFeedback feedback); +} diff --git a/src/main/java/com/quantai/trader/model/ArtifactTraderModelService.java b/src/main/java/com/quantai/trader/model/ArtifactTraderModelService.java deleted file mode 100644 index 8f6bf7d..0000000 --- a/src/main/java/com/quantai/trader/model/ArtifactTraderModelService.java +++ /dev/null @@ -1,72 +0,0 @@ -package com.quantai.trader.model; - -import com.quantai.trader.artifact.TraderArtifactBundle; -import com.quantai.trader.artifact.TraderArtifactModelPolicy; -import com.quantai.trader.domain.*; -import org.springframework.stereotype.Service; - -import java.math.BigDecimal; -import java.util.Map; - -@Service -public class ArtifactTraderModelService implements TraderModelService { - @Override - public TraderModelOutput evaluate(TraderMarketSnapshot snapshot, TraderArtifactBundle bundle) { - TraderArtifactModelPolicy policy = bundle.modelPolicy(); - DirectionOutput direction = direction(snapshot, bundle, policy.direction()); - return new TraderModelOutput( - "model_output_" + snapshot.cycleId(), - snapshot.runId(), - snapshot.cycleId(), - bundle.modelBundleVersion(), - bundle.calibrationBundleVersion(), - direction, - entry(bundle, policy.entry()), - continuation(bundle, policy.continuation()), - exit(bundle, policy.exit()), - risk(snapshot, bundle, policy.risk()), - policy.uncertainty(), - policy.oodScore(), - Map.of("artifactPolicy", bundle.bundleHashSha256())); - } - - private DirectionOutput direction(TraderMarketSnapshot snapshot, TraderArtifactBundle bundle, TraderArtifactModelPolicy.DirectionPolicy policy) { - BigDecimal longProb = snapshot.markPrice().compareTo(snapshot.indexPrice()) >= 0 - ? policy.longProbWhenMarkGteIndex() - : policy.longProbWhenMarkLtIndex(); - BigDecimal shortProb = BigDecimal.ONE.subtract(longProb).subtract(policy.neutralProb()).max(BigDecimal.ZERO); - BigDecimal neutralProb = BigDecimal.ONE.subtract(longProb).subtract(shortProb); - return new DirectionOutput(longProb, shortProb, neutralProb, longProb.max(shortProb), - longProb.subtract(shortProb).abs(), policy.expectedReturnBps(), policy.horizonMinutes(), - policy.modelVersion(), bundle.calibrationBundleVersion(), Map.of("source", "artifact_policy")); - } - - private EntryOutput entry(TraderArtifactBundle bundle, TraderArtifactModelPolicy.EntryPolicy policy) { - return new EntryOutput(policy.longEntryProb(), policy.shortEntryProb(), policy.entryQualityScore(), - policy.expectedEdgeBps(), policy.pricePlanId(), policy.pricePlanConfigHash(), - policy.stopDistanceBps(), policy.targetDistanceBps(), policy.maxHoldMinutes(), policy.costBps(), - policy.modelVersion(), bundle.calibrationBundleVersion(), Map.of("source", "artifact_policy")); - } - - private ContinueOutput continuation(TraderArtifactBundle bundle, TraderArtifactModelPolicy.ContinuePolicy policy) { - return new ContinueOutput(policy.longContinueProb(), policy.shortContinueProb(), policy.trendPersistenceProb(), - policy.holdEdgeBps(), policy.continueVsExitEdgeBps(), - policy.modelVersion(), bundle.calibrationBundleVersion(), Map.of("source", "artifact_policy")); - } - - private ExitOutput exit(TraderArtifactBundle bundle, TraderArtifactModelPolicy.ExitPolicy policy) { - return new ExitOutput(policy.longExitProb(), policy.shortExitProb(), policy.profitGivebackProb(), - policy.reversalProb(), policy.stopRiskProb(), policy.stagnationProb(), policy.expectedGivebackBps(), - policy.modelVersion(), bundle.calibrationBundleVersion(), Map.of("source", "artifact_policy")); - } - - private RiskOutput risk(TraderMarketSnapshot snapshot, TraderArtifactBundle bundle, TraderArtifactModelPolicy.RiskPolicy policy) { - BigDecimal liquidityCapacity = snapshot.dataReady() - ? policy.liquidityCapacityRatioWhenReady() - : policy.liquidityCapacityRatioWhenNotReady(); - return new RiskOutput(policy.marketRiskProb(), policy.positionRiskProb(), policy.marketRiskSeverityBps(), - policy.positionRiskSeverityBps(), policy.drawdownProb(), policy.expectedShortfallBps(), - policy.volatilityExpansionProb(), policy.spikeProb(), policy.liquidityRiskProb(), liquidityCapacity, - policy.modelVersion(), bundle.calibrationBundleVersion(), Map.of("source", "artifact_policy")); - } -} diff --git a/src/main/java/com/quantai/trader/model/OnnxTraderModelService.java b/src/main/java/com/quantai/trader/model/OnnxTraderModelService.java new file mode 100644 index 0000000..9d3cdd9 --- /dev/null +++ b/src/main/java/com/quantai/trader/model/OnnxTraderModelService.java @@ -0,0 +1,267 @@ +package com.quantai.trader.model; + +import com.fasterxml.jackson.databind.ObjectMapper; +import com.quantai.trader.artifact.TraderArtifactBundle; +import com.quantai.trader.artifact.TraderCalibrationManifest; +import com.quantai.trader.artifact.TraderModelManifest; +import com.quantai.trader.config.TraderProperties; +import com.quantai.trader.domain.*; +import com.quantai.trader.enums.TraderErrorCode; +import com.quantai.trader.feature.TraderFeatureVectorBuilder; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.stereotype.Service; + +import java.math.BigDecimal; +import java.nio.file.Path; +import java.util.LinkedHashMap; +import java.util.Map; +import java.util.Set; +import java.util.regex.Matcher; +import java.util.regex.Pattern; +import java.util.stream.Collectors; + +@Service +public class OnnxTraderModelService implements TraderModelService { + private static final Logger log = LoggerFactory.getLogger(OnnxTraderModelService.class); + private static final Pattern OUTPUT_REFERENCE = Pattern.compile("^([A-Za-z0-9_]+)\\[(\\d+)]$"); + private static final Set REQUIRED_TYPES = Set.of("DIRECTION", "ENTRY", "CONTINUE", "EXIT", "RISK"); + + private final TraderProperties properties; + private final ObjectMapper objectMapper; + private final TraderFeatureVectorBuilder featureVectorBuilder; + private final TraderOnnxInferenceClient inferenceClient; + + public OnnxTraderModelService(TraderProperties properties, ObjectMapper objectMapper, + TraderFeatureVectorBuilder featureVectorBuilder, + TraderOnnxInferenceClient inferenceClient) { + this.properties = properties; + this.objectMapper = objectMapper; + this.featureVectorBuilder = featureVectorBuilder; + this.inferenceClient = inferenceClient; + } + + @Override + public TraderModelOutput evaluate(TraderMarketSnapshot snapshot, TraderArtifactBundle bundle) { + float[] features = featureVectorBuilder.build(snapshot, bundle); + Map manifests = manifestsByType(bundle); + Map> rawOutputs = new LinkedHashMap<>(); + Path artifactRoot = Path.of(properties.artifact().artifactRoot()); + log.info("event=trader.model.onnx_started runId={} cycleId={} modelBundleVersion={} calibrationBundleVersion={} featureCount={}", + snapshot.runId(), snapshot.cycleId(), bundle.modelBundleVersion(), bundle.calibrationBundleVersion(), features.length); + for (TraderModelManifest manifest : manifests.values()) { + Path modelPath = artifactRoot.resolve(manifest.artifactPath()); + log.debug("event=trader.model.onnx_model_started runId={} cycleId={} modelName={} modelType={} modelPath={}", + snapshot.runId(), snapshot.cycleId(), manifest.modelName(), manifest.modelType(), modelPath); + Map tensors = inferenceClient.infer(manifest, modelPath, features); + Map mapped = mappedOutputs(manifest, tensors); + rawOutputs.put(manifest.modelType(), mapped); + log.debug("event=trader.model.onnx_model_mapped runId={} cycleId={} modelName={} outputFieldCount={}", + snapshot.runId(), snapshot.cycleId(), manifest.modelName(), mapped.size()); + } + + DirectionOutput direction = direction(rawOutputs.get("DIRECTION"), calibrator(bundle, manifests.get("DIRECTION"))); + EntryOutput entry = entry(rawOutputs.get("ENTRY"), calibrator(bundle, manifests.get("ENTRY")), + bounds(manifests.get("ENTRY"))); + ContinueOutput continuation = continuation(rawOutputs.get("CONTINUE"), calibrator(bundle, manifests.get("CONTINUE")), + bounds(manifests.get("CONTINUE"))); + ExitOutput exit = exit(rawOutputs.get("EXIT"), calibrator(bundle, manifests.get("EXIT")), + bounds(manifests.get("EXIT"))); + RiskOutput risk = risk(rawOutputs.get("RISK"), calibrator(bundle, manifests.get("RISK")), + bounds(manifests.get("RISK"))); + + TraderModelOutput output = new TraderModelOutput( + "model_output_" + snapshot.cycleId(), + snapshot.runId(), + snapshot.cycleId(), + metadata(bundle, manifests, direction, snapshot), + direction, + entry, + continuation, + exit, + risk); + log.info("event=trader.model.onnx_evaluated runId={} cycleId={} modelBundleVersion={} featureCount={} directionLong={} directionShort={} marketRisk={}", + snapshot.runId(), snapshot.cycleId(), bundle.modelBundleVersion(), features.length, + direction.longProb(), direction.shortProb(), risk.marketRiskProb()); + return output; + } + + private Map manifestsByType(TraderArtifactBundle bundle) { + Map manifests = new LinkedHashMap<>(); + for (TraderModelManifest manifest : bundle.modelManifests()) { + TraderModelManifest previous = manifests.putIfAbsent(manifest.modelType(), manifest); + if (previous != null) { + throw modelException("model_manifest.json must contain exactly one manifest per V4 model type: " + + manifest.modelType()); + } + } + if (!manifests.keySet().containsAll(REQUIRED_TYPES)) { + throw modelException("model_manifest.json does not contain all five V4 model types"); + } + return Map.copyOf(manifests); + } + + private Map mappedOutputs(TraderModelManifest manifest, Map tensors) { + Map mapped = new LinkedHashMap<>(); + for (Map.Entry entry : manifest.outputMapping().entrySet()) { + // manifest 只允许 tensor[index] 显式映射,避免 Java 根据位置临时猜字段。 + if (!(entry.getValue() instanceof String referenceText)) { + throw modelException("output_mapping_json value must be tensor[index]: " + manifest.modelName()); + } + Matcher matcher = OUTPUT_REFERENCE.matcher(referenceText); + if (!matcher.matches()) { + throw modelException("output_mapping_json value must be tensor[index]: " + referenceText); + } + String tensorName = matcher.group(1); + int index = Integer.parseInt(matcher.group(2)); + float[] values = tensors.get(tensorName); + if (values == null) { + throw modelException("ONNX output tensor is missing: model=" + manifest.modelName() + + ", tensor=" + tensorName); + } + if (index < 0 || index >= values.length) { + throw modelException("ONNX output tensor index is out of range: model=" + manifest.modelName() + + ", mapping=" + referenceText); + } + mapped.put(entry.getKey(), BigDecimal.valueOf(values[index])); + } + return Map.copyOf(mapped); + } + + private DirectionOutput direction(Map raw, TraderProbabilityCalibrator calibrator) { + return new DirectionOutput( + probability(raw, calibrator, "long_prob", "longProb"), + probability(raw, calibrator, "short_prob", "shortProb"), + probability(raw, calibrator, "neutral_prob", "neutralProb")); + } + + private EntryOutput entry(Map raw, TraderProbabilityCalibrator calibrator, + TraderOutputSchemaBounds bounds) { + return new EntryOutput( + probability(raw, calibrator, "long_entry_prob", "longEntryProb"), + probability(raw, calibrator, "short_entry_prob", "shortEntryProb"), + bps(raw, bounds, "long_expected_net_edge_bps", "longExpectedNetEdgeBps"), + bps(raw, bounds, "short_expected_net_edge_bps", "shortExpectedNetEdgeBps")); + } + + private ContinueOutput continuation(Map raw, TraderProbabilityCalibrator calibrator, + TraderOutputSchemaBounds bounds) { + return new ContinueOutput( + probability(raw, calibrator, "long_continue_prob", "longContinueProb"), + probability(raw, calibrator, "short_continue_prob", "shortContinueProb"), + bps(raw, bounds, "long_expected_continue_edge_bps", "longExpectedContinueEdgeBps"), + bps(raw, bounds, "short_expected_continue_edge_bps", "shortExpectedContinueEdgeBps")); + } + + private ExitOutput exit(Map raw, TraderProbabilityCalibrator calibrator, + TraderOutputSchemaBounds bounds) { + return new ExitOutput( + probability(raw, calibrator, "long_exit_prob", "longExitProb"), + probability(raw, calibrator, "short_exit_prob", "shortExitProb"), + bps(raw, bounds, "long_adverse_move_bps", "longAdverseMoveBps"), + bps(raw, bounds, "short_adverse_move_bps", "shortAdverseMoveBps"), + Map.of( + "adverse_move_prob", probability(raw, calibrator, "adverse_move_prob", "adverse_move_prob"), + "reversal_prob", probability(raw, calibrator, "reversal_prob", "reversal_prob"), + "stop_hit_prob", probability(raw, calibrator, "stop_hit_prob", "stop_hit_prob"), + "stagnation_prob", probability(raw, calibrator, "stagnation_prob", "stagnation_prob"))); + } + + private RiskOutput risk(Map raw, TraderProbabilityCalibrator calibrator, + TraderOutputSchemaBounds bounds) { + return new RiskOutput( + probability(raw, calibrator, "market_risk_prob", "marketRiskProb"), + probability(raw, calibrator, "long_position_risk_prob", "longPositionRiskProb"), + probability(raw, calibrator, "short_position_risk_prob", "shortPositionRiskProb"), + bps(raw, bounds, "market_path_risk_bps", "marketPathRiskBps"), + bps(raw, bounds, "long_position_path_risk_bps", "longPositionPathRiskBps"), + bps(raw, bounds, "short_position_path_risk_bps", "shortPositionPathRiskBps"), + Map.of( + "market_drawdown_prob", probability(raw, calibrator, "market_drawdown_prob", "market_drawdown_prob"), + "volatility_expansion_prob", probability(raw, calibrator, "volatility_expansion_prob", "volatility_expansion_prob"), + "spike_prob", probability(raw, calibrator, "spike_prob", "spike_prob"), + "liquidity_deterioration_prob", probability(raw, calibrator, "liquidity_deterioration_prob", "liquidity_deterioration_prob"), + "position_drawdown_prob", probability(raw, calibrator, "position_drawdown_prob", "position_drawdown_prob"))); + } + + private TraderModelOutputMetadata metadata(TraderArtifactBundle bundle, Map manifests, + DirectionOutput direction, TraderMarketSnapshot snapshot) { + Map modelVersions = manifests.values().stream() + .collect(Collectors.toUnmodifiableMap(TraderModelManifest::modelName, TraderModelManifest::sourceHash)); + Map calibrationVersions = bundle.calibrationManifests().stream() + .collect(Collectors.toUnmodifiableMap(TraderCalibrationManifest::modelName, TraderCalibrationManifest::calibratorVersion)); + TraderModelManifest reference = manifests.get("DIRECTION"); + // 第一版不再单独输出 uncertainty 子模型,用方向最大概率的反面表示本轮不确定性。 + BigDecimal uncertainty = BigDecimal.ONE.subtract(direction.confidence()); + BigDecimal oodScore = dataQualityProbability(snapshot, "ood_score"); + return new TraderModelOutputMetadata( + bundle.modelBundleVersion(), + bundle.calibrationBundleVersion(), + modelVersions, + calibrationVersions, + reference.featureSchemaHash(), + reference.featureOrderHash(), + reference.outputSchemaHash(), + uncertainty, + oodScore); + } + + private TraderProbabilityCalibrator calibrator(TraderArtifactBundle bundle, TraderModelManifest manifest) { + TraderCalibrationManifest calibrationManifest = bundle.calibrationManifests().stream() + .filter(item -> item.modelName().equals(manifest.modelName())) + .findFirst() + .orElseThrow(() -> modelException("calibration manifest is missing for model: " + manifest.modelName())); + Path path = Path.of(properties.artifact().artifactRoot()).resolve(calibrationManifest.calibratorPath()); + log.debug("event=trader.model.calibrator_loaded modelName={} calibratorVersion={} calibratorPath={}", + manifest.modelName(), calibrationManifest.calibratorVersion(), path); + return TraderProbabilityCalibrator.read(objectMapper, path, manifest.modelName()); + } + + private TraderOutputSchemaBounds bounds(TraderModelManifest manifest) { + return TraderOutputSchemaBounds.read(objectMapper, + Path.of(properties.artifact().artifactRoot()).resolve(manifest.outputSchemaPath())); + } + + private BigDecimal probability(Map raw, TraderProbabilityCalibrator calibrator, + String outputKey, String targetName) { + BigDecimal value = requiredRaw(raw, outputKey); + return calibrator.calibrate(targetName, value); + } + + private BigDecimal bps(Map raw, TraderOutputSchemaBounds bounds, + String outputKey, String fieldName) { + return bounds.clip(fieldName, requiredRaw(raw, outputKey)); + } + + private BigDecimal requiredRaw(Map raw, String outputKey) { + BigDecimal value = raw.get(outputKey); + if (value == null) { + throw modelException("ONNX mapped output is missing: " + outputKey); + } + return value; + } + + private BigDecimal dataQualityProbability(TraderMarketSnapshot snapshot, String field) { + Object raw = snapshot.dataQualityJson().get(field); + BigDecimal value; + if (raw instanceof BigDecimal decimal) { + value = decimal; + } else if (raw instanceof Number number) { + value = BigDecimal.valueOf(number.doubleValue()); + } else { + log.warn("event=trader.model.ood_missing runId={} cycleId={} field={}", + snapshot.runId(), snapshot.cycleId(), field); + throw modelException("snapshot dataQualityJson must contain numeric field: " + field); + } + if (value.compareTo(BigDecimal.ZERO) < 0 || value.compareTo(BigDecimal.ONE) > 0) { + log.warn("event=trader.model.ood_out_of_range runId={} cycleId={} field={} value={}", + snapshot.runId(), snapshot.cycleId(), field, value); + throw modelException("snapshot dataQualityJson probability is outside [0,1]: " + field); + } + return value; + } + + private static TraderException modelException(String message) { + return new TraderException(TraderErrorCode.TRADER_MODEL_ARTIFACT_MISSING, message); + } +} diff --git a/src/main/java/com/quantai/trader/model/OrtTraderOnnxInferenceClient.java b/src/main/java/com/quantai/trader/model/OrtTraderOnnxInferenceClient.java new file mode 100644 index 0000000..5f56602 --- /dev/null +++ b/src/main/java/com/quantai/trader/model/OrtTraderOnnxInferenceClient.java @@ -0,0 +1,78 @@ +package com.quantai.trader.model; + +import ai.onnxruntime.OnnxTensor; +import ai.onnxruntime.OnnxValue; +import ai.onnxruntime.OrtEnvironment; +import ai.onnxruntime.OrtException; +import ai.onnxruntime.OrtSession; +import com.quantai.trader.artifact.TraderModelManifest; +import com.quantai.trader.domain.TraderException; +import com.quantai.trader.enums.TraderErrorCode; +import org.springframework.beans.factory.DisposableBean; +import org.springframework.stereotype.Component; + +import java.nio.file.Path; +import java.util.Arrays; +import java.util.LinkedHashMap; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentMap; + +@Component +public class OrtTraderOnnxInferenceClient implements TraderOnnxInferenceClient, DisposableBean { + private final OrtEnvironment environment = OrtEnvironment.getEnvironment(); + private final ConcurrentMap sessions = new ConcurrentHashMap<>(); + + @Override + public Map infer(TraderModelManifest manifest, Path modelPath, float[] features) { + try { + OrtSession session = sessions.computeIfAbsent(modelPath.toAbsolutePath().normalize(), this::createSession); + try (OnnxTensor input = OnnxTensor.createTensor(environment, new float[][]{features}); + OrtSession.Result result = session.run(Map.of(manifest.inputTensorName(), input))) { + Map outputs = new LinkedHashMap<>(); + for (Map.Entry entry : result) { + outputs.put(entry.getKey(), flatten(entry.getValue())); + } + return Map.copyOf(outputs); + } + } catch (OrtException exception) { + throw new TraderException(TraderErrorCode.TRADER_MODEL_ARTIFACT_MISSING, + "ONNX inference failed for model: " + manifest.modelName()); + } + } + + private OrtSession createSession(Path path) { + try { + // ONNX session 建立成本高,外层按模型文件路径缓存,进程退出时统一关闭。 + return environment.createSession(path.toString(), new OrtSession.SessionOptions()); + } catch (OrtException exception) { + throw new TraderException(TraderErrorCode.TRADER_MODEL_ARTIFACT_MISSING, + "ONNX model cannot be loaded: " + path); + } + } + + private float[] flatten(OnnxValue value) { + try { + Object raw = value.getValue(); + if (raw instanceof float[] values) { + return Arrays.copyOf(values, values.length); + } + if (raw instanceof float[][] matrix && matrix.length == 1) { + return Arrays.copyOf(matrix[0], matrix[0].length); + } + throw new TraderException(TraderErrorCode.TRADER_MODEL_ARTIFACT_MISSING, + "ONNX output tensor must be FLOAT32 with shape [n] or [1,n]"); + } catch (OrtException exception) { + throw new TraderException(TraderErrorCode.TRADER_MODEL_ARTIFACT_MISSING, + "ONNX output tensor cannot be read"); + } + } + + @Override + public void destroy() throws Exception { + for (OrtSession session : sessions.values()) { + session.close(); + } + sessions.clear(); + } +} diff --git a/src/main/java/com/quantai/trader/model/ReplayFixtureTraderModelService.java b/src/main/java/com/quantai/trader/model/ReplayFixtureTraderModelService.java new file mode 100644 index 0000000..27d461d --- /dev/null +++ b/src/main/java/com/quantai/trader/model/ReplayFixtureTraderModelService.java @@ -0,0 +1,94 @@ +package com.quantai.trader.model; + +import com.quantai.trader.artifact.TraderArtifactBundle; +import com.quantai.trader.artifact.TraderModelManifest; +import com.quantai.trader.artifact.TraderReplayModelFixture; +import com.quantai.trader.config.TraderProperties; +import com.quantai.trader.domain.*; +import com.quantai.trader.enums.TraderErrorCode; +import com.quantai.trader.enums.TraderRunMode; +import org.springframework.stereotype.Service; + +import java.math.BigDecimal; +import java.util.Map; +import java.util.stream.Collectors; + +@Service +public class ReplayFixtureTraderModelService implements TraderModelService { + private final TraderRunMode runMode; + + public ReplayFixtureTraderModelService(TraderProperties properties) { + this.runMode = properties.runMode(); + } + + public ReplayFixtureTraderModelService() { + this.runMode = TraderRunMode.REPLAY_SIM; + } + + @Override + public TraderModelOutput evaluate(TraderMarketSnapshot snapshot, TraderArtifactBundle bundle) { + if (runMode != TraderRunMode.REPLAY_SIM) { + throw new TraderException(TraderErrorCode.TRADER_MODEL_ARTIFACT_MISSING, + "replay fixture model service only allows REPLAY_SIM"); + } + TraderReplayModelFixture fixture = bundle.requireReplayModelFixture(); + DirectionOutput direction = direction(snapshot, fixture.direction()); + return new TraderModelOutput( + "model_output_" + snapshot.cycleId(), + snapshot.runId(), + snapshot.cycleId(), + metadata(bundle, fixture), + direction, + entry(fixture.entry()), + continuation(fixture.continuation()), + exit(fixture.exit()), + risk(fixture.risk())); + } + + private TraderModelOutputMetadata metadata(TraderArtifactBundle bundle, TraderReplayModelFixture fixture) { + Map modelVersions = bundle.modelManifests().stream() + .collect(Collectors.toUnmodifiableMap(TraderModelManifest::modelName, ignored -> bundle.modelBundleVersion())); + Map calibrationVersions = bundle.modelManifests().stream() + .collect(Collectors.toUnmodifiableMap(TraderModelManifest::modelName, ignored -> bundle.calibrationBundleVersion())); + return new TraderModelOutputMetadata( + bundle.modelBundleVersion(), + bundle.calibrationBundleVersion(), + modelVersions, + calibrationVersions, + fixture.featureSchemaHash(), + fixture.featureOrderHash(), + fixture.outputSchemaHash(), + fixture.uncertainty(), + fixture.oodScore()); + } + + private DirectionOutput direction(TraderMarketSnapshot snapshot, TraderReplayModelFixture.DirectionFixture fixture) { + BigDecimal longProb = snapshot.markPrice().compareTo(snapshot.indexPrice()) >= 0 + ? fixture.longProbWhenMarkGteIndex() + : fixture.longProbWhenMarkLtIndex(); + BigDecimal shortProb = BigDecimal.ONE.subtract(longProb).subtract(fixture.neutralProb()).max(BigDecimal.ZERO); + BigDecimal neutralProb = BigDecimal.ONE.subtract(longProb).subtract(shortProb); + return new DirectionOutput(longProb, shortProb, neutralProb); + } + + private EntryOutput entry(TraderReplayModelFixture.EntryFixture fixture) { + return new EntryOutput(fixture.longEntryProb(), fixture.shortEntryProb(), + fixture.longExpectedNetEdgeBps(), fixture.shortExpectedNetEdgeBps()); + } + + private ContinueOutput continuation(TraderReplayModelFixture.ContinueFixture fixture) { + return new ContinueOutput(fixture.longContinueProb(), fixture.shortContinueProb(), + fixture.longExpectedContinueEdgeBps(), fixture.shortExpectedContinueEdgeBps()); + } + + private ExitOutput exit(TraderReplayModelFixture.ExitFixture fixture) { + return new ExitOutput(fixture.longExitProb(), fixture.shortExitProb(), + fixture.longAdverseMoveBps(), fixture.shortAdverseMoveBps(), fixture.exitReasonScores()); + } + + private RiskOutput risk(TraderReplayModelFixture.RiskFixture fixture) { + return new RiskOutput(fixture.marketRiskProb(), fixture.longPositionRiskProb(), fixture.shortPositionRiskProb(), + fixture.marketPathRiskBps(), fixture.longPositionPathRiskBps(), fixture.shortPositionPathRiskBps(), + fixture.riskReasonScores()); + } +} diff --git a/src/main/java/com/quantai/trader/model/RoutingTraderModelService.java b/src/main/java/com/quantai/trader/model/RoutingTraderModelService.java new file mode 100644 index 0000000..794eafe --- /dev/null +++ b/src/main/java/com/quantai/trader/model/RoutingTraderModelService.java @@ -0,0 +1,49 @@ +package com.quantai.trader.model; + +import com.quantai.trader.artifact.TraderArtifactBundle; +import com.quantai.trader.config.TraderProperties; +import com.quantai.trader.domain.TraderException; +import com.quantai.trader.domain.TraderMarketSnapshot; +import com.quantai.trader.domain.TraderModelOutput; +import com.quantai.trader.enums.TraderErrorCode; +import com.quantai.trader.enums.TraderRunMode; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.context.annotation.Primary; +import org.springframework.stereotype.Service; + +@Primary +@Service +public class RoutingTraderModelService implements TraderModelService { + private static final Logger log = LoggerFactory.getLogger(RoutingTraderModelService.class); + + private final TraderProperties properties; + private final ReplayFixtureTraderModelService replayFixtureTraderModelService; + private final OnnxTraderModelService onnxTraderModelService; + + public RoutingTraderModelService(TraderProperties properties, + ReplayFixtureTraderModelService replayFixtureTraderModelService, + OnnxTraderModelService onnxTraderModelService) { + this.properties = properties; + this.replayFixtureTraderModelService = replayFixtureTraderModelService; + this.onnxTraderModelService = onnxTraderModelService; + } + + @Override + public TraderModelOutput evaluate(TraderMarketSnapshot snapshot, TraderArtifactBundle bundle) { + if (properties.runMode() == TraderRunMode.REPLAY_SIM) { + log.info("event=trader.model.route runId={} cycleId={} runMode=REPLAY_SIM modelService=ReplayFixtureTraderModelService", + snapshot.runId(), snapshot.cycleId()); + return replayFixtureTraderModelService.evaluate(snapshot, bundle); + } + if (properties.runMode() == TraderRunMode.SHADOW) { + log.info("event=trader.model.route runId={} cycleId={} runMode=SHADOW modelService=OnnxTraderModelService", + snapshot.runId(), snapshot.cycleId()); + return onnxTraderModelService.evaluate(snapshot, bundle); + } + log.warn("event=trader.model.route_rejected runId={} cycleId={} runMode={}", + snapshot.runId(), snapshot.cycleId(), properties.runMode()); + throw new TraderException(TraderErrorCode.TRADER_MODEL_ARTIFACT_MISSING, + "P0 model service only supports REPLAY_SIM and SHADOW"); + } +} diff --git a/src/main/java/com/quantai/trader/model/TraderOnnxInferenceClient.java b/src/main/java/com/quantai/trader/model/TraderOnnxInferenceClient.java new file mode 100644 index 0000000..d738c67 --- /dev/null +++ b/src/main/java/com/quantai/trader/model/TraderOnnxInferenceClient.java @@ -0,0 +1,10 @@ +package com.quantai.trader.model; + +import com.quantai.trader.artifact.TraderModelManifest; + +import java.nio.file.Path; +import java.util.Map; + +public interface TraderOnnxInferenceClient { + Map infer(TraderModelManifest manifest, Path modelPath, float[] features); +} diff --git a/src/main/java/com/quantai/trader/model/TraderOutputSchemaBounds.java b/src/main/java/com/quantai/trader/model/TraderOutputSchemaBounds.java new file mode 100644 index 0000000..b933626 --- /dev/null +++ b/src/main/java/com/quantai/trader/model/TraderOutputSchemaBounds.java @@ -0,0 +1,84 @@ +package com.quantai.trader.model; + +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.quantai.trader.domain.TraderException; +import com.quantai.trader.enums.TraderErrorCode; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.math.BigDecimal; +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.HashMap; +import java.util.Iterator; +import java.util.Map; + +final class TraderOutputSchemaBounds { + private static final Logger log = LoggerFactory.getLogger(TraderOutputSchemaBounds.class); + + private final Map ranges; + + private TraderOutputSchemaBounds(Map ranges) { + this.ranges = Map.copyOf(ranges); + } + + static TraderOutputSchemaBounds read(ObjectMapper objectMapper, Path path) { + if (!Files.isRegularFile(path)) { + throw modelException("output_schema.json is missing: " + path); + } + try { + JsonNode root = objectMapper.readTree(path.toFile()); + Map ranges = new HashMap<>(); + collectRanges(root, ranges); + if (ranges.isEmpty()) { + throw modelException("output_schema.json must define field ranges: " + path); + } + log.debug("event=trader.output_schema.loaded path={} rangedFieldCount={}", path, ranges.size()); + return new TraderOutputSchemaBounds(ranges); + } catch (IOException exception) { + throw modelException("output_schema.json cannot be read: " + path); + } + } + + BigDecimal clip(String fieldName, BigDecimal value) { + Range range = ranges.get(fieldName); + if (range == null) { + log.warn("event=trader.output_schema.range_missing field={}", fieldName); + throw modelException("output_schema.json does not define range for field: " + fieldName); + } + if (value.compareTo(range.min()) < 0) { + log.debug("event=trader.output_schema.clipped field={} raw={} clipped={}", fieldName, value, range.min()); + return range.min(); + } + if (value.compareTo(range.max()) > 0) { + log.debug("event=trader.output_schema.clipped field={} raw={} clipped={}", fieldName, value, range.max()); + return range.max(); + } + return value; + } + + private static void collectRanges(JsonNode node, Map ranges) { + if (!node.isObject()) { + return; + } + Iterator> iterator = node.properties().iterator(); + while (iterator.hasNext()) { + Map.Entry entry = iterator.next(); + JsonNode child = entry.getValue(); + JsonNode rangeNode = child.path("range"); + if (rangeNode.isArray() && rangeNode.size() == 2 && rangeNode.get(0).isNumber() && rangeNode.get(1).isNumber()) { + ranges.put(entry.getKey(), new Range(rangeNode.get(0).decimalValue(), rangeNode.get(1).decimalValue())); + } + collectRanges(child, ranges); + } + } + + private static TraderException modelException(String message) { + return new TraderException(TraderErrorCode.TRADER_MODEL_ARTIFACT_MISSING, message); + } + + private record Range(BigDecimal min, BigDecimal max) { + } +} diff --git a/src/main/java/com/quantai/trader/model/TraderProbabilityCalibrator.java b/src/main/java/com/quantai/trader/model/TraderProbabilityCalibrator.java new file mode 100644 index 0000000..3636ad3 --- /dev/null +++ b/src/main/java/com/quantai/trader/model/TraderProbabilityCalibrator.java @@ -0,0 +1,129 @@ +package com.quantai.trader.model; + +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.quantai.trader.domain.TraderException; +import com.quantai.trader.enums.TraderErrorCode; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.math.BigDecimal; +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.stream.StreamSupport; + +final class TraderProbabilityCalibrator { + private static final Logger log = LoggerFactory.getLogger(TraderProbabilityCalibrator.class); + + private final String modelName; + private final String method; + private final BigDecimal clipMin; + private final BigDecimal clipMax; + private final Map> targets; + + private TraderProbabilityCalibrator(String modelName, String method, BigDecimal clipMin, BigDecimal clipMax, + Map> targets) { + this.modelName = modelName; + this.method = method; + this.clipMin = clipMin; + this.clipMax = clipMax; + this.targets = targets; + } + + static TraderProbabilityCalibrator read(ObjectMapper objectMapper, Path path, String modelName) { + if (!Files.isRegularFile(path)) { + throw modelException("calibrator is missing: " + path); + } + try { + JsonNode root = objectMapper.readTree(path.toFile()); + String method = requiredText(root, "method", path); + if (!"BINNING".equals(method)) { + throw modelException("probability calibrator method must be BINNING for model " + modelName); + } + JsonNode clip = root.path("clip"); + BigDecimal clipMin = decimal(clip.path("min"), "clip.min", path); + BigDecimal clipMax = decimal(clip.path("max"), "clip.max", path); + JsonNode targetsNode = root.path("targets"); + if (!targetsNode.isObject() || targetsNode.isEmpty()) { + throw modelException("calibrator targets must not be empty: " + path); + } + Map> targets = StreamSupport.stream(targetsNode.properties().spliterator(), false) + .collect(java.util.stream.Collectors.toUnmodifiableMap( + Map.Entry::getKey, + entry -> bins(entry.getValue().path("bins"), entry.getKey(), path))); + log.debug("event=trader.calibrator.loaded modelName={} method={} targetCount={} path={}", + modelName, method, targets.size(), path); + return new TraderProbabilityCalibrator(modelName, method, clipMin, clipMax, targets); + } catch (IOException exception) { + throw modelException("calibrator cannot be read: " + path); + } + } + + BigDecimal calibrate(String targetName, BigDecimal rawProbability) { + if (!"BINNING".equals(method)) { + throw modelException("unsupported calibrator method for model " + modelName); + } + List bins = targets.get(targetName); + if (bins == null || bins.isEmpty()) { + log.warn("event=trader.calibrator.target_missing modelName={} target={}", modelName, targetName); + throw modelException("calibrator target is missing: model=" + modelName + ", target=" + targetName); + } + for (Bin bin : bins) { + if (rawProbability.compareTo(bin.min()) >= 0 && rawProbability.compareTo(bin.max()) <= 0) { + BigDecimal calibrated = bin.calibrated(); + if (calibrated.compareTo(clipMin) < 0 || calibrated.compareTo(clipMax) > 0 + || calibrated.compareTo(BigDecimal.ZERO) < 0 || calibrated.compareTo(BigDecimal.ONE) > 0) { + log.warn("event=trader.calibrator.output_out_of_range modelName={} target={} raw={} calibrated={} clipMin={} clipMax={}", + modelName, targetName, rawProbability, calibrated, clipMin, clipMax); + throw modelException("calibrated probability is outside allowed range: model=" + + modelName + ", target=" + targetName); + } + return calibrated; + } + } + log.warn("event=trader.calibrator.bin_missing modelName={} target={} raw={}", + modelName, targetName, rawProbability); + throw modelException("raw probability does not match any calibrator bin: model=" + + modelName + ", target=" + targetName + ", value=" + rawProbability); + } + + private static List bins(JsonNode binsNode, String targetName, Path path) { + if (!binsNode.isArray() || binsNode.isEmpty()) { + throw modelException("calibrator target has no bins: " + targetName + " in " + path); + } + List bins = new ArrayList<>(); + for (JsonNode node : binsNode) { + bins.add(new Bin( + decimal(node.path("min"), targetName + ".min", path), + decimal(node.path("max"), targetName + ".max", path), + decimal(node.path("calibrated"), targetName + ".calibrated", path))); + } + return List.copyOf(bins); + } + + private static String requiredText(JsonNode node, String field, Path path) { + String value = node.path(field).asText(""); + if (value.isBlank()) { + throw modelException("calibrator field is required: " + field + " in " + path); + } + return value; + } + + private static BigDecimal decimal(JsonNode node, String field, Path path) { + if (!node.isNumber()) { + throw modelException("calibrator numeric field is required: " + field + " in " + path); + } + return node.decimalValue(); + } + + private static TraderException modelException(String message) { + return new TraderException(TraderErrorCode.TRADER_MODEL_ARTIFACT_MISSING, message); + } + + private record Bin(BigDecimal min, BigDecimal max, BigDecimal calibrated) { + } +} diff --git a/src/main/java/com/quantai/trader/outbox/JdbcTraderOutboxRepository.java b/src/main/java/com/quantai/trader/outbox/JdbcTraderOutboxRepository.java index 3fa7f1e..79a1d3d 100644 --- a/src/main/java/com/quantai/trader/outbox/JdbcTraderOutboxRepository.java +++ b/src/main/java/com/quantai/trader/outbox/JdbcTraderOutboxRepository.java @@ -1,6 +1,7 @@ package com.quantai.trader.outbox; import com.fasterxml.jackson.databind.ObjectMapper; +import com.quantai.trader.enums.PositionSide; import com.quantai.trader.persistence.TraderJsonCodec; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -35,4 +36,31 @@ public class JdbcTraderOutboxRepository implements TraderOutboxRepository { log.info("event=trader.outbox.inserted runId={} cycleId={} destination={} aggregateType={} aggregateId={} status={}", event.runId(), event.cycleId(), event.destination(), event.aggregateType(), event.aggregateId(), event.status()); } + + @Override + public void markSent(String outboxId) { + jdbcTemplate.update(""" + update trader_outbox + set status = 'SENT', updated_at = current_timestamp(3) + where outbox_id = ? + """, + outboxId); + log.info("event=trader.outbox.sent outboxId={}", outboxId); + } + + @Override + public boolean hasUnsentExposureIncrease(String runId, String symbol, PositionSide side) { + Integer count = jdbcTemplate.queryForObject(""" + select count(*) + from trader_outbox o + join trader_action a on a.run_id = o.run_id and a.action_id = o.aggregate_id + where o.run_id = ? + and a.symbol = ? + and a.side = ? + and a.action_type in ('OPEN_LONG','OPEN_SHORT','ADD_LONG','ADD_SHORT') + and o.status in ('PENDING','SENDING','FAILED') + """, + Integer.class, runId, symbol, side.name()); + return count != null && count > 0; + } } diff --git a/src/main/java/com/quantai/trader/outbox/P0OutboxDispatcher.java b/src/main/java/com/quantai/trader/outbox/P0OutboxDispatcher.java new file mode 100644 index 0000000..576f7b4 --- /dev/null +++ b/src/main/java/com/quantai/trader/outbox/P0OutboxDispatcher.java @@ -0,0 +1,76 @@ +package com.quantai.trader.outbox; + +import com.quantai.trader.domain.TraderAction; +import com.quantai.trader.domain.TraderAppFeedback; +import com.quantai.trader.domain.TraderMarketSnapshot; +import com.quantai.trader.enums.FeedbackSource; +import com.quantai.trader.feedback.TraderFeedbackRepository; +import com.quantai.trader.replay.state.TraderPostActionStateRepository; +import com.quantai.trader.replay.state.TraderReplayState; +import com.quantai.trader.replay.state.TraderReplayStateStore; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.stereotype.Component; +import org.springframework.transaction.annotation.Transactional; + +import java.time.Instant; +import java.util.Map; + +@Component +public class P0OutboxDispatcher implements TraderOutboxDispatcher { + private static final Logger log = LoggerFactory.getLogger(P0OutboxDispatcher.class); + + private final TraderOutboxRepository outboxRepository; + private final TraderFeedbackRepository feedbackRepository; + private final TraderReplayStateStore stateStore; + private final TraderPostActionStateRepository postActionStateRepository; + + public P0OutboxDispatcher(TraderOutboxRepository outboxRepository, + TraderFeedbackRepository feedbackRepository, + TraderReplayStateStore stateStore, + TraderPostActionStateRepository postActionStateRepository) { + this.outboxRepository = outboxRepository; + this.feedbackRepository = feedbackRepository; + this.stateStore = stateStore; + this.postActionStateRepository = postActionStateRepository; + } + + @Override + @Transactional + public void dispatch(TraderAction action, TraderReplayState currentState, TraderMarketSnapshot snapshot, String destination) { + TraderAppFeedback feedback = p0Feedback(action, snapshot, destination); + feedbackRepository.insert(feedback); + outboxRepository.markSent("outbox_" + action.actionId()); + TraderReplayState nextState = stateStore.advance(currentState, action, snapshot); + postActionStateRepository.insertPostActionState(nextState); + log.info("event=trader.outbox.dispatched runId={} cycleId={} action={} destination={} feedbackId={}", + action.runId(), action.cycleId(), action.actionType(), destination, feedback.feedbackId()); + } + + private TraderAppFeedback p0Feedback(TraderAction action, TraderMarketSnapshot snapshot, String destination) { + FeedbackSource source = switch (destination) { + case "REPLAY_SIM_EXECUTION" -> FeedbackSource.REPLAY_SIMULATOR; + case "SHADOW_RECORDER" -> FeedbackSource.SHADOW_APP; + default -> throw new IllegalArgumentException("P0 outbox destination is not allowed: " + destination); + }; + Instant now = snapshot.snapshotTime(); + return new TraderAppFeedback( + "feedback_" + action.actionId(), + action.runId(), + action.cycleId(), + action.actionId(), + source, + false, + null, + "RECORDED", + now, + now, + null, + null, + null, + null, + null, + null, + Map.of("destination", destination, "actionType", action.actionType().name())); + } +} diff --git a/src/main/java/com/quantai/trader/outbox/TraderOutboxDispatcher.java b/src/main/java/com/quantai/trader/outbox/TraderOutboxDispatcher.java new file mode 100644 index 0000000..7eb19e7 --- /dev/null +++ b/src/main/java/com/quantai/trader/outbox/TraderOutboxDispatcher.java @@ -0,0 +1,9 @@ +package com.quantai.trader.outbox; + +import com.quantai.trader.domain.TraderAction; +import com.quantai.trader.domain.TraderMarketSnapshot; +import com.quantai.trader.replay.state.TraderReplayState; + +public interface TraderOutboxDispatcher { + void dispatch(TraderAction action, TraderReplayState currentState, TraderMarketSnapshot snapshot, String destination); +} diff --git a/src/main/java/com/quantai/trader/outbox/TraderOutboxRepository.java b/src/main/java/com/quantai/trader/outbox/TraderOutboxRepository.java index c31f150..3f3ba08 100644 --- a/src/main/java/com/quantai/trader/outbox/TraderOutboxRepository.java +++ b/src/main/java/com/quantai/trader/outbox/TraderOutboxRepository.java @@ -1,5 +1,11 @@ package com.quantai.trader.outbox; +import com.quantai.trader.enums.PositionSide; + public interface TraderOutboxRepository { void insert(TraderOutboxEvent event); + + void markSent(String outboxId); + + boolean hasUnsentExposureIncrease(String runId, String symbol, PositionSide side); } diff --git a/src/main/java/com/quantai/trader/persistence/JdbcTraderDecisionTraceWriter.java b/src/main/java/com/quantai/trader/persistence/JdbcTraderDecisionTraceWriter.java index 2352d5c..66ae33e 100644 --- a/src/main/java/com/quantai/trader/persistence/JdbcTraderDecisionTraceWriter.java +++ b/src/main/java/com/quantai/trader/persistence/JdbcTraderDecisionTraceWriter.java @@ -30,10 +30,16 @@ public class JdbcTraderDecisionTraceWriter implements TraderDecisionTraceWriter TraderMarketSnapshot snapshot, TraderModelOutput modelOutput, TraderPositionState positionState, + TraderAccountState accountState, + TraderExecutionState executionState, TraderPositionManagerDecision pmDecision, TraderRiskDecision riskDecision, TraderAction action) { upsertRun(cycle); + insertMarketSnapshot(snapshot); + insertPositionState(positionState, "PM_INPUT"); + insertAccountState(accountState, "PM_INPUT"); + insertExecutionState(executionState, "PM_INPUT"); insertCycle(cycle, positionState, riskDecision); insertModelOutput(modelOutput, snapshot); insertPmDecision(pmDecision); @@ -66,6 +72,63 @@ public class JdbcTraderDecisionTraceWriter implements TraderDecisionTraceWriter Timestamp.from(cycle.cycleTime())); } + void insertMarketSnapshot(TraderMarketSnapshot snapshot) { + jdbcTemplate.update(""" + insert into trader_market_snapshot + (run_id, cycle_id, snapshot_id, symbol, snapshot_time, feature_version, + mark_price, index_price, spread_bps, funding_rate_bps, + depth_notional_5bps, depth_notional_10bps, depth_notional_25bps, + data_ready, feature_json, data_quality_json) + values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + """, + snapshot.runId(), snapshot.cycleId(), snapshot.snapshotId(), snapshot.symbol(), + Timestamp.from(snapshot.snapshotTime()), snapshot.featureVersion(), snapshot.markPrice(), + snapshot.indexPrice(), snapshot.spreadBps(), snapshot.fundingRateBps(), + snapshot.depthNotional5Bps(), snapshot.depthNotional10Bps(), snapshot.depthNotional25Bps(), + snapshot.dataReady(), jsonCodec.toJson(snapshot.featureJson()), jsonCodec.toJson(snapshot.dataQualityJson())); + } + + void insertPositionState(TraderPositionState state, String role) { + jdbcTemplate.update(""" + insert into trader_position_state + (run_id, cycle_id, position_state_id, state_role, symbol, side, position_ratio, + average_entry_price, current_price, unrealized_pnl_bps, liquidation_buffer_bps, + add_count, remaining_add_capacity, last_add_time) + values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + """, + state.runId(), state.cycleId(), state.positionStateId(), role, state.symbol(), state.side().name(), + state.positionRatio(), state.averageEntryPrice(), state.currentPrice(), state.unrealizedPnlBps(), + state.liquidationBufferBps(), state.addCount(), state.remainingAddCapacity(), + state.lastAddTime() == null ? null : Timestamp.from(state.lastAddTime())); + } + + void insertAccountState(TraderAccountState state, String role) { + jdbcTemplate.update(""" + insert into trader_account_state + (run_id, cycle_id, account_state_id, state_role, daily_drawdown_bps, + portfolio_exposure_ratio, remaining_symbol_capacity_ratio, consecutive_losses) + values (?, ?, ?, ?, ?, ?, ?, ?) + """, + state.runId(), state.cycleId(), state.accountStateId(), role, state.dailyDrawdownBps(), + state.portfolioExposureRatio(), state.remainingSymbolCapacityRatio(), state.consecutiveLosses()); + } + + void insertExecutionState(TraderExecutionState state, String role) { + jdbcTemplate.update(""" + insert into trader_execution_state + (run_id, cycle_id, execution_state_id, state_role, symbol, open_orders_json, + expected_slippage_bps, exchange_latency_ms, api_error_count, maker_fee_bps, + taker_fee_bps, min_notional, price_tick_size, lot_size_step_size, + market_lot_size_step_size, liquidity_capacity_ratio) + values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + """, + state.runId(), state.cycleId(), state.executionStateId(), role, state.symbol(), + jsonCodec.toJson(state.openOrders()), state.expectedSlippageBps(), state.exchangeLatencyMs(), + state.apiErrorCount(), state.makerFeeBps(), state.takerFeeBps(), state.minNotional(), + state.priceTickSize(), state.lotSizeStepSize(), state.marketLotSizeStepSize(), + state.liquidityCapacityRatio()); + } + void insertCycle(TraderDecisionCycle cycle, TraderPositionState positionState, TraderRiskDecision riskDecision) { jdbcTemplate.update(""" insert into trader_decision_cycle @@ -83,14 +146,15 @@ public class JdbcTraderDecisionTraceWriter implements TraderDecisionTraceWriter insert into trader_model_output (run_id, cycle_id, model_output_id, model_bundle_version, calibration_bundle_version, direction_json, entry_json, continue_json, exit_json, risk_json, - uncertainty, ood_score, usable, blocker) - values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + metadata_json, uncertainty, ood_score, usable, blocker) + values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) """, modelOutput.runId(), modelOutput.cycleId(), modelOutput.modelOutputId(), - modelOutput.modelBundleVersion(), modelOutput.calibrationBundleVersion(), + modelOutput.metadata().modelBundleVersion(), modelOutput.metadata().calibrationBundleVersion(), jsonCodec.toJson(modelOutput.direction()), jsonCodec.toJson(modelOutput.entry()), jsonCodec.toJson(modelOutput.continuation()), jsonCodec.toJson(modelOutput.exit()), - jsonCodec.toJson(modelOutput.risk()), modelOutput.uncertainty(), modelOutput.oodScore(), + jsonCodec.toJson(modelOutput.risk()), jsonCodec.toJson(modelOutput.metadata()), + modelOutput.metadata().uncertainty(), modelOutput.metadata().oodScore(), snapshot.dataReady(), snapshot.dataReady() ? null : "DATA_NOT_READY"); } diff --git a/src/main/java/com/quantai/trader/persistence/TraderDecisionTraceWriter.java b/src/main/java/com/quantai/trader/persistence/TraderDecisionTraceWriter.java index 5deecbe..4bd0b4f 100644 --- a/src/main/java/com/quantai/trader/persistence/TraderDecisionTraceWriter.java +++ b/src/main/java/com/quantai/trader/persistence/TraderDecisionTraceWriter.java @@ -7,6 +7,8 @@ public interface TraderDecisionTraceWriter { TraderMarketSnapshot snapshot, TraderModelOutput modelOutput, TraderPositionState positionState, + TraderAccountState accountState, + TraderExecutionState executionState, TraderPositionManagerDecision pmDecision, TraderRiskDecision riskDecision, TraderAction action); diff --git a/src/main/java/com/quantai/trader/position/TraderPositionManager.java b/src/main/java/com/quantai/trader/position/TraderPositionManager.java index 008e47e..617a5ea 100644 --- a/src/main/java/com/quantai/trader/position/TraderPositionManager.java +++ b/src/main/java/com/quantai/trader/position/TraderPositionManager.java @@ -33,7 +33,8 @@ public class TraderPositionManager { TraderPmConfig.SizingConfig sizing = input.pmConfig().sizing(); EntryOutput entry = input.modelOutput().entry(); RiskOutput risk = input.modelOutput().risk(); - BigDecimal expectedEdge = entry.expectedEdgeBps().max(BigDecimal.ZERO); + TraderPricePlanContext pricePlan = input.pricePlanContext(); + BigDecimal expectedEdge = entry.netEdgeBpsFor(side).max(BigDecimal.ZERO); if (expectedEdge.compareTo(sizing.minEdgeBps()) < 0) { return BigDecimal.ZERO; } @@ -41,8 +42,8 @@ public class TraderPositionManager { ? input.modelOutput().direction().longProb() : input.modelOutput().direction().shortProb(); BigDecimal entryProb = side.isLong() ? entry.longEntryProb() : entry.shortEntryProb(); - BigDecimal stopLossBudget = entry.stopDistanceBps().add(entry.costBps()).max(BigDecimal.ONE); - BigDecimal uncertaintyPenalty = BigDecimal.ONE.subtract(input.modelOutput().uncertainty() + BigDecimal stopLossBudget = pricePlan.stopDistanceBps().add(pricePlan.costBps()).max(BigDecimal.ONE); + BigDecimal uncertaintyPenalty = BigDecimal.ONE.subtract(input.modelOutput().metadata().uncertainty() .multiply(sizing.uncertaintyPenaltyMultiplier())).max(BigDecimal.ZERO); BigDecimal edgeRiskBudget = TraderNumbers.safeDivide(expectedEdge, stopLossBudget) .multiply(directionStrength) @@ -50,7 +51,7 @@ public class TraderPositionManager { .multiply(BigDecimal.ONE.subtract(risk.marketRiskProb())) .multiply(uncertaintyPenalty); BigDecimal raw = sizing.baseRatio().multiply(edgeRiskBudget); - BigDecimal liquidityCap = risk.liquidityCapacityRatio().multiply(sizing.maxLiquidityUsageRatio()); + BigDecimal liquidityCap = input.executionState().liquidityCapacityRatio().multiply(sizing.maxLiquidityUsageRatio()); BigDecimal lossBudgetCap = TraderNumbers.safeDivide(sizing.maxLossPerTradeBps(), stopLossBudget); BigDecimal hardCap = min(sizing.maxSingleLegRatio(), min(input.accountState().remainingSymbolCapacityRatio(), min(liquidityCap, lossBudgetCap))); if (hardCap.compareTo(sizing.minInitialRatio()) < 0) { @@ -61,7 +62,8 @@ public class TraderPositionManager { public BigDecimal calculateAddRatio(PositionManagerInput input) { BigDecimal raw = input.pmConfig().sizing().baseRatio() - .multiply(input.modelOutput().continuation().continueVsExitEdgeBps().max(BigDecimal.ZERO)) + .multiply(input.modelOutput().continuation() + .continueEdgeBpsFor(input.positionState().side()).max(BigDecimal.ZERO)) .divide(new BigDecimal("100"), java.math.MathContext.DECIMAL64); return TraderNumbers.clamp(raw, input.pmConfig().sizing().minAddRatio(), min(input.pmConfig().sizing().maxAddRatio(), input.positionState().remainingAddCapacity())); @@ -73,20 +75,20 @@ public class TraderPositionManager { RiskOutput risk = input.modelOutput().risk(); TraderPmConfig.OpenRuleConfig open = input.pmConfig().open(); boolean longPass = direction.longProb().compareTo(open.longOpenProb()) > 0 - && direction.directionMargin().compareTo(open.minDirectionMargin()) > 0 + && direction.margin().compareTo(open.minDirectionMargin()) > 0 && entry.longEntryProb().compareTo(open.minLongEntryProb()) > 0 - && entry.expectedEdgeBps().compareTo(open.minExpectedEdgeBps()) > 0 + && entry.longExpectedNetEdgeBps().compareTo(open.minExpectedEdgeBps()) > 0 && risk.marketRiskProb().compareTo(open.maxMarketRiskProb()) < 0 - && risk.liquidityCapacityRatio().compareTo(open.minLiquidityCapacityRatio()) >= 0 - && input.modelOutput().oodScore().compareTo(open.maxOodScore()) <= 0 + && input.executionState().liquidityCapacityRatio().compareTo(open.minLiquidityCapacityRatio()) >= 0 + && input.modelOutput().metadata().oodScore().compareTo(open.maxOodScore()) <= 0 && input.executionState().openOrders().isEmpty(); boolean shortPass = direction.shortProb().compareTo(open.shortOpenProb()) > 0 - && direction.directionMargin().compareTo(open.minDirectionMargin()) > 0 + && direction.margin().compareTo(open.minDirectionMargin()) > 0 && entry.shortEntryProb().compareTo(open.minShortEntryProb()) > 0 - && entry.expectedEdgeBps().compareTo(open.minExpectedEdgeBps()) > 0 + && entry.shortExpectedNetEdgeBps().compareTo(open.minExpectedEdgeBps()) > 0 && risk.marketRiskProb().compareTo(open.maxMarketRiskProb()) < 0 - && risk.liquidityCapacityRatio().compareTo(open.minLiquidityCapacityRatio()) >= 0 - && input.modelOutput().oodScore().compareTo(open.maxOodScore()) <= 0 + && input.executionState().liquidityCapacityRatio().compareTo(open.minLiquidityCapacityRatio()) >= 0 + && input.modelOutput().metadata().oodScore().compareTo(open.maxOodScore()) <= 0 && input.executionState().openOrders().isEmpty(); BigDecimal longRatio = longPass ? calculateInitialRatio(input, PositionSide.LONG) : BigDecimal.ZERO; BigDecimal shortRatio = shortPass ? calculateInitialRatio(input, PositionSide.SHORT) : BigDecimal.ZERO; @@ -106,7 +108,7 @@ public class TraderPositionManager { } if (shouldReduce(input)) { TraderActionType action = input.positionState().side().isLong() ? TraderActionType.REDUCE_LONG : TraderActionType.REDUCE_SHORT; - return decision(input, action, input.positionState().side(), null, null, new BigDecimal("0.50"), "PROFIT_GIVEBACK_REDUCE"); + return decision(input, action, input.positionState().side(), null, null, new BigDecimal("0.50"), "ADVERSE_MOVE_REDUCE"); } if (shouldAdd(input)) { BigDecimal ratio = calculateAddRatio(input); @@ -130,17 +132,28 @@ public class TraderPositionManager { : input.modelOutput().continuation().shortContinueProb(); return sidePass && continueProb.compareTo(add.minContinueProb()) > 0 - && input.modelOutput().continuation().continueVsExitEdgeBps().compareTo(add.minContinueVsExitEdgeBps()) > 0 - && input.modelOutput().entry().expectedEdgeBps().compareTo(add.minExpectedEdgeBps()) > 0 + && input.modelOutput().continuation().continueEdgeBpsFor(input.positionState().side()) + .compareTo(add.minContinueVsExitEdgeBps()) > 0 + && input.modelOutput().entry().netEdgeBpsFor(input.positionState().side()) + .compareTo(add.minExpectedEdgeBps()) > 0 && input.modelOutput().risk().marketRiskProb().compareTo(add.maxMarketRiskProb()) < 0 - && input.modelOutput().risk().positionRiskProb().compareTo(add.maxPositionRiskProb()) < 0 - && input.modelOutput().risk().liquidityCapacityRatio().compareTo(add.minLiquidityCapacityRatio()) >= 0 + && input.modelOutput().risk().sideRiskProbFor(input.positionState().side()) + .compareTo(add.maxPositionRiskProb()) < 0 + && input.executionState().liquidityCapacityRatio().compareTo(add.minLiquidityCapacityRatio()) >= 0 && input.positionState().unrealizedPnlBps().compareTo(BigDecimal.ZERO) > 0 && input.positionState().addCount() < add.maxAddCount() + && addCooldownPassed(input, add) && input.positionState().remainingAddCapacity().compareTo(BigDecimal.ZERO) > 0 && input.executionState().openOrders().isEmpty(); } + private boolean addCooldownPassed(PositionManagerInput input, TraderPmConfig.AddRuleConfig add) { + if (input.positionState().lastAddTime() == null) { + return true; + } + return !input.cycle().cycleTime().isBefore(input.positionState().lastAddTime().plusSeconds(add.cooldownMinutes() * 60)); + } + private boolean shouldClose(PositionManagerInput input) { TraderPmConfig.ExitRuleConfig exit = input.pmConfig().exit(); boolean exitProb = input.positionState().side().isLong() @@ -151,9 +164,11 @@ public class TraderPositionManager { : input.modelOutput().continuation().shortContinueProb(); return exitProb || continueProb.compareTo(exit.closeContinueMax()) < 0 - || input.modelOutput().risk().positionRiskProb().compareTo(exit.closePositionRiskProb()) > 0 + || input.modelOutput().risk().sideRiskProbFor(input.positionState().side()) + .compareTo(exit.closePositionRiskProb()) > 0 || input.modelOutput().risk().marketRiskProb().compareTo(exit.closeMarketRiskProb()) > 0 - || input.modelOutput().risk().expectedShortfallBps().compareTo(exit.maxExpectedShortfallBps()) > 0; + || input.modelOutput().risk().positionPathRiskBpsFor(input.positionState().side()) + .compareTo(exit.maxPositionPathRiskBps()) > 0; } private boolean shouldReduce(PositionManagerInput input) { @@ -161,7 +176,7 @@ public class TraderPositionManager { BigDecimal continueProb = input.positionState().side().isLong() ? input.modelOutput().continuation().longContinueProb() : input.modelOutput().continuation().shortContinueProb(); - return input.modelOutput().exit().profitGivebackProb().compareTo(exit.reduceGivebackProb()) > 0 + return input.modelOutput().exit().reasonScore("adverse_move_prob").compareTo(exit.reduceAdverseMoveProb()) > 0 && continueProb.compareTo(exit.reduceContinueMin()) >= 0 && continueProb.compareTo(exit.reduceContinueMax()) <= 0 && input.positionState().unrealizedPnlBps().compareTo(exit.minProfitForReduceBps()) > 0; @@ -173,9 +188,9 @@ public class TraderPositionManager { private TraderPositionManagerDecision decision(PositionManagerInput input, TraderActionType action, PositionSide side, BigDecimal targetRatio, BigDecimal addRatio, BigDecimal reduceRatio, String reason) { - EntryOutput entry = input.modelOutput().entry(); - BigDecimal stopPrice = action.increasesExposure() ? priceFromBps(input.snapshot().markPrice(), entry.stopDistanceBps(), side, false) : null; - BigDecimal targetPrice = action.increasesExposure() ? priceFromBps(input.snapshot().markPrice(), entry.targetDistanceBps(), side, true) : null; + TraderPricePlanContext pricePlan = input.pricePlanContext(); + BigDecimal stopPrice = action.increasesExposure() ? priceFromBps(input.snapshot().markPrice(), pricePlan.stopDistanceBps(), side, false) : null; + BigDecimal targetPrice = action.increasesExposure() ? priceFromBps(input.snapshot().markPrice(), pricePlan.targetDistanceBps(), side, true) : null; return new TraderPositionManagerDecision( "pm_" + input.cycle().cycleId(), input.cycle().runId(), @@ -186,15 +201,17 @@ public class TraderPositionManager { input.executionState().executionStateId(), action, side, - action.increasesExposure() ? entry.pricePlanId() : null, - action.increasesExposure() ? entry.pricePlanConfigHash() : null, + action.increasesExposure() ? pricePlan.pricePlanId() : null, + action.increasesExposure() ? pricePlan.pricePlanConfigHash() : null, targetRatio, addRatio, reduceRatio, stopPrice, targetPrice, reason, - Map.of("pmConfigVersion", input.pmConfig().pmConfigVersion())); + Map.of( + "pmConfigVersion", input.pmConfig().pmConfigVersion(), + "pricePlanMaxHoldMinutes", pricePlan.maxHoldMinutes())); } private BigDecimal priceFromBps(BigDecimal markPrice, BigDecimal distanceBps, PositionSide side, boolean profitTarget) { diff --git a/src/main/java/com/quantai/trader/replay/ReplayMarketEvent.java b/src/main/java/com/quantai/trader/replay/ReplayMarketEvent.java index 3958a03..2ba1450 100644 --- a/src/main/java/com/quantai/trader/replay/ReplayMarketEvent.java +++ b/src/main/java/com/quantai/trader/replay/ReplayMarketEvent.java @@ -1,15 +1,44 @@ package com.quantai.trader.replay; +import static com.quantai.trader.util.TraderNumbers.nonNegative; +import static com.quantai.trader.util.TraderNumbers.positive; +import static com.quantai.trader.util.TraderNumbers.required; +import static com.quantai.trader.util.TraderNumbers.requiredText; + import java.math.BigDecimal; import java.time.Instant; +import java.util.Map; +import java.util.Objects; public record ReplayMarketEvent( String runId, String symbol, Instant eventTime, + String featureVersion, BigDecimal markPrice, BigDecimal indexPrice, BigDecimal spreadBps, - BigDecimal depthNotional5Bps + BigDecimal fundingRateBps, + BigDecimal depthNotional5Bps, + BigDecimal depthNotional10Bps, + BigDecimal depthNotional25Bps, + boolean dataReady, + Map featureJson, + Map dataQualityJson ) { + public ReplayMarketEvent { + runId = requiredText(runId, "runId"); + symbol = requiredText(symbol, "symbol"); + eventTime = Objects.requireNonNull(eventTime, "eventTime is required"); + featureVersion = requiredText(featureVersion, "featureVersion"); + markPrice = positive(markPrice, "markPrice"); + indexPrice = positive(indexPrice, "indexPrice"); + spreadBps = nonNegative(spreadBps, "spreadBps"); + fundingRateBps = required(fundingRateBps, "fundingRateBps"); + depthNotional5Bps = nonNegative(depthNotional5Bps, "depthNotional5Bps"); + depthNotional10Bps = nonNegative(depthNotional10Bps, "depthNotional10Bps"); + depthNotional25Bps = nonNegative(depthNotional25Bps, "depthNotional25Bps"); + featureJson = Map.copyOf(featureJson == null ? Map.of() : featureJson); + dataQualityJson = Map.copyOf(dataQualityJson == null ? Map.of() : dataQualityJson); + } } diff --git a/src/main/java/com/quantai/trader/replay/TraderP0CycleRunner.java b/src/main/java/com/quantai/trader/replay/TraderP0CycleRunner.java index b3d4bd9..7f21aa2 100644 --- a/src/main/java/com/quantai/trader/replay/TraderP0CycleRunner.java +++ b/src/main/java/com/quantai/trader/replay/TraderP0CycleRunner.java @@ -1,6 +1,7 @@ package com.quantai.trader.replay; import com.quantai.trader.artifact.TraderArtifactBundle; +import com.quantai.trader.artifact.TraderArtifactManifestRepository; import com.quantai.trader.artifact.TraderArtifactLoader; import com.quantai.trader.config.TraderProperties; import com.quantai.trader.domain.*; @@ -9,6 +10,7 @@ import com.quantai.trader.evidence.EvidenceAppender; import com.quantai.trader.model.TraderModelService; import com.quantai.trader.outbox.TraderOutboxRepository; import com.quantai.trader.outbox.TraderOutboxEvent; +import com.quantai.trader.outbox.TraderOutboxDispatcher; import com.quantai.trader.persistence.TraderDecisionTraceWriter; import com.quantai.trader.position.TraderPositionManager; import com.quantai.trader.replay.state.TraderReplayState; @@ -16,11 +18,15 @@ import com.quantai.trader.replay.state.TraderReplayStateStore; import com.quantai.trader.risk.RiskGateInput; import com.quantai.trader.risk.RiskLimits; import com.quantai.trader.risk.TraderRiskGate; +import com.quantai.trader.runtime.TraderRuntimeControlDecision; +import com.quantai.trader.runtime.TraderRuntimeControlService; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.springframework.stereotype.Service; +import org.springframework.transaction.annotation.Transactional; +import org.springframework.transaction.support.TransactionSynchronization; +import org.springframework.transaction.support.TransactionSynchronizationManager; -import java.math.BigDecimal; import java.time.Instant; import java.util.Map; @@ -36,8 +42,11 @@ public class TraderP0CycleRunner { private final TraderActionFactory actionFactory; private final EvidenceAppender evidenceAppender; private final TraderDecisionTraceWriter traceWriter; + private final TraderArtifactManifestRepository artifactManifestRepository; private final TraderOutboxRepository outboxRepository; + private final TraderOutboxDispatcher outboxDispatcher; private final TraderReplayStateStore stateStore; + private final TraderRuntimeControlService runtimeControlService; public TraderP0CycleRunner(TraderProperties properties, TraderArtifactLoader artifactLoader, @@ -47,8 +56,11 @@ public class TraderP0CycleRunner { TraderActionFactory actionFactory, EvidenceAppender evidenceAppender, TraderDecisionTraceWriter traceWriter, + TraderArtifactManifestRepository artifactManifestRepository, TraderOutboxRepository outboxRepository, - TraderReplayStateStore stateStore) { + TraderOutboxDispatcher outboxDispatcher, + TraderReplayStateStore stateStore, + TraderRuntimeControlService runtimeControlService) { this.properties = properties; this.artifactLoader = artifactLoader; this.modelService = modelService; @@ -57,47 +69,88 @@ public class TraderP0CycleRunner { this.actionFactory = actionFactory; this.evidenceAppender = evidenceAppender; this.traceWriter = traceWriter; + this.artifactManifestRepository = artifactManifestRepository; this.outboxRepository = outboxRepository; + this.outboxDispatcher = outboxDispatcher; this.stateStore = stateStore; + this.runtimeControlService = runtimeControlService; } + @Transactional public TraderCycleResult runCycle(ReplayMarketEvent event) { String cycleId = "cycle_" + event.runId() + "_" + event.eventTime().toEpochMilli(); TraderArtifactBundle bundle = artifactLoader.loadActiveBundle(); + artifactManifestRepository.upsertActiveBundle(bundle); TraderDecisionCycle cycle = new TraderDecisionCycle(event.runId(), cycleId, event.symbol(), event.eventTime(), properties.runMode(), bundle.modelBundleVersion(), bundle.calibrationBundleVersion(), bundle.pmConfigVersion()); - TraderMarketSnapshot snapshot = snapshot(event, cycleId); - TraderReplayState state = stateStore.load(cycle, snapshot); - evidenceAppender.append(cycle.runId(), cycle.cycleId(), "MARKET_SNAPSHOT", snapshot.dataReady(), "SNAPSHOT_BUILT", null, Map.of()); - TraderModelOutput modelOutput = modelService.evaluate(snapshot, bundle); - evidenceAppender.append(cycle.runId(), cycle.cycleId(), "MODEL_OUTPUT", true, "MODEL_EVALUATED", null, Map.of("modelOutputId", modelOutput.modelOutputId())); - PositionManagerInput pmInput = new PositionManagerInput(cycle, snapshot, modelOutput, - state.positionState(), state.accountState(), state.executionState(), bundle.pmConfig()); - TraderPositionManagerDecision pmDecision = positionManager.decide(pmInput); - evidenceAppender.append(cycle.runId(), cycle.cycleId(), "PM_DECISION", true, pmDecision.reason(), null, Map.of("action", pmDecision.candidateAction().name())); - TraderRiskDecision riskDecision = riskGate.evaluate(new RiskGateInput(pmDecision, pmInput.positionState(), pmInput.accountState(), - pmInput.executionState(), snapshot, riskLimits())); - evidenceAppender.append(cycle.runId(), cycle.cycleId(), "RISK_DECISION", riskDecision.allowAction(), riskDecision.allowAction() ? "RISK_PASS" : riskDecision.blocker(), riskDecision.blocker(), Map.of()); - TraderAction action = actionFactory.create(pmDecision, riskDecision, event.symbol()); - traceWriter.persistCycleTrace(cycle, snapshot, modelOutput, pmInput.positionState(), pmDecision, riskDecision, action); - outboxRepository.insert(new TraderOutboxEvent("outbox_" + action.actionId(), action.runId(), action.cycleId(), - "TRADER_ACTION", action.actionId(), "ACTION_CREATED", properties.runMode().name() + "_RECORDER", - Map.of("actionType", action.actionType().name()), action.idempotencyKey(), "PENDING", Instant.now())); - stateStore.advance(state, action, snapshot); - log.info("event=trader.cycle.completed runId={} cycleId={} action={} outbox=PENDING", action.runId(), action.cycleId(), action.actionType()); - return new TraderCycleResult(cycle.runId(), cycle.cycleId(), pmDecision, riskDecision, action); + runtimeControlService.acquireCycleLock(cycle); + try { + TraderMarketSnapshot snapshot = snapshot(event, cycleId); + TraderReplayState state = stateStore.load(cycle, snapshot); + evidenceAppender.append(cycle.runId(), cycle.cycleId(), "MARKET_SNAPSHOT", snapshot.dataReady(), "SNAPSHOT_BUILT", null, Map.of()); + TraderModelOutput modelOutput = modelService.evaluate(snapshot, bundle); + evidenceAppender.append(cycle.runId(), cycle.cycleId(), "MODEL_OUTPUT", true, "MODEL_EVALUATED", null, Map.of("modelOutputId", modelOutput.modelOutputId())); + PositionManagerInput pmInput = new PositionManagerInput(cycle, snapshot, modelOutput, + bundle.pricePlanContext(), state.positionState(), state.accountState(), state.executionState(), bundle.pmConfig()); + TraderPositionManagerDecision pmDecision = positionManager.decide(pmInput); + evidenceAppender.append(cycle.runId(), cycle.cycleId(), "PM_DECISION", true, pmDecision.reason(), null, Map.of("action", pmDecision.candidateAction().name())); + TraderRuntimeControlDecision runtimeDecision = runtimeControlService.validateExposureIncrease(cycle, pmDecision); + if (runtimeDecision.allowed() + && pmDecision.candidateAction().increasesExposure() + && outboxRepository.hasUnsentExposureIncrease(cycle.runId(), cycle.symbol(), pmDecision.side())) { + runtimeDecision = TraderRuntimeControlDecision.block("OUTBOX_PENDING_OPEN_ADD"); + } + TraderRiskDecision riskDecision = riskGate.evaluate(new RiskGateInput(pmDecision, pmInput.positionState(), pmInput.accountState(), + pmInput.executionState(), snapshot, riskLimits(runtimeDecision))); + evidenceAppender.append(cycle.runId(), cycle.cycleId(), "RISK_DECISION", riskDecision.allowAction(), riskDecision.allowAction() ? "RISK_PASS" : riskDecision.blocker(), riskDecision.blocker(), Map.of()); + TraderAction action = actionFactory.create(pmDecision, riskDecision, event.symbol()); + traceWriter.persistCycleTrace(cycle, snapshot, modelOutput, pmInput.positionState(), + pmInput.accountState(), pmInput.executionState(), pmDecision, riskDecision, action); + String destination = outboxDestination(); + outboxRepository.insert(new TraderOutboxEvent("outbox_" + action.actionId(), action.runId(), action.cycleId(), + "TRADER_ACTION", action.actionId(), "ACTION_CREATED", destination, + Map.of("actionType", action.actionType().name()), action.idempotencyKey(), "PENDING", Instant.now())); + dispatchAfterCommit(action, state, snapshot, destination); + log.info("event=trader.cycle.completed runId={} cycleId={} action={} outbox=PENDING runtimeBlocker={}", + action.runId(), action.cycleId(), action.actionType(), runtimeDecision.blocker()); + return new TraderCycleResult(cycle.runId(), cycle.cycleId(), pmDecision, riskDecision, action); + } finally { + runtimeControlService.releaseCycleLock(cycle); + } } private TraderMarketSnapshot snapshot(ReplayMarketEvent event, String cycleId) { return new TraderMarketSnapshot("snapshot_" + cycleId, event.runId(), cycleId, event.symbol(), event.eventTime(), - "feature-v4-p0", event.markPrice(), event.indexPrice(), event.spreadBps(), BigDecimal.ZERO, - event.depthNotional5Bps(), event.depthNotional5Bps(), event.depthNotional5Bps(), - event.depthNotional5Bps().compareTo(BigDecimal.ZERO) > 0, Map.of(), Map.of()); + event.featureVersion(), event.markPrice(), event.indexPrice(), event.spreadBps(), event.fundingRateBps(), + event.depthNotional5Bps(), event.depthNotional10Bps(), event.depthNotional25Bps(), + event.dataReady(), event.featureJson(), event.dataQualityJson()); } - private RiskLimits riskLimits() { + private RiskLimits riskLimits(TraderRuntimeControlDecision runtimeDecision) { return new RiskLimits(properties.risk().maxDailyLossBps(), properties.risk().maxTotalExposureRatio(), properties.risk().minLiquidationBufferBps(), properties.execution().maxApiErrorCount(), - properties.execution().maxExchangeLatencyMs(), false, false); + properties.execution().maxExchangeLatencyMs(), false, !runtimeDecision.allowed(), runtimeDecision.blocker()); + } + + private String outboxDestination() { + return switch (properties.execution().mode()) { + case REPLAY_SIM -> "REPLAY_SIM_EXECUTION"; + case SHADOW -> "SHADOW_RECORDER"; + case PAPER, REAL -> throw new IllegalStateException("P0 runtime guard must reject PAPER/REAL execution mode"); + }; + } + + private void dispatchAfterCommit(TraderAction action, TraderReplayState state, + TraderMarketSnapshot snapshot, String destination) { + if (!TransactionSynchronizationManager.isSynchronizationActive()) { + outboxDispatcher.dispatch(action, state, snapshot, destination); + return; + } + TransactionSynchronizationManager.registerSynchronization(new TransactionSynchronization() { + @Override + public void afterCommit() { + outboxDispatcher.dispatch(action, state, snapshot, destination); + } + }); } } diff --git a/src/main/java/com/quantai/trader/replay/state/JdbcTraderPostActionStateRepository.java b/src/main/java/com/quantai/trader/replay/state/JdbcTraderPostActionStateRepository.java new file mode 100644 index 0000000..40dd25d --- /dev/null +++ b/src/main/java/com/quantai/trader/replay/state/JdbcTraderPostActionStateRepository.java @@ -0,0 +1,70 @@ +package com.quantai.trader.replay.state; + +import com.fasterxml.jackson.databind.ObjectMapper; +import com.quantai.trader.domain.TraderAccountState; +import com.quantai.trader.domain.TraderExecutionState; +import com.quantai.trader.domain.TraderPositionState; +import com.quantai.trader.persistence.TraderJsonCodec; +import org.springframework.jdbc.core.JdbcTemplate; +import org.springframework.stereotype.Repository; + +import java.sql.Timestamp; + +@Repository +public class JdbcTraderPostActionStateRepository implements TraderPostActionStateRepository { + private final JdbcTemplate jdbcTemplate; + private final TraderJsonCodec jsonCodec; + + public JdbcTraderPostActionStateRepository(JdbcTemplate jdbcTemplate, ObjectMapper objectMapper) { + this.jdbcTemplate = jdbcTemplate; + this.jsonCodec = new TraderJsonCodec(objectMapper); + } + + @Override + public void insertPostActionState(TraderReplayState state) { + insertPositionState(state.positionState()); + insertAccountState(state.accountState()); + insertExecutionState(state.executionState()); + } + + private void insertPositionState(TraderPositionState state) { + jdbcTemplate.update(""" + insert into trader_position_state + (run_id, cycle_id, position_state_id, state_role, symbol, side, position_ratio, + average_entry_price, current_price, unrealized_pnl_bps, liquidation_buffer_bps, + add_count, remaining_add_capacity, last_add_time) + values (?, ?, ?, 'POST_ACTION', ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + """, + state.runId(), state.cycleId(), state.positionStateId() + "_post", state.symbol(), + state.side().name(), state.positionRatio(), state.averageEntryPrice(), state.currentPrice(), + state.unrealizedPnlBps(), state.liquidationBufferBps(), state.addCount(), + state.remainingAddCapacity(), state.lastAddTime() == null ? null : Timestamp.from(state.lastAddTime())); + } + + private void insertAccountState(TraderAccountState state) { + jdbcTemplate.update(""" + insert into trader_account_state + (run_id, cycle_id, account_state_id, state_role, daily_drawdown_bps, + portfolio_exposure_ratio, remaining_symbol_capacity_ratio, consecutive_losses) + values (?, ?, ?, 'POST_ACTION', ?, ?, ?, ?) + """, + state.runId(), state.cycleId(), state.accountStateId() + "_post", state.dailyDrawdownBps(), + state.portfolioExposureRatio(), state.remainingSymbolCapacityRatio(), state.consecutiveLosses()); + } + + private void insertExecutionState(TraderExecutionState state) { + jdbcTemplate.update(""" + insert into trader_execution_state + (run_id, cycle_id, execution_state_id, state_role, symbol, open_orders_json, + expected_slippage_bps, exchange_latency_ms, api_error_count, maker_fee_bps, + taker_fee_bps, min_notional, price_tick_size, lot_size_step_size, + market_lot_size_step_size, liquidity_capacity_ratio) + values (?, ?, ?, 'POST_ACTION', ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + """, + state.runId(), state.cycleId(), state.executionStateId() + "_post", state.symbol(), + jsonCodec.toJson(state.openOrders()), state.expectedSlippageBps(), state.exchangeLatencyMs(), + state.apiErrorCount(), state.makerFeeBps(), state.takerFeeBps(), state.minNotional(), + state.priceTickSize(), state.lotSizeStepSize(), state.marketLotSizeStepSize(), + state.liquidityCapacityRatio()); + } +} diff --git a/src/main/java/com/quantai/trader/replay/state/P0ReplayStateStore.java b/src/main/java/com/quantai/trader/replay/state/P0ReplayStateStore.java index be09775..91a5ca4 100644 --- a/src/main/java/com/quantai/trader/replay/state/P0ReplayStateStore.java +++ b/src/main/java/com/quantai/trader/replay/state/P0ReplayStateStore.java @@ -11,7 +11,6 @@ import org.springframework.stereotype.Component; import java.math.BigDecimal; import java.math.MathContext; -import java.time.Instant; import java.util.List; import java.util.concurrent.ConcurrentHashMap; @@ -88,7 +87,7 @@ public class P0ReplayStateStore implements TraderReplayStateStore { BigDecimal weightedEntry = weightedEntry(current.averageEntryPrice(), current.positionRatio(), snapshot.markPrice(), addRatio); return new TraderPositionState("position_state_" + action.cycleId(), action.runId(), action.cycleId(), action.symbol(), current.side(), newRatio, weightedEntry, snapshot.markPrice(), pnlBps(current.side(), weightedEntry, snapshot.markPrice()), - current.liquidationBufferBps(), current.addCount() + 1, remainingCapacity(newRatio), Instant.now()); + current.liquidationBufferBps(), current.addCount() + 1, remainingCapacity(newRatio), snapshot.snapshotTime()); } private TraderPositionState reducePosition(TraderPositionState current, TraderAction action, TraderMarketSnapshot snapshot) { diff --git a/src/main/java/com/quantai/trader/replay/state/TraderPostActionStateRepository.java b/src/main/java/com/quantai/trader/replay/state/TraderPostActionStateRepository.java new file mode 100644 index 0000000..f788349 --- /dev/null +++ b/src/main/java/com/quantai/trader/replay/state/TraderPostActionStateRepository.java @@ -0,0 +1,5 @@ +package com.quantai.trader.replay.state; + +public interface TraderPostActionStateRepository { + void insertPostActionState(TraderReplayState state); +} diff --git a/src/main/java/com/quantai/trader/risk/RiskLimits.java b/src/main/java/com/quantai/trader/risk/RiskLimits.java index fbc5914..8a9c1d3 100644 --- a/src/main/java/com/quantai/trader/risk/RiskLimits.java +++ b/src/main/java/com/quantai/trader/risk/RiskLimits.java @@ -9,6 +9,7 @@ public record RiskLimits( int maxApiErrorCount, long maxExchangeLatencyMs, boolean killSwitchActive, - boolean executionBlocked + boolean executionBlocked, + String executionBlocker ) { } diff --git a/src/main/java/com/quantai/trader/risk/TraderRiskGate.java b/src/main/java/com/quantai/trader/risk/TraderRiskGate.java index ec8d666..f731949 100644 --- a/src/main/java/com/quantai/trader/risk/TraderRiskGate.java +++ b/src/main/java/com/quantai/trader/risk/TraderRiskGate.java @@ -17,7 +17,7 @@ public class TraderRiskGate { if (input.riskLimits().killSwitchActive() && input.pmDecision().candidateAction().increasesExposure()) { decision = block(input, "KILL_SWITCH_ACTIVE"); } else if (input.riskLimits().executionBlocked()) { - decision = block(input, "EXECUTION_BLOCKED"); + decision = block(input, input.riskLimits().executionBlocker()); } else if (input.accountState().dailyDrawdownBps().compareTo(input.riskLimits().maxDailyLossBps()) >= 0) { decision = block(input, "MAX_DAILY_LOSS"); } else if (input.accountState().portfolioExposureRatio().compareTo(input.riskLimits().maxTotalExposureRatio()) >= 0 @@ -46,6 +46,9 @@ public class TraderRiskGate { } private TraderRiskDecision block(RiskGateInput input, String blocker) { + if (blocker == null || blocker.isBlank()) { + blocker = "EXECUTION_BLOCKED"; + } TraderActionType finalAction = input.pmDecision().candidateAction().increasesExposure() ? TraderActionType.WAIT : input.pmDecision().candidateAction(); return new TraderRiskDecision("risk_" + input.pmDecision().cycleId(), input.pmDecision().runId(), input.pmDecision().cycleId(), input.pmDecision().pmDecisionId(), false, input.pmDecision().candidateAction(), finalAction, blocker, diff --git a/src/main/java/com/quantai/trader/runtime/RedisTraderRuntimeControlService.java b/src/main/java/com/quantai/trader/runtime/RedisTraderRuntimeControlService.java new file mode 100644 index 0000000..834c203 --- /dev/null +++ b/src/main/java/com/quantai/trader/runtime/RedisTraderRuntimeControlService.java @@ -0,0 +1,93 @@ +package com.quantai.trader.runtime; + +import com.quantai.trader.config.TraderProperties; +import com.quantai.trader.domain.TraderDecisionCycle; +import com.quantai.trader.domain.TraderException; +import com.quantai.trader.domain.TraderPositionManagerDecision; +import com.quantai.trader.enums.TraderErrorCode; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.data.redis.core.StringRedisTemplate; +import org.springframework.stereotype.Component; + +import java.time.Duration; + +@Component +public class RedisTraderRuntimeControlService implements TraderRuntimeControlService { + private static final Logger log = LoggerFactory.getLogger(RedisTraderRuntimeControlService.class); + + private final TraderProperties properties; + private final StringRedisTemplate redisTemplate; + + public RedisTraderRuntimeControlService(TraderProperties properties, StringRedisTemplate redisTemplate) { + this.properties = properties; + this.redisTemplate = redisTemplate; + } + + @Override + public void acquireCycleLock(TraderDecisionCycle cycle) { + String key = key("cycle:lock:" + cycle.symbol() + ":" + cycle.cycleTime().toEpochMilli()); + try { + Boolean acquired = redisTemplate.opsForValue().setIfAbsent(key, cycle.cycleId(), Duration.ofSeconds(30)); + if (!Boolean.TRUE.equals(acquired)) { + throw new TraderException(TraderErrorCode.TRADER_RUNTIME_CONTROL_BLOCKED, + "cycle lock already exists: " + key); + } + log.info("event=trader.runtime.lock_acquired runId={} cycleId={} key={}", + cycle.runId(), cycle.cycleId(), key); + } catch (TraderException exception) { + throw exception; + } catch (RuntimeException exception) { + throw new TraderException(TraderErrorCode.TRADER_RUNTIME_CONTROL_BLOCKED, + "redis unavailable while acquiring cycle lock"); + } + } + + @Override + public void releaseCycleLock(TraderDecisionCycle cycle) { + String key = key("cycle:lock:" + cycle.symbol() + ":" + cycle.cycleTime().toEpochMilli()); + try { + redisTemplate.delete(key); + log.info("event=trader.runtime.lock_released runId={} cycleId={} key={}", + cycle.runId(), cycle.cycleId(), key); + } catch (RuntimeException exception) { + log.warn("event=trader.runtime.lock_release_failed runId={} cycleId={} key={} blocker=REDIS_UNAVAILABLE", + cycle.runId(), cycle.cycleId(), key); + } + } + + @Override + public TraderRuntimeControlDecision validateExposureIncrease(TraderDecisionCycle cycle, + TraderPositionManagerDecision decision) { + if (!decision.candidateAction().increasesExposure()) { + return TraderRuntimeControlDecision.allow(); + } + try { + String activePointer = redisTemplate.opsForValue().get(key("model:active:" + cycle.symbol())); + String expectedPointer = cycle.modelBundleVersion() + "|" + cycle.calibrationBundleVersion() + "|" + cycle.pmConfigVersion(); + if (properties.release().activePointerCheckEnabled() && activePointer == null) { + return TraderRuntimeControlDecision.block("ACTIVE_POINTER_MISSING"); + } + if (properties.release().activePointerCheckEnabled() && !expectedPointer.equals(activePointer)) { + return TraderRuntimeControlDecision.block("ACTIVE_POINTER_MISMATCH"); + } + if (Boolean.TRUE.equals(redisTemplate.hasKey(key("risk:kill-switch:global")))) { + return TraderRuntimeControlDecision.block("KILL_SWITCH_ACTIVE"); + } + if (Boolean.TRUE.equals(redisTemplate.hasKey(key("risk:kill-switch:symbol:" + cycle.symbol())))) { + return TraderRuntimeControlDecision.block("KILL_SWITCH_ACTIVE"); + } + if (Boolean.TRUE.equals(redisTemplate.hasKey(key("risk:execution:close-only:" + cycle.symbol())))) { + return TraderRuntimeControlDecision.block("CLOSE_ONLY_ACTIVE"); + } + redisTemplate.opsForValue().set(key("runtime:open-add-probe:" + cycle.cycleId()), "1", Duration.ofSeconds(30)); + return TraderRuntimeControlDecision.allow(); + } catch (RuntimeException exception) { + return TraderRuntimeControlDecision.block("REDIS_UNAVAILABLE"); + } + } + + private String key(String suffix) { + return properties.runtime().redisKeyPrefix() + ":" + suffix; + } +} diff --git a/src/main/java/com/quantai/trader/runtime/TraderRuntimeControlDecision.java b/src/main/java/com/quantai/trader/runtime/TraderRuntimeControlDecision.java new file mode 100644 index 0000000..5bdc62c --- /dev/null +++ b/src/main/java/com/quantai/trader/runtime/TraderRuntimeControlDecision.java @@ -0,0 +1,14 @@ +package com.quantai.trader.runtime; + +public record TraderRuntimeControlDecision( + boolean allowed, + String blocker +) { + public static TraderRuntimeControlDecision allow() { + return new TraderRuntimeControlDecision(true, null); + } + + public static TraderRuntimeControlDecision block(String blocker) { + return new TraderRuntimeControlDecision(false, blocker); + } +} diff --git a/src/main/java/com/quantai/trader/runtime/TraderRuntimeControlService.java b/src/main/java/com/quantai/trader/runtime/TraderRuntimeControlService.java new file mode 100644 index 0000000..4a11afe --- /dev/null +++ b/src/main/java/com/quantai/trader/runtime/TraderRuntimeControlService.java @@ -0,0 +1,13 @@ +package com.quantai.trader.runtime; + +import com.quantai.trader.domain.TraderDecisionCycle; +import com.quantai.trader.domain.TraderPositionManagerDecision; + +public interface TraderRuntimeControlService { + void acquireCycleLock(TraderDecisionCycle cycle); + + void releaseCycleLock(TraderDecisionCycle cycle); + + TraderRuntimeControlDecision validateExposureIncrease(TraderDecisionCycle cycle, + TraderPositionManagerDecision decision); +} diff --git a/src/main/java/com/quantai/trader/util/TraderNumbers.java b/src/main/java/com/quantai/trader/util/TraderNumbers.java index 726c14a..cb8750c 100644 --- a/src/main/java/com/quantai/trader/util/TraderNumbers.java +++ b/src/main/java/com/quantai/trader/util/TraderNumbers.java @@ -3,7 +3,6 @@ package com.quantai.trader.util; import java.math.BigDecimal; import java.math.MathContext; import java.util.Collection; -import java.util.Objects; public final class TraderNumbers { public static final BigDecimal ZERO = BigDecimal.ZERO; @@ -14,7 +13,10 @@ public final class TraderNumbers { } public static BigDecimal required(BigDecimal value, String field) { - return Objects.requireNonNull(value, field + " is required"); + if (value == null) { + throw new IllegalArgumentException(field + " is required"); + } + return value; } public static String requiredText(String value, String field) { @@ -68,6 +70,9 @@ public final class TraderNumbers { } public static Collection requiredCollection(Collection value, String field) { - return Objects.requireNonNull(value, field + " is required"); + if (value == null) { + throw new IllegalArgumentException(field + " is required"); + } + return value; } } diff --git a/src/main/resources/db/migration/V1__trader_v4_p0_schema.sql b/src/main/resources/db/migration/V1__trader_v4_p0_schema.sql index 99035ca..41a27bd 100644 --- a/src/main/resources/db/migration/V1__trader_v4_p0_schema.sql +++ b/src/main/resources/db/migration/V1__trader_v4_p0_schema.sql @@ -33,22 +33,111 @@ create table trader_decision_cycle ( ); create table trader_model_bundle_manifest ( - id bigint primary key auto_increment, - model_bundle_version varchar(96) not null, - calibration_bundle_version varchar(96) not null, - feature_version varchar(96) not null, - label_version varchar(96) not null, - split_version varchar(96) not null, - required_models_json json not null, - provided_models_json json not null, - missing_models_json json not null, - bundle_hash_sha256 varchar(64) not null, - complete boolean not null, - status varchar(32) not null, - created_at datetime(3) not null default current_timestamp(3), + id bigint primary key auto_increment comment '自增主键', + manifest_schema_version varchar(32) not null comment '清单格式版本,代码启动时必须校验', + model_bundle_version varchar(96) not null comment '五个小模型组成的一整包模型版本', + calibration_bundle_version varchar(96) not null comment '这一包模型对应的校准器版本', + feature_version varchar(96) not null comment '训练和运行使用的特征版本', + label_version varchar(96) not null comment '训练标签版本', + split_version varchar(96) not null comment '训练、验证、测试切分版本', + training_run_id varchar(128) not null comment '训练运行编号,用来回查训练日志', + training_export_id varchar(128) not null comment '训练产物导出编号', + backtest_manifest_id varchar(128) not null comment '对应的回测报告编号', + required_models_json json not null comment '必须具备的模型类型', + provided_models_json json not null comment '实际提供的模型类型', + missing_models_json json not null comment '缺失的模型类型,运行时必须为空数组', + allowed_run_modes_json json not null comment '允许使用的运行模式,P0 只能包含 REPLAY_SIM、SHADOW', + bundle_hash_sha256 varchar(64) not null comment '整包文件指纹', + complete boolean not null comment '模型包是否完整', + status varchar(32) not null comment '状态:候选、启用、拒绝、下线', + created_at datetime(3) not null default current_timestamp(3) comment '创建时间', + updated_at datetime(3) not null default current_timestamp(3) on update current_timestamp(3) comment '更新时间', unique key uk_trader_model_bundle (model_bundle_version, calibration_bundle_version), + key idx_trader_model_bundle_status (status, created_at), + key idx_trader_model_bundle_training (training_run_id, training_export_id), constraint chk_trader_model_bundle_status check (status in ('CANDIDATE','ACTIVE','REJECTED','RETIRED')) -); +) comment='Trader 模型整包清单,控制五个小模型是否齐全、是否允许进入运行'; + +create table trader_model_manifest ( + id bigint primary key auto_increment comment '自增主键', + model_bundle_version varchar(96) not null comment '所属模型整包版本', + calibration_bundle_version varchar(96) not null comment '对应校准器整包版本', + model_name varchar(96) not null comment '小模型稳定名字', + model_type varchar(32) not null comment '小模型类型,只允许 DIRECTION、ENTRY、CONTINUE、EXIT、RISK', + side varchar(16) not null comment '适用方向:LONG、SHORT、BOTH', + symbol_scope_json json not null comment '适用交易对范围', + bar_interval varchar(16) not null comment '输入行情周期', + horizon_minutes int not null comment '预测周期,单位分钟', + model_format varchar(32) not null comment '模型文件格式,首版固定 ONNX', + model_runtime varchar(64) not null comment 'Java 侧加载方式,首版固定 ONNX_RUNTIME_JAVA', + model_runtime_version varchar(64) not null comment '运行库版本', + onnx_opset_version int not null comment 'ONNX 算子版本', + producer_name varchar(128) not null comment '模型导出工具名称', + producer_version varchar(64) not null comment '模型导出工具版本', + artifact_path varchar(512) not null comment '模型文件相对 artifact root 的路径', + artifact_hash_sha256 varchar(64) not null comment '模型文件指纹', + source_hash varchar(64) not null comment '训练代码和关键训练配置指纹', + feature_version varchar(96) not null comment '特征版本', + feature_schema_path varchar(512) not null comment '输入特征清单文件路径', + feature_schema_hash varchar(64) not null comment '输入特征清单指纹', + feature_order_path varchar(512) not null comment '输入特征顺序文件路径', + feature_order_hash varchar(64) not null comment '输入特征顺序指纹', + input_tensor_name varchar(128) not null comment 'ONNX 输入张量名', + input_dtype varchar(32) not null comment '输入数字类型', + input_shape_json json not null comment '输入形状', + input_example_path varchar(512) not null comment '固定样例输入文件,用于启动冒烟检查', + output_schema_path varchar(512) not null comment '输出字段说明文件路径', + output_schema_hash varchar(64) not null comment '输出字段说明指纹', + output_tensor_names_json json not null comment 'ONNX 原始输出张量名', + output_mapping_json json not null comment '原始输出到业务字段的映射', + output_value_rules_json json not null comment '输出合法范围', + label_version varchar(96) not null comment '标签版本', + split_version varchar(96) not null comment '训练切分版本', + training_fold varchar(64) not null comment '训练折编号', + train_start datetime(3) not null comment '训练数据开始时间', + train_end datetime(3) not null comment '训练数据结束时间', + validation_start datetime(3) not null comment '验证数据开始时间', + validation_end datetime(3) not null comment '验证数据结束时间', + test_start datetime(3) not null comment '测试数据开始时间', + test_end datetime(3) not null comment '测试数据结束时间', + metrics_json json not null comment '模型评估结果', + status varchar(32) not null comment '状态:候选、启用、拒绝、下线', + created_at datetime(3) not null default current_timestamp(3) comment '创建时间', + updated_at datetime(3) not null default current_timestamp(3) on update current_timestamp(3) comment '更新时间', + unique key uk_trader_model_manifest (model_bundle_version, calibration_bundle_version, model_name, side, horizon_minutes), + key idx_trader_model_manifest_type (model_type, status), + key idx_trader_model_manifest_bundle (model_bundle_version, calibration_bundle_version), + constraint chk_trader_model_type check (model_type in ('DIRECTION','ENTRY','CONTINUE','EXIT','RISK')), + constraint chk_trader_model_side check (side in ('LONG','SHORT','BOTH')), + constraint chk_trader_model_format check (model_format in ('ONNX')), + constraint chk_trader_model_runtime check (model_runtime in ('ONNX_RUNTIME_JAVA')), + constraint chk_trader_model_input_dtype check (input_dtype in ('FLOAT32')), + constraint chk_trader_model_horizon check (horizon_minutes > 0), + constraint chk_trader_model_manifest_status check (status in ('CANDIDATE','ACTIVE','REJECTED','RETIRED')) +) comment='Trader 单个小模型清单,说明模型如何加载、输入是什么、输出是什么'; + +create table trader_calibration_manifest ( + id bigint primary key auto_increment comment '自增主键', + calibration_bundle_version varchar(96) not null comment '校准器整包版本', + model_bundle_version varchar(96) not null comment '对应模型整包版本', + model_name varchar(96) not null comment '对应小模型名字', + calibrator_version varchar(96) not null comment '单个校准器版本', + calibration_method varchar(64) not null comment '校准方法', + calibrator_path varchar(512) not null comment '校准文件路径', + calibrator_hash_sha256 varchar(64) not null comment '校准文件指纹', + calibration_window_from datetime(3) not null comment '校准数据开始时间', + calibration_window_to datetime(3) not null comment '校准数据结束时间', + calibration_metrics_json json not null comment '校准效果指标', + bucket_metrics_json json not null comment '分桶校准明细', + output_after_calibration_schema_hash varchar(64) not null comment '校准后输出格式指纹', + status varchar(32) not null comment '状态:候选、启用、拒绝、下线', + created_at datetime(3) not null default current_timestamp(3) comment '创建时间', + updated_at datetime(3) not null default current_timestamp(3) on update current_timestamp(3) comment '更新时间', + unique key uk_trader_calibration_manifest (calibration_bundle_version, model_bundle_version, model_name), + key idx_trader_calibration_bundle (model_bundle_version, calibration_bundle_version), + constraint chk_trader_calibration_method check (calibration_method in ('ISOTONIC','PLATT','BINNING','NONE')), + constraint chk_trader_calibration_manifest_status check (status in ('CANDIDATE','ACTIVE','REJECTED','RETIRED')) +) comment='Trader 校准器清单,说明每个小模型的原始输出如何变成可用概率或分数'; create table trader_pm_config_manifest ( id bigint primary key auto_increment, @@ -65,6 +154,90 @@ create table trader_pm_config_manifest ( constraint chk_trader_pm_config_status check (status in ('CANDIDATE','ACTIVE','REJECTED','RETIRED')) ); +create table trader_market_snapshot ( + id bigint primary key auto_increment, + run_id varchar(64) not null, + cycle_id varchar(128) not null, + snapshot_id varchar(128) not null, + symbol varchar(32) not null, + snapshot_time datetime(3) not null, + feature_version varchar(96) not null, + mark_price decimal(28,10) not null, + index_price decimal(28,10) not null, + spread_bps decimal(18,8) not null, + funding_rate_bps decimal(18,8) not null, + depth_notional_5bps decimal(28,10) not null, + depth_notional_10bps decimal(28,10) not null, + depth_notional_25bps decimal(28,10) not null, + data_ready boolean not null, + feature_json json not null, + data_quality_json json not null, + created_at datetime(3) not null default current_timestamp(3), + unique key uk_trader_market_snapshot (run_id, snapshot_id), + key idx_trader_market_snapshot_cycle (run_id, cycle_id) +); + +create table trader_position_state ( + id bigint primary key auto_increment, + run_id varchar(64) not null, + cycle_id varchar(128) not null, + position_state_id varchar(128) not null, + state_role varchar(32) not null, + symbol varchar(32) not null, + side varchar(16) not null, + position_ratio decimal(18,8) not null, + average_entry_price decimal(28,10) null, + current_price decimal(28,10) not null, + unrealized_pnl_bps decimal(18,8) not null, + liquidation_buffer_bps decimal(18,8) not null, + add_count int not null, + remaining_add_capacity decimal(18,8) not null, + last_add_time datetime(3) null, + created_at datetime(3) not null default current_timestamp(3), + unique key uk_trader_position_state (run_id, position_state_id), + key idx_trader_position_latest (run_id, symbol, state_role, created_at), + constraint chk_trader_position_side check (side in ('NONE','LONG','SHORT')), + constraint chk_trader_position_role check (state_role in ('PM_INPUT','POST_ACTION','FEEDBACK_UPDATE')) +); + +create table trader_account_state ( + id bigint primary key auto_increment, + run_id varchar(64) not null, + cycle_id varchar(128) not null, + account_state_id varchar(128) not null, + state_role varchar(32) not null, + daily_drawdown_bps decimal(18,8) not null, + portfolio_exposure_ratio decimal(18,8) not null, + remaining_symbol_capacity_ratio decimal(18,8) not null, + consecutive_losses int not null, + created_at datetime(3) not null default current_timestamp(3), + unique key uk_trader_account_state (run_id, account_state_id), + constraint chk_trader_account_role check (state_role in ('PM_INPUT','POST_ACTION','FEEDBACK_UPDATE')) +); + +create table trader_execution_state ( + id bigint primary key auto_increment, + run_id varchar(64) not null, + cycle_id varchar(128) not null, + execution_state_id varchar(128) not null, + state_role varchar(32) not null, + symbol varchar(32) not null, + open_orders_json json not null, + expected_slippage_bps decimal(18,8) not null, + exchange_latency_ms bigint not null, + api_error_count int not null, + maker_fee_bps decimal(18,8) not null, + taker_fee_bps decimal(18,8) not null, + min_notional decimal(28,10) not null, + price_tick_size decimal(28,10) not null, + lot_size_step_size decimal(28,10) not null, + market_lot_size_step_size decimal(28,10) not null, + liquidity_capacity_ratio decimal(18,8) not null, + created_at datetime(3) not null default current_timestamp(3), + unique key uk_trader_execution_state (run_id, execution_state_id), + constraint chk_trader_execution_role check (state_role in ('PM_INPUT','POST_ACTION','FEEDBACK_UPDATE')) +); + create table trader_model_output ( id bigint primary key auto_increment, run_id varchar(64) not null, @@ -77,6 +250,7 @@ create table trader_model_output ( continue_json json not null, exit_json json not null, risk_json json not null, + metadata_json json not null, uncertainty decimal(18,8) not null, ood_score decimal(18,8) not null, usable boolean not null, diff --git a/src/test/java/com/quantai/trader/TestFixtures.java b/src/test/java/com/quantai/trader/TestFixtures.java index 47cf8a7..b2fbd36 100644 --- a/src/test/java/com/quantai/trader/TestFixtures.java +++ b/src/test/java/com/quantai/trader/TestFixtures.java @@ -8,13 +8,17 @@ import com.quantai.trader.enums.PositionSide; import com.quantai.trader.enums.TraderActionType; import com.quantai.trader.enums.TraderExecutionMode; import com.quantai.trader.enums.TraderRunMode; +import com.quantai.trader.replay.ReplayMarketEvent; import com.quantai.trader.risk.RiskLimits; import java.math.BigDecimal; import java.io.IOException; import java.nio.file.Files; import java.nio.file.Path; +import java.security.MessageDigest; +import java.security.NoSuchAlgorithmException; import java.time.Instant; +import java.util.LinkedHashMap; import java.util.List; import java.util.Map; @@ -37,14 +41,19 @@ public final class TestFixtures { } public static TraderProperties propertiesWithArtifactRoot(Path artifactRoot) { + return propertiesWithArtifactRoot(artifactRoot, TraderRunMode.SHADOW, TraderExecutionMode.SHADOW); + } + + public static TraderProperties propertiesWithArtifactRoot(Path artifactRoot, TraderRunMode runMode, + TraderExecutionMode executionMode) { TraderProperties base = properties(); return new TraderProperties( base.serviceName(), - base.runMode(), + runMode, base.symbol(), new TraderProperties.Artifact("trader-v4-btc-p0", "cal-v4-btc-p0", "pm-v4-btc-p0", artifactRoot.toString()), base.feedback(), - base.execution(), + new TraderProperties.Execution(executionMode, base.execution().maxApiErrorCount(), base.execution().maxExchangeLatencyMs()), base.runtime(), base.outbox(), base.release(), @@ -60,7 +69,7 @@ public final class TestFixtures { "BTC-USDT-PERP", new TraderProperties.Artifact("trader-v4-btc-p0", "cal-v4-btc-p0", "pm-v4-btc-p0", "/tmp/trader-v4-p0"), new TraderProperties.Feedback(feedbackHttpEnabled), - new TraderProperties.Execution(executionMode, 3, 1500), + new TraderProperties.Execution(executionMode, 3, 1500L), new TraderProperties.Runtime("trader:v4:test", true, tradingEnabled), new TraderProperties.Outbox(true, 5), new TraderProperties.Release(true, true, true), @@ -98,27 +107,85 @@ public final class TestFixtures { public static TraderMarketSnapshot snapshot(boolean dataReady, String depthNotional5Bps) { return new TraderMarketSnapshot("snapshot-1", "run-1", "cycle-1", "BTC-USDT-PERP", T0, "feature-v4-p0", bd("100"), bd("99.5"), bd("1.2"), BigDecimal.ZERO, - bd(depthNotional5Bps), bd(depthNotional5Bps), bd(depthNotional5Bps), dataReady, Map.of(), Map.of()); + bd(depthNotional5Bps), bd(depthNotional5Bps), bd(depthNotional5Bps), dataReady, + featureJson(), dataQualityJson()); + } + + public static ReplayMarketEvent replayEvent(String runId, Instant eventTime, String markPrice, + String indexPrice, String depthNotional) { + BigDecimal depth = bd(depthNotional); + return new ReplayMarketEvent(runId, "BTC-USDT-PERP", eventTime, "feature-v4-p0", + bd(markPrice), bd(indexPrice), bd("1.2"), bd("0.5"), + depth, depth.multiply(new BigDecimal("1.4")), depth.multiply(new BigDecimal("2.2")), + depth.compareTo(BigDecimal.ZERO) > 0, featureJson(), dataQualityJson()); + } + + public static Map dataQualityJson() { + return Map.of("ood_score", bd("0.05"), "data_quality_flag", "OK"); + } + + public static Map featureJson() { + Map features = new LinkedHashMap<>(); + features.put("ret_1m_bps", bd("1.1")); + features.put("ret_5m_bps", bd("4.5")); + features.put("ret_15m_bps", bd("8.2")); + features.put("ret_60m_bps", bd("18.6")); + features.put("ret_240m_bps", bd("42.0")); + features.put("realized_vol_15m_bps", bd("9.4")); + features.put("realized_vol_60m_bps", bd("12.7")); + features.put("vol_ratio_15m_60m", bd("0.74")); + features.put("range_15m_bps", bd("21.2")); + features.put("range_60m_bps", bd("55.4")); + features.put("volume_zscore_60m", bd("0.35")); + features.put("trend_consistency_15m", bd("0.20")); + features.put("channel_position_60m_pct", bd("0.62")); + features.put("upper_breakout_60m_bps", bd("0.0")); + features.put("lower_breakout_60m_bps", bd("0.0")); + features.put("upper_failed_break_reclaim_15m_bps", bd("0.0")); + features.put("lower_failed_break_reclaim_15m_bps", bd("0.0")); + features.put("sweep_up_15m_bps", bd("6.8")); + features.put("sweep_down_15m_bps", bd("5.1")); + features.put("compression_score_4h_pct", bd("0.30")); + features.put("compression_release_15m_bps", bd("2.4")); + features.put("taker_imbalance_1m", bd("0.08")); + features.put("taker_imbalance_5m", bd("0.12")); + features.put("taker_imbalance_15m", bd("0.04")); + features.put("level1_ofi_1m", bd("0.10")); + features.put("spread_bps", bd("1.2")); + features.put("spread_rank_24h_pct", bd("0.40")); + features.put("oi_delta_15m_bps", bd("3.2")); + features.put("oi_delta_60m_bps", bd("6.5")); + features.put("funding_bps", bd("0.5")); + features.put("mark_index_basis_bps", bd("5.0")); + features.put("liquidation_buy_notional_1m", bd("1000")); + features.put("liquidation_sell_notional_1m", bd("800")); + features.put("liquidation_imbalance_15m", bd("0.10")); + features.put("liquidation_notional_zscore_15m", bd("0.25")); + features.put("liquidation_available", bd("1")); + features.put("minute_of_day_sin", bd("0.0")); + features.put("minute_of_day_cos", bd("1.0")); + features.put("minutes_to_next_funding", bd("120")); + return Map.copyOf(features); } public static TraderModelOutput modelOutput() { return modelOutput("0.70", "0.20", "0.70", "0.30", "0.66", "0.34", - "0.20", "0.50", "0.18", "0.12", "1.0", "12", "4", "0.10", "0.05", "20"); + "0.20", "0.50", "0.18", "0.12", "12", "4", "0.10", "0.05", "20"); } public static TraderModelOutput shortModelOutput() { return modelOutput("0.20", "0.70", "0.30", "0.70", "0.34", "0.66", - "0.50", "0.20", "0.18", "0.12", "1.0", "12", "4", "0.10", "0.05", "20"); + "0.50", "0.20", "0.18", "0.12", "12", "4", "0.10", "0.05", "20"); } public static TraderModelOutput modelOutput(String longProb, String shortProb, String longEntryProb, String shortEntryProb, String longContinueProb, String shortContinueProb, String longExitProb, String shortExitProb, - String marketRiskProb, String positionRiskProb, - String liquidityCapacityRatio, String expectedEdgeBps, - String continueVsExitEdgeBps, String uncertainty, - String oodScore, String expectedShortfallBps) { + String marketRiskProb, String sidePositionRiskProb, + String netEdgeBps, String continueEdgeBps, + String uncertainty, String oodScore, + String positionPathRiskBps) { BigDecimal longP = bd(longProb); BigDecimal shortP = bd(shortProb); BigDecimal neutral = BigDecimal.ONE.subtract(longP).subtract(shortP); @@ -126,23 +193,53 @@ public final class TestFixtures { "model-output-1", "run-1", "cycle-1", - "trader-v4-btc-p0", - "cal-v4-btc-p0", - new DirectionOutput(longP, shortP, neutral, longP.max(shortP), longP.subtract(shortP).abs(), - bd("8.0"), 45, "direction-p0", "cal-v4-btc-p0", Map.of()), - new EntryOutput(bd(longEntryProb), bd(shortEntryProb), bd("0.70"), bd(expectedEdgeBps), - "p0-plan-atr-2r", "p0-price-plan-hash", bd("35"), bd("70"), 45, bd("4.0"), - "entry-p0", "cal-v4-btc-p0", Map.of()), - new ContinueOutput(bd(longContinueProb), bd(shortContinueProb), bd("0.60"), bd("5.0"), - bd(continueVsExitEdgeBps), "continue-p0", "cal-v4-btc-p0", Map.of()), - new ExitOutput(bd(longExitProb), bd(shortExitProb), bd("0.20"), bd("0.25"), bd("0.22"), - bd("0.20"), bd("10"), "exit-p0", "cal-v4-btc-p0", Map.of()), - new RiskOutput(bd(marketRiskProb), bd(positionRiskProb), bd("20"), bd("18"), bd("0.15"), - bd(expectedShortfallBps), bd("0.20"), bd("0.10"), bd("0.12"), - bd(liquidityCapacityRatio), "risk-p0", "cal-v4-btc-p0", Map.of()), - bd(uncertainty), - bd(oodScore), - Map.of("fixture", "p0")); + modelMetadata(uncertainty, oodScore), + new DirectionOutput(longP, shortP, neutral), + new EntryOutput(bd(longEntryProb), bd(shortEntryProb), bd(netEdgeBps), bd(netEdgeBps)), + new ContinueOutput(bd(longContinueProb), bd(shortContinueProb), bd(continueEdgeBps), bd(continueEdgeBps)), + new ExitOutput(bd(longExitProb), bd(shortExitProb), bd("10"), bd("10"), exitReasonScores()), + new RiskOutput(bd(marketRiskProb), bd(sidePositionRiskProb), bd(sidePositionRiskProb), bd("20"), + bd(positionPathRiskBps), bd(positionPathRiskBps), riskReasonScores())); + } + + public static TraderModelOutputMetadata modelMetadata(String uncertainty, String oodScore) { + Map versions = Map.of( + "DIRECTION", "trader-v4-btc-p0", + "ENTRY", "trader-v4-btc-p0", + "CONTINUE", "trader-v4-btc-p0", + "EXIT", "trader-v4-btc-p0", + "RISK", "trader-v4-btc-p0"); + Map calibrationVersions = Map.of( + "DIRECTION", "cal-v4-btc-p0", + "ENTRY", "cal-v4-btc-p0", + "CONTINUE", "cal-v4-btc-p0", + "EXIT", "cal-v4-btc-p0", + "RISK", "cal-v4-btc-p0"); + return new TraderModelOutputMetadata("trader-v4-btc-p0", "cal-v4-btc-p0", + versions, calibrationVersions, + "feature-schema-hash", "feature-order-hash", "output-schema-hash", + bd(uncertainty), bd(oodScore)); + } + + public static TraderPricePlanContext pricePlan() { + return new TraderPricePlanContext("p0-plan-atr-2r", "p0-price-plan-hash", bd("35"), bd("70"), 45, bd("4.0")); + } + + public static Map exitReasonScores() { + return Map.of( + "adverse_move_prob", bd("0.20"), + "reversal_prob", bd("0.25"), + "stop_hit_prob", bd("0.22"), + "stagnation_prob", bd("0.20")); + } + + public static Map riskReasonScores() { + return Map.of( + "market_drawdown_prob", bd("0.15"), + "volatility_expansion_prob", bd("0.20"), + "spike_prob", bd("0.10"), + "liquidity_deterioration_prob", bd("0.12"), + "position_drawdown_prob", bd("0.14")); } public static PositionManagerInput pmInput(TraderModelOutput modelOutput, TraderPositionState positionState) { @@ -151,7 +248,7 @@ public final class TestFixtures { public static PositionManagerInput pmInput(TraderModelOutput modelOutput, TraderPositionState positionState, TraderAccountState accountState, TraderExecutionState executionState) { - return new PositionManagerInput(cycle(), snapshot(), modelOutput, positionState, accountState, executionState, pmConfig()); + return new PositionManagerInput(cycle(), snapshot(), modelOutput, pricePlan(), positionState, accountState, executionState, pmConfig()); } public static TraderPositionState flatPosition() { @@ -192,9 +289,14 @@ public final class TestFixtures { } public static TraderExecutionState execution(List openOrders, long latencyMs, int apiErrorCount) { + return execution(openOrders, latencyMs, apiErrorCount, "1.0"); + } + + public static TraderExecutionState execution(List openOrders, long latencyMs, int apiErrorCount, + String liquidityCapacityRatio) { return new TraderExecutionState("execution-state-1", "run-1", "cycle-1", "BTC-USDT-PERP", openOrders, bd("1.5"), latencyMs, apiErrorCount, bd("1"), bd("4"), bd("5"), - bd("0.1"), bd("0.001"), bd("0.001"), BigDecimal.ONE); + bd("0.1"), bd("0.001"), bd("0.001"), bd(liquidityCapacityRatio)); } public static TraderPositionManagerDecision pmDecision(TraderActionType action, PositionSide side) { @@ -213,31 +315,243 @@ public final class TestFixtures { } public static RiskLimits riskLimits() { - return new RiskLimits(bd("200"), BigDecimal.ONE, bd("500"), 3, 1500, false, false); + return new RiskLimits(bd("200"), BigDecimal.ONE, bd("500"), 3, 1500, false, false, null); } public static void writeArtifactBundle(Path artifactRoot) throws IOException { Files.createDirectories(artifactRoot.resolve("manifests")); + Files.createDirectories(artifactRoot.resolve("models")); + Files.createDirectories(artifactRoot.resolve("calibrators")); + Files.createDirectories(artifactRoot.resolve("schemas")); + Files.createDirectories(artifactRoot.resolve("examples")); + String directionHash = writeArtifact(artifactRoot.resolve("models/direction.json"), "{\"model\":\"DIRECTION\"}"); + String entryHash = writeArtifact(artifactRoot.resolve("models/entry.json"), "{\"model\":\"ENTRY\"}"); + String continueHash = writeArtifact(artifactRoot.resolve("models/continue.json"), "{\"model\":\"CONTINUE\"}"); + String exitHash = writeArtifact(artifactRoot.resolve("models/exit.json"), "{\"model\":\"EXIT\"}"); + String riskHash = writeArtifact(artifactRoot.resolve("models/risk.json"), "{\"model\":\"RISK\"}"); + String featureSchemaHash = writeArtifact(artifactRoot.resolve("schemas/features.json"), "{\"features\":[\"mark_price\",\"index_price\"]}"); + String featureOrderHash = writeArtifact(artifactRoot.resolve("schemas/feature_order.json"), """ + [ + "ret_1m_bps","ret_5m_bps","ret_15m_bps","ret_60m_bps","ret_240m_bps", + "realized_vol_15m_bps","realized_vol_60m_bps","vol_ratio_15m_60m", + "range_15m_bps","range_60m_bps","volume_zscore_60m","trend_consistency_15m", + "channel_position_60m_pct","upper_breakout_60m_bps","lower_breakout_60m_bps", + "upper_failed_break_reclaim_15m_bps","lower_failed_break_reclaim_15m_bps", + "sweep_up_15m_bps","sweep_down_15m_bps","compression_score_4h_pct", + "compression_release_15m_bps","taker_imbalance_1m","taker_imbalance_5m", + "taker_imbalance_15m","level1_ofi_1m","spread_bps","spread_rank_24h_pct", + "oi_delta_15m_bps","oi_delta_60m_bps","funding_bps","mark_index_basis_bps", + "liquidation_buy_notional_1m","liquidation_sell_notional_1m", + "liquidation_imbalance_15m","liquidation_notional_zscore_15m", + "liquidation_available","minute_of_day_sin","minute_of_day_cos", + "minutes_to_next_funding" + ] + """); + String outputSchemaHash = writeArtifact(artifactRoot.resolve("schemas/outputs.json"), """ + { + "direction": { + "longProb": {"type": "decimal", "range": [0.0, 1.0]}, + "shortProb": {"type": "decimal", "range": [0.0, 1.0]}, + "neutralProb": {"type": "decimal", "range": [0.0, 1.0]} + }, + "entry": { + "longEntryProb": {"type": "decimal", "range": [0.0, 1.0]}, + "shortEntryProb": {"type": "decimal", "range": [0.0, 1.0]}, + "longExpectedNetEdgeBps": {"type": "decimal", "range": [-1000.0, 1000.0]}, + "shortExpectedNetEdgeBps": {"type": "decimal", "range": [-1000.0, 1000.0]} + }, + "continue": { + "longContinueProb": {"type": "decimal", "range": [0.0, 1.0]}, + "shortContinueProb": {"type": "decimal", "range": [0.0, 1.0]}, + "longExpectedContinueEdgeBps": {"type": "decimal", "range": [-1000.0, 1000.0]}, + "shortExpectedContinueEdgeBps": {"type": "decimal", "range": [-1000.0, 1000.0]} + }, + "exit": { + "longExitProb": {"type": "decimal", "range": [0.0, 1.0]}, + "shortExitProb": {"type": "decimal", "range": [0.0, 1.0]}, + "longAdverseMoveBps": {"type": "decimal", "range": [0.0, 1000.0]}, + "shortAdverseMoveBps": {"type": "decimal", "range": [0.0, 1000.0]}, + "exitReasonScores": { + "adverse_move_prob": {"type": "decimal", "range": [0.0, 1.0]}, + "reversal_prob": {"type": "decimal", "range": [0.0, 1.0]}, + "stop_hit_prob": {"type": "decimal", "range": [0.0, 1.0]}, + "stagnation_prob": {"type": "decimal", "range": [0.0, 1.0]} + } + }, + "risk": { + "marketRiskProb": {"type": "decimal", "range": [0.0, 1.0]}, + "longPositionRiskProb": {"type": "decimal", "range": [0.0, 1.0]}, + "shortPositionRiskProb": {"type": "decimal", "range": [0.0, 1.0]}, + "marketPathRiskBps": {"type": "decimal", "range": [0.0, 1000.0]}, + "longPositionPathRiskBps": {"type": "decimal", "range": [0.0, 1000.0]}, + "shortPositionPathRiskBps": {"type": "decimal", "range": [0.0, 1000.0]}, + "riskReasonScores": { + "market_drawdown_prob": {"type": "decimal", "range": [0.0, 1.0]}, + "volatility_expansion_prob": {"type": "decimal", "range": [0.0, 1.0]}, + "spike_prob": {"type": "decimal", "range": [0.0, 1.0]}, + "liquidity_deterioration_prob": {"type": "decimal", "range": [0.0, 1.0]}, + "position_drawdown_prob": {"type": "decimal", "range": [0.0, 1.0]} + } + } + } + """); + writeArtifact(artifactRoot.resolve("examples/input.json"), "{\"mark_price\":100,\"index_price\":99.5}"); + String directionCalibratorHash = writeArtifact(artifactRoot.resolve("calibrators/direction.json"), + calibratorJson("direction-cal-p0", Map.of("longProb", "0.62", "shortProb", "0.28", "neutralProb", "0.10"))); + String entryCalibratorHash = writeArtifact(artifactRoot.resolve("calibrators/entry.json"), + calibratorJson("entry-cal-p0", Map.of("longEntryProb", "0.63", "shortEntryProb", "0.42"))); + String continueCalibratorHash = writeArtifact(artifactRoot.resolve("calibrators/continue.json"), + calibratorJson("continue-cal-p0", Map.of("longContinueProb", "0.61", "shortContinueProb", "0.39"))); + String exitCalibratorHash = writeArtifact(artifactRoot.resolve("calibrators/exit.json"), + calibratorJson("exit-cal-p0", Map.of( + "longExitProb", "0.24", "shortExitProb", "0.48", + "adverse_move_prob", "0.20", "reversal_prob", "0.25", + "stop_hit_prob", "0.22", "stagnation_prob", "0.20"))); + String riskCalibratorHash = writeArtifact(artifactRoot.resolve("calibrators/risk.json"), + calibratorJson("risk-cal-p0", Map.of( + "marketRiskProb", "0.20", "longPositionRiskProb", "0.18", "shortPositionRiskProb", "0.28", + "market_drawdown_prob", "0.15", "volatility_expansion_prob", "0.20", + "spike_prob", "0.10", "liquidity_deterioration_prob", "0.12", + "position_drawdown_prob", "0.14"))); Files.writeString(artifactRoot.resolve("manifests/model_bundle_manifest.json"), """ { + "manifest_schema_version": "trader-model-bundle-v1", "model_bundle_version": "trader-v4-btc-p0", "calibration_bundle_version": "cal-v4-btc-p0", "feature_version": "feature-v4-p0", "label_version": "label-v4-p0", "split_version": "split-v4-p0", + "training_run_id": "train-run-p0", + "training_export_id": "export-p0", + "backtest_manifest_id": "backtest-p0", "required_models_json": ["DIRECTION", "ENTRY", "CONTINUE", "EXIT", "RISK"], "provided_models_json": ["DIRECTION", "ENTRY", "CONTINUE", "EXIT", "RISK"], "missing_models_json": [], + "allowed_run_modes_json": ["REPLAY_SIM", "SHADOW"], "bundle_hash_sha256": "00000000000000000000000000000000000000000000000000000000000000a1", "complete": true, "status": "ACTIVE" } """); + Files.writeString(artifactRoot.resolve("manifests/model_manifest.json"), """ + [ + { + "model_bundle_version":"trader-v4-btc-p0","calibration_bundle_version":"cal-v4-btc-p0", + "model_name":"DIRECTION","model_type":"DIRECTION","side":"BOTH", + "symbol_scope_json":["BTC-USDT-PERP"],"bar_interval":"1m","horizon_minutes":45, + "model_format":"ONNX","model_runtime":"ONNX_RUNTIME_JAVA","model_runtime_version":"1.22.0", + "onnx_opset_version":17,"producer_name":"test-exporter","producer_version":"p0", + "feature_version":"feature-v4-p0","feature_schema_path":"schemas/features.json", + "feature_schema_hash":"%6$s","feature_order_path":"schemas/feature_order.json","feature_order_hash":"%8$s", + "input_tensor_name":"features","input_dtype":"FLOAT32","input_shape_json":{"batch":1,"features":39}, + "input_example_path":"examples/input.json","output_schema_path":"schemas/outputs.json", + "output_schema_hash":"%7$s","output_tensor_names_json":["probabilities"], + "output_mapping_json":{"long_prob":"probabilities[0]","short_prob":"probabilities[1]","neutral_prob":"probabilities[2]"}, + "output_value_rules_json":{"probability":"[0,1]"}, + "label_version":"label-v4-p0","split_version":"split-v4-p0","training_fold":"fold-0", + "train_start":"2026-01-01T00:00:00Z","train_end":"2026-03-01T00:00:00Z", + "validation_start":"2026-03-01T00:00:00Z","validation_end":"2026-04-01T00:00:00Z", + "test_start":"2026-04-01T00:00:00Z","test_end":"2026-05-01T00:00:00Z", + "metrics_json":{"sample_count":1000},"artifact_path":"models/direction.json", + "artifact_hash_sha256":"%1$s","source_hash":"source-direction","status":"ACTIVE" + }, + { + "model_bundle_version":"trader-v4-btc-p0","calibration_bundle_version":"cal-v4-btc-p0", + "model_name":"ENTRY","model_type":"ENTRY","side":"BOTH", + "symbol_scope_json":["BTC-USDT-PERP"],"bar_interval":"1m","horizon_minutes":45, + "model_format":"ONNX","model_runtime":"ONNX_RUNTIME_JAVA","model_runtime_version":"1.22.0", + "onnx_opset_version":17,"producer_name":"test-exporter","producer_version":"p0", + "feature_version":"feature-v4-p0","feature_schema_path":"schemas/features.json", + "feature_schema_hash":"%6$s","feature_order_path":"schemas/feature_order.json","feature_order_hash":"%8$s", + "input_tensor_name":"features","input_dtype":"FLOAT32","input_shape_json":{"batch":1,"features":39}, + "input_example_path":"examples/input.json","output_schema_path":"schemas/outputs.json", + "output_schema_hash":"%7$s","output_tensor_names_json":["entry"], + "output_mapping_json":{"long_entry_prob":"entry[0]","short_entry_prob":"entry[1]","long_expected_net_edge_bps":"entry[2]","short_expected_net_edge_bps":"entry[3]"}, + "output_value_rules_json":{"probability":"[0,1]"}, + "label_version":"label-v4-p0","split_version":"split-v4-p0","training_fold":"fold-0", + "train_start":"2026-01-01T00:00:00Z","train_end":"2026-03-01T00:00:00Z", + "validation_start":"2026-03-01T00:00:00Z","validation_end":"2026-04-01T00:00:00Z", + "test_start":"2026-04-01T00:00:00Z","test_end":"2026-05-01T00:00:00Z", + "metrics_json":{"sample_count":1000},"artifact_path":"models/entry.json", + "artifact_hash_sha256":"%2$s","source_hash":"source-entry","status":"ACTIVE" + }, + { + "model_bundle_version":"trader-v4-btc-p0","calibration_bundle_version":"cal-v4-btc-p0", + "model_name":"CONTINUE","model_type":"CONTINUE","side":"BOTH", + "symbol_scope_json":["BTC-USDT-PERP"],"bar_interval":"1m","horizon_minutes":45, + "model_format":"ONNX","model_runtime":"ONNX_RUNTIME_JAVA","model_runtime_version":"1.22.0", + "onnx_opset_version":17,"producer_name":"test-exporter","producer_version":"p0", + "feature_version":"feature-v4-p0","feature_schema_path":"schemas/features.json", + "feature_schema_hash":"%6$s","feature_order_path":"schemas/feature_order.json","feature_order_hash":"%8$s", + "input_tensor_name":"features","input_dtype":"FLOAT32","input_shape_json":{"batch":1,"features":39}, + "input_example_path":"examples/input.json","output_schema_path":"schemas/outputs.json", + "output_schema_hash":"%7$s","output_tensor_names_json":["continue"], + "output_mapping_json":{"long_continue_prob":"continue[0]","short_continue_prob":"continue[1]","long_expected_continue_edge_bps":"continue[2]","short_expected_continue_edge_bps":"continue[3]"}, + "output_value_rules_json":{"probability":"[0,1]"}, + "label_version":"label-v4-p0","split_version":"split-v4-p0","training_fold":"fold-0", + "train_start":"2026-01-01T00:00:00Z","train_end":"2026-03-01T00:00:00Z", + "validation_start":"2026-03-01T00:00:00Z","validation_end":"2026-04-01T00:00:00Z", + "test_start":"2026-04-01T00:00:00Z","test_end":"2026-05-01T00:00:00Z", + "metrics_json":{"sample_count":1000},"artifact_path":"models/continue.json", + "artifact_hash_sha256":"%3$s","source_hash":"source-continue","status":"ACTIVE" + }, + { + "model_bundle_version":"trader-v4-btc-p0","calibration_bundle_version":"cal-v4-btc-p0", + "model_name":"EXIT","model_type":"EXIT","side":"BOTH", + "symbol_scope_json":["BTC-USDT-PERP"],"bar_interval":"1m","horizon_minutes":45, + "model_format":"ONNX","model_runtime":"ONNX_RUNTIME_JAVA","model_runtime_version":"1.22.0", + "onnx_opset_version":17,"producer_name":"test-exporter","producer_version":"p0", + "feature_version":"feature-v4-p0","feature_schema_path":"schemas/features.json", + "feature_schema_hash":"%6$s","feature_order_path":"schemas/feature_order.json","feature_order_hash":"%8$s", + "input_tensor_name":"features","input_dtype":"FLOAT32","input_shape_json":{"batch":1,"features":39}, + "input_example_path":"examples/input.json","output_schema_path":"schemas/outputs.json", + "output_schema_hash":"%7$s","output_tensor_names_json":["exit"], + "output_mapping_json":{"long_exit_prob":"exit[0]","short_exit_prob":"exit[1]","long_adverse_move_bps":"exit[2]","short_adverse_move_bps":"exit[3]","adverse_move_prob":"exit[4]","reversal_prob":"exit[5]","stop_hit_prob":"exit[6]","stagnation_prob":"exit[7]"}, + "output_value_rules_json":{"probability":"[0,1]"}, + "label_version":"label-v4-p0","split_version":"split-v4-p0","training_fold":"fold-0", + "train_start":"2026-01-01T00:00:00Z","train_end":"2026-03-01T00:00:00Z", + "validation_start":"2026-03-01T00:00:00Z","validation_end":"2026-04-01T00:00:00Z", + "test_start":"2026-04-01T00:00:00Z","test_end":"2026-05-01T00:00:00Z", + "metrics_json":{"sample_count":1000},"artifact_path":"models/exit.json", + "artifact_hash_sha256":"%4$s","source_hash":"source-exit","status":"ACTIVE" + }, + { + "model_bundle_version":"trader-v4-btc-p0","calibration_bundle_version":"cal-v4-btc-p0", + "model_name":"RISK","model_type":"RISK","side":"BOTH", + "symbol_scope_json":["BTC-USDT-PERP"],"bar_interval":"1m","horizon_minutes":45, + "model_format":"ONNX","model_runtime":"ONNX_RUNTIME_JAVA","model_runtime_version":"1.22.0", + "onnx_opset_version":17,"producer_name":"test-exporter","producer_version":"p0", + "feature_version":"feature-v4-p0","feature_schema_path":"schemas/features.json", + "feature_schema_hash":"%6$s","feature_order_path":"schemas/feature_order.json","feature_order_hash":"%8$s", + "input_tensor_name":"features","input_dtype":"FLOAT32","input_shape_json":{"batch":1,"features":39}, + "input_example_path":"examples/input.json","output_schema_path":"schemas/outputs.json", + "output_schema_hash":"%7$s","output_tensor_names_json":["risk"], + "output_mapping_json":{"market_risk_prob":"risk[0]","long_position_risk_prob":"risk[1]","short_position_risk_prob":"risk[2]","market_path_risk_bps":"risk[3]","long_position_path_risk_bps":"risk[4]","short_position_path_risk_bps":"risk[5]","market_drawdown_prob":"risk[6]","volatility_expansion_prob":"risk[7]","spike_prob":"risk[8]","liquidity_deterioration_prob":"risk[9]","position_drawdown_prob":"risk[10]"}, + "output_value_rules_json":{"probability":"[0,1]"}, + "label_version":"label-v4-p0","split_version":"split-v4-p0","training_fold":"fold-0", + "train_start":"2026-01-01T00:00:00Z","train_end":"2026-03-01T00:00:00Z", + "validation_start":"2026-03-01T00:00:00Z","validation_end":"2026-04-01T00:00:00Z", + "test_start":"2026-04-01T00:00:00Z","test_end":"2026-05-01T00:00:00Z", + "metrics_json":{"sample_count":1000},"artifact_path":"models/risk.json", + "artifact_hash_sha256":"%5$s","source_hash":"source-risk","status":"ACTIVE" + } + ] + """.formatted(directionHash, entryHash, continueHash, exitHash, riskHash, featureSchemaHash, outputSchemaHash, featureOrderHash)); + Files.writeString(artifactRoot.resolve("manifests/calibration_manifest.json"), """ + [ + {"calibration_bundle_version":"cal-v4-btc-p0","model_bundle_version":"trader-v4-btc-p0","model_name":"DIRECTION","calibrator_version":"direction-cal-p0","calibration_method":"BINNING","calibrator_path":"calibrators/direction.json","calibrator_hash_sha256":"%s","calibration_window_from":"2026-03-01T00:00:00Z","calibration_window_to":"2026-04-01T00:00:00Z","calibration_metrics_json":{},"bucket_metrics_json":{},"output_after_calibration_schema_hash":"output-cal-hash","status":"ACTIVE"}, + {"calibration_bundle_version":"cal-v4-btc-p0","model_bundle_version":"trader-v4-btc-p0","model_name":"ENTRY","calibrator_version":"entry-cal-p0","calibration_method":"BINNING","calibrator_path":"calibrators/entry.json","calibrator_hash_sha256":"%s","calibration_window_from":"2026-03-01T00:00:00Z","calibration_window_to":"2026-04-01T00:00:00Z","calibration_metrics_json":{},"bucket_metrics_json":{},"output_after_calibration_schema_hash":"output-cal-hash","status":"ACTIVE"}, + {"calibration_bundle_version":"cal-v4-btc-p0","model_bundle_version":"trader-v4-btc-p0","model_name":"CONTINUE","calibrator_version":"continue-cal-p0","calibration_method":"BINNING","calibrator_path":"calibrators/continue.json","calibrator_hash_sha256":"%s","calibration_window_from":"2026-03-01T00:00:00Z","calibration_window_to":"2026-04-01T00:00:00Z","calibration_metrics_json":{},"bucket_metrics_json":{},"output_after_calibration_schema_hash":"output-cal-hash","status":"ACTIVE"}, + {"calibration_bundle_version":"cal-v4-btc-p0","model_bundle_version":"trader-v4-btc-p0","model_name":"EXIT","calibrator_version":"exit-cal-p0","calibration_method":"BINNING","calibrator_path":"calibrators/exit.json","calibrator_hash_sha256":"%s","calibration_window_from":"2026-03-01T00:00:00Z","calibration_window_to":"2026-04-01T00:00:00Z","calibration_metrics_json":{},"bucket_metrics_json":{},"output_after_calibration_schema_hash":"output-cal-hash","status":"ACTIVE"}, + {"calibration_bundle_version":"cal-v4-btc-p0","model_bundle_version":"trader-v4-btc-p0","model_name":"RISK","calibrator_version":"risk-cal-p0","calibration_method":"BINNING","calibrator_path":"calibrators/risk.json","calibrator_hash_sha256":"%s","calibration_window_from":"2026-03-01T00:00:00Z","calibration_window_to":"2026-04-01T00:00:00Z","calibration_metrics_json":{},"bucket_metrics_json":{},"output_after_calibration_schema_hash":"output-cal-hash","status":"ACTIVE"} + ] + """.formatted(directionCalibratorHash, entryCalibratorHash, continueCalibratorHash, exitCalibratorHash, riskCalibratorHash)); Files.writeString(artifactRoot.resolve("manifests/position_manager_manifest.json"), """ { "pm_config_version": "pm-v4-btc-p0", "model_bundle_version": "trader-v4-btc-p0", "calibration_bundle_version": "cal-v4-btc-p0", + "threshold_stability_json": {}, "allowed_run_modes_json": ["REPLAY_SIM", "SHADOW"], "config_hash_sha256": "00000000000000000000000000000000000000000000000000000000000000b1", "status": "ACTIVE", @@ -274,11 +588,11 @@ public final class TestFixtures { "closePositionRiskProb": 0.70, "closeMarketRiskProb": 0.70, "closeContinueMax": 0.25, - "reduceGivebackProb": 0.62, + "reduceAdverseMoveProb": 0.62, "reduceContinueMin": 0.35, "reduceContinueMax": 0.70, "minProfitForReduceBps": 5.0, - "maxExpectedShortfallBps": 80 + "maxPositionPathRiskBps": 80 }, "sizing": { "baseRatio": 0.80, @@ -296,64 +610,110 @@ public final class TestFixtures { } } """); - Files.writeString(artifactRoot.resolve("model_output_policy.json"), """ + Files.writeString(artifactRoot.resolve("price_plan_context.json"), """ + { + "pricePlanId": "p0-plan-atr-2r", + "pricePlanConfigHash": "p0-price-plan-hash", + "stopDistanceBps": 35, + "targetDistanceBps": 70, + "maxHoldMinutes": 45, + "costBps": 4.0 + } + """); + Files.writeString(artifactRoot.resolve("replay_model_fixture.json"), """ { "direction": { "longProbWhenMarkGteIndex": 0.62, "longProbWhenMarkLtIndex": 0.32, - "neutralProb": 0.10, - "expectedReturnBps": 8.0, - "horizonMinutes": 45, - "modelVersion": "direction-p0" + "neutralProb": 0.10 }, "entry": { "longEntryProb": 0.63, "shortEntryProb": 0.42, - "entryQualityScore": 0.64, - "expectedEdgeBps": 12.0, - "pricePlanId": "p0-plan-atr-2r", - "pricePlanConfigHash": "p0-price-plan-hash", - "stopDistanceBps": 35, - "targetDistanceBps": 70, - "maxHoldMinutes": 45, - "costBps": 4.0, - "modelVersion": "entry-p0" + "longExpectedNetEdgeBps": 12.0, + "shortExpectedNetEdgeBps": 8.0 }, "continuation": { "longContinueProb": 0.61, "shortContinueProb": 0.39, - "trendPersistenceProb": 0.58, - "holdEdgeBps": 5.0, - "continueVsExitEdgeBps": 3.0, - "modelVersion": "continue-p0" + "longExpectedContinueEdgeBps": 3.0, + "shortExpectedContinueEdgeBps": 1.5 }, "exit": { "longExitProb": 0.24, "shortExitProb": 0.48, - "profitGivebackProb": 0.20, - "reversalProb": 0.25, - "stopRiskProb": 0.22, - "stagnationProb": 0.20, - "expectedGivebackBps": 10, - "modelVersion": "exit-p0" + "longAdverseMoveBps": 10, + "shortAdverseMoveBps": 18, + "exitReasonScores": { + "adverse_move_prob": 0.20, + "reversal_prob": 0.25, + "stop_hit_prob": 0.22, + "stagnation_prob": 0.20 + } }, "risk": { "marketRiskProb": 0.20, - "positionRiskProb": 0.18, - "marketRiskSeverityBps": 20, - "positionRiskSeverityBps": 18, - "drawdownProb": 0.15, - "expectedShortfallBps": 20, - "volatilityExpansionProb": 0.20, - "spikeProb": 0.10, - "liquidityRiskProb": 0.12, - "liquidityCapacityRatioWhenReady": 1.0, - "liquidityCapacityRatioWhenNotReady": 0, - "modelVersion": "risk-p0" + "longPositionRiskProb": 0.18, + "shortPositionRiskProb": 0.28, + "marketPathRiskBps": 20, + "longPositionPathRiskBps": 20, + "shortPositionPathRiskBps": 30, + "riskReasonScores": { + "market_drawdown_prob": 0.15, + "volatility_expansion_prob": 0.20, + "spike_prob": 0.10, + "liquidity_deterioration_prob": 0.12, + "position_drawdown_prob": 0.14 + } }, "uncertainty": 0.10, - "oodScore": 0.05 + "oodScore": 0.05, + "featureSchemaHash": "feature-schema-hash", + "featureOrderHash": "feature-order-hash", + "outputSchemaHash": "output-schema-hash" } """); } + + private static String calibratorJson(String version, Map targets) { + StringBuilder targetJson = new StringBuilder(); + int index = 0; + for (Map.Entry entry : targets.entrySet()) { + if (index++ > 0) { + targetJson.append(","); + } + targetJson.append("\"").append(entry.getKey()).append("\":") + .append("{\"bins\":[{\"min\":0.0,\"max\":1.0,\"calibrated\":") + .append(entry.getValue()) + .append("}]}"); + } + return """ + { + "calibrator_version": "%s", + "method": "BINNING", + "targets": {%s}, + "clip": {"min": 0.0, "max": 1.0}, + "fail_policy": "FAIL_FAST" + } + """.formatted(version, targetJson); + } + + private static String writeArtifact(Path path, String content) throws IOException { + Files.writeString(path, content); + return sha256(content); + } + + private static String sha256(String content) { + try { + MessageDigest digest = MessageDigest.getInstance("SHA-256"); + byte[] hash = digest.digest(content.getBytes(java.nio.charset.StandardCharsets.UTF_8)); + StringBuilder builder = new StringBuilder(hash.length * 2); + for (byte value : hash) { + builder.append(String.format("%02x", value & 0xff)); + } + return builder.toString(); + } catch (NoSuchAlgorithmException exception) { + throw new IllegalStateException("SHA-256 is not available", exception); + } + } } diff --git a/src/test/java/com/quantai/trader/artifact/JdbcTraderArtifactManifestRepositoryTest.java b/src/test/java/com/quantai/trader/artifact/JdbcTraderArtifactManifestRepositoryTest.java new file mode 100644 index 0000000..7b9ead2 --- /dev/null +++ b/src/test/java/com/quantai/trader/artifact/JdbcTraderArtifactManifestRepositoryTest.java @@ -0,0 +1,35 @@ +package com.quantai.trader.artifact; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.io.TempDir; +import org.springframework.jdbc.core.JdbcTemplate; + +import java.io.IOException; +import java.nio.file.Path; + +import static com.quantai.trader.TestFixtures.*; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.contains; +import static org.mockito.Mockito.*; + +class JdbcTraderArtifactManifestRepositoryTest { + @TempDir + Path artifactRoot; + + @Test + void upsertsModelCalibrationAndPmManifestsForActiveBundle() throws IOException { + writeArtifactBundle(artifactRoot); + TraderArtifactBundle bundle = new TraderArtifactLoader(propertiesWithArtifactRoot(artifactRoot), objectMapper()) + .loadActiveBundle(); + JdbcTemplate jdbcTemplate = mock(JdbcTemplate.class); + JdbcTraderArtifactManifestRepository repository = new JdbcTraderArtifactManifestRepository(jdbcTemplate, objectMapper()); + + repository.upsertActiveBundle(bundle); + + verify(jdbcTemplate).update(contains("insert into trader_model_bundle_manifest"), any(Object[].class)); + verify(jdbcTemplate, times(5)).update(contains("insert into trader_model_manifest"), any(Object[].class)); + verify(jdbcTemplate, times(5)).update(contains("insert into trader_calibration_manifest"), any(Object[].class)); + verify(jdbcTemplate).update(contains("insert into trader_pm_config_manifest"), any(Object[].class)); + verifyNoMoreInteractions(jdbcTemplate); + } +} diff --git a/src/test/java/com/quantai/trader/artifact/TraderArtifactLoaderTest.java b/src/test/java/com/quantai/trader/artifact/TraderArtifactLoaderTest.java index f4759e7..635a6e4 100644 --- a/src/test/java/com/quantai/trader/artifact/TraderArtifactLoaderTest.java +++ b/src/test/java/com/quantai/trader/artifact/TraderArtifactLoaderTest.java @@ -1,10 +1,12 @@ package com.quantai.trader.artifact; import com.quantai.trader.config.TraderProperties; +import com.quantai.trader.model.ReplayFixtureTraderModelService; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.io.TempDir; import java.io.IOException; +import java.nio.file.Files; import java.nio.file.Path; import java.util.Set; @@ -24,7 +26,25 @@ class TraderArtifactLoaderTest { assertThat(bundle.providedModels()).containsExactlyInAnyOrder("DIRECTION", "ENTRY", "CONTINUE", "EXIT", "RISK"); assertThat(bundle.pmConfig().pmConfigVersion()).isEqualTo("pm-v4-btc-p0"); - assertThat(bundle.modelPolicy().entry().pricePlanConfigHash()).isEqualTo("p0-price-plan-hash"); + assertThat(bundle.pricePlanContext().pricePlanConfigHash()).isEqualTo("p0-price-plan-hash"); + assertThat(bundle.modelManifests()).allSatisfy(manifest -> { + assertThat(manifest.featureOrderPath()).isEqualTo("schemas/feature_order.json"); + assertThat(manifest.inputShapeJson()).containsEntry("features", 39); + assertThat(manifest.onnxOpsetVersion()).isEqualTo(17); + }); + assertThat(bundle.requireReplayModelFixture().entry().longExpectedNetEdgeBps()).isEqualByComparingTo("12.0"); + assertThat(bundle.requireReplayModelFixture().entry().shortExpectedNetEdgeBps()).isEqualByComparingTo("8.0"); + } + + @Test + void shadowModeCannotUseReplayFixtureInference() throws IOException { + writeArtifactBundle(artifactRoot); + TraderProperties properties = propertiesWithArtifactRoot(artifactRoot); + TraderArtifactBundle bundle = new TraderArtifactLoader(properties, objectMapper()).loadActiveBundle(); + + assertThatThrownBy(() -> new ReplayFixtureTraderModelService(properties).evaluate(snapshot(), bundle)) + .isInstanceOf(com.quantai.trader.domain.TraderException.class) + .hasMessageContaining("only allows REPLAY_SIM"); } @Test @@ -41,7 +61,9 @@ class TraderArtifactLoaderTest { assertThatThrownBy(() -> new TraderArtifactBundle( bundle.modelBundleVersion(), bundle.calibrationBundleVersion(), bundle.pmConfigVersion(), - bundle.bundleHashSha256(), Set.of("DIRECTION", "ENTRY", "CONTINUE", "RISK"), bundle.pmConfig(), bundle.modelPolicy())) + bundle.bundleHashSha256(), Set.of("DIRECTION", "ENTRY", "CONTINUE", "RISK"), + bundle.modelBundleManifest(), bundle.modelManifests(), bundle.calibrationManifests(), + bundle.pmConfigManifest(), bundle.pmConfig(), bundle.pricePlanContext(), bundle.replayModelFixture())) .isInstanceOf(IllegalArgumentException.class) .hasMessageContaining("all five"); } @@ -52,4 +74,37 @@ class TraderArtifactLoaderTest { .isInstanceOf(com.quantai.trader.domain.TraderException.class) .hasMessageContaining("artifact file is missing"); } + + @Test + void rejectsMissingFeatureOrderArtifact() throws IOException { + writeArtifactBundle(artifactRoot); + Files.delete(artifactRoot.resolve("schemas/feature_order.json")); + + assertThatThrownBy(() -> new TraderArtifactLoader(propertiesWithArtifactRoot(artifactRoot), objectMapper()).loadActiveBundle()) + .isInstanceOf(com.quantai.trader.domain.TraderException.class) + .hasMessageContaining("artifact referenced by manifest is missing"); + } + + @Test + void rejectsNonV4InputShape() throws IOException { + writeArtifactBundle(artifactRoot); + Path manifest = artifactRoot.resolve("manifests/model_manifest.json"); + Files.writeString(manifest, Files.readString(manifest).replace("\"features\":39", "\"features\":38")); + + assertThatThrownBy(() -> new TraderArtifactLoader(propertiesWithArtifactRoot(artifactRoot), objectMapper()).loadActiveBundle()) + .isInstanceOf(com.quantai.trader.domain.TraderException.class) + .hasMessageContaining("runtime contract is invalid"); + } + + @Test + void rejectsLegacyOutputMappingKeys() throws IOException { + writeArtifactBundle(artifactRoot); + Path manifest = artifactRoot.resolve("manifests/model_manifest.json"); + Files.writeString(manifest, Files.readString(manifest) + .replace("\"long_expected_net_edge_bps\":\"entry[2]\"", "\"expected_net_edge_bps\":\"entry[2]\"")); + + assertThatThrownBy(() -> new TraderArtifactLoader(propertiesWithArtifactRoot(artifactRoot), objectMapper()).loadActiveBundle()) + .isInstanceOf(com.quantai.trader.domain.TraderException.class) + .hasMessageContaining("rejected legacy key"); + } } diff --git a/src/test/java/com/quantai/trader/config/JacksonConfigTest.java b/src/test/java/com/quantai/trader/config/JacksonConfigTest.java new file mode 100644 index 0000000..5c260bf --- /dev/null +++ b/src/test/java/com/quantai/trader/config/JacksonConfigTest.java @@ -0,0 +1,17 @@ +package com.quantai.trader.config; + +import org.junit.jupiter.api.Test; + +import java.time.Instant; +import static org.assertj.core.api.Assertions.assertThat; + +class JacksonConfigTest { + @Test + void createsObjectMapperWithJavaTimeSupport() throws Exception { + var mapper = new JacksonConfig().traderObjectMapper(); + + Instant value = mapper.readValue("\"2026-06-26T00:00:00Z\"", Instant.class); + + assertThat(value).isEqualTo(Instant.parse("2026-06-26T00:00:00Z")); + } +} diff --git a/src/test/java/com/quantai/trader/controller/TraderControllerTest.java b/src/test/java/com/quantai/trader/controller/TraderControllerTest.java index 6e145a0..d57cd3a 100644 --- a/src/test/java/com/quantai/trader/controller/TraderControllerTest.java +++ b/src/test/java/com/quantai/trader/controller/TraderControllerTest.java @@ -7,6 +7,7 @@ import com.quantai.trader.enums.FeedbackSource; import com.quantai.trader.enums.TraderErrorCode; import com.quantai.trader.enums.TraderExecutionMode; import com.quantai.trader.enums.TraderRunMode; +import com.quantai.trader.feedback.TraderFeedbackRepository; import org.junit.jupiter.api.Test; import org.springframework.http.ResponseEntity; @@ -21,7 +22,8 @@ import static org.assertj.core.api.Assertions.assertThatThrownBy; class TraderControllerTest { @Test void feedbackEndpointRejectsWhenHttpFeedbackIsDisabled() { - TraderFeedbackController controller = new TraderFeedbackController(properties(), new FeedbackValidator()); + RecordingFeedbackRepository repository = new RecordingFeedbackRepository(); + TraderFeedbackController controller = new TraderFeedbackController(properties(), new FeedbackValidator(), repository); assertThatThrownBy(() -> controller.feedback(feedback(FeedbackSource.SHADOW_APP, false))) .isInstanceOf(TraderException.class) @@ -31,7 +33,8 @@ class TraderControllerTest { @Test void feedbackEndpointRejectsPaperAppSourceInP0EvenWhenHttpIsEnabledForTest() { TraderFeedbackController controller = new TraderFeedbackController( - properties(TraderRunMode.SHADOW, TraderExecutionMode.SHADOW, false, true), new FeedbackValidator()); + properties(TraderRunMode.SHADOW, TraderExecutionMode.SHADOW, false, true), new FeedbackValidator(), + new RecordingFeedbackRepository()); assertThatThrownBy(() -> controller.feedback(feedback(FeedbackSource.PAPER_APP, true))) .isInstanceOf(TraderException.class) @@ -40,12 +43,15 @@ class TraderControllerTest { @Test void feedbackEndpointAcceptsShadowRecorderFeedbackWhenExplicitlyEnabled() { + RecordingFeedbackRepository repository = new RecordingFeedbackRepository(); TraderFeedbackController controller = new TraderFeedbackController( - properties(TraderRunMode.SHADOW, TraderExecutionMode.SHADOW, false, true), new FeedbackValidator()); + properties(TraderRunMode.SHADOW, TraderExecutionMode.SHADOW, false, true), new FeedbackValidator(), + repository); Map result = controller.feedback(feedback(FeedbackSource.SHADOW_APP, false)); assertThat(result).containsEntry("accepted", true).containsEntry("feedbackId", "feedback-1"); + assertThat(repository.items()).containsExactly(feedback(FeedbackSource.SHADOW_APP, false)); } @Test @@ -83,4 +89,17 @@ class TraderControllerTest { null, Map.of()); } + + private static final class RecordingFeedbackRepository implements TraderFeedbackRepository { + private final java.util.List items = new java.util.ArrayList<>(); + + @Override + public void insert(TraderAppFeedback feedback) { + items.add(feedback); + } + + java.util.List items() { + return items; + } + } } diff --git a/src/test/java/com/quantai/trader/controller/TraderReplayControllerTest.java b/src/test/java/com/quantai/trader/controller/TraderReplayControllerTest.java new file mode 100644 index 0000000..2af51aa --- /dev/null +++ b/src/test/java/com/quantai/trader/controller/TraderReplayControllerTest.java @@ -0,0 +1,45 @@ +package com.quantai.trader.controller; + +import com.quantai.trader.domain.TraderAction; +import com.quantai.trader.domain.TraderRiskDecision; +import com.quantai.trader.enums.PositionSide; +import com.quantai.trader.enums.TraderActionType; +import com.quantai.trader.replay.ReplayMarketEvent; +import com.quantai.trader.replay.TraderCycleResult; +import com.quantai.trader.replay.TraderP0CycleRunner; +import org.junit.jupiter.api.Test; + +import java.util.Map; + +import static com.quantai.trader.TestFixtures.pmDecision; +import static com.quantai.trader.TestFixtures.replayEvent; +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +class TraderReplayControllerTest { + @Test + void delegatesReplayCycleRequestToRunner() { + TraderP0CycleRunner runner = mock(TraderP0CycleRunner.class); + TraderReplayController controller = new TraderReplayController(runner); + ReplayMarketEvent event = replayEvent("run-1", com.quantai.trader.TestFixtures.T0, "100", "99.5", "1000"); + TraderCycleResult expected = result(); + when(runner.runCycle(event)).thenReturn(expected); + + TraderCycleResult actual = controller.runOneCycle(event); + + assertThat(actual).isSameAs(expected); + verify(runner).runCycle(event); + } + + private TraderCycleResult result() { + TraderRiskDecision riskDecision = new TraderRiskDecision("risk-1", "run-1", "cycle-1", + "pm-cycle-1", true, TraderActionType.WAIT, TraderActionType.WAIT, null, Map.of()); + TraderAction action = new TraderAction("action-1", "run-1", "cycle-1", "model-output-1", + "pm-cycle-1", "risk-1", TraderActionType.WAIT, "BTC-USDT-PERP", PositionSide.NONE, + null, null, null, null, null, null, false, "idem-1", "WAIT", Map.of()); + return new TraderCycleResult("run-1", "cycle-1", pmDecision(TraderActionType.WAIT, PositionSide.NONE), + riskDecision, action); + } +} diff --git a/src/test/java/com/quantai/trader/domain/ModelOutputContractTest.java b/src/test/java/com/quantai/trader/domain/ModelOutputContractTest.java index 092bfff..3bad04a 100644 --- a/src/test/java/com/quantai/trader/domain/ModelOutputContractTest.java +++ b/src/test/java/com/quantai/trader/domain/ModelOutputContractTest.java @@ -16,8 +16,7 @@ class ModelOutputContractTest { @Test void rejectsProbabilityOutsideClosedUnitRange() { assertThatThrownBy(() -> new DirectionOutput( - bd("1.01"), BigDecimal.ZERO, BigDecimal.ZERO, BigDecimal.ONE, BigDecimal.ONE, - bd("1"), 45, "direction-p0", "cal-v4-btc-p0", Map.of())) + bd("1.01"), BigDecimal.ZERO, BigDecimal.ZERO)) .isInstanceOf(IllegalArgumentException.class) .hasMessageContaining("direction.longProb"); } diff --git a/src/test/java/com/quantai/trader/evidence/JdbcTraderEvidenceRepositoryTest.java b/src/test/java/com/quantai/trader/evidence/JdbcTraderEvidenceRepositoryTest.java new file mode 100644 index 0000000..cf75af0 --- /dev/null +++ b/src/test/java/com/quantai/trader/evidence/JdbcTraderEvidenceRepositoryTest.java @@ -0,0 +1,29 @@ +package com.quantai.trader.evidence; + +import com.quantai.trader.domain.TraderEvidence; +import org.junit.jupiter.api.Test; +import org.springframework.jdbc.core.JdbcTemplate; + +import java.time.Instant; +import java.util.Map; + +import static com.quantai.trader.TestFixtures.objectMapper; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; + +class JdbcTraderEvidenceRepositoryTest { + @Test + void insertsEvidenceWithJsonDetails() { + JdbcTemplate jdbcTemplate = mock(JdbcTemplate.class); + JdbcTraderEvidenceRepository repository = new JdbcTraderEvidenceRepository(jdbcTemplate, objectMapper()); + TraderEvidence evidence = new TraderEvidence("evidence-1", "run-1", "cycle-1", "MODEL_OUTPUT", + true, "MODEL_EVALUATED", null, Instant.parse("2026-06-26T00:00:00Z"), + Map.of("modelOutputId", "model-output-1")); + + repository.insert(evidence); + + verify(jdbcTemplate).update(anyString(), any(Object[].class)); + } +} diff --git a/src/test/java/com/quantai/trader/feature/TraderFeatureVectorBuilderTest.java b/src/test/java/com/quantai/trader/feature/TraderFeatureVectorBuilderTest.java new file mode 100644 index 0000000..50af260 --- /dev/null +++ b/src/test/java/com/quantai/trader/feature/TraderFeatureVectorBuilderTest.java @@ -0,0 +1,68 @@ +package com.quantai.trader.feature; + +import com.quantai.trader.artifact.TraderArtifactLoader; +import com.quantai.trader.domain.TraderMarketSnapshot; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.io.TempDir; + +import java.io.IOException; +import java.nio.file.Path; +import java.util.LinkedHashMap; +import java.util.Map; + +import static com.quantai.trader.TestFixtures.*; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +class TraderFeatureVectorBuilderTest { + @TempDir + Path artifactRoot; + + @Test + void buildsFeatureVectorByArtifactFeatureOrder() throws IOException { + writeArtifactBundle(artifactRoot); + var properties = propertiesWithArtifactRoot(artifactRoot); + var bundle = new TraderArtifactLoader(properties, objectMapper()).loadActiveBundle(); + TraderFeatureVectorBuilder builder = new TraderFeatureVectorBuilder(properties, objectMapper()); + + float[] values = builder.build(snapshot(), bundle); + + assertThat(values).hasSize(39); + assertThat(values[0]).isEqualTo(1.1f); + assertThat(values[38]).isEqualTo(120.0f); + } + + @Test + void rejectsMissingFeatureInsteadOfFillingZero() throws IOException { + writeArtifactBundle(artifactRoot); + var properties = propertiesWithArtifactRoot(artifactRoot); + var bundle = new TraderArtifactLoader(properties, objectMapper()).loadActiveBundle(); + TraderFeatureVectorBuilder builder = new TraderFeatureVectorBuilder(properties, objectMapper()); + Map features = new LinkedHashMap<>(featureJson()); + features.remove("ret_1m_bps"); + TraderMarketSnapshot badSnapshot = new TraderMarketSnapshot("snapshot-1", "run-1", "cycle-1", + "BTC-USDT-PERP", T0, "feature-v4-p0", bd("100"), bd("99.5"), bd("1.2"), + bd("0.5"), bd("1000"), bd("1400"), bd("2200"), true, features, dataQualityJson()); + + assertThatThrownBy(() -> builder.build(badSnapshot, bundle)) + .isInstanceOf(com.quantai.trader.domain.TraderException.class) + .hasMessageContaining("snapshot feature is missing"); + } + + @Test + void rejectsRuntimeFieldsMixedIntoModelFeatures() throws IOException { + writeArtifactBundle(artifactRoot); + var properties = propertiesWithArtifactRoot(artifactRoot); + var bundle = new TraderArtifactLoader(properties, objectMapper()).loadActiveBundle(); + TraderFeatureVectorBuilder builder = new TraderFeatureVectorBuilder(properties, objectMapper()); + Map features = new LinkedHashMap<>(featureJson()); + features.put("account_balance", bd("1000")); + TraderMarketSnapshot badSnapshot = new TraderMarketSnapshot("snapshot-1", "run-1", "cycle-1", + "BTC-USDT-PERP", T0, "feature-v4-p0", bd("100"), bd("99.5"), bd("1.2"), + bd("0.5"), bd("1000"), bd("1400"), bd("2200"), true, features, dataQualityJson()); + + assertThatThrownBy(() -> builder.build(badSnapshot, bundle)) + .isInstanceOf(com.quantai.trader.domain.TraderException.class) + .hasMessageContaining("outside feature_order"); + } +} diff --git a/src/test/java/com/quantai/trader/feedback/JdbcTraderFeedbackRepositoryTest.java b/src/test/java/com/quantai/trader/feedback/JdbcTraderFeedbackRepositoryTest.java new file mode 100644 index 0000000..d5baf2c --- /dev/null +++ b/src/test/java/com/quantai/trader/feedback/JdbcTraderFeedbackRepositoryTest.java @@ -0,0 +1,40 @@ +package com.quantai.trader.feedback; + +import com.quantai.trader.domain.TraderAppFeedback; +import com.quantai.trader.enums.FeedbackSource; +import org.junit.jupiter.api.Test; +import org.mockito.ArgumentCaptor; +import org.springframework.jdbc.core.JdbcTemplate; + +import java.sql.Timestamp; +import java.util.Map; + +import static com.quantai.trader.TestFixtures.T0; +import static com.quantai.trader.TestFixtures.objectMapper; +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.ArgumentMatchers.contains; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; + +class JdbcTraderFeedbackRepositoryTest { + @Test + void insertsFeedbackWithSourceFillFlagsAndTimestamps() { + JdbcTemplate jdbcTemplate = mock(JdbcTemplate.class); + JdbcTraderFeedbackRepository repository = new JdbcTraderFeedbackRepository(jdbcTemplate, objectMapper()); + TraderAppFeedback feedback = new TraderAppFeedback( + "feedback-1", "run-1", "cycle-1", "action-1", + FeedbackSource.SHADOW_APP, false, null, "RECORDED", T0, T0.plusMillis(10), + null, null, null, null, null, null, Map.of("destination", "SHADOW_RECORDER")); + + repository.insert(feedback); + + ArgumentCaptor args = ArgumentCaptor.forClass(Object[].class); + verify(jdbcTemplate).update(contains("insert into trader_app_feedback"), args.capture()); + assertThat(args.getValue()[4]).isEqualTo("SHADOW_APP"); + assertThat(args.getValue()[5]).isEqualTo(false); + assertThat(args.getValue()[7]).isEqualTo("RECORDED"); + assertThat(args.getValue()[8]).isEqualTo(Timestamp.from(T0)); + assertThat(args.getValue()[9]).isEqualTo(Timestamp.from(T0.plusMillis(10))); + assertThat(args.getValue()[16].toString()).contains("SHADOW_RECORDER"); + } +} diff --git a/src/test/java/com/quantai/trader/model/OnnxTraderModelServiceTest.java b/src/test/java/com/quantai/trader/model/OnnxTraderModelServiceTest.java new file mode 100644 index 0000000..0aa01f2 --- /dev/null +++ b/src/test/java/com/quantai/trader/model/OnnxTraderModelServiceTest.java @@ -0,0 +1,77 @@ +package com.quantai.trader.model; + +import com.quantai.trader.artifact.TraderArtifactLoader; +import com.quantai.trader.artifact.TraderModelManifest; +import com.quantai.trader.feature.TraderFeatureVectorBuilder; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.io.TempDir; + +import java.io.IOException; +import java.nio.file.Path; +import java.util.Map; + +import static com.quantai.trader.TestFixtures.*; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +class OnnxTraderModelServiceTest { + @TempDir + Path artifactRoot; + + @Test + void evaluatesFiveModelFamiliesWithFeatureOrderCalibrationAndOutputBounds() throws IOException { + writeArtifactBundle(artifactRoot); + var properties = propertiesWithArtifactRoot(artifactRoot); + var bundle = new TraderArtifactLoader(properties, objectMapper()).loadActiveBundle(); + OnnxTraderModelService service = new OnnxTraderModelService( + properties, + objectMapper(), + new TraderFeatureVectorBuilder(properties, objectMapper()), + new FakeInferenceClient()); + + var output = service.evaluate(snapshot(), bundle); + + assertThat(output.direction().longProb()).isEqualByComparingTo("0.62"); + assertThat(output.entry().longExpectedNetEdgeBps()).isEqualByComparingTo("12"); + assertThat(output.continuation().shortExpectedContinueEdgeBps()).isEqualByComparingTo("1.5"); + assertThat(output.exit().reasonScore("reversal_prob")).isEqualByComparingTo("0.25"); + assertThat(output.risk().reasonScore("liquidity_deterioration_prob")).isEqualByComparingTo("0.12"); + assertThat(output.metadata().uncertainty()).isEqualByComparingTo("0.38"); + assertThat(output.metadata().oodScore()).isEqualByComparingTo("0.05"); + } + + @Test + void failsWhenShadowSnapshotHasNoOodScore() throws IOException { + writeArtifactBundle(artifactRoot); + var properties = propertiesWithArtifactRoot(artifactRoot); + var bundle = new TraderArtifactLoader(properties, objectMapper()).loadActiveBundle(); + OnnxTraderModelService service = new OnnxTraderModelService( + properties, + objectMapper(), + new TraderFeatureVectorBuilder(properties, objectMapper()), + new FakeInferenceClient()); + + var snapshot = new com.quantai.trader.domain.TraderMarketSnapshot("snapshot-1", "run-1", "cycle-1", + "BTC-USDT-PERP", T0, "feature-v4-p0", bd("100"), bd("99.5"), bd("1.2"), + bd("0.5"), bd("1000"), bd("1400"), bd("2200"), true, featureJson(), Map.of()); + + assertThatThrownBy(() -> service.evaluate(snapshot, bundle)) + .isInstanceOf(com.quantai.trader.domain.TraderException.class) + .hasMessageContaining("ood_score"); + } + + private static final class FakeInferenceClient implements TraderOnnxInferenceClient { + @Override + public Map infer(TraderModelManifest manifest, Path modelPath, float[] features) { + return switch (manifest.modelType()) { + case "DIRECTION" -> Map.of("probabilities", new float[]{0.5f, 0.5f, 0.5f}); + case "ENTRY" -> Map.of("entry", new float[]{0.5f, 0.5f, 12.0f, 8.0f}); + case "CONTINUE" -> Map.of("continue", new float[]{0.5f, 0.5f, 3.0f, 1.5f}); + case "EXIT" -> Map.of("exit", new float[]{0.5f, 0.5f, 10.0f, 18.0f, 0.5f, 0.5f, 0.5f, 0.5f}); + case "RISK" -> Map.of("risk", new float[]{0.5f, 0.5f, 0.5f, 20.0f, 20.0f, 30.0f, + 0.5f, 0.5f, 0.5f, 0.5f, 0.5f}); + default -> throw new IllegalArgumentException("unexpected model type: " + manifest.modelType()); + }; + } + } +} diff --git a/src/test/java/com/quantai/trader/model/OrtTraderOnnxInferenceClientTest.java b/src/test/java/com/quantai/trader/model/OrtTraderOnnxInferenceClientTest.java new file mode 100644 index 0000000..8474af6 --- /dev/null +++ b/src/test/java/com/quantai/trader/model/OrtTraderOnnxInferenceClientTest.java @@ -0,0 +1,32 @@ +package com.quantai.trader.model; + +import com.quantai.trader.artifact.TraderArtifactLoader; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.io.TempDir; + +import java.io.IOException; +import java.nio.file.Path; + +import static com.quantai.trader.TestFixtures.*; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +class OrtTraderOnnxInferenceClientTest { + @TempDir + Path artifactRoot; + + @Test + void failsClearlyWhenOnnxFileIsNotLoadable() throws IOException { + writeArtifactBundle(artifactRoot); + var properties = propertiesWithArtifactRoot(artifactRoot); + var bundle = new TraderArtifactLoader(properties, objectMapper()).loadActiveBundle(); + var manifest = bundle.modelManifests().stream() + .filter(item -> item.modelType().equals("DIRECTION")) + .findFirst() + .orElseThrow(); + + assertThatThrownBy(() -> new OrtTraderOnnxInferenceClient() + .infer(manifest, artifactRoot.resolve(manifest.artifactPath()), new float[39])) + .isInstanceOf(com.quantai.trader.domain.TraderException.class) + .hasMessageContaining("ONNX model cannot be loaded"); + } +} diff --git a/src/test/java/com/quantai/trader/model/RoutingTraderModelServiceTest.java b/src/test/java/com/quantai/trader/model/RoutingTraderModelServiceTest.java new file mode 100644 index 0000000..7ffc618 --- /dev/null +++ b/src/test/java/com/quantai/trader/model/RoutingTraderModelServiceTest.java @@ -0,0 +1,84 @@ +package com.quantai.trader.model; + +import com.quantai.trader.artifact.TraderArtifactLoader; +import com.quantai.trader.domain.TraderException; +import com.quantai.trader.enums.TraderExecutionMode; +import com.quantai.trader.enums.TraderRunMode; +import com.quantai.trader.feature.TraderFeatureVectorBuilder; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.io.TempDir; + +import java.io.IOException; +import java.nio.file.Path; + +import static com.quantai.trader.TestFixtures.*; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +class RoutingTraderModelServiceTest { + @TempDir + Path artifactRoot; + + @Test + void replaySimUsesReplayFixtureOnly() throws IOException { + writeArtifactBundle(artifactRoot); + var properties = propertiesWithArtifactRoot(artifactRoot, TraderRunMode.REPLAY_SIM, TraderExecutionMode.REPLAY_SIM); + var bundle = new TraderArtifactLoader(properties, objectMapper()).loadActiveBundle(); + var service = new RoutingTraderModelService( + properties, + new ReplayFixtureTraderModelService(properties), + new OnnxTraderModelService(properties, objectMapper(), + new TraderFeatureVectorBuilder(properties, objectMapper()), + (manifest, modelPath, features) -> { + throw new AssertionError("REPLAY_SIM must not call ONNX"); + })); + + var output = service.evaluate(snapshot(), bundle); + + assertThat(output.direction().longProb()).isEqualByComparingTo("0.62"); + } + + @Test + void shadowUsesOnnxService() throws IOException { + writeArtifactBundle(artifactRoot); + var properties = propertiesWithArtifactRoot(artifactRoot, TraderRunMode.SHADOW, TraderExecutionMode.SHADOW); + var bundle = new TraderArtifactLoader(properties, objectMapper()).loadActiveBundle(); + var service = new RoutingTraderModelService( + properties, + new ReplayFixtureTraderModelService(properties), + new OnnxTraderModelService(properties, objectMapper(), + new TraderFeatureVectorBuilder(properties, objectMapper()), + (manifest, modelPath, features) -> switch (manifest.modelType()) { + case "DIRECTION" -> java.util.Map.of("probabilities", new float[]{0.5f, 0.5f, 0.5f}); + case "ENTRY" -> java.util.Map.of("entry", new float[]{0.5f, 0.5f, 12.0f, 8.0f}); + case "CONTINUE" -> java.util.Map.of("continue", new float[]{0.5f, 0.5f, 3.0f, 1.5f}); + case "EXIT" -> java.util.Map.of("exit", new float[]{0.5f, 0.5f, 10.0f, 18.0f, 0.5f, 0.5f, 0.5f, 0.5f}); + case "RISK" -> java.util.Map.of("risk", new float[]{0.5f, 0.5f, 0.5f, 20.0f, 20.0f, 30.0f, + 0.5f, 0.5f, 0.5f, 0.5f, 0.5f}); + default -> throw new IllegalArgumentException("unexpected model type: " + manifest.modelType()); + })); + + var output = service.evaluate(snapshot(), bundle); + + assertThat(output.risk().marketRiskProb()).isEqualByComparingTo("0.20"); + } + + @Test + void rejectsRunModesOutsideP0Scope() throws IOException { + writeArtifactBundle(artifactRoot); + var bundle = new TraderArtifactLoader( + propertiesWithArtifactRoot(artifactRoot, TraderRunMode.REPLAY_SIM, TraderExecutionMode.REPLAY_SIM), + objectMapper()).loadActiveBundle(); + var properties = propertiesWithArtifactRoot(artifactRoot, TraderRunMode.PAPER, TraderExecutionMode.PAPER); + var service = new RoutingTraderModelService( + properties, + new ReplayFixtureTraderModelService(properties), + new OnnxTraderModelService(properties, objectMapper(), + new TraderFeatureVectorBuilder(properties, objectMapper()), + (manifest, modelPath, features) -> java.util.Map.of())); + + assertThatThrownBy(() -> service.evaluate(snapshot(), bundle)) + .isInstanceOf(TraderException.class) + .hasMessageContaining("only supports REPLAY_SIM and SHADOW"); + } +} diff --git a/src/test/java/com/quantai/trader/model/TraderOutputSchemaBoundsTest.java b/src/test/java/com/quantai/trader/model/TraderOutputSchemaBoundsTest.java new file mode 100644 index 0000000..ca43ad8 --- /dev/null +++ b/src/test/java/com/quantai/trader/model/TraderOutputSchemaBoundsTest.java @@ -0,0 +1,55 @@ +package com.quantai.trader.model; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.io.TempDir; + +import java.io.IOException; +import java.math.BigDecimal; +import java.nio.file.Files; +import java.nio.file.Path; + +import static com.quantai.trader.TestFixtures.objectMapper; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +class TraderOutputSchemaBoundsTest { + @TempDir + Path tempDir; + + @Test + void clipsBpsByOutputSchemaRange() throws IOException { + Path path = tempDir.resolve("output_schema.json"); + Files.writeString(path, """ + {"entry":{"longExpectedNetEdgeBps":{"type":"decimal","range":[-10.0,10.0]}}} + """); + + TraderOutputSchemaBounds bounds = TraderOutputSchemaBounds.read(objectMapper(), path); + + assertThat(bounds.clip("longExpectedNetEdgeBps", new BigDecimal("12"))).isEqualByComparingTo("10"); + assertThat(bounds.clip("longExpectedNetEdgeBps", new BigDecimal("-12"))).isEqualByComparingTo("-10"); + assertThat(bounds.clip("longExpectedNetEdgeBps", new BigDecimal("2"))).isEqualByComparingTo("2"); + } + + @Test + void rejectsMissingFieldRange() throws IOException { + Path path = tempDir.resolve("output_schema.json"); + Files.writeString(path, """ + {"entry":{"longExpectedNetEdgeBps":{"type":"decimal","range":[-10.0,10.0]}}} + """); + TraderOutputSchemaBounds bounds = TraderOutputSchemaBounds.read(objectMapper(), path); + + assertThatThrownBy(() -> bounds.clip("shortExpectedNetEdgeBps", BigDecimal.ONE)) + .isInstanceOf(com.quantai.trader.domain.TraderException.class) + .hasMessageContaining("does not define range"); + } + + @Test + void rejectsSchemaWithoutRanges() throws IOException { + Path path = tempDir.resolve("output_schema.json"); + Files.writeString(path, "{\"entry\":{}}"); + + assertThatThrownBy(() -> TraderOutputSchemaBounds.read(objectMapper(), path)) + .isInstanceOf(com.quantai.trader.domain.TraderException.class) + .hasMessageContaining("must define field ranges"); + } +} diff --git a/src/test/java/com/quantai/trader/model/TraderProbabilityCalibratorTest.java b/src/test/java/com/quantai/trader/model/TraderProbabilityCalibratorTest.java new file mode 100644 index 0000000..38a724e --- /dev/null +++ b/src/test/java/com/quantai/trader/model/TraderProbabilityCalibratorTest.java @@ -0,0 +1,58 @@ +package com.quantai.trader.model; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.io.TempDir; + +import java.io.IOException; +import java.math.BigDecimal; +import java.nio.file.Files; +import java.nio.file.Path; + +import static com.quantai.trader.TestFixtures.objectMapper; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +class TraderProbabilityCalibratorTest { + @TempDir + Path tempDir; + + @Test + void appliesBinningCalibration() throws IOException { + Path path = tempDir.resolve("calibrator.json"); + Files.writeString(path, """ + {"method":"BINNING","targets":{"longProb":{"bins":[{"min":0.0,"max":1.0,"calibrated":0.62}]}},"clip":{"min":0.0,"max":1.0}} + """); + + TraderProbabilityCalibrator calibrator = TraderProbabilityCalibrator.read(objectMapper(), path, "DIRECTION"); + + assertThat(calibrator.calibrate("longProb", new BigDecimal("0.40"))).isEqualByComparingTo("0.62"); + } + + @Test + void rejectsUnsupportedCalibrationMethod() throws IOException { + Path path = tempDir.resolve("calibrator.json"); + Files.writeString(path, """ + {"method":"NONE","targets":{"longProb":{"bins":[{"min":0.0,"max":1.0,"calibrated":0.62}]}},"clip":{"min":0.0,"max":1.0}} + """); + + assertThatThrownBy(() -> TraderProbabilityCalibrator.read(objectMapper(), path, "DIRECTION")) + .isInstanceOf(com.quantai.trader.domain.TraderException.class) + .hasMessageContaining("must be BINNING"); + } + + @Test + void rejectsMissingTargetAndOutOfRangeBins() throws IOException { + Path path = tempDir.resolve("calibrator.json"); + Files.writeString(path, """ + {"method":"BINNING","targets":{"longProb":{"bins":[{"min":0.0,"max":0.1,"calibrated":0.62}]}},"clip":{"min":0.0,"max":1.0}} + """); + TraderProbabilityCalibrator calibrator = TraderProbabilityCalibrator.read(objectMapper(), path, "DIRECTION"); + + assertThatThrownBy(() -> calibrator.calibrate("shortProb", new BigDecimal("0.05"))) + .isInstanceOf(com.quantai.trader.domain.TraderException.class) + .hasMessageContaining("target is missing"); + assertThatThrownBy(() -> calibrator.calibrate("longProb", new BigDecimal("0.50"))) + .isInstanceOf(com.quantai.trader.domain.TraderException.class) + .hasMessageContaining("does not match any"); + } +} diff --git a/src/test/java/com/quantai/trader/outbox/P0OutboxDispatcherTest.java b/src/test/java/com/quantai/trader/outbox/P0OutboxDispatcherTest.java new file mode 100644 index 0000000..6f67249 --- /dev/null +++ b/src/test/java/com/quantai/trader/outbox/P0OutboxDispatcherTest.java @@ -0,0 +1,82 @@ +package com.quantai.trader.outbox; + +import com.quantai.trader.domain.*; +import com.quantai.trader.enums.FeedbackSource; +import com.quantai.trader.enums.PositionSide; +import com.quantai.trader.enums.TraderActionType; +import com.quantai.trader.feedback.TraderFeedbackRepository; +import com.quantai.trader.replay.state.TraderPostActionStateRepository; +import com.quantai.trader.replay.state.TraderReplayState; +import com.quantai.trader.replay.state.TraderReplayStateStore; +import org.junit.jupiter.api.Test; +import org.mockito.ArgumentCaptor; + +import java.util.Map; + +import static com.quantai.trader.TestFixtures.*; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.mockito.Mockito.*; + +class P0OutboxDispatcherTest { + private final TraderOutboxRepository outboxRepository = mock(TraderOutboxRepository.class); + private final TraderFeedbackRepository feedbackRepository = mock(TraderFeedbackRepository.class); + private final TraderReplayStateStore stateStore = mock(TraderReplayStateStore.class); + private final TraderPostActionStateRepository postActionStateRepository = mock(TraderPostActionStateRepository.class); + private final P0OutboxDispatcher dispatcher = new P0OutboxDispatcher( + outboxRepository, feedbackRepository, stateStore, postActionStateRepository); + + @Test + void dispatchesShadowRecorderFeedbackMarksOutboxAndPersistsPostActionState() { + TraderAction action = action(TraderActionType.OPEN_LONG, PositionSide.LONG); + TraderMarketSnapshot snapshot = snapshot(); + TraderReplayState currentState = new TraderReplayState(flatPosition(), account(), execution()); + TraderReplayState nextState = new TraderReplayState(longPosition("0"), account("0", "0.20", "0.80"), execution()); + when(stateStore.advance(currentState, action, snapshot)).thenReturn(nextState); + + dispatcher.dispatch(action, currentState, snapshot, "SHADOW_RECORDER"); + + ArgumentCaptor feedback = ArgumentCaptor.forClass(TraderAppFeedback.class); + verify(feedbackRepository).insert(feedback.capture()); + assertThat(feedback.getValue().feedbackSource()).isEqualTo(FeedbackSource.SHADOW_APP); + assertThat(feedback.getValue().realFill()).isFalse(); + assertThat(feedback.getValue().orderStatus()).isEqualTo("RECORDED"); + assertThat(feedback.getValue().rawFeedbackJson()) + .containsEntry("destination", "SHADOW_RECORDER") + .containsEntry("actionType", "OPEN_LONG"); + verify(outboxRepository).markSent("outbox_" + action.actionId()); + verify(postActionStateRepository).insertPostActionState(nextState); + } + + @Test + void mapsReplaySimDestinationToReplaySimulatorFeedback() { + TraderAction action = action(TraderActionType.OPEN_LONG, PositionSide.LONG); + TraderReplayState state = new TraderReplayState(flatPosition(), account(), execution()); + when(stateStore.advance(state, action, snapshot())).thenReturn(state); + + dispatcher.dispatch(action, state, snapshot(), "REPLAY_SIM_EXECUTION"); + + ArgumentCaptor feedback = ArgumentCaptor.forClass(TraderAppFeedback.class); + verify(feedbackRepository).insert(feedback.capture()); + assertThat(feedback.getValue().feedbackSource()).isEqualTo(FeedbackSource.REPLAY_SIMULATOR); + } + + @Test + void rejectsPaperOrRealDestinationsInP0Dispatcher() { + TraderAction action = action(TraderActionType.OPEN_LONG, PositionSide.LONG); + TraderReplayState state = new TraderReplayState(flatPosition(), account(), execution()); + + assertThatThrownBy(() -> dispatcher.dispatch(action, state, snapshot(), "PAPER_EXECUTION")) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("P0 outbox destination is not allowed"); + + verifyNoInteractions(feedbackRepository, outboxRepository, postActionStateRepository); + } + + private TraderAction action(TraderActionType type, PositionSide side) { + TraderRiskDecision riskDecision = new TraderRiskDecision( + "risk-1", "run-1", "cycle-1", "pm-cycle-1", + true, type, type, null, Map.of()); + return new TraderActionFactory().create(pmDecision(type, side), riskDecision, "BTC-USDT-PERP"); + } +} diff --git a/src/test/java/com/quantai/trader/persistence/JdbcTraderDecisionTraceWriterTest.java b/src/test/java/com/quantai/trader/persistence/JdbcTraderDecisionTraceWriterTest.java index 886d755..91f5ab2 100644 --- a/src/test/java/com/quantai/trader/persistence/JdbcTraderDecisionTraceWriterTest.java +++ b/src/test/java/com/quantai/trader/persistence/JdbcTraderDecisionTraceWriterTest.java @@ -26,9 +26,9 @@ class JdbcTraderDecisionTraceWriterTest { "pm-cycle-1", true, TraderActionType.OPEN_LONG, TraderActionType.OPEN_LONG, null, Map.of()); var action = new TraderActionFactory().create(pmDecision(TraderActionType.OPEN_LONG, PositionSide.LONG), riskDecision, "BTC-USDT-PERP"); - writer.persistCycleTrace(cycle(), snapshot(), modelOutput(), flatPosition(), + writer.persistCycleTrace(cycle(), snapshot(), modelOutput(), flatPosition(), account(), execution(), pmDecision(TraderActionType.OPEN_LONG, PositionSide.LONG), riskDecision, action); - verify(jdbcTemplate, times(6)).update(anyString(), any(Object[].class)); + verify(jdbcTemplate, times(10)).update(anyString(), any(Object[].class)); } } diff --git a/src/test/java/com/quantai/trader/position/TraderPositionManagerTest.java b/src/test/java/com/quantai/trader/position/TraderPositionManagerTest.java index 9e9e58c..6945362 100644 --- a/src/test/java/com/quantai/trader/position/TraderPositionManagerTest.java +++ b/src/test/java/com/quantai/trader/position/TraderPositionManagerTest.java @@ -1,7 +1,10 @@ package com.quantai.trader.position; +import com.quantai.trader.domain.PositionManagerInput; import com.quantai.trader.domain.TraderModelOutput; import com.quantai.trader.domain.TraderPositionManagerDecision; +import com.quantai.trader.domain.TraderPositionState; +import com.quantai.trader.enums.PositionSide; import com.quantai.trader.enums.TraderActionType; import org.junit.jupiter.api.Test; @@ -36,12 +39,13 @@ class TraderPositionManagerTest { @Test void waitsWhenDataOrLiquidityIsInsufficient() { TraderModelOutput noLiquidity = modelOutput("0.70", "0.20", "0.70", "0.30", "0.66", "0.34", - "0.20", "0.50", "0.18", "0.12", "0", "12", "4", "0.10", "0.05", "20"); + "0.20", "0.50", "0.18", "0.12", "12", "4", "0.10", "0.05", "20"); - assertThat(positionManager.decide(pmInput(noLiquidity, flatPosition())).candidateAction()) + assertThat(positionManager.decide(pmInput(noLiquidity, flatPosition(), account(), + execution(java.util.List.of(), 10, 0, "0"))).candidateAction()) .isEqualTo(TraderActionType.WAIT); assertThat(positionManager.decide(new com.quantai.trader.domain.PositionManagerInput( - cycle(), snapshot(false, "1000"), modelOutput(), flatPosition(), account(), execution(), pmConfig())).candidateAction()) + cycle(), snapshot(false, "1000"), modelOutput(), pricePlan(), flatPosition(), account(), execution(), pmConfig())).candidateAction()) .isEqualTo(TraderActionType.WAIT); } @@ -55,6 +59,20 @@ class TraderPositionManagerTest { assertThat(hold.candidateAction()).isEqualTo(TraderActionType.HOLD); } + @Test + void holdsWhenAddCooldownHasNotElapsed() { + TraderPositionState coolingPosition = new TraderPositionState( + "position-state-1", "run-1", "cycle-1", "BTC-USDT-PERP", + PositionSide.LONG, bd("0.30"), bd("100"), bd("101"), bd("10"), bd("1000"), + 1, bd("0.40"), T0); + + TraderPositionManagerDecision decision = positionManager.decide(new PositionManagerInput( + cycle(), snapshot(), modelOutput(), pricePlan(), coolingPosition, account(), execution(), pmConfig())); + + assertThat(decision.candidateAction()).isEqualTo(TraderActionType.HOLD); + assertThat(decision.reason()).isEqualTo("CONTINUE_HOLD"); + } + @Test void usesPositionSideContinuationProbabilityForShortAdd() { TraderPositionManagerDecision decision = positionManager.decide(pmInput(shortModelOutput(), shortPosition("10"))); @@ -66,7 +84,7 @@ class TraderPositionManagerTest { @Test void closesExistingPositionWhenExitOrRiskSignalIsHigh() { TraderModelOutput exitHigh = modelOutput("0.70", "0.20", "0.70", "0.30", "0.66", "0.34", - "0.80", "0.50", "0.18", "0.12", "1.0", "12", "4", "0.10", "0.05", "20"); + "0.80", "0.50", "0.18", "0.12", "12", "4", "0.10", "0.05", "20"); TraderPositionManagerDecision decision = positionManager.decide(pmInput(exitHigh, longPosition("10"))); @@ -77,7 +95,7 @@ class TraderPositionManagerTest { @Test void returnsZeroInitialRatioWhenExpectedEdgeIsBelowSizingFloor() { TraderModelOutput weakEdge = modelOutput("0.70", "0.20", "0.70", "0.30", "0.66", "0.34", - "0.20", "0.50", "0.18", "0.12", "1.0", "0.5", "4", "0.10", "0.05", "20"); + "0.20", "0.50", "0.18", "0.12", "0.5", "4", "0.10", "0.05", "20"); BigDecimal ratio = positionManager.calculateInitialRatio(pmInput(weakEdge, flatPosition()), com.quantai.trader.enums.PositionSide.LONG); diff --git a/src/test/java/com/quantai/trader/replay/TraderP0CycleRunnerTest.java b/src/test/java/com/quantai/trader/replay/TraderP0CycleRunnerTest.java index cc01e12..1dc2a7a 100644 --- a/src/test/java/com/quantai/trader/replay/TraderP0CycleRunnerTest.java +++ b/src/test/java/com/quantai/trader/replay/TraderP0CycleRunnerTest.java @@ -2,23 +2,28 @@ package com.quantai.trader.replay; import com.quantai.trader.artifact.TraderArtifactLoader; import com.quantai.trader.domain.*; +import com.quantai.trader.enums.TraderExecutionMode; import com.quantai.trader.enums.TraderActionType; +import com.quantai.trader.enums.TraderRunMode; import com.quantai.trader.evidence.EvidenceAppender; import com.quantai.trader.evidence.TraderEvidenceRepository; -import com.quantai.trader.model.ArtifactTraderModelService; +import com.quantai.trader.model.ReplayFixtureTraderModelService; import com.quantai.trader.outbox.TraderOutboxEvent; +import com.quantai.trader.outbox.TraderOutboxDispatcher; import com.quantai.trader.outbox.TraderOutboxRepository; import com.quantai.trader.persistence.TraderDecisionTraceWriter; import com.quantai.trader.position.TraderPositionManager; import com.quantai.trader.replay.state.P0ReplayStateStore; +import com.quantai.trader.replay.state.TraderReplayState; import com.quantai.trader.risk.TraderRiskGate; +import com.quantai.trader.runtime.TraderRuntimeControlDecision; +import com.quantai.trader.runtime.TraderRuntimeControlService; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.io.TempDir; import java.util.ArrayList; import java.util.List; import java.io.IOException; -import java.math.BigDecimal; import java.nio.file.Path; import static com.quantai.trader.TestFixtures.*; @@ -35,28 +40,31 @@ class TraderP0CycleRunnerTest { EvidenceAppender evidenceAppender = new EvidenceAppender(evidenceRepository); RecordingTraceWriter traceWriter = new RecordingTraceWriter(); RecordingOutboxRepository outboxRepository = new RecordingOutboxRepository(); - var properties = propertiesWithArtifactRoot(artifactRoot); + var properties = propertiesWithArtifactRoot(artifactRoot, TraderRunMode.REPLAY_SIM, TraderExecutionMode.REPLAY_SIM); + P0ReplayStateStore stateStore = new P0ReplayStateStore(properties); TraderP0CycleRunner runner = new TraderP0CycleRunner( properties, new TraderArtifactLoader(properties, objectMapper()), - new ArtifactTraderModelService(), + new ReplayFixtureTraderModelService(properties), new TraderPositionManager(), new TraderRiskGate(), new TraderActionFactory(), evidenceAppender, traceWriter, + bundle -> { + }, outboxRepository, - new P0ReplayStateStore(properties)); + new ImmediateDispatcher(stateStore), + stateStore, + new AllowRuntimeControlService()); - TraderCycleResult result = runner.runCycle(new ReplayMarketEvent( - "run-1", "BTC-USDT-PERP", T0, new BigDecimal("100"), new BigDecimal("99.5"), - new BigDecimal("1.2"), new BigDecimal("1000"))); + TraderCycleResult result = runner.runCycle(replayEvent("run-1", T0, "100", "99.5", "1000")); assertThat(result.action().actionType()).isEqualTo(TraderActionType.OPEN_LONG); assertThat(result.action().reduceOnly()).isFalse(); assertThat(traceWriter.actions()).containsExactly(result.action()); assertThat(outboxRepository.events()).hasSize(1); - assertThat(outboxRepository.events().getFirst().destination()).isEqualTo("SHADOW_RECORDER"); + assertThat(outboxRepository.events().getFirst().destination()).isEqualTo("REPLAY_SIM_EXECUTION"); assertThat(evidenceRepository.items()).extracting("stage") .containsExactly("MARKET_SNAPSHOT", "MODEL_OUTPUT", "PM_DECISION", "RISK_DECISION"); } @@ -68,22 +76,25 @@ class TraderP0CycleRunnerTest { EvidenceAppender evidenceAppender = new EvidenceAppender(evidenceRepository); RecordingTraceWriter traceWriter = new RecordingTraceWriter(); RecordingOutboxRepository outboxRepository = new RecordingOutboxRepository(); - var properties = propertiesWithArtifactRoot(artifactRoot); + var properties = propertiesWithArtifactRoot(artifactRoot, TraderRunMode.REPLAY_SIM, TraderExecutionMode.REPLAY_SIM); + P0ReplayStateStore stateStore = new P0ReplayStateStore(properties); TraderP0CycleRunner runner = new TraderP0CycleRunner( properties, new TraderArtifactLoader(properties, objectMapper()), - new ArtifactTraderModelService(), + new ReplayFixtureTraderModelService(properties), new TraderPositionManager(), new TraderRiskGate(), new TraderActionFactory(), evidenceAppender, traceWriter, + bundle -> { + }, outboxRepository, - new P0ReplayStateStore(properties)); + new ImmediateDispatcher(stateStore), + stateStore, + new AllowRuntimeControlService()); - TraderCycleResult result = runner.runCycle(new ReplayMarketEvent( - "run-1", "BTC-USDT-PERP", T0.plusSeconds(60), new BigDecimal("100"), new BigDecimal("99.5"), - new BigDecimal("1.2"), BigDecimal.ZERO)); + TraderCycleResult result = runner.runCycle(replayEvent("run-1", T0.plusSeconds(60), "100", "99.5", "0")); assertThat(result.action().actionType()).isEqualTo(TraderActionType.WAIT); assertThat(result.action().pricePlanId()).isNull(); @@ -99,25 +110,26 @@ class TraderP0CycleRunnerTest { EvidenceAppender evidenceAppender = new EvidenceAppender(evidenceRepository); RecordingTraceWriter traceWriter = new RecordingTraceWriter(); RecordingOutboxRepository outboxRepository = new RecordingOutboxRepository(); - var properties = propertiesWithArtifactRoot(artifactRoot); + var properties = propertiesWithArtifactRoot(artifactRoot, TraderRunMode.REPLAY_SIM, TraderExecutionMode.REPLAY_SIM); + P0ReplayStateStore stateStore = new P0ReplayStateStore(properties); TraderP0CycleRunner runner = new TraderP0CycleRunner( properties, new TraderArtifactLoader(properties, objectMapper()), - new ArtifactTraderModelService(), + new ReplayFixtureTraderModelService(properties), new TraderPositionManager(), new TraderRiskGate(), new TraderActionFactory(), evidenceAppender, traceWriter, + bundle -> { + }, outboxRepository, - new P0ReplayStateStore(properties)); + new ImmediateDispatcher(stateStore), + stateStore, + new AllowRuntimeControlService()); - TraderCycleResult first = runner.runCycle(new ReplayMarketEvent( - "run-state-1", "BTC-USDT-PERP", T0, new BigDecimal("100"), new BigDecimal("99.5"), - new BigDecimal("1.2"), new BigDecimal("1000"))); - TraderCycleResult second = runner.runCycle(new ReplayMarketEvent( - "run-state-1", "BTC-USDT-PERP", T0.plusSeconds(60), new BigDecimal("101"), new BigDecimal("100.5"), - new BigDecimal("1.2"), new BigDecimal("1000"))); + TraderCycleResult first = runner.runCycle(replayEvent("run-state-1", T0, "100", "99.5", "1000")); + TraderCycleResult second = runner.runCycle(replayEvent("run-state-1", T0.plusSeconds(60), "101", "100.5", "1000")); assertThat(first.action().actionType()).isEqualTo(TraderActionType.OPEN_LONG); assertThat(second.action().actionType()).isEqualTo(TraderActionType.ADD_LONG); @@ -147,18 +159,41 @@ class TraderP0CycleRunnerTest { events.add(event); } + @Override + public void markSent(String outboxId) { + } + + @Override + public boolean hasUnsentExposureIncrease(String runId, String symbol, com.quantai.trader.enums.PositionSide side) { + return false; + } + List events() { return events; } } + private static final class ImmediateDispatcher implements TraderOutboxDispatcher { + private final P0ReplayStateStore stateStore; + + private ImmediateDispatcher(P0ReplayStateStore stateStore) { + this.stateStore = stateStore; + } + + @Override + public void dispatch(TraderAction action, TraderReplayState currentState, TraderMarketSnapshot snapshot, String destination) { + stateStore.advance(currentState, action, snapshot); + } + } + private static final class RecordingTraceWriter implements TraderDecisionTraceWriter { private final List actions = new ArrayList<>(); private final List positionStates = new ArrayList<>(); @Override public void persistCycleTrace(TraderDecisionCycle cycle, TraderMarketSnapshot snapshot, TraderModelOutput modelOutput, - TraderPositionState positionState, TraderPositionManagerDecision pmDecision, + TraderPositionState positionState, TraderAccountState accountState, + TraderExecutionState executionState, TraderPositionManagerDecision pmDecision, TraderRiskDecision riskDecision, TraderAction action) { actions.add(action); positionStates.add(positionState); @@ -172,4 +207,20 @@ class TraderP0CycleRunnerTest { return positionStates; } } + + private static final class AllowRuntimeControlService implements TraderRuntimeControlService { + @Override + public void acquireCycleLock(TraderDecisionCycle cycle) { + } + + @Override + public void releaseCycleLock(TraderDecisionCycle cycle) { + } + + @Override + public TraderRuntimeControlDecision validateExposureIncrease(TraderDecisionCycle cycle, + TraderPositionManagerDecision decision) { + return TraderRuntimeControlDecision.allow(); + } + } } diff --git a/src/test/java/com/quantai/trader/replay/state/JdbcTraderPostActionStateRepositoryTest.java b/src/test/java/com/quantai/trader/replay/state/JdbcTraderPostActionStateRepositoryTest.java new file mode 100644 index 0000000..0343f99 --- /dev/null +++ b/src/test/java/com/quantai/trader/replay/state/JdbcTraderPostActionStateRepositoryTest.java @@ -0,0 +1,37 @@ +package com.quantai.trader.replay.state; + +import org.junit.jupiter.api.Test; +import org.mockito.ArgumentCaptor; +import org.springframework.jdbc.core.JdbcTemplate; + +import static com.quantai.trader.TestFixtures.*; +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.ArgumentMatchers.contains; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoMoreInteractions; + +class JdbcTraderPostActionStateRepositoryTest { + @Test + void insertsPostActionPositionAccountAndExecutionStateWithPostIds() { + JdbcTemplate jdbcTemplate = mock(JdbcTemplate.class); + JdbcTraderPostActionStateRepository repository = new JdbcTraderPostActionStateRepository(jdbcTemplate, objectMapper()); + TraderReplayState state = new TraderReplayState(longPosition("12"), account(), executionWithOpenOrder()); + + repository.insertPostActionState(state); + + ArgumentCaptor positionArgs = ArgumentCaptor.forClass(Object[].class); + ArgumentCaptor accountArgs = ArgumentCaptor.forClass(Object[].class); + ArgumentCaptor executionArgs = ArgumentCaptor.forClass(Object[].class); + verify(jdbcTemplate).update(contains("insert into trader_position_state"), positionArgs.capture()); + verify(jdbcTemplate).update(contains("insert into trader_account_state"), accountArgs.capture()); + verify(jdbcTemplate).update(contains("insert into trader_execution_state"), executionArgs.capture()); + verifyNoMoreInteractions(jdbcTemplate); + + assertThat(positionArgs.getValue()[2]).isEqualTo("position-state-1_post"); + assertThat(positionArgs.getValue()[4]).isEqualTo("LONG"); + assertThat(accountArgs.getValue()[2]).isEqualTo("account-state-1_post"); + assertThat(executionArgs.getValue()[2]).isEqualTo("execution-state-1_post"); + assertThat(executionArgs.getValue()[4].toString()).contains("order-1"); + } +} diff --git a/src/test/java/com/quantai/trader/risk/TraderRiskGateTest.java b/src/test/java/com/quantai/trader/risk/TraderRiskGateTest.java index d32e098..f72fd9a 100644 --- a/src/test/java/com/quantai/trader/risk/TraderRiskGateTest.java +++ b/src/test/java/com/quantai/trader/risk/TraderRiskGateTest.java @@ -14,7 +14,7 @@ class TraderRiskGateTest { @Test void killSwitchBlocksOnlyExposureIncreasingActions() { - RiskLimits limits = new RiskLimits(bd("200"), java.math.BigDecimal.ONE, bd("500"), 3, 1500, true, false); + RiskLimits limits = new RiskLimits(bd("200"), java.math.BigDecimal.ONE, bd("500"), 3, 1500, true, false, null); TraderRiskDecision open = riskGate.evaluate(new RiskGateInput( pmDecision(TraderActionType.OPEN_LONG, PositionSide.LONG), flatPosition(), account(), execution(), snapshot(), limits)); diff --git a/src/test/java/com/quantai/trader/runtime/RedisTraderRuntimeControlServiceTest.java b/src/test/java/com/quantai/trader/runtime/RedisTraderRuntimeControlServiceTest.java new file mode 100644 index 0000000..7bad31c --- /dev/null +++ b/src/test/java/com/quantai/trader/runtime/RedisTraderRuntimeControlServiceTest.java @@ -0,0 +1,95 @@ +package com.quantai.trader.runtime; + +import com.quantai.trader.domain.TraderException; +import com.quantai.trader.enums.PositionSide; +import com.quantai.trader.enums.TraderActionType; +import com.quantai.trader.enums.TraderErrorCode; +import org.junit.jupiter.api.Test; +import org.springframework.data.redis.core.StringRedisTemplate; +import org.springframework.data.redis.core.ValueOperations; + +import java.time.Duration; + +import static com.quantai.trader.TestFixtures.*; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.*; + +class RedisTraderRuntimeControlServiceTest { + private final StringRedisTemplate redisTemplate = mock(StringRedisTemplate.class); + @SuppressWarnings("unchecked") + private final ValueOperations valueOperations = mock(ValueOperations.class); + private final RedisTraderRuntimeControlService service = new RedisTraderRuntimeControlService(properties(), redisTemplate); + + @Test + void acquiresAndReleasesCycleLockWithDeterministicRedisKey() { + when(redisTemplate.opsForValue()).thenReturn(valueOperations); + String lockKey = "trader:v4:test:cycle:lock:BTC-USDT-PERP:" + T0.toEpochMilli(); + when(valueOperations.setIfAbsent(eq(lockKey), eq("cycle-1"), eq(Duration.ofSeconds(30)))).thenReturn(true); + + service.acquireCycleLock(cycle()); + service.releaseCycleLock(cycle()); + + verify(valueOperations).setIfAbsent(lockKey, "cycle-1", Duration.ofSeconds(30)); + verify(redisTemplate).delete(lockKey); + } + + @Test + void reportsExistingCycleLockAsRuntimeBlockNotRedisOutage() { + when(redisTemplate.opsForValue()).thenReturn(valueOperations); + when(valueOperations.setIfAbsent(anyString(), anyString(), any(Duration.class))).thenReturn(false); + + assertThatThrownBy(() -> service.acquireCycleLock(cycle())) + .isInstanceOfSatisfying(TraderException.class, exception -> { + assertThat(exception.code()).isEqualTo(TraderErrorCode.TRADER_RUNTIME_CONTROL_BLOCKED); + assertThat(exception.getMessage()).contains("cycle lock already exists"); + }); + } + + @Test + void allowsExposureIncreaseWhenActivePointerAndRuntimeSwitchesPass() { + when(redisTemplate.opsForValue()).thenReturn(valueOperations); + when(valueOperations.get("trader:v4:test:model:active:BTC-USDT-PERP")) + .thenReturn("trader-v4-btc-p0|cal-v4-btc-p0|pm-v4-btc-p0"); + when(redisTemplate.hasKey(anyString())).thenReturn(false); + + TraderRuntimeControlDecision decision = service.validateExposureIncrease( + cycle(), pmDecision(TraderActionType.OPEN_LONG, PositionSide.LONG)); + + assertThat(decision.allowed()).isTrue(); + verify(valueOperations).set("trader:v4:test:runtime:open-add-probe:cycle-1", "1", Duration.ofSeconds(30)); + } + + @Test + void blocksExposureIncreaseWhenActivePointerIsMissing() { + when(redisTemplate.opsForValue()).thenReturn(valueOperations); + when(valueOperations.get("trader:v4:test:model:active:BTC-USDT-PERP")).thenReturn(null); + + TraderRuntimeControlDecision decision = service.validateExposureIncrease( + cycle(), pmDecision(TraderActionType.OPEN_LONG, PositionSide.LONG)); + + assertThat(decision.allowed()).isFalse(); + assertThat(decision.blocker()).isEqualTo("ACTIVE_POINTER_MISSING"); + } + + @Test + void skipsRedisChecksWhenActionDoesNotIncreaseExposure() { + TraderRuntimeControlDecision decision = service.validateExposureIncrease( + cycle(), pmDecision(TraderActionType.HOLD, PositionSide.LONG)); + + assertThat(decision.allowed()).isTrue(); + verifyNoInteractions(redisTemplate); + } + + @Test + void blocksExposureIncreaseWhenRedisFailsDuringValidation() { + when(redisTemplate.opsForValue()).thenThrow(new RuntimeException("redis down")); + + TraderRuntimeControlDecision decision = service.validateExposureIncrease( + cycle(), pmDecision(TraderActionType.OPEN_LONG, PositionSide.LONG)); + + assertThat(decision.allowed()).isFalse(); + assertThat(decision.blocker()).isEqualTo("REDIS_UNAVAILABLE"); + } +} diff --git a/src/test/java/com/quantai/trader/runtime/StartupValidationRunnerTest.java b/src/test/java/com/quantai/trader/runtime/StartupValidationRunnerTest.java new file mode 100644 index 0000000..8256e6f --- /dev/null +++ b/src/test/java/com/quantai/trader/runtime/StartupValidationRunnerTest.java @@ -0,0 +1,18 @@ +package com.quantai.trader.runtime; + +import org.junit.jupiter.api.Test; + +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; + +class StartupValidationRunnerTest { + @Test + void delegatesStartupValidationToRuntimeGuard() { + P0RuntimeGuard guard = mock(P0RuntimeGuard.class); + StartupValidationRunner runner = new StartupValidationRunner(guard); + + runner.run(null); + + verify(guard).validateStartup(); + } +} diff --git a/training/README.md b/training/README.md new file mode 100644 index 0000000..8ee6952 --- /dev/null +++ b/training/README.md @@ -0,0 +1,33 @@ +# Trader V4 Training Pipeline + +This directory contains the executable training chain for Trader V4. +Large data stays under `/Users/zach/Desktop/quant-strategy-training-data`. + +Run order: + +```bash +PY=/Users/zach/IdeaProjects/quant-trading-ai/quant-strategy-server/.venv/bin/python +RUN_ID=btc-v4-p0-001 +ROOT=/Users/zach/Desktop/quant-strategy-training-data + +$PY training/scripts/01_audit_source_data.py --run-id $RUN_ID --data-root $ROOT --symbol BTC-USDT-PERP --start-date 2025-06-20 --end-date 2026-06-19 +$PY training/scripts/02_build_replay_1m.py --run-id $RUN_ID --data-root $ROOT --symbol BTC-USDT-PERP --start-date 2025-06-20 --end-date 2026-06-19 +$PY training/scripts/03_build_splits.py --run-id $RUN_ID --data-root $ROOT +$PY training/scripts/04_build_feature_frame.py --run-id $RUN_ID --data-root $ROOT +$PY training/scripts/05_build_price_plan_context.py --run-id $RUN_ID --data-root $ROOT +$PY training/scripts/06_build_direction_labels.py --run-id $RUN_ID --data-root $ROOT +$PY training/scripts/07_build_entry_labels.py --run-id $RUN_ID --data-root $ROOT +$PY training/scripts/08_build_position_state_samples.py --run-id $RUN_ID --data-root $ROOT +$PY training/scripts/09_build_continue_exit_risk_labels.py --run-id $RUN_ID --data-root $ROOT +$PY training/scripts/10_build_train_datasets.py --run-id $RUN_ID --data-root $ROOT +$PY training/scripts/11_train_small_models.py --run-id $RUN_ID --data-root $ROOT +$PY training/scripts/12_calibrate_models.py --run-id $RUN_ID --data-root $ROOT +$PY training/scripts/13_search_pm_thresholds.py --run-id $RUN_ID --data-root $ROOT +$PY training/scripts/14_integrated_backtest.py --run-id $RUN_ID --data-root $ROOT +$PY training/scripts/15_export_artifact_bundle.py --run-id $RUN_ID --data-root $ROOT +$PY training/scripts/16_validate_artifact_bundle.py --artifact-root $ROOT/trader-v4/runs/$RUN_ID/export/trader-model-bundle-$RUN_ID/artifact_bundle +$PY training/scripts/17_promote_artifact_bundle.py --artifact-root $ROOT/trader-v4/runs/$RUN_ID/export/trader-model-bundle-$RUN_ID/artifact_bundle --reason "validation_locked and latest_stress passed for SHADOW" +$PY training/scripts/16_validate_artifact_bundle.py --artifact-root $ROOT/trader-v4/runs/$RUN_ID/export/trader-model-bundle-$RUN_ID/artifact_bundle --require-active --run-onnx +``` + +Java SHADOW 只加载 `ACTIVE` 包。15 号脚本永远只生成 `CANDIDATE`,16 号校验通过且上线门槛通过后,17 号脚本才允许把包提升为 `ACTIVE`。 diff --git a/training/requirements.txt b/training/requirements.txt new file mode 100644 index 0000000..f2593cf --- /dev/null +++ b/training/requirements.txt @@ -0,0 +1,7 @@ +pandas==2.2.3 +pyarrow==24.0.0 +numpy==2.4.6 +scikit-learn==1.7.0 +scipy==1.18.0 +onnx==1.22.0 +onnxruntime==1.27.0 diff --git a/training/scripts/01_audit_source_data.py b/training/scripts/01_audit_source_data.py new file mode 100644 index 0000000..594f665 --- /dev/null +++ b/training/scripts/01_audit_source_data.py @@ -0,0 +1,23 @@ +from __future__ import annotations + +import argparse + +import _bootstrap # noqa: F401 +from trader_training.io_utils import add_common_args, setup_logging +from trader_training.replay import write_audit_outputs + + +def main() -> None: + parser = argparse.ArgumentParser() + add_common_args(parser) + parser.add_argument("--symbol", default="BTC-USDT-PERP") + parser.add_argument("--start-date") + parser.add_argument("--end-date") + parser.add_argument("--min-ready-days", type=int, default=250) + args = parser.parse_args() + setup_logging() + write_audit_outputs(args) + + +if __name__ == "__main__": + main() diff --git a/training/scripts/02_build_replay_1m.py b/training/scripts/02_build_replay_1m.py new file mode 100644 index 0000000..758fd82 --- /dev/null +++ b/training/scripts/02_build_replay_1m.py @@ -0,0 +1,26 @@ +from __future__ import annotations + +import argparse +from pathlib import Path + +import _bootstrap # noqa: F401 +from trader_training.io_utils import add_common_args, setup_logging +from trader_training.replay import build_replay_1m + + +def main() -> None: + parser = argparse.ArgumentParser() + add_common_args(parser) + parser.add_argument("--raw-root", type=Path) + parser.add_argument("--symbol", default="BTC-USDT-PERP") + parser.add_argument("--start-date") + parser.add_argument("--end-date") + parser.add_argument("--min-minutes-per-day", type=int, default=1400) + parser.add_argument("--min-replay-ready-days", type=int, default=250) + args = parser.parse_args() + setup_logging() + build_replay_1m(args) + + +if __name__ == "__main__": + main() diff --git a/training/scripts/03_build_splits.py b/training/scripts/03_build_splits.py new file mode 100644 index 0000000..7418d3e --- /dev/null +++ b/training/scripts/03_build_splits.py @@ -0,0 +1,31 @@ +from __future__ import annotations + +import argparse +from pathlib import Path + +import _bootstrap # noqa: F401 +from trader_training.io_utils import add_common_args, setup_logging +from trader_training.replay import build_splits + + +def main() -> None: + parser = argparse.ArgumentParser() + add_common_args(parser) + parser.add_argument("--replay-path", type=Path) + parser.add_argument("--fit-inner-start", default="2025-06-20") + parser.add_argument("--fit-inner-end", default="2026-01-15") + parser.add_argument("--tune-inner-start", default="2026-01-16") + parser.add_argument("--tune-inner-end", default="2026-02-28") + parser.add_argument("--validation-locked-start", default="2026-03-01") + parser.add_argument("--validation-locked-end", default="2026-04-30") + parser.add_argument("--latest-stress-start", default="2026-05-01") + parser.add_argument("--latest-stress-end", default="2026-06-19") + parser.add_argument("--gap-minutes", type=int, default=60) + parser.add_argument("--fold-count", type=int, default=3) + args = parser.parse_args() + setup_logging() + build_splits(args) + + +if __name__ == "__main__": + main() diff --git a/training/scripts/04_build_feature_frame.py b/training/scripts/04_build_feature_frame.py new file mode 100644 index 0000000..800b6b2 --- /dev/null +++ b/training/scripts/04_build_feature_frame.py @@ -0,0 +1,23 @@ +from __future__ import annotations + +import argparse +from pathlib import Path + +import _bootstrap # noqa: F401 +from trader_training.features import build_feature_frame +from trader_training.io_utils import add_common_args, setup_logging + + +def main() -> None: + parser = argparse.ArgumentParser() + add_common_args(parser) + parser.add_argument("--replay-path", type=Path) + parser.add_argument("--split-manifest-path", type=Path) + parser.add_argument("--allow-incomplete-days", action="store_true") + args = parser.parse_args() + setup_logging() + build_feature_frame(args) + + +if __name__ == "__main__": + main() diff --git a/training/scripts/05_build_price_plan_context.py b/training/scripts/05_build_price_plan_context.py new file mode 100644 index 0000000..b5eb5b0 --- /dev/null +++ b/training/scripts/05_build_price_plan_context.py @@ -0,0 +1,23 @@ +from __future__ import annotations + +import argparse +from pathlib import Path + +import _bootstrap # noqa: F401 +from trader_training.io_utils import add_common_args, setup_logging +from trader_training.labels import write_price_plan_context + + +def main() -> None: + parser = argparse.ArgumentParser() + add_common_args(parser) + parser.add_argument("--label-config-path", type=Path) + parser.add_argument("--cost-config-path", type=Path) + parser.add_argument("--price-plan-id", default="btc-p0-plan-45m") + args = parser.parse_args() + setup_logging() + write_price_plan_context(args) + + +if __name__ == "__main__": + main() diff --git a/training/scripts/06_build_direction_labels.py b/training/scripts/06_build_direction_labels.py new file mode 100644 index 0000000..fdff91e --- /dev/null +++ b/training/scripts/06_build_direction_labels.py @@ -0,0 +1,23 @@ +from __future__ import annotations + +import argparse +from pathlib import Path + +import _bootstrap # noqa: F401 +from trader_training.io_utils import add_common_args, setup_logging +from trader_training.labels import build_direction_labels + + +def main() -> None: + parser = argparse.ArgumentParser() + add_common_args(parser) + parser.add_argument("--feature-path", type=Path) + parser.add_argument("--replay-path", type=Path) + parser.add_argument("--label-config-path", type=Path) + args = parser.parse_args() + setup_logging() + build_direction_labels(args) + + +if __name__ == "__main__": + main() diff --git a/training/scripts/07_build_entry_labels.py b/training/scripts/07_build_entry_labels.py new file mode 100644 index 0000000..74463d8 --- /dev/null +++ b/training/scripts/07_build_entry_labels.py @@ -0,0 +1,25 @@ +from __future__ import annotations + +import argparse +from pathlib import Path + +import _bootstrap # noqa: F401 +from trader_training.io_utils import add_common_args, setup_logging +from trader_training.labels import build_entry_labels + + +def main() -> None: + parser = argparse.ArgumentParser() + add_common_args(parser) + parser.add_argument("--feature-path", type=Path) + parser.add_argument("--replay-path", type=Path) + parser.add_argument("--label-config-path", type=Path) + parser.add_argument("--cost-config-path", type=Path) + parser.add_argument("--price-plan-context-path", type=Path) + args = parser.parse_args() + setup_logging() + build_entry_labels(args) + + +if __name__ == "__main__": + main() diff --git a/training/scripts/08_build_position_state_samples.py b/training/scripts/08_build_position_state_samples.py new file mode 100644 index 0000000..e106050 --- /dev/null +++ b/training/scripts/08_build_position_state_samples.py @@ -0,0 +1,21 @@ +from __future__ import annotations + +import argparse +from pathlib import Path + +import _bootstrap # noqa: F401 +from trader_training.io_utils import add_common_args, setup_logging +from trader_training.labels import build_position_state_samples + + +def main() -> None: + parser = argparse.ArgumentParser() + add_common_args(parser) + parser.add_argument("--entry-label-path", type=Path) + args = parser.parse_args() + setup_logging() + build_position_state_samples(args) + + +if __name__ == "__main__": + main() diff --git a/training/scripts/09_build_continue_exit_risk_labels.py b/training/scripts/09_build_continue_exit_risk_labels.py new file mode 100644 index 0000000..2c37036 --- /dev/null +++ b/training/scripts/09_build_continue_exit_risk_labels.py @@ -0,0 +1,25 @@ +from __future__ import annotations + +import argparse +from pathlib import Path + +import _bootstrap # noqa: F401 +from trader_training.io_utils import add_common_args, setup_logging +from trader_training.labels import build_continue_exit_risk_labels + + +def main() -> None: + parser = argparse.ArgumentParser() + add_common_args(parser) + parser.add_argument("--feature-path", type=Path) + parser.add_argument("--replay-path", type=Path) + parser.add_argument("--label-config-path", type=Path) + parser.add_argument("--cost-config-path", type=Path) + parser.add_argument("--price-plan-context-path", type=Path) + args = parser.parse_args() + setup_logging() + build_continue_exit_risk_labels(args) + + +if __name__ == "__main__": + main() diff --git a/training/scripts/10_build_train_datasets.py b/training/scripts/10_build_train_datasets.py new file mode 100644 index 0000000..3018bce --- /dev/null +++ b/training/scripts/10_build_train_datasets.py @@ -0,0 +1,26 @@ +from __future__ import annotations + +import argparse +from pathlib import Path + +import _bootstrap # noqa: F401 +from trader_training.datasets import build_train_datasets +from trader_training.io_utils import add_common_args, setup_logging + + +def main() -> None: + parser = argparse.ArgumentParser() + add_common_args(parser) + parser.add_argument("--feature-path", type=Path) + parser.add_argument("--direction-label-path", type=Path) + parser.add_argument("--entry-label-path", type=Path) + parser.add_argument("--continue-label-path", type=Path) + parser.add_argument("--exit-label-path", type=Path) + parser.add_argument("--risk-label-path", type=Path) + args = parser.parse_args() + setup_logging() + build_train_datasets(args) + + +if __name__ == "__main__": + main() diff --git a/training/scripts/11_train_small_models.py b/training/scripts/11_train_small_models.py new file mode 100644 index 0000000..3037fd0 --- /dev/null +++ b/training/scripts/11_train_small_models.py @@ -0,0 +1,20 @@ +from __future__ import annotations + +import argparse + +import _bootstrap # noqa: F401 +from trader_training.io_utils import add_common_args, setup_logging +from trader_training.training import train_small_models + + +def main() -> None: + parser = argparse.ArgumentParser() + add_common_args(parser) + parser.add_argument("--max-rows", type=int, default=0) + args = parser.parse_args() + setup_logging() + train_small_models(args) + + +if __name__ == "__main__": + main() diff --git a/training/scripts/12_calibrate_models.py b/training/scripts/12_calibrate_models.py new file mode 100644 index 0000000..d7485fd --- /dev/null +++ b/training/scripts/12_calibrate_models.py @@ -0,0 +1,19 @@ +from __future__ import annotations + +import argparse + +import _bootstrap # noqa: F401 +from trader_training.io_utils import add_common_args, setup_logging +from trader_training.training import build_calibrators + + +def main() -> None: + parser = argparse.ArgumentParser() + add_common_args(parser) + args = parser.parse_args() + setup_logging() + build_calibrators(args) + + +if __name__ == "__main__": + main() diff --git a/training/scripts/13_search_pm_thresholds.py b/training/scripts/13_search_pm_thresholds.py new file mode 100644 index 0000000..783be61 --- /dev/null +++ b/training/scripts/13_search_pm_thresholds.py @@ -0,0 +1,19 @@ +from __future__ import annotations + +import argparse + +import _bootstrap # noqa: F401 +from trader_training.io_utils import add_common_args, setup_logging +from trader_training.pm import search_pm_thresholds + + +def main() -> None: + parser = argparse.ArgumentParser() + add_common_args(parser) + args = parser.parse_args() + setup_logging() + search_pm_thresholds(args) + + +if __name__ == "__main__": + main() diff --git a/training/scripts/14_integrated_backtest.py b/training/scripts/14_integrated_backtest.py new file mode 100644 index 0000000..afd023b --- /dev/null +++ b/training/scripts/14_integrated_backtest.py @@ -0,0 +1,19 @@ +from __future__ import annotations + +import argparse + +import _bootstrap # noqa: F401 +from trader_training.io_utils import add_common_args, setup_logging +from trader_training.pm import integrated_backtest + + +def main() -> None: + parser = argparse.ArgumentParser() + add_common_args(parser) + args = parser.parse_args() + setup_logging() + integrated_backtest(args) + + +if __name__ == "__main__": + main() diff --git a/training/scripts/15_export_artifact_bundle.py b/training/scripts/15_export_artifact_bundle.py new file mode 100644 index 0000000..b6dd780 --- /dev/null +++ b/training/scripts/15_export_artifact_bundle.py @@ -0,0 +1,21 @@ +from __future__ import annotations + +import argparse +from pathlib import Path + +import _bootstrap # noqa: F401 +from trader_training.exporter import export_artifact_bundle +from trader_training.io_utils import add_common_args, setup_logging + + +def main() -> None: + parser = argparse.ArgumentParser() + add_common_args(parser) + parser.add_argument("--export-root", type=Path) + args = parser.parse_args() + setup_logging() + export_artifact_bundle(args) + + +if __name__ == "__main__": + main() diff --git a/training/scripts/16_validate_artifact_bundle.py b/training/scripts/16_validate_artifact_bundle.py new file mode 100644 index 0000000..b2d32bf --- /dev/null +++ b/training/scripts/16_validate_artifact_bundle.py @@ -0,0 +1,22 @@ +from __future__ import annotations + +import argparse +from pathlib import Path + +import _bootstrap # noqa: F401 +from trader_training.io_utils import setup_logging +from trader_training.validator import validate_artifact_bundle + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--artifact-root", type=Path, required=True) + parser.add_argument("--require-active", action="store_true") + parser.add_argument("--run-onnx", action="store_true") + args = parser.parse_args() + setup_logging() + validate_artifact_bundle(args) + + +if __name__ == "__main__": + main() diff --git a/training/scripts/17_promote_artifact_bundle.py b/training/scripts/17_promote_artifact_bundle.py new file mode 100644 index 0000000..e4c054c --- /dev/null +++ b/training/scripts/17_promote_artifact_bundle.py @@ -0,0 +1,21 @@ +from __future__ import annotations + +import argparse +from pathlib import Path + +import _bootstrap # noqa: F401 +from trader_training.io_utils import setup_logging +from trader_training.promote import promote_artifact_bundle + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--artifact-root", type=Path, required=True) + parser.add_argument("--reason", required=True) + args = parser.parse_args() + setup_logging() + promote_artifact_bundle(args) + + +if __name__ == "__main__": + main() diff --git a/training/scripts/_bootstrap.py b/training/scripts/_bootstrap.py new file mode 100644 index 0000000..636aafb --- /dev/null +++ b/training/scripts/_bootstrap.py @@ -0,0 +1,8 @@ +from __future__ import annotations + +import sys +from pathlib import Path + +TRAINING_ROOT = Path(__file__).resolve().parents[1] +if str(TRAINING_ROOT) not in sys.path: + sys.path.insert(0, str(TRAINING_ROOT)) diff --git a/training/tests/test_training_contract.py b/training/tests/test_training_contract.py new file mode 100644 index 0000000..0a28e43 --- /dev/null +++ b/training/tests/test_training_contract.py @@ -0,0 +1,146 @@ +from __future__ import annotations + +import sys +import tempfile +import unittest +from argparse import Namespace +from pathlib import Path + +import numpy as np +import pandas as pd + +TRAINING_ROOT = Path(__file__).resolve().parents[1] +if str(TRAINING_ROOT) not in sys.path: + sys.path.insert(0, str(TRAINING_ROOT)) + +from trader_training.onnx_export import LinearHead, export_heads +from trader_training.io_utils import read_json, write_json +from trader_training.promote import promote_artifact_bundle +from trader_training.replay import build_splits +from trader_training.schemas import FEATURE_ORDER, LATEST_STRESS_SPLIT, MODEL_OUTPUTS, OUTPUT_MAPPING, TRAINING_SPLITS, VALIDATION_LOCKED_SPLIT + + +class TrainingContractTest(unittest.TestCase): + def test_feature_order_is_v4_contract_size(self) -> None: + self.assertEqual(39, len(FEATURE_ORDER)) + self.assertEqual(len(FEATURE_ORDER), len(set(FEATURE_ORDER))) + self.assertEqual("ret_1m_bps", FEATURE_ORDER[0]) + self.assertEqual("minutes_to_next_funding", FEATURE_ORDER[-1]) + + def test_output_mapping_matches_model_outputs(self) -> None: + for model_name, fields in MODEL_OUTPUTS.items(): + self.assertEqual(set(fields), set(OUTPUT_MAPPING[model_name])) + self.assertEqual([f"prediction[{idx}]" for idx in range(len(fields))], [OUTPUT_MAPPING[model_name][field] for field in fields]) + + def test_split_builder_uses_locked_validation_contract(self) -> None: + with tempfile.TemporaryDirectory() as tmp: + data_root = Path(tmp) + replay_path = data_root / "replay_1m.parquet" + frame = pd.DataFrame( + { + "event_time": pd.date_range("2025-06-20", "2026-06-19", freq="D", tz="UTC"), + "symbol": "BTC-USDT-PERP", + } + ) + frame.to_parquet(replay_path, index=False) + + build_splits( + Namespace( + data_root=data_root, + run_id="unit-split", + replay_path=replay_path, + fit_inner_start="2025-06-20", + fit_inner_end="2026-01-15", + tune_inner_start="2026-01-16", + tune_inner_end="2026-02-28", + validation_locked_start="2026-03-01", + validation_locked_end="2026-04-30", + latest_stress_start="2026-05-01", + latest_stress_end="2026-06-19", + gap_minutes=0, + fold_count=2, + ) + ) + + manifest = read_json(data_root / "trader-v4" / "runs" / "unit-split" / "split" / "split_manifest.json") + self.assertEqual(set(TRAINING_SPLITS), {item["split_id"] for item in manifest["splits"]}) + self.assertEqual([VALIDATION_LOCKED_SPLIT, LATEST_STRESS_SPLIT], manifest["sealed_splits"]) + self.assertEqual("FINAL_GATE_ONLY", manifest["latest_stress_policy"]) + + def test_exported_onnx_accepts_java_feature_shape(self) -> None: + import onnxruntime as ort + + with tempfile.TemporaryDirectory() as tmp: + path = Path(tmp) / "direction.onnx" + export_heads( + path, + [ + LinearHead( + "direction", + "softmax", + np.zeros((39, 3), dtype=np.float32), + np.array([0.1, 0.2, 0.3], dtype=np.float32), + ) + ], + ) + session = ort.InferenceSession(str(path)) + output = session.run(None, {"features": np.zeros((1, 39), dtype=np.float32)})[0] + self.assertEqual((1, 3), output.shape) + self.assertAlmostEqual(1.0, float(output.sum()), places=6) + + def test_promotion_requires_passed_validation_and_marks_all_manifests_active(self) -> None: + with tempfile.TemporaryDirectory() as tmp: + root = Path(tmp) / "artifact_bundle" + manifest_dir = root / "manifests" + manifest_dir.mkdir(parents=True) + write_json(root.parent / "artifact_validation_result.json", {"status": "PASS", "release_gate_status": "PASS", "release_gate_reasons": []}) + write_json(manifest_dir / "model_bundle_manifest.json", {"status": "CANDIDATE"}) + write_json(manifest_dir / "model_manifest.json", [{"model_name": "DIRECTION", "status": "CANDIDATE"}]) + write_json(manifest_dir / "calibration_manifest.json", [{"model_name": "DIRECTION", "status": "CANDIDATE"}]) + write_json(manifest_dir / "position_manager_manifest.json", {"status": "CANDIDATE"}) + write_json(manifest_dir / "training_export_manifest.json", {"status": "CANDIDATE"}) + + promote_artifact_bundle(Namespace(artifact_root=root, reason="unit test")) + + self.assertEqual("ACTIVE", read_json(manifest_dir / "model_bundle_manifest.json")["status"]) + self.assertEqual("ACTIVE", read_json(manifest_dir / "model_manifest.json")[0]["status"]) + self.assertEqual("ACTIVE", read_json(manifest_dir / "calibration_manifest.json")[0]["status"]) + self.assertEqual("ACTIVE", read_json(manifest_dir / "position_manager_manifest.json")["status"]) + self.assertEqual("ACTIVE", read_json(manifest_dir / "training_export_manifest.json")["status"]) + + def test_promotion_refuses_failed_validation(self) -> None: + with tempfile.TemporaryDirectory() as tmp: + root = Path(tmp) / "artifact_bundle" + (root / "manifests").mkdir(parents=True) + write_json(root.parent / "artifact_validation_result.json", {"status": "FAIL"}) + with self.assertRaises(SystemExit): + promote_artifact_bundle(Namespace(artifact_root=root, reason="unit test")) + result = read_json(root.parent / "artifact_promotion_result.json") + self.assertEqual("REFUSED", result["status"]) + self.assertEqual("validation result is not PASS", result["message"]) + + def test_promotion_refuses_failed_release_gate_and_overwrites_stale_result(self) -> None: + with tempfile.TemporaryDirectory() as tmp: + root = Path(tmp) / "artifact_bundle" + (root / "manifests").mkdir(parents=True) + write_json(root.parent / "artifact_promotion_result.json", {"status": "ACTIVE"}) + write_json( + root.parent / "artifact_validation_result.json", + { + "status": "PASS", + "release_gate_status": "REJECTED", + "release_gate_reasons": ["backtest_status=REJECTED"], + }, + ) + + with self.assertRaises(SystemExit): + promote_artifact_bundle(Namespace(artifact_root=root, reason="unit test")) + + result = read_json(root.parent / "artifact_promotion_result.json") + self.assertEqual("REFUSED", result["status"]) + self.assertEqual("release gate is not PASS", result["message"]) + self.assertEqual(["backtest_status=REJECTED"], result["release_gate_reasons"]) + + +if __name__ == "__main__": + unittest.main() diff --git a/training/trader_training/__init__.py b/training/trader_training/__init__.py new file mode 100644 index 0000000..e05e30e --- /dev/null +++ b/training/trader_training/__init__.py @@ -0,0 +1 @@ +"""Trader V4 training pipeline.""" diff --git a/training/trader_training/datasets.py b/training/trader_training/datasets.py new file mode 100644 index 0000000..5ea016c --- /dev/null +++ b/training/trader_training/datasets.py @@ -0,0 +1,112 @@ +from __future__ import annotations + +import logging +from typing import Any + +import pandas as pd + +from trader_training.io_utils import manifest, read_parquet, require_columns, run_root, write_json, write_parquet, write_text +from trader_training.schemas import FEATURE_ORDER, TRAINING_SPLITS + + +def _feature_base(root, args) -> pd.DataFrame: + path = args.feature_path or root / "feature" / "feature_frame.parquet" + frame = read_parquet(path) + require_columns(frame, ("sample_id", "split_id", "data_quality_flag", *FEATURE_ORDER), "feature_frame") + trainable = frame[frame["data_quality_flag"].isin(["OK", "PARTIAL_OPTIONAL"])].copy() + trainable = trainable[trainable["split_id"].isin(TRAINING_SPLITS)].copy() + if trainable.empty: + raise ValueError("no trainable feature rows; check feature quality report") + return trainable + + +def build_train_datasets(args: Any) -> None: + root = run_root(args) + feature = _feature_base(root, args) + dataset_dir = root / "dataset" + manifests = {} + + direction = read_parquet(args.direction_label_path or root / "label" / "direction_labels.parquet") + direction_ds = feature.merge(direction[["sample_id", "long_target", "short_target", "neutral_target", "future_return_bps"]], on="sample_id", how="inner") + manifests["direction"] = _write_dataset(dataset_dir / "direction_train.parquet", direction_ds) + + entry = read_parquet(args.entry_label_path or root / "label" / "entry_labels.parquet") + entry_pivot = _entry_pivot(entry) + entry_ds = feature.merge(entry_pivot, on="sample_id", how="inner") + manifests["entry"] = _write_dataset(dataset_dir / "entry_train.parquet", entry_ds) + + continuation = read_parquet(args.continue_label_path or root / "label" / "continue_labels.parquet") + continue_ds = feature.merge( + continuation[["sample_id", "long_continue_target", "short_continue_target", "long_expected_continue_edge_bps", "short_expected_continue_edge_bps"]], + on="sample_id", + how="inner", + ) + manifests["continue"] = _write_dataset(dataset_dir / "continue_train.parquet", continue_ds) + + exit_labels = read_parquet(args.exit_label_path or root / "label" / "exit_labels.parquet") + exit_cols = [ + "sample_id", + "long_exit_target", + "short_exit_target", + "long_adverse_move_bps", + "short_adverse_move_bps", + "adverse_move_prob_label", + "reversal_prob_label", + "stop_hit_prob_label", + "stagnation_prob_label", + ] + exit_ds = feature.merge(exit_labels[exit_cols], on="sample_id", how="inner") + manifests["exit"] = _write_dataset(dataset_dir / "exit_train.parquet", exit_ds) + + risk = read_parquet(args.risk_label_path or root / "label" / "risk_labels.parquet") + risk_cols = [ + "sample_id", + "market_risk_target", + "market_path_risk_bps", + "long_position_path_risk_bps", + "short_position_path_risk_bps", + "long_position_risk_target", + "short_position_risk_target", + "market_drawdown_prob_label", + "volatility_expansion_prob_label", + "spike_prob_label", + "liquidity_deterioration_prob_label", + "position_drawdown_prob_label", + ] + risk_ds = feature.merge(risk[risk_cols], on="sample_id", how="inner") + manifests["risk"] = _write_dataset(dataset_dir / "risk_train.parquet", risk_ds) + + write_json(dataset_dir / "dataset_manifest.json", {"datasets": manifests}) + _write_dataset_report(dataset_dir / "dataset_quality_report.md", manifests) + logging.info("trader.training.datasets_written runId=%s datasets=%s", args.run_id, sorted(manifests)) + + +def _entry_pivot(entry: pd.DataFrame) -> pd.DataFrame: + require_columns(entry, ("sample_id", "side", "entry_target", "expected_net_edge_bps"), "entry_labels") + long = entry[entry["side"] == "LONG"][["sample_id", "entry_target", "expected_net_edge_bps"]].rename( + columns={"entry_target": "long_entry_target", "expected_net_edge_bps": "long_expected_net_edge_bps"} + ) + short = entry[entry["side"] == "SHORT"][["sample_id", "entry_target", "expected_net_edge_bps"]].rename( + columns={"entry_target": "short_entry_target", "expected_net_edge_bps": "short_expected_net_edge_bps"} + ) + return long.merge(short, on="sample_id", how="inner") + + +def _write_dataset(path, frame: pd.DataFrame) -> dict: + data_hash = write_parquet(path, frame) + return manifest( + path, + { + "row_count": len(frame), + "feature_count": len(FEATURE_ORDER), + "data_hash_sha256": data_hash, + "split_counts": frame["split_id"].value_counts().to_dict() if "split_id" in frame.columns else {}, + }, + ) + + +def _write_dataset_report(path, manifests: dict) -> None: + lines = ["# Trader Dataset Quality Report", "", "| dataset | rows | hash |", "| --- | ---: | --- |"] + for name, item in manifests.items(): + lines.append(f"| {name} | {item['row_count']} | {item['data_hash_sha256']} |") + write_text(path, "\n".join(lines) + "\n") diff --git a/training/trader_training/exporter.py b/training/trader_training/exporter.py new file mode 100644 index 0000000..c715816 --- /dev/null +++ b/training/trader_training/exporter.py @@ -0,0 +1,244 @@ +from __future__ import annotations + +import logging +import shutil +from pathlib import Path +from typing import Any + +import pandas as pd + +from trader_training.io_utils import read_json, read_parquet, run_root, sha256_file, sha256_json, utc_now_text, write_json +from trader_training.pm import default_pm_config +from trader_training.schemas import ( + CALIBRATION_BUNDLE_VERSION, + FEATURE_ORDER, + FEATURE_VERSION, + FIT_SPLIT, + LABEL_VERSION, + MODEL_BUNDLE_VERSION, + MODEL_OUTPUTS, + OUTPUT_MAPPING, + OUTPUT_SCHEMA, + PM_CONFIG_VERSION, + SPLIT_VERSION, + TUNE_SPLIT, + VALIDATION_LOCKED_SPLIT, +) + + +REQUIRED_MODELS = ("DIRECTION", "ENTRY", "CONTINUE", "EXIT", "RISK") +EXPORT_STATUS = "CANDIDATE" + + +def export_artifact_bundle(args: Any) -> None: + root = run_root(args) + export_root = args.export_root or root / "export" / f"trader-model-bundle-{args.run_id}" + artifact_root = export_root / "artifact_bundle" + for folder in ("models", "schemas", "calibrators", "manifests", "examples"): + (artifact_root / folder).mkdir(parents=True, exist_ok=True) + + feature_schema_path = root / "feature" / "feature_schema.json" + feature_order_path = root / "feature" / "feature_order.json" + output_schema_path = artifact_root / "schemas" / "output_schema.json" + shutil.copy2(feature_schema_path, artifact_root / "schemas" / "feature_schema.json") + shutil.copy2(feature_order_path, artifact_root / "schemas" / "feature_order.json") + write_json(output_schema_path, OUTPUT_SCHEMA) + + sample_input = _sample_input(root) + write_json(artifact_root / "examples" / "sample_input.json", sample_input) + write_json(artifact_root / "examples" / "sample_output.json", _sample_output()) + + price_plan = read_json(root / "label" / "price_plan_context.json") + write_json(artifact_root / "price_plan_context.json", price_plan) + + model_manifest_rows = [] + calibration_manifest_rows = [] + model_hashes = {} + train_start = _split_time(root, FIT_SPLIT, "start") + train_end = _split_time(root, FIT_SPLIT, "end") + validation_start = _split_time(root, VALIDATION_LOCKED_SPLIT, "start") + validation_end = _split_time(root, VALIDATION_LOCKED_SPLIT, "end") + calibration_train_manifest = read_json(root / "calibration" / "calibration_train_manifest.json") + calibration_quality = {row["model_name"]: row for row in calibration_train_manifest.get("calibrators", [])} + for model_name in REQUIRED_MODELS: + src = root / "model" / model_name.lower() / f"{model_name.lower()}.onnx" + dst = artifact_root / "models" / f"{model_name.lower()}.onnx" + shutil.copy2(src, dst) + model_hashes[model_name] = sha256_file(dst) + cal_src = root / "calibration" / model_name.lower() / "calibrator.json" + cal_dst = artifact_root / "calibrators" / model_name.lower() / "calibrator.json" + cal_dst.parent.mkdir(parents=True, exist_ok=True) + shutil.copy2(cal_src, cal_dst) + cal_hash = sha256_file(cal_dst) + metrics = read_json(root / "model" / model_name.lower() / "model_train_result.json") + model_manifest_rows.append( + _model_manifest_row( + model_name, + f"models/{model_name.lower()}.onnx", + model_hashes[model_name], + sha256_file(artifact_root / "schemas" / "feature_schema.json"), + sha256_file(artifact_root / "schemas" / "feature_order.json"), + sha256_file(output_schema_path), + metrics.get("metrics", {}), + metrics.get("quality_status", "UNKNOWN"), + metrics.get("quality_reasons", []), + train_start, + train_end, + validation_start, + validation_end, + EXPORT_STATUS, + ) + ) + calibration_manifest_rows.append( + { + "calibration_bundle_version": CALIBRATION_BUNDLE_VERSION, + "model_bundle_version": MODEL_BUNDLE_VERSION, + "model_name": model_name, + "calibrator_version": f"{model_name.lower()}-cal-v4-btc-p0", + "calibration_method": "BINNING", + "calibrator_path": f"calibrators/{model_name.lower()}/calibrator.json", + "calibrator_hash_sha256": cal_hash, + "calibration_window_from": _split_time(root, TUNE_SPLIT, "start"), + "calibration_window_to": _split_time(root, TUNE_SPLIT, "end"), + "calibration_metrics_json": {}, + "bucket_metrics_json": {}, + "output_after_calibration_schema_hash": sha256_file(output_schema_path), + "quality_status": calibration_quality.get(model_name, {}).get("quality_status", "UNKNOWN"), + "quality_reasons_json": calibration_quality.get(model_name, {}).get("quality_reasons", []), + "status": EXPORT_STATUS, + } + ) + + pm_payload = read_json(root / "pm-search" / "position_manager_config.json") if (root / "pm-search" / "position_manager_config.json").is_file() else {"config": default_pm_config(), "threshold_stability_json": {}} + pm_config = pm_payload["config"] + pm_hash = sha256_json(pm_config) + backtest_manifest = read_json(root / "backtest" / "backtest_manifest.json") + write_json( + artifact_root / "manifests" / "position_manager_manifest.json", + { + "pm_config_version": PM_CONFIG_VERSION, + "model_bundle_version": MODEL_BUNDLE_VERSION, + "calibration_bundle_version": CALIBRATION_BUNDLE_VERSION, + "threshold_stability_json": pm_payload.get("threshold_stability_json", {}), + "allowed_run_modes_json": ["REPLAY_SIM", "SHADOW"], + "config_json": pm_config, + "config_hash_sha256": pm_hash, + "status": EXPORT_STATUS, + }, + ) + write_json(artifact_root / "manifests" / "model_manifest.json", model_manifest_rows) + write_json(artifact_root / "manifests" / "calibration_manifest.json", calibration_manifest_rows) + bundle_hash = sha256_json({"models": model_hashes, "feature_order_hash": sha256_file(artifact_root / "schemas" / "feature_order.json"), "pm_hash": pm_hash}) + write_json( + artifact_root / "manifests" / "model_bundle_manifest.json", + { + "manifest_schema_version": "trader-model-bundle-manifest-v4-p0", + "model_bundle_version": MODEL_BUNDLE_VERSION, + "calibration_bundle_version": CALIBRATION_BUNDLE_VERSION, + "feature_version": FEATURE_VERSION, + "label_version": LABEL_VERSION, + "split_version": SPLIT_VERSION, + "training_run_id": args.run_id, + "training_export_id": f"export-{args.run_id}", + "backtest_manifest_id": f"backtest-{args.run_id}", + "required_models_json": list(REQUIRED_MODELS), + "provided_models_json": list(REQUIRED_MODELS), + "missing_models_json": [], + "allowed_run_modes_json": ["REPLAY_SIM", "SHADOW"], + "bundle_hash_sha256": bundle_hash, + "model_quality_status_json": {row["model_name"]: row["quality_status"] for row in model_manifest_rows}, + "calibration_quality_status_json": {row["model_name"]: row["quality_status"] for row in calibration_manifest_rows}, + "backtest_status": backtest_manifest.get("status"), + "backtest_status_reasons_json": backtest_manifest.get("status_reasons", []), + "backtest_metrics_json": backtest_manifest.get("metrics", {}), + "complete": True, + "status": EXPORT_STATUS, + }, + ) + write_json( + artifact_root / "manifests" / "training_export_manifest.json", + {"created_at": utc_now_text(), "status": EXPORT_STATUS, "artifact_root": str(artifact_root), "bundle_hash_sha256": bundle_hash}, + ) + logging.info("trader.training.artifact_exported runId=%s status=%s bundleHash=%s path=%s", args.run_id, EXPORT_STATUS, bundle_hash, artifact_root) + + +def _model_manifest_row( + model_name: str, + artifact_path: str, + artifact_hash: str, + feature_schema_hash: str, + feature_order_hash: str, + output_schema_hash: str, + metrics: dict, + quality_status: str, + quality_reasons: list[str], + train_start: str, + train_end: str, + validation_start: str, + validation_end: str, + status: str, +) -> dict: + return { + "model_bundle_version": MODEL_BUNDLE_VERSION, + "calibration_bundle_version": CALIBRATION_BUNDLE_VERSION, + "model_name": model_name, + "model_type": model_name, + "side": "BOTH", + "symbol_scope_json": ["BTC-USDT-PERP"], + "bar_interval": "1m", + "horizon_minutes": 45, + "model_format": "ONNX", + "model_runtime": "ONNX_RUNTIME_JAVA", + "model_runtime_version": "1.22.0", + "onnx_opset_version": 17, + "producer_name": "trader-training", + "producer_version": "v4-p0", + "feature_version": FEATURE_VERSION, + "feature_schema_path": "schemas/feature_schema.json", + "feature_schema_hash": feature_schema_hash, + "feature_order_path": "schemas/feature_order.json", + "feature_order_hash": feature_order_hash, + "input_tensor_name": "features", + "input_dtype": "FLOAT32", + "input_shape_json": {"features": len(FEATURE_ORDER), "batch": 1}, + "input_example_path": "examples/sample_input.json", + "output_schema_path": "schemas/output_schema.json", + "output_schema_hash": output_schema_hash, + "output_tensor_names_json": ["prediction"], + "output_mapping_json": OUTPUT_MAPPING[model_name], + "output_value_rules_json": {"clip_by_output_schema": True}, + "label_version": LABEL_VERSION, + "split_version": SPLIT_VERSION, + "training_fold": "fold_01", + "train_start": train_start, + "train_end": train_end, + "validation_start": validation_start, + "validation_end": validation_end, + "test_start": validation_start, + "test_end": validation_end, + "metrics_json": metrics, + "quality_status": quality_status, + "quality_reasons_json": quality_reasons, + "artifact_path": artifact_path, + "artifact_hash_sha256": artifact_hash, + "source_hash": artifact_hash, + "status": status, + } + + +def _sample_input(root) -> dict: + features = read_parquet(root / "feature" / "feature_frame.parquet") + row = features[features["data_quality_flag"].isin(["OK", "PARTIAL_OPTIONAL"])].iloc[0] + return {feature: float(row[feature]) for feature in FEATURE_ORDER} + + +def _sample_output() -> dict: + return {model: {field: 0.0 for field in fields} for model, fields in MODEL_OUTPUTS.items()} + + +def _split_time(root, split_id: str, key: str) -> str: + manifest = read_json(root / "split" / "split_manifest.json") + for item in manifest["splits"]: + if item["split_id"] == split_id: + return item[key] + return "2026-01-01T00:00:00Z" diff --git a/training/trader_training/features.py b/training/trader_training/features.py new file mode 100644 index 0000000..ce5b2a4 --- /dev/null +++ b/training/trader_training/features.py @@ -0,0 +1,342 @@ +from __future__ import annotations + +import logging +from typing import Any + +import numpy as np +import pandas as pd + +from trader_training.io_utils import ( + manifest, + read_parquet, + require_columns, + run_root, + sha256_json, + to_utc_series, + write_json, + write_parquet, + write_text, +) +from trader_training.replay import assign_split +from trader_training.schemas import FEATURE_ORDER, FEATURE_VERSION, FEATURES, FIT_SPLIT, TUNE_SPLIT, VALIDATION_LOCKED_SPLIT + + +META_COLUMNS = [ + "sample_id", + "symbol", + "event_time", + "open_time_ms", + "split_id", + "walk_forward_fold", + "feature_version", + "data_quality_flag", +] + + +def _safe_divide(numerator: pd.Series, denominator: pd.Series, default: float = 0.0) -> pd.Series: + result = numerator / denominator.replace(0, np.nan) + return result.replace([np.inf, -np.inf], np.nan).fillna(default) + + +def _rolling_rank_last(values: pd.Series, window: int) -> pd.Series: + def calc(raw: np.ndarray) -> float: + last = raw[-1] + return float(np.sum(raw <= last) / len(raw)) + + return values.rolling(window, min_periods=window).apply(calc, raw=True) + + +def _complete_days(frame: pd.DataFrame) -> pd.DataFrame: + frame = frame.copy() + frame["event_date"] = frame["event_time"].dt.strftime("%Y-%m-%d") + counts = frame.groupby(["symbol", "event_date"])["event_time"].count() + complete = counts[counts == 1440].reset_index()[["symbol", "event_date"]] + return frame.merge(complete, on=["symbol", "event_date"], how="inner").drop(columns=["event_date"]) + + +def build_feature_frame(args: Any) -> None: + root = run_root(args) + replay_path = args.replay_path or root / "replay" / "replay_1m.parquet" + split_manifest_path = args.split_manifest_path or root / "split" / "split_manifest.json" + replay = read_parquet(replay_path) + required = [ + "symbol", + "event_time", + "open_time_ms", + "open", + "high", + "low", + "close", + "volume", + "taker_buy_volume", + "taker_sell_volume", + "funding_bps", + "mark_price", + "index_price", + "next_funding_time", + "open_interest", + "spread_bps", + "level1_ofi_1m", + "liquidation_buy_notional_1m", + "liquidation_sell_notional_1m", + "liquidation_available", + ] + require_columns(replay, required, "replay_1m") + replay = replay.copy() + replay["event_time"] = to_utc_series(replay["event_time"]) + replay["next_funding_time"] = to_utc_series(replay["next_funding_time"]) + replay = replay.sort_values(["symbol", "event_time"]).reset_index(drop=True) + if not args.allow_incomplete_days: + before = len(replay) + replay = _complete_days(replay) + logging.info("trader.training.feature_complete_days rowBefore=%s rowAfter=%s", before, len(replay)) + + frames: list[pd.DataFrame] = [] + for symbol, group in replay.groupby("symbol", sort=False): + group = group.sort_values("event_time").reset_index(drop=True).copy() + close = group["close"].astype(float) + high = group["high"].astype(float) + low = group["low"].astype(float) + volume = group["volume"].astype(float) + log_ret = np.log(close / close.shift(1)) + group["ret_1m_bps"] = (close / close.shift(1) - 1.0) * 10000.0 + group["ret_5m_bps"] = (close / close.shift(5) - 1.0) * 10000.0 + group["ret_15m_bps"] = (close / close.shift(15) - 1.0) * 10000.0 + group["ret_60m_bps"] = (close / close.shift(60) - 1.0) * 10000.0 + group["ret_240m_bps"] = (close / close.shift(240) - 1.0) * 10000.0 + group["realized_vol_15m_bps"] = log_ret.rolling(15, min_periods=15).std() * 10000.0 + group["realized_vol_60m_bps"] = log_ret.rolling(60, min_periods=60).std() * 10000.0 + group["vol_ratio_15m_60m"] = _safe_divide(group["realized_vol_15m_bps"], group["realized_vol_60m_bps"].clip(lower=1.0)) + group["range_15m_bps"] = (high.rolling(15, min_periods=15).max() / low.rolling(15, min_periods=15).min() - 1.0) * 10000.0 + group["range_60m_bps"] = (high.rolling(60, min_periods=60).max() / low.rolling(60, min_periods=60).min() - 1.0) * 10000.0 + vol_mean = volume.rolling(60, min_periods=60).mean() + vol_std = volume.rolling(60, min_periods=60).std().replace(0, np.nan) + group["volume_zscore_60m"] = ((volume - vol_mean) / vol_std).fillna(0.0) + group["trend_consistency_15m"] = np.sign(group["ret_1m_bps"]).rolling(15, min_periods=15).mean() + high60 = high.rolling(60, min_periods=60).max() + low60 = low.rolling(60, min_periods=60).min() + group["channel_position_60m_pct"] = ((close - low60) / (high60 - low60).clip(lower=1e-12)).clip(0.0, 1.0) + prev_high60 = high.shift(1).rolling(60, min_periods=60).max() + prev_low60 = low.shift(1).rolling(60, min_periods=60).min() + group["upper_breakout_60m_bps"] = ((close / prev_high60 - 1.0).clip(lower=0.0)) * 10000.0 + group["lower_breakout_60m_bps"] = ((prev_low60 / close - 1.0).clip(lower=0.0)) * 10000.0 + recent_high15 = high.rolling(15, min_periods=15).max() + recent_low15 = low.rolling(15, min_periods=15).min() + broke_up = recent_high15 > prev_high60 + broke_down = recent_low15 < prev_low60 + group["upper_failed_break_reclaim_15m_bps"] = np.where(broke_up, ((prev_high60 - close).clip(lower=0.0) / close) * 10000.0, 0.0) + group["lower_failed_break_reclaim_15m_bps"] = np.where(broke_down, ((close - prev_low60).clip(lower=0.0) / close) * 10000.0, 0.0) + group["sweep_up_15m_bps"] = ((recent_high15 / close - 1.0).clip(lower=0.0)) * 10000.0 + group["sweep_down_15m_bps"] = ((close / recent_low15 - 1.0).clip(lower=0.0)) * 10000.0 + rank = _rolling_rank_last(group["range_15m_bps"], 240) + group["compression_score_4h_pct"] = 1.0 - rank + group["compression_release_15m_bps"] = (group["range_15m_bps"] - group["range_15m_bps"].rolling(240, min_periods=240).median()).clip(lower=0.0) + buy = group["taker_buy_volume"].astype(float) + sell = group["taker_sell_volume"].astype(float) + group["taker_imbalance_1m"] = _safe_divide(buy - sell, buy + sell) + group["taker_imbalance_5m"] = _safe_divide(buy.rolling(5, min_periods=5).sum() - sell.rolling(5, min_periods=5).sum(), (buy + sell).rolling(5, min_periods=5).sum()) + group["taker_imbalance_15m"] = _safe_divide(buy.rolling(15, min_periods=15).sum() - sell.rolling(15, min_periods=15).sum(), (buy + sell).rolling(15, min_periods=15).sum()) + group["spread_rank_24h_pct"] = _rolling_rank_last(group["spread_bps"].astype(float), 1440) + group["oi_delta_15m_bps"] = (group["open_interest"].astype(float) / group["open_interest"].astype(float).shift(15) - 1.0) * 10000.0 + group["oi_delta_60m_bps"] = (group["open_interest"].astype(float) / group["open_interest"].astype(float).shift(60) - 1.0) * 10000.0 + group["mark_index_basis_bps"] = (group["mark_price"].astype(float) / group["index_price"].astype(float) - 1.0) * 10000.0 + liq_buy = group["liquidation_buy_notional_1m"].astype(float) + liq_sell = group["liquidation_sell_notional_1m"].astype(float) + liq_total_15 = (liq_buy + liq_sell).rolling(15, min_periods=1).sum() + group["liquidation_imbalance_15m"] = _safe_divide(liq_buy.rolling(15, min_periods=1).sum() - liq_sell.rolling(15, min_periods=1).sum(), liq_total_15) + liq_mean = liq_total_15.rolling(1440, min_periods=60).mean() + liq_std = liq_total_15.rolling(1440, min_periods=60).std().replace(0, np.nan) + group["liquidation_notional_zscore_15m"] = ((liq_total_15 - liq_mean) / liq_std).fillna(0.0) + minute_of_day = group["event_time"].dt.hour * 60 + group["event_time"].dt.minute + group["minute_of_day_sin"] = np.sin(2 * np.pi * minute_of_day / 1440.0) + group["minute_of_day_cos"] = np.cos(2 * np.pi * minute_of_day / 1440.0) + group["minutes_to_next_funding"] = ((group["next_funding_time"] - group["event_time"]).dt.total_seconds() / 60.0).clip(0.0, 480.0) + group["symbol"] = symbol + frames.append(group) + + frame = pd.concat(frames, ignore_index=True) + frame["sample_id"] = frame["symbol"].astype(str) + ":" + frame["open_time_ms"].astype(str) + frame["split_id"] = assign_split(frame["event_time"], split_manifest_path) + frame["walk_forward_fold"] = np.where(frame["split_id"].eq(FIT_SPLIT), "fold_01", "NO_FOLD") + frame["feature_version"] = FEATURE_VERSION + hard_na = frame[FEATURE_ORDER].isna().any(axis=1) + optional_missing = frame["liquidation_available"].fillna(0).eq(0) + frame["data_quality_flag"] = np.where(hard_na, "WARMUP", np.where(optional_missing, "PARTIAL_OPTIONAL", "OK")) + ordered = frame[META_COLUMNS + FEATURE_ORDER].copy() + for feature in FEATURE_ORDER: + ordered[feature] = pd.to_numeric(ordered[feature], errors="coerce").astype("float32") + + feature_dir = root / "feature" + data_hash = write_parquet(feature_dir / "feature_frame.parquet", ordered) + schema = [feature.as_json() for feature in FEATURES] + feature_order_hash = write_json(feature_dir / "feature_order.json", FEATURE_ORDER) + feature_schema_hash = write_json(feature_dir / "feature_schema.json", schema) + write_json( + feature_dir / "feature_frame.manifest.json", + manifest( + feature_dir / "feature_frame.parquet", + { + "row_count": len(ordered), + "ok_row_count": int(ordered["data_quality_flag"].eq("OK").sum()), + "partial_optional_row_count": int(ordered["data_quality_flag"].eq("PARTIAL_OPTIONAL").sum()), + "warmup_row_count": int(ordered["data_quality_flag"].eq("WARMUP").sum()), + "feature_count": len(FEATURE_ORDER), + "feature_version": FEATURE_VERSION, + "feature_order_hash": feature_order_hash, + "feature_schema_hash": feature_schema_hash, + "data_hash_sha256": data_hash, + }, + ), + ) + write_feature_report(feature_dir / "feature_quality_report.md", ordered, feature_schema_hash, feature_order_hash) + logging.info( + "trader.training.feature_written runId=%s rowCount=%s splitCounts=%s eventFrom=%s eventTo=%s path=%s", + args.run_id, + len(ordered), + ordered["split_id"].value_counts().to_dict(), + ordered["event_time"].min(), + ordered["event_time"].max(), + feature_dir / "feature_frame.parquet", + ) + + +def write_feature_report(path, frame: pd.DataFrame, feature_schema_hash: str, feature_order_hash: str) -> None: + split_rows = [] + for split_id, group in frame.groupby("split_id", sort=True): + split_rows.append( + { + "split_id": split_id, + "rows": len(group), + "start": str(group["event_time"].min()), + "end": str(group["event_time"].max()), + "ok": int(group["data_quality_flag"].eq("OK").sum()), + "partial_optional": int(group["data_quality_flag"].eq("PARTIAL_OPTIONAL").sum()), + "warmup": int(group["data_quality_flag"].eq("WARMUP").sum()), + } + ) + finite_rows = [] + for feature in FEATURE_ORDER: + series = pd.to_numeric(frame[feature], errors="coerce") + values = series.to_numpy(dtype=float) + finite_rows.append( + { + "feature": feature, + "nan_count": int(series.isna().sum()), + "inf_count": int(np.isinf(values).sum()), + "finite_count": int(np.isfinite(values).sum()), + } + ) + correlation_rows = _high_correlation_rows(frame) + drift_rows = _drift_rows(frame) + lines = [ + "# Trader Feature Quality Report", + "", + f"- row_count: {len(frame)}", + f"- OK: {int(frame['data_quality_flag'].eq('OK').sum())}", + f"- PARTIAL_OPTIONAL: {int(frame['data_quality_flag'].eq('PARTIAL_OPTIONAL').sum())}", + f"- WARMUP: {int(frame['data_quality_flag'].eq('WARMUP').sum())}", + f"- feature_schema_hash: {feature_schema_hash}", + f"- feature_order_hash: {feature_order_hash}", + "", + "## Split Coverage", + "", + _markdown_table(split_rows, ["split_id", "rows", "start", "end", "ok", "partial_optional", "warmup"]), + "", + "## Source Coverage", + "", + f"- replay_1m_required_columns: present", + f"- liquidation_available_share: {float(frame['liquidation_available'].mean()):.6f}", + f"- feature_rows_with_optional_liquidation_missing: {int(frame['data_quality_flag'].eq('PARTIAL_OPTIONAL').sum())}", + "", + "## Leakage Check", + "", + "- 所有特征只使用当前分钟收盘后已经知道的数据,滚动窗口都只看 `<= t`。", + "- 未来价格、未来收益、目标标签不进入 `feature_frame.parquet`。", + "", + "## Extreme Value Check", + "", + _markdown_table(finite_rows, ["feature", "nan_count", "inf_count", "finite_count"]), + "", + "## High Correlation Check", + "", + _markdown_table(correlation_rows, ["feature_a", "feature_b", "corr_abs"]), + "", + "## Drift Check", + "", + _markdown_table( + drift_rows, + ["feature", "train_p50", "tune_p50", "validation_p50", "p50_diff", "train_p99", "tune_p99", "validation_p99", "p99_diff"], + ), + "", + "## Distribution", + "", + "| feature | null_count | min | p01 | p50 | p99 | max |", + "| --- | ---: | ---: | ---: | ---: | ---: | ---: |", + ] + for feature in FEATURE_ORDER: + series = pd.to_numeric(frame[feature], errors="coerce") + quantiles = series.quantile([0.01, 0.5, 0.99]) + lines.append( + f"| {feature} | {int(series.isna().sum())} | {series.min():.6g} | {quantiles.loc[0.01]:.6g} | {quantiles.loc[0.5]:.6g} | {quantiles.loc[0.99]:.6g} | {series.max():.6g} |" + ) + write_text(path, "\n".join(lines) + "\n") + + +def feature_order_hash() -> str: + return sha256_json(FEATURE_ORDER) + + +def _high_correlation_rows(frame: pd.DataFrame) -> list[dict[str, object]]: + sample = frame[FEATURE_ORDER].apply(pd.to_numeric, errors="coerce").dropna() + if len(sample) > 5000: + sample = sample.sample(5000, random_state=7) + if sample.empty: + return [{"feature_a": "NONE", "feature_b": "NONE", "corr_abs": 0.0}] + corr = sample.corr().abs() + rows = [] + for left_index, left in enumerate(FEATURE_ORDER): + for right in FEATURE_ORDER[left_index + 1 :]: + value = corr.loc[left, right] + if pd.notna(value) and value >= 0.95: + rows.append({"feature_a": left, "feature_b": right, "corr_abs": round(float(value), 6)}) + return rows[:30] or [{"feature_a": "NONE", "feature_b": "NONE", "corr_abs": 0.0}] + + +def _drift_rows(frame: pd.DataFrame) -> list[dict[str, object]]: + train = frame[frame["split_id"].eq(FIT_SPLIT)] + validation = frame[frame["split_id"].eq(VALIDATION_LOCKED_SPLIT)] + tune = frame[frame["split_id"].eq(TUNE_SPLIT)] + rows = [] + for feature in FEATURE_ORDER: + train_series = pd.to_numeric(train[feature], errors="coerce") + validation_series = pd.to_numeric(validation[feature], errors="coerce") + tune_series = pd.to_numeric(tune[feature], errors="coerce") + train_p50 = float(train_series.quantile(0.5)) if not train_series.empty else 0.0 + tune_p50 = float(tune_series.quantile(0.5)) if not tune_series.empty else 0.0 + validation_p50 = float(validation_series.quantile(0.5)) if not validation_series.empty else 0.0 + train_p99 = float(train_series.quantile(0.99)) if not train_series.empty else 0.0 + tune_p99 = float(tune_series.quantile(0.99)) if not tune_series.empty else 0.0 + validation_p99 = float(validation_series.quantile(0.99)) if not validation_series.empty else 0.0 + rows.append( + { + "feature": feature, + "train_p50": round(train_p50, 6), + "tune_p50": round(tune_p50, 6), + "validation_p50": round(validation_p50, 6), + "p50_diff": round(validation_p50 - train_p50, 6), + "train_p99": round(train_p99, 6), + "tune_p99": round(tune_p99, 6), + "validation_p99": round(validation_p99, 6), + "p99_diff": round(validation_p99 - train_p99, 6), + } + ) + return rows + + +def _markdown_table(rows: list[dict[str, object]], columns: list[str]) -> str: + if not rows: + rows = [{column: "" for column in columns}] + lines = ["| " + " | ".join(columns) + " |", "| " + " | ".join("---" for _ in columns) + " |"] + for row in rows: + lines.append("| " + " | ".join(str(row.get(column, "")) for column in columns) + " |") + return "\n".join(lines) diff --git a/training/trader_training/io_utils.py b/training/trader_training/io_utils.py new file mode 100644 index 0000000..98eb5ab --- /dev/null +++ b/training/trader_training/io_utils.py @@ -0,0 +1,162 @@ +from __future__ import annotations + +import argparse +import hashlib +import json +import logging +from datetime import UTC, datetime +from pathlib import Path +from typing import Any, Iterable + +import pandas as pd + + +DEFAULT_DATA_ROOT = Path("/Users/zach/Desktop/quant-strategy-training-data") +DEFAULT_RAW_ROOT = DEFAULT_DATA_ROOT / "crypto-lake" / "raw" + + +def setup_logging() -> None: + logging.basicConfig( + level=logging.INFO, + format="%(asctime)s %(levelname)s event=%(message)s", + ) + + +def add_common_args(parser: argparse.ArgumentParser) -> None: + parser.add_argument("--data-root", type=Path, default=DEFAULT_DATA_ROOT) + parser.add_argument("--run-id", required=True) + parser.add_argument("--config", type=Path) + parser.add_argument("--workspace", type=Path) + parser.add_argument("--fail-fast", action="store_true") + + +def run_root(args: argparse.Namespace) -> Path: + return args.data_root / "trader-v4" / "runs" / args.run_id + + +def ensure_dir(path: Path) -> Path: + path.mkdir(parents=True, exist_ok=True) + return path + + +def canonical_json_bytes(value: Any) -> bytes: + return json.dumps(value, ensure_ascii=False, sort_keys=False, separators=(",", ":")).encode("utf-8") + + +def sha256_bytes(data: bytes) -> str: + return hashlib.sha256(data).hexdigest() + + +def sha256_json(value: Any) -> str: + return sha256_bytes(canonical_json_bytes(value)) + + +def sha256_file(path: Path) -> str: + digest = hashlib.sha256() + with path.open("rb") as fh: + for chunk in iter(lambda: fh.read(1024 * 1024), b""): + digest.update(chunk) + return digest.hexdigest() + + +def write_json(path: Path, value: Any) -> str: + ensure_dir(path.parent) + data = canonical_json_bytes(value) + path.write_bytes(data + b"\n") + return sha256_bytes(data) + + +def read_json(path: Path) -> Any: + with path.open("r", encoding="utf-8") as fh: + return json.load(fh) + + +def write_parquet(path: Path, frame: pd.DataFrame) -> str: + ensure_dir(path.parent) + frame.to_parquet(path, index=False) + return sha256_file(path) + + +def read_parquet(path: Path) -> pd.DataFrame: + if not path.is_file(): + raise FileNotFoundError(f"required parquet is missing: {path}") + return pd.read_parquet(path) + + +def write_text(path: Path, text: str) -> None: + ensure_dir(path.parent) + path.write_text(text, encoding="utf-8") + + +def utc_now_text() -> str: + return datetime.now(UTC).replace(microsecond=0).isoformat().replace("+00:00", "Z") + + +def to_utc_series(values: pd.Series) -> pd.Series: + if pd.api.types.is_datetime64_any_dtype(values): + return pd.to_datetime(values, utc=True) + numeric = pd.to_numeric(values, errors="coerce") + if numeric.notna().any(): + max_value = numeric.dropna().abs().max() + unit = "ms" if max_value > 10_000_000_000 else "s" + return pd.to_datetime(numeric, unit=unit, utc=True) + return pd.to_datetime(values, utc=True, errors="coerce") + + +def open_time_ms(values: pd.Series) -> pd.Series: + dt = to_utc_series(values) + return (dt.astype("int64") // 1_000_000).astype("Int64") + + +def date_texts(start_date: str | None, end_date: str | None) -> tuple[str | None, str | None]: + return start_date, end_date + + +def partition_files(raw_root: Path, table: str, symbol: str, start_date: str | None, end_date: str | None) -> list[Path]: + base = raw_root / f"table={table}" + if not base.is_dir(): + return [] + files = sorted(base.glob(f"exchange=*/symbol={symbol}/dt=*/data.parquet")) + selected: list[Path] = [] + for file in files: + dt_part = next((part.split("=", 1)[1] for part in file.parts if part.startswith("dt=")), "") + if start_date and dt_part < start_date: + continue + if end_date and dt_part > end_date: + continue + selected.append(file) + return selected + + +def read_partitioned_table( + raw_root: Path, + table: str, + symbol: str, + start_date: str | None, + end_date: str | None, + columns: Iterable[str] | None = None, +) -> pd.DataFrame: + files = partition_files(raw_root, table, symbol, start_date, end_date) + if not files: + return pd.DataFrame() + logging.info("trader.training.raw_read_started table=%s symbol=%s fileCount=%s", table, symbol, len(files)) + frames = [pd.read_parquet(file, columns=list(columns) if columns else None) for file in files] + frame = pd.concat(frames, ignore_index=True) + logging.info("trader.training.raw_read_finished table=%s rowCount=%s", table, len(frame)) + return frame + + +def require_columns(frame: pd.DataFrame, columns: Iterable[str], name: str) -> None: + missing = [column for column in columns if column not in frame.columns] + if missing: + raise ValueError(f"{name} is missing required columns: {missing}") + + +def manifest(path: Path, extra: dict[str, Any]) -> dict[str, Any]: + payload = { + "path": str(path), + "hash_sha256": sha256_file(path) if path.is_file() else None, + "created_at": utc_now_text(), + } + payload.update(extra) + return payload diff --git a/training/trader_training/labels.py b/training/trader_training/labels.py new file mode 100644 index 0000000..f4f5e64 --- /dev/null +++ b/training/trader_training/labels.py @@ -0,0 +1,417 @@ +from __future__ import annotations + +import logging +from typing import Any + +import numpy as np +import pandas as pd + +from trader_training.io_utils import ( + manifest, + read_json, + read_parquet, + require_columns, + run_root, + sha256_json, + to_utc_series, + write_json, + write_parquet, + write_text, +) +from trader_training.schemas import LABEL_VERSION + + +DEFAULT_LABEL_CONFIG = { + "direction": {"horizon_minutes": 45, "long_threshold_bps": 5.0, "short_threshold_bps": -5.0}, + "entry": {"max_hold_minutes": 45, "target_bps": 12.0, "stop_bps": 8.0, "min_expected_net_edge_bps": 3.0}, + "continue": {"horizon_minutes": 30, "min_expected_continue_edge_bps": 2.0}, + "exit": {"horizon_minutes": 30, "adverse_move_bps": 8.0, "stagnation_abs_return_bps": 2.0}, + "risk": {"horizon_minutes": 30, "market_drawdown_bps": 12.0, "vol_expansion_ratio": 1.6, "spike_bps": 20.0}, +} + + +DEFAULT_COST_CONFIG = { + "fee_bps": 4.0, + "slippage_bps": 2.0, + "funding_cost_bps": 0.5, +} + + +def _load_config(path, default): + if path is None: + return default + value = read_json(path) + merged = default.copy() + for key, item in value.items(): + if isinstance(item, dict) and isinstance(merged.get(key), dict): + merged[key] = {**merged[key], **item} + else: + merged[key] = item + return merged + + +def _base_frames(args: Any) -> tuple[pd.DataFrame, pd.DataFrame]: + root = run_root(args) + feature_path = args.feature_path or root / "feature" / "feature_frame.parquet" + replay_path = args.replay_path or root / "replay" / "replay_1m.parquet" + features = read_parquet(feature_path) + replay = read_parquet(replay_path) + require_columns(features, ("sample_id", "symbol", "event_time", "open_time_ms", "split_id", "walk_forward_fold", "data_quality_flag"), "feature_frame") + require_columns(replay, ("symbol", "event_time", "open_time_ms", "open", "high", "low", "close", "spread_bps"), "replay_1m") + features = features.copy() + replay = replay.copy() + features["event_time"] = to_utc_series(features["event_time"]) + replay["event_time"] = to_utc_series(replay["event_time"]) + replay = replay.sort_values(["symbol", "event_time"]).reset_index(drop=True) + return features, replay + + +def _future_path(group: pd.DataFrame, index: int, horizon: int) -> pd.DataFrame: + start = index + 1 + end = min(len(group), index + horizon + 1) + return group.iloc[start:end] + + +def _contiguous_future_path(group: pd.DataFrame, index: int, horizon: int) -> pd.DataFrame: + path = _future_path(group, index, horizon) + if len(path) < horizon: + return pd.DataFrame() + current_ms = int(group.iloc[index]["open_time_ms"]) + expected = current_ms + np.arange(1, horizon + 1, dtype=np.int64) * 60_000 + actual = path["open_time_ms"].astype("int64").to_numpy() + if len(actual) != len(expected) or not np.array_equal(actual, expected): + return pd.DataFrame() + return path + + +def _side_return_bps(side: str, entry_price: float, exit_price: float) -> float: + if side == "LONG": + return (exit_price / entry_price - 1.0) * 10000.0 + return (entry_price / exit_price - 1.0) * 10000.0 + + +def _path_stats(group: pd.DataFrame, index: int, side: str, horizon: int, target_bps: float, stop_bps: float) -> dict[str, Any]: + current = group.iloc[index] + entry = float(current["close"]) + path = _contiguous_future_path(group, index, horizon) + if path.empty: + return {"valid": False} + target_price = entry * (1.0 + target_bps / 10000.0) if side == "LONG" else entry * (1.0 - target_bps / 10000.0) + stop_price = entry * (1.0 - stop_bps / 10000.0) if side == "LONG" else entry * (1.0 + stop_bps / 10000.0) + target_hit = False + stop_hit = False + ambiguous = False + time_to_target_ms = -1 + time_to_stop_ms = -1 + for _, row in path.iterrows(): + high = float(row["high"]) + low = float(row["low"]) + if side == "LONG": + target_now = high >= target_price + stop_now = low <= stop_price + else: + target_now = low <= target_price + stop_now = high >= stop_price + if target_now and stop_now: + ambiguous = True + stop_hit = True + time_to_stop_ms = int(row["open_time_ms"] - current["open_time_ms"]) + break + if target_now: + target_hit = True + time_to_target_ms = int(row["open_time_ms"] - current["open_time_ms"]) + break + if stop_now: + stop_hit = True + time_to_stop_ms = int(row["open_time_ms"] - current["open_time_ms"]) + break + exit_price = float(path.iloc[-1]["close"]) + final_return_bps = _side_return_bps(side, entry, exit_price) + if side == "LONG": + mfe = (path["high"].max() / entry - 1.0) * 10000.0 + mae = (entry / path["low"].min() - 1.0) * 10000.0 + else: + mfe = (entry / path["low"].min() - 1.0) * 10000.0 + mae = (path["high"].max() / entry - 1.0) * 10000.0 + if target_hit: + gross = target_bps + elif stop_hit: + gross = -stop_bps + else: + gross = final_return_bps + return { + "valid": True, + "target_hit": int(target_hit), + "stop_hit": int(stop_hit), + "timeout_hit": int(not target_hit and not stop_hit), + "ambiguous_hit": int(ambiguous), + "time_to_target_ms": time_to_target_ms, + "time_to_stop_ms": time_to_stop_ms, + "gross_edge_bps": float(gross), + "future_return_bps": float(final_return_bps), + "mfe_bps": float(mfe), + "mae_bps": float(mae), + "future_spread_p80": float(path["spread_bps"].quantile(0.8)), + "future_realized_vol_bps": float(np.log(path["close"].astype(float) / path["close"].astype(float).shift(1)).std() * 10000.0), + } + + +def write_price_plan_context(args: Any) -> None: + root = run_root(args) + cost = _load_config(args.cost_config_path, DEFAULT_COST_CONFIG) + labels = _load_config(args.label_config_path, DEFAULT_LABEL_CONFIG) + entry = labels["entry"] + cost_bps = float(cost["fee_bps"]) + float(cost["slippage_bps"]) + float(cost["funding_cost_bps"]) + context = { + "pricePlanId": args.price_plan_id, + "pricePlanConfigHash": sha256_json({"entry": entry, "cost": cost}), + "stopDistanceBps": float(entry["stop_bps"]), + "targetDistanceBps": float(entry["target_bps"]), + "maxHoldMinutes": int(entry["max_hold_minutes"]), + "costBps": cost_bps, + } + path = root / "label" / "price_plan_context.json" + write_json(path, context) + frame = pd.DataFrame([{ + "price_plan_id": context["pricePlanId"], + "price_plan_hash": context["pricePlanConfigHash"], + "target_bps": context["targetDistanceBps"], + "stop_bps": context["stopDistanceBps"], + "max_hold_minutes": context["maxHoldMinutes"], + "cost_bps": context["costBps"], + }]) + write_parquet(root / "label" / "price_plan_context.parquet", frame) + logging.info("trader.training.price_plan_written runId=%s path=%s", args.run_id, path) + + +def build_direction_labels(args: Any) -> None: + root = run_root(args) + config = _load_config(args.label_config_path, DEFAULT_LABEL_CONFIG)["direction"] + features, replay = _base_frames(args) + horizon = int(config["horizon_minutes"]) + replay = replay[["symbol", "event_time", "open_time_ms", "close"]].copy() + future = replay[["symbol", "open_time_ms", "close"]].copy() + future["open_time_ms"] = future["open_time_ms"].astype("int64") - horizon * 60_000 + future = future.rename(columns={"close": "future_close"}) + merged = features.merge(replay[["symbol", "open_time_ms", "close"]], on=["symbol", "open_time_ms"], how="left") + merged = merged.merge(future, on=["symbol", "open_time_ms"], how="left") + merged["future_return_bps"] = (merged["future_close"] / merged["close"] - 1.0) * 10000.0 + merged["direction_label"] = np.select( + [merged["future_return_bps"] >= float(config["long_threshold_bps"]), merged["future_return_bps"] <= float(config["short_threshold_bps"])], + ["LONG", "SHORT"], + default="NEUTRAL", + ) + out = pd.DataFrame( + { + "sample_id": merged["sample_id"], + "symbol": merged["symbol"], + "event_time": merged["event_time"], + "horizon_minutes": horizon, + "future_return_bps": merged["future_return_bps"], + "direction_label": merged["direction_label"], + "long_target": merged["direction_label"].eq("LONG").astype("int8"), + "short_target": merged["direction_label"].eq("SHORT").astype("int8"), + "neutral_target": merged["direction_label"].eq("NEUTRAL").astype("int8"), + "split_id": merged["split_id"], + "walk_forward_fold": merged["walk_forward_fold"], + "label_version": LABEL_VERSION, + } + ).dropna(subset=["future_return_bps"]) + path = root / "label" / "direction_labels.parquet" + data_hash = write_parquet(path, out) + _write_label_manifest(root / "label" / "direction_labels.manifest.json", path, out, data_hash) + _write_distribution_report(root / "label" / "direction_label_report.md", out, "direction_label") + logging.info("trader.training.direction_labels_written runId=%s rowCount=%s", args.run_id, len(out)) + + +def build_entry_labels(args: Any) -> None: + root = run_root(args) + labels = _load_config(args.label_config_path, DEFAULT_LABEL_CONFIG) + cost = _load_config(args.cost_config_path, DEFAULT_COST_CONFIG) + plan_path = args.price_plan_context_path or root / "label" / "price_plan_context.json" + plan = read_json(plan_path) + features, replay = _base_frames(args) + entry_conf = labels["entry"] + cost_bps = float(cost["fee_bps"]) + float(cost["slippage_bps"]) + float(cost["funding_cost_bps"]) + rows: list[dict[str, Any]] = [] + groups, index_by_key = _group_replay_with_index(replay) + for feature in features.itertuples(index=False): + key = (feature.symbol, int(feature.open_time_ms)) + index = index_by_key.get(key) + if index is None: + continue + group = groups[feature.symbol] + for side in ("LONG", "SHORT"): + stats = _path_stats(group, index, side, int(entry_conf["max_hold_minutes"]), float(entry_conf["target_bps"]), float(entry_conf["stop_bps"])) + if not stats["valid"]: + continue + expected = stats["gross_edge_bps"] - cost_bps + rows.append( + { + "sample_id": feature.sample_id, + "symbol": feature.symbol, + "event_time": feature.event_time, + "side": side, + "price_plan_id": plan["pricePlanId"], + "price_plan_hash": plan["pricePlanConfigHash"], + "target_hit": stats["target_hit"], + "stop_hit": stats["stop_hit"], + "timeout_hit": stats["timeout_hit"], + "ambiguous_hit": stats["ambiguous_hit"], + "time_to_target_ms": stats["time_to_target_ms"], + "time_to_stop_ms": stats["time_to_stop_ms"], + "gross_edge_bps": stats["gross_edge_bps"], + "cost_bps": cost_bps, + "expected_net_edge_bps": expected, + "entry_target": int(stats["target_hit"] == 1 and expected >= float(entry_conf["min_expected_net_edge_bps"])), + "split_id": feature.split_id, + "walk_forward_fold": feature.walk_forward_fold, + "label_version": LABEL_VERSION, + } + ) + out = pd.DataFrame(rows) + path = root / "label" / "entry_labels.parquet" + data_hash = write_parquet(path, out) + _write_label_manifest(root / "label" / "entry_labels.manifest.json", path, out, data_hash) + _write_distribution_report(root / "label" / "entry_label_report.md", out, "entry_target") + logging.info("trader.training.entry_labels_written runId=%s rowCount=%s", args.run_id, len(out)) + + +def build_position_state_samples(args: Any) -> None: + root = run_root(args) + entry_path = args.entry_label_path or root / "label" / "entry_labels.parquet" + entry = read_parquet(entry_path) + if entry.empty: + raise ValueError("entry labels are required before building position samples") + samples = entry[entry["entry_target"] == 1].copy() + samples["position_age_minutes"] = 0 + samples["unrealized_pnl_bps"] = 0.0 + samples["mfe_bps"] = samples["gross_edge_bps"].clip(lower=0) + samples["mae_bps"] = (-samples["gross_edge_bps"]).clip(lower=0) + path = root / "label" / "position_state_samples.parquet" + data_hash = write_parquet(path, samples) + write_json(root / "label" / "position_state_samples.manifest.json", manifest(path, {"row_count": len(samples), "data_hash_sha256": data_hash})) + logging.info("trader.training.position_samples_written runId=%s rowCount=%s", args.run_id, len(samples)) + + +def build_continue_exit_risk_labels(args: Any) -> None: + root = run_root(args) + labels = _load_config(args.label_config_path, DEFAULT_LABEL_CONFIG) + cost = _load_config(args.cost_config_path, DEFAULT_COST_CONFIG) + plan = read_json(args.price_plan_context_path or root / "label" / "price_plan_context.json") + features, replay = _base_frames(args) + cost_bps = float(cost["fee_bps"]) + float(cost["slippage_bps"]) + float(cost["funding_cost_bps"]) + horizon = int(labels["continue"]["horizon_minutes"]) + target_bps = float(plan["targetDistanceBps"]) + stop_bps = float(plan["stopDistanceBps"]) + rows_continue: list[dict[str, Any]] = [] + rows_exit: list[dict[str, Any]] = [] + rows_risk: list[dict[str, Any]] = [] + groups, index_by_key = _group_replay_with_index(replay) + for feature in features.itertuples(index=False): + key = (feature.symbol, int(feature.open_time_ms)) + index = index_by_key.get(key) + if index is None: + continue + group = groups[feature.symbol] + long_stats = _path_stats(group, index, "LONG", horizon, target_bps, stop_bps) + short_stats = _path_stats(group, index, "SHORT", horizon, target_bps, stop_bps) + if not long_stats["valid"] or not short_stats["valid"]: + continue + long_edge = long_stats["future_return_bps"] - cost_bps + short_edge = short_stats["future_return_bps"] - cost_bps + min_continue = float(labels["continue"]["min_expected_continue_edge_bps"]) + adverse_threshold = float(labels["exit"]["adverse_move_bps"]) + rows_continue.append( + { + "sample_id": feature.sample_id, + "symbol": feature.symbol, + "event_time": feature.event_time, + "long_continue_target": int(long_edge >= min_continue and long_stats["mae_bps"] < stop_bps), + "short_continue_target": int(short_edge >= min_continue and short_stats["mae_bps"] < stop_bps), + "long_expected_continue_edge_bps": long_edge, + "short_expected_continue_edge_bps": short_edge, + "split_id": feature.split_id, + "walk_forward_fold": feature.walk_forward_fold, + "label_version": LABEL_VERSION, + } + ) + stagnation = int(abs(long_stats["future_return_bps"]) <= float(labels["exit"]["stagnation_abs_return_bps"])) + rows_exit.append( + { + "sample_id": feature.sample_id, + "symbol": feature.symbol, + "event_time": feature.event_time, + "long_exit_target": int(long_stats["stop_hit"] == 1 or long_stats["mae_bps"] >= adverse_threshold), + "short_exit_target": int(short_stats["stop_hit"] == 1 or short_stats["mae_bps"] >= adverse_threshold), + "long_adverse_move_bps": long_stats["mae_bps"], + "short_adverse_move_bps": short_stats["mae_bps"], + "adverse_move_prob_label": int(max(long_stats["mae_bps"], short_stats["mae_bps"]) >= adverse_threshold), + "reversal_prob_label": int(np.sign(long_stats["future_return_bps"]) != np.sign(feature.ret_15m_bps) if hasattr(feature, "ret_15m_bps") else 0), + "stop_hit_prob_label": int(long_stats["stop_hit"] == 1 or short_stats["stop_hit"] == 1), + "stagnation_prob_label": stagnation, + "split_id": feature.split_id, + "walk_forward_fold": feature.walk_forward_fold, + "label_version": LABEL_VERSION, + } + ) + path_risk = max(long_stats["mae_bps"], short_stats["mae_bps"]) + vol_ratio = 0.0 if long_stats["future_realized_vol_bps"] != long_stats["future_realized_vol_bps"] else long_stats["future_realized_vol_bps"] + rows_risk.append( + { + "sample_id": feature.sample_id, + "symbol": feature.symbol, + "event_time": feature.event_time, + "market_risk_target": int(path_risk >= float(labels["risk"]["market_drawdown_bps"])), + "market_path_risk_bps": path_risk, + "long_position_path_risk_bps": long_stats["mae_bps"], + "short_position_path_risk_bps": short_stats["mae_bps"], + "long_position_risk_target": int(long_stats["mae_bps"] >= stop_bps), + "short_position_risk_target": int(short_stats["mae_bps"] >= stop_bps), + "market_drawdown_prob_label": int(path_risk >= float(labels["risk"]["market_drawdown_bps"])), + "volatility_expansion_prob_label": int(vol_ratio >= float(labels["risk"]["spike_bps"])), + "spike_prob_label": int(max(long_stats["mfe_bps"], short_stats["mfe_bps"], path_risk) >= float(labels["risk"]["spike_bps"])), + "liquidity_deterioration_prob_label": int(long_stats["future_spread_p80"] >= float(replay["spread_bps"].quantile(0.9))), + "position_drawdown_prob_label": int(max(long_stats["mae_bps"], short_stats["mae_bps"]) >= stop_bps), + "split_id": feature.split_id, + "walk_forward_fold": feature.walk_forward_fold, + "label_version": LABEL_VERSION, + } + ) + outputs = [ + ("continue", pd.DataFrame(rows_continue), "long_continue_target"), + ("exit", pd.DataFrame(rows_exit), "long_exit_target"), + ("risk", pd.DataFrame(rows_risk), "market_risk_target"), + ] + report_parts = ["# Continue Exit Risk Label Report", ""] + for name, frame, target in outputs: + path = root / "label" / f"{name}_labels.parquet" + data_hash = write_parquet(path, frame) + _write_label_manifest(root / "label" / f"{name}_labels.manifest.json", path, frame, data_hash) + report_parts.append(f"## {name}") + report_parts.append("") + report_parts.append(str(frame[target].value_counts(dropna=False).to_dict() if not frame.empty else {})) + report_parts.append("") + logging.info("trader.training.%s_labels_written runId=%s rowCount=%s", name, args.run_id, len(frame)) + write_text(root / "label" / "continue_exit_risk_label_report.md", "\n".join(report_parts) + "\n") + + +def _write_label_manifest(path, parquet_path, frame: pd.DataFrame, data_hash: str) -> None: + write_json(path, manifest(parquet_path, {"row_count": len(frame), "label_version": LABEL_VERSION, "data_hash_sha256": data_hash})) + + +def _write_distribution_report(path, frame: pd.DataFrame, column: str) -> None: + counts = frame[column].value_counts(dropna=False).to_dict() if not frame.empty else {} + lines = ["# Label Report", "", f"- row_count: {len(frame)}", f"- target_column: {column}", f"- distribution: {counts}", ""] + write_text(path, "\n".join(lines)) + + +def _group_replay_with_index(replay: pd.DataFrame) -> tuple[dict[str, pd.DataFrame], dict[tuple[str, int], int]]: + groups: dict[str, pd.DataFrame] = {} + index_by_key: dict[tuple[str, int], int] = {} + for symbol, group in replay.groupby("symbol", sort=False): + grouped = group.sort_values("event_time").reset_index(drop=True) + groups[symbol] = grouped + for idx, row in grouped.iterrows(): + index_by_key[(symbol, int(row["open_time_ms"]))] = idx + return groups, index_by_key diff --git a/training/trader_training/onnx_export.py b/training/trader_training/onnx_export.py new file mode 100644 index 0000000..e9618a9 --- /dev/null +++ b/training/trader_training/onnx_export.py @@ -0,0 +1,73 @@ +from __future__ import annotations + +from dataclasses import dataclass +from pathlib import Path + +import numpy as np + + +@dataclass(frozen=True) +class LinearHead: + name: str + kind: str + weight: np.ndarray + bias: np.ndarray + + +def require_onnx(): + try: + import onnx + from onnx import TensorProto, helper, numpy_helper + except ModuleNotFoundError as exc: + raise SystemExit("Python package 'onnx' is required. Install training/requirements.txt before export.") from exc + return onnx, TensorProto, helper, numpy_helper + + +def export_heads(path: Path, heads: list[LinearHead], feature_count: int = 39, opset: int = 17) -> None: + onnx, TensorProto, helper, numpy_helper = require_onnx() + nodes = [] + initializers = [] + concat_inputs = [] + for idx, head in enumerate(heads): + weight = np.asarray(head.weight, dtype=np.float32) + bias = np.asarray(head.bias, dtype=np.float32).reshape(1, -1) + if weight.ndim == 1: + weight = weight.reshape(feature_count, 1) + weight_name = f"{head.name}_W" + bias_name = f"{head.name}_B" + linear_name = f"{head.name}_linear" + out_name = f"{head.name}_out" + initializers.append(numpy_helper.from_array(weight, weight_name)) + initializers.append(numpy_helper.from_array(bias, bias_name)) + nodes.append(helper.make_node("MatMul", ["features", weight_name], [f"{linear_name}_mm"], name=f"{head.name}_matmul")) + nodes.append(helper.make_node("Add", [f"{linear_name}_mm", bias_name], [linear_name], name=f"{head.name}_add")) + if head.kind == "sigmoid": + nodes.append(helper.make_node("Sigmoid", [linear_name], [out_name], name=f"{head.name}_sigmoid")) + elif head.kind == "softmax": + nodes.append(helper.make_node("Softmax", [linear_name], [out_name], name=f"{head.name}_softmax", axis=1)) + elif head.kind == "identity": + out_name = linear_name + else: + raise ValueError(f"unsupported ONNX head kind: {head.kind}") + concat_inputs.append(out_name) + if len(concat_inputs) == 1: + nodes.append(helper.make_node("Identity", concat_inputs, ["prediction"], name="prediction_identity")) + else: + nodes.append(helper.make_node("Concat", concat_inputs, ["prediction"], name="prediction_concat", axis=1)) + graph = helper.make_graph( + nodes, + "trader_v4_linear_heads", + [helper.make_tensor_value_info("features", TensorProto.FLOAT, [1, feature_count])], + [helper.make_tensor_value_info("prediction", TensorProto.FLOAT, [1, sum(_head_width(head) for head in heads)])], + initializer=initializers, + ) + model = helper.make_model(graph, producer_name="trader-training", opset_imports=[helper.make_opsetid("", opset)]) + model.ir_version = 10 + onnx.checker.check_model(model) + path.parent.mkdir(parents=True, exist_ok=True) + onnx.save(model, path) + + +def _head_width(head: LinearHead) -> int: + bias = np.asarray(head.bias) + return int(bias.size) diff --git a/training/trader_training/pm.py b/training/trader_training/pm.py new file mode 100644 index 0000000..c24df89 --- /dev/null +++ b/training/trader_training/pm.py @@ -0,0 +1,541 @@ +from __future__ import annotations + +import itertools +import logging +from typing import Any + +import numpy as np +import pandas as pd + +from trader_training.io_utils import read_json, read_parquet, run_root, sha256_json, write_json, write_parquet, write_text +from trader_training.schemas import LATEST_STRESS_SPLIT, PM_CONFIG_VERSION, TUNE_SPLIT, VALIDATION_LOCKED_SPLIT + + +def default_pm_config() -> dict: + return { + "pmConfigVersion": PM_CONFIG_VERSION, + "open": { + "longOpenProb": 0.58, + "shortOpenProb": 0.58, + "minLongEntryProb": 0.55, + "minShortEntryProb": 0.55, + "maxMarketRiskProb": 0.45, + "minExpectedEdgeBps": 3.0, + "minDirectionMargin": 0.03, + "minLiquidityCapacityRatio": 0.10, + "maxOodScore": 0.80, + }, + "add": { + "minLongProb": 0.60, + "minShortProb": 0.60, + "minContinueProb": 0.58, + "minEntryProb": 0.55, + "maxExitProb": 0.45, + "maxMarketRiskProb": 0.45, + "maxPositionRiskProb": 0.50, + "minExpectedEdgeBps": 3.0, + "minContinueVsExitEdgeBps": 0.0, + "minLiquidityCapacityRatio": 0.10, + "minPostTradeLiquidationBufferBps": 500.0, + "maxAddCount": 3, + "cooldownMinutes": 5, + }, + "exit": { + "closeExitProb": 0.70, + "closePositionRiskProb": 0.70, + "closeMarketRiskProb": 0.70, + "closeContinueMax": 0.25, + "reduceAdverseMoveProb": 0.62, + "reduceContinueMin": 0.35, + "reduceContinueMax": 0.70, + "minProfitForReduceBps": 5.0, + "maxPositionPathRiskBps": 80.0, + }, + "sizing": { + "baseRatio": 0.80, + "minInitialRatio": 0.05, + "maxSingleLegRatio": 1.0, + "minAddRatio": 0.02, + "maxAddRatio": 0.25, + "maxTotalPositionRatio": 1.0, + "minEdgeBps": 3.0, + "maxLossPerTradeBps": 80.0, + "maxLiquidityUsageRatio": 0.20, + "uncertaintyPenaltyMultiplier": 0.50, + "minPostTradeLiquidationBufferBps": 500.0, + }, + } + + +def search_pm_thresholds(args: Any) -> None: + root = run_root(args) + frame = _pm_tune_frame(root) + candidate_rows: list[dict[str, Any]] = [] + best_score = -float("inf") + best_thresholds: dict[str, float] | None = None + best_metrics: dict[str, Any] | None = None + best_trades = pd.DataFrame() + + for thresholds in _threshold_candidates(): + trades = _simulate_open_trades(frame, thresholds) + metrics = _trade_metrics(trades) + score = _score_thresholds(metrics) + candidate_rows.append({**thresholds, **metrics, "score": score}) + if score > best_score: + best_score = score + best_thresholds = thresholds + best_metrics = metrics + best_trades = trades + + if best_thresholds is None or best_metrics is None: + raise ValueError("PM threshold search did not evaluate any candidate") + + config = _pm_config_from_thresholds(best_thresholds) + threshold_stability = { + "source": "tune_predictions_and_entry_labels", + "method": "deterministic_grid_search_v1", + "candidate_count": len(candidate_rows), + "best_score": best_score, + "best_metrics": best_metrics, + } + payload = { + "pm_config_version": PM_CONFIG_VERSION, + "config": config, + "config_hash_sha256": sha256_json(config), + "threshold_stability_json": threshold_stability, + } + candidate_frame = pd.DataFrame(candidate_rows).sort_values("score", ascending=False).reset_index(drop=True) + equity_curve = _equity_curve(best_trades) + regime_metrics = _regime_metrics(best_trades) + write_json(root / "pm-search" / "position_manager_config.json", payload) + write_json(root / "pm-search" / "pm_threshold_config.json", payload) + write_text(root / "pm-search" / "pm_search_candidates.csv", candidate_frame.to_csv(index=False)) + write_parquet(root / "pm-search" / "pm_backtest_trades.parquet", best_trades) + write_text(root / "pm-search" / "pm_equity_curve.csv", equity_curve.to_csv(index=False)) + write_text(root / "pm-search" / "pm_regime_metrics.csv", regime_metrics.to_csv(index=False)) + _write_pm_report(root / "pm-search" / "pm_threshold_report.md", candidate_frame, best_thresholds, best_metrics) + _write_pm_report(root / "pm-search" / "pm_search_report.md", candidate_frame, best_thresholds, best_metrics) + logging.info( + "trader.training.pm_thresholds_searched runId=%s candidateCount=%s bestScore=%.6f tradeCount=%s totalWeightedEdgeBps=%.6f", + args.run_id, + len(candidate_rows), + best_score, + best_metrics["trade_count"], + best_metrics["total_weighted_edge_bps"], + ) + + +def integrated_backtest(args: Any) -> None: + root = run_root(args) + config_path = root / "pm-search" / "position_manager_config.json" + if not config_path.is_file(): + raise FileNotFoundError(f"PM config is required before backtest: {config_path}") + pm_payload = read_json(config_path) + trades_path = root / "pm-search" / "pm_backtest_trades.parquet" + # PM search is allowed to use tune_inner, but final acceptance must be + # measured on the sealed validation_locked and latest_stress splits. + tune_trades = read_parquet(trades_path) if trades_path.is_file() else _simulate_open_trades(_pm_tune_frame(root), _thresholds_from_config(pm_payload["config"])) + tune_trades["eval_split"] = TUNE_SPLIT + validation_locked_trades = _simulate_open_trades(_pm_frame(root, VALIDATION_LOCKED_SPLIT), _thresholds_from_config(pm_payload["config"])) + validation_locked_trades["eval_split"] = VALIDATION_LOCKED_SPLIT + stress_trades = _simulate_open_trades(_pm_frame(root, LATEST_STRESS_SPLIT), _thresholds_from_config(pm_payload["config"])) + stress_trades["eval_split"] = LATEST_STRESS_SPLIT + trades = pd.concat([tune_trades, validation_locked_trades, stress_trades], ignore_index=True) + metrics = { + TUNE_SPLIT: _trade_metrics(tune_trades), + VALIDATION_LOCKED_SPLIT: _trade_metrics(validation_locked_trades), + LATEST_STRESS_SPLIT: _trade_metrics(stress_trades), + "combined": _trade_metrics(trades), + } + status, status_reasons = _backtest_status(metrics) + equity_curve = _equity_curve(trades) + regime_metrics = _regime_metrics(trades) + result = { + "backtest_manifest_id": f"backtest-{args.run_id}", + "mode": "VALIDATION_PM_BACKTEST", + "pm_config_hash_sha256": pm_payload["config_hash_sha256"], + "metrics": metrics, + "status_reasons": status_reasons, + "status": status, + } + write_json(root / "backtest" / "backtest_manifest.json", result) + write_parquet(root / "backtest" / "backtest_trades.parquet", trades) + write_text(root / "backtest" / "equity_curve.csv", equity_curve.to_csv(index=False)) + write_text(root / "backtest" / "regime_metrics.csv", regime_metrics.to_csv(index=False)) + _write_backtest_report(root / "backtest" / "backtest_report.md", result) + _write_failure_cases(root / "backtest" / "failure_cases.md", trades) + _write_no_baseline_ablation(root / "backtest" / "direction_ablation_backtest_report.md") + logging.info( + "trader.training.backtest_written runId=%s status=%s tradeCount=%s totalWeightedEdgeBps=%.6f maxDrawdownBps=%.6f", + args.run_id, + status, + metrics[VALIDATION_LOCKED_SPLIT]["trade_count"], + metrics[VALIDATION_LOCKED_SPLIT]["total_weighted_edge_bps"], + metrics[VALIDATION_LOCKED_SPLIT]["max_drawdown_bps"], + ) + + +def _pm_tune_frame(root) -> pd.DataFrame: + return _pm_frame(root, TUNE_SPLIT) + + +def _pm_frame(root, split_id: str) -> pd.DataFrame: + prediction_files = { + TUNE_SPLIT: "tune_predictions.parquet", + VALIDATION_LOCKED_SPLIT: "validation_locked_predictions.parquet", + LATEST_STRESS_SPLIT: "latest_stress_predictions.parquet", + } + prediction_file = prediction_files[split_id] + direction = read_parquet(root / "model" / "direction" / prediction_file) + entry = read_parquet(root / "model" / "entry" / prediction_file).rename( + columns={ + "long_expected_net_edge_bps": "pred_long_expected_net_edge_bps", + "short_expected_net_edge_bps": "pred_short_expected_net_edge_bps", + } + ) + risk = read_parquet(root / "model" / "risk" / prediction_file) + entry_dataset = read_parquet(root / "dataset" / "entry_train.parquet").rename( + columns={ + "long_expected_net_edge_bps": "actual_long_expected_net_edge_bps", + "short_expected_net_edge_bps": "actual_short_expected_net_edge_bps", + } + ) + entry_cols = [ + "sample_id", + "long_entry_prob", + "short_entry_prob", + "pred_long_expected_net_edge_bps", + "pred_short_expected_net_edge_bps", + ] + risk_cols = ["sample_id", "market_risk_prob", "long_position_risk_prob", "short_position_risk_prob"] + actual_cols = ["sample_id", "actual_long_expected_net_edge_bps", "actual_short_expected_net_edge_bps", "long_entry_target", "short_entry_target"] + frame = ( + direction[["sample_id", "symbol", "event_time", "split_id", "long_prob", "short_prob", "neutral_prob"]] + .merge(entry[entry_cols], on="sample_id", how="inner") + .merge(risk[risk_cols], on="sample_id", how="inner") + .merge(entry_dataset[actual_cols], on="sample_id", how="inner") + ) + if frame.empty: + raise ValueError(f"PM frame is empty for {split_id}; check model predictions and entry dataset") + logging.info( + "trader.training.pm_frame_loaded splitId=%s rowCount=%s splitCounts=%s", + split_id, + len(frame), + frame["split_id"].value_counts().to_dict(), + ) + return frame + + +def _threshold_candidates() -> list[dict[str, float]]: + values = itertools.product( + [0.54, 0.56, 0.58, 0.60], + [0.54, 0.56, 0.58, 0.60], + [0.50, 0.52, 0.55, 0.58], + [0.35, 0.45, 0.55], + [1.0, 2.0, 3.0, 5.0], + [0.02, 0.03, 0.05], + ) + return [ + { + "long_open_prob": long_prob, + "short_open_prob": short_prob, + "min_entry_prob": entry_prob, + "max_market_risk_prob": risk_prob, + "min_expected_edge_bps": edge_bps, + "min_direction_margin": margin, + } + for long_prob, short_prob, entry_prob, risk_prob, edge_bps, margin in values + ] + + +def _simulate_open_trades(frame: pd.DataFrame, thresholds: dict[str, float]) -> pd.DataFrame: + long_mask = ( + (frame["long_prob"] >= thresholds["long_open_prob"]) + & ((frame["long_prob"] - frame["short_prob"]) >= thresholds["min_direction_margin"]) + & (frame["long_entry_prob"] >= thresholds["min_entry_prob"]) + & (frame["market_risk_prob"] <= thresholds["max_market_risk_prob"]) + & (frame["pred_long_expected_net_edge_bps"] >= thresholds["min_expected_edge_bps"]) + ) + short_mask = ( + (frame["short_prob"] >= thresholds["short_open_prob"]) + & ((frame["short_prob"] - frame["long_prob"]) >= thresholds["min_direction_margin"]) + & (frame["short_entry_prob"] >= thresholds["min_entry_prob"]) + & (frame["market_risk_prob"] <= thresholds["max_market_risk_prob"]) + & (frame["pred_short_expected_net_edge_bps"] >= thresholds["min_expected_edge_bps"]) + ) + long_score = frame["pred_long_expected_net_edge_bps"] + (frame["long_prob"] - frame["short_prob"]) * 10.0 + short_score = frame["pred_short_expected_net_edge_bps"] + (frame["short_prob"] - frame["long_prob"]) * 10.0 + side = np.where(long_mask & (~short_mask | (long_score >= short_score)), "LONG", np.where(short_mask, "SHORT", "")) + trades = frame.loc[side != ""].copy().reset_index(drop=True) + if trades.empty: + return _empty_trade_frame() + trades["side"] = side[side != ""] + is_long = trades["side"].eq("LONG") + trades["direction_prob"] = np.where(is_long, trades["long_prob"], trades["short_prob"]) + trades["entry_prob"] = np.where(is_long, trades["long_entry_prob"], trades["short_entry_prob"]) + trades["predicted_edge_bps"] = np.where(is_long, trades["pred_long_expected_net_edge_bps"], trades["pred_short_expected_net_edge_bps"]) + trades["actual_edge_bps"] = np.where(is_long, trades["actual_long_expected_net_edge_bps"], trades["actual_short_expected_net_edge_bps"]) + trades["entry_target"] = np.where(is_long, trades["long_entry_target"], trades["short_entry_target"]) + trades["planned_ratio"] = _planned_ratio(trades["predicted_edge_bps"], trades["market_risk_prob"], thresholds["min_expected_edge_bps"]) + trades["weighted_edge_bps"] = trades["actual_edge_bps"] * trades["planned_ratio"] + trades["threshold_hash"] = sha256_json(thresholds)[:16] + return trades[ + [ + "sample_id", + "symbol", + "event_time", + "split_id", + "side", + "direction_prob", + "entry_prob", + "market_risk_prob", + "predicted_edge_bps", + "actual_edge_bps", + "entry_target", + "planned_ratio", + "weighted_edge_bps", + "threshold_hash", + ] + ].sort_values("event_time") + + +def _empty_trade_frame() -> pd.DataFrame: + return pd.DataFrame( + columns=[ + "sample_id", + "symbol", + "event_time", + "split_id", + "side", + "direction_prob", + "entry_prob", + "market_risk_prob", + "predicted_edge_bps", + "actual_edge_bps", + "entry_target", + "planned_ratio", + "weighted_edge_bps", + "threshold_hash", + ] + ) + + +def _planned_ratio(predicted_edge: pd.Series, market_risk: pd.Series, min_edge: float) -> np.ndarray: + edge_strength = ((predicted_edge.astype(float) - min_edge) / 20.0).clip(lower=0.0, upper=1.5) + risk_discount = (1.0 - market_risk.astype(float)).clip(lower=0.0, upper=1.0) + return (edge_strength * risk_discount).clip(lower=0.05, upper=1.0).to_numpy() + + +def _trade_metrics(trades: pd.DataFrame) -> dict[str, Any]: + if trades.empty: + return { + "trade_count": 0, + "win_rate": 0.0, + "avg_actual_edge_bps": 0.0, + "avg_weighted_edge_bps": 0.0, + "total_weighted_edge_bps": 0.0, + "max_drawdown_bps": 0.0, + "avg_planned_ratio": 0.0, + "profit_factor": 0.0, + "max_consecutive_losses": 0, + } + equity = trades["weighted_edge_bps"].astype(float).cumsum() + drawdown = equity.cummax() - equity + gains = trades.loc[trades["weighted_edge_bps"] > 0, "weighted_edge_bps"].astype(float).sum() + losses = -trades.loc[trades["weighted_edge_bps"] < 0, "weighted_edge_bps"].astype(float).sum() + return { + "trade_count": int(len(trades)), + "win_rate": float((trades["actual_edge_bps"].astype(float) > 0).mean()), + "avg_actual_edge_bps": float(trades["actual_edge_bps"].astype(float).mean()), + "avg_weighted_edge_bps": float(trades["weighted_edge_bps"].astype(float).mean()), + "total_weighted_edge_bps": float(equity.iloc[-1]), + "max_drawdown_bps": float(drawdown.max()), + "avg_planned_ratio": float(trades["planned_ratio"].astype(float).mean()), + "profit_factor": float(gains / losses) if losses > 0 else float("inf"), + "max_consecutive_losses": _max_consecutive_losses(trades["weighted_edge_bps"].astype(float).to_numpy()), + } + + +def _max_consecutive_losses(values: np.ndarray) -> int: + max_count = 0 + current = 0 + for value in values: + if value < 0: + current += 1 + max_count = max(max_count, current) + else: + current = 0 + return max_count + + +def _backtest_status(metrics: dict[str, dict[str, Any]]) -> tuple[str, list[str]]: + reasons: list[str] = [] + validation_locked = metrics[VALIDATION_LOCKED_SPLIT] + stress = metrics[LATEST_STRESS_SPLIT] + if validation_locked["total_weighted_edge_bps"] <= 0: + reasons.append("validation_locked_net_edge_not_positive") + if validation_locked["trade_count"] < 80: + reasons.append("validation_locked_trade_count_below_80") + if validation_locked["profit_factor"] < 1.15: + reasons.append("validation_locked_profit_factor_below_1.15") + if validation_locked["avg_weighted_edge_bps"] <= 0: + reasons.append("validation_locked_avg_trade_edge_not_positive") + if validation_locked["max_consecutive_losses"] > 8: + reasons.append("validation_locked_max_consecutive_losses_above_8") + if stress["trade_count"] < 20: + reasons.append("latest_stress_trade_count_below_20") + if stress["profit_factor"] < 1.0: + reasons.append("latest_stress_profit_factor_below_1.0") + if stress["avg_weighted_edge_bps"] < -3.0: + reasons.append("latest_stress_avg_trade_edge_below_minus_3") + if stress["max_consecutive_losses"] > 10: + reasons.append("latest_stress_max_consecutive_losses_above_10") + if validation_locked["total_weighted_edge_bps"] > 0 and stress["total_weighted_edge_bps"] < -0.5 * validation_locked["total_weighted_edge_bps"]: + reasons.append("latest_stress_loss_too_large_vs_validation") + return ("REJECTED", reasons) if reasons else ("PASS", []) + + +def _score_thresholds(metrics: dict[str, Any]) -> float: + if metrics["trade_count"] == 0: + return -1_000_000.0 + low_sample_penalty = max(0, 20 - int(metrics["trade_count"])) * 5.0 + return ( + metrics["avg_weighted_edge_bps"] * np.sqrt(metrics["trade_count"]) + + metrics["total_weighted_edge_bps"] * 0.05 + - metrics["max_drawdown_bps"] * 0.25 + - low_sample_penalty + ) + + +def _pm_config_from_thresholds(thresholds: dict[str, float]) -> dict: + config = default_pm_config() + config["open"].update( + { + "longOpenProb": thresholds["long_open_prob"], + "shortOpenProb": thresholds["short_open_prob"], + "minLongEntryProb": thresholds["min_entry_prob"], + "minShortEntryProb": thresholds["min_entry_prob"], + "maxMarketRiskProb": thresholds["max_market_risk_prob"], + "minExpectedEdgeBps": thresholds["min_expected_edge_bps"], + "minDirectionMargin": thresholds["min_direction_margin"], + } + ) + config["add"]["maxMarketRiskProb"] = thresholds["max_market_risk_prob"] + config["add"]["minExpectedEdgeBps"] = thresholds["min_expected_edge_bps"] + config["sizing"]["minEdgeBps"] = thresholds["min_expected_edge_bps"] + config["sizing"]["maxSingleLegRatio"] = 1.0 + return config + + +def _thresholds_from_config(config: dict) -> dict[str, float]: + open_config = config["open"] + return { + "long_open_prob": float(open_config["longOpenProb"]), + "short_open_prob": float(open_config["shortOpenProb"]), + "min_entry_prob": float(min(open_config["minLongEntryProb"], open_config["minShortEntryProb"])), + "max_market_risk_prob": float(open_config["maxMarketRiskProb"]), + "min_expected_edge_bps": float(open_config["minExpectedEdgeBps"]), + "min_direction_margin": float(open_config["minDirectionMargin"]), + } + + +def _equity_curve(trades: pd.DataFrame) -> pd.DataFrame: + if trades.empty: + return pd.DataFrame(columns=["event_time", "trade_index", "weighted_edge_bps", "equity_bps", "drawdown_bps"]) + curve = trades[["event_time", "weighted_edge_bps"]].copy().reset_index(drop=True) + curve["trade_index"] = np.arange(1, len(curve) + 1) + curve["equity_bps"] = curve["weighted_edge_bps"].astype(float).cumsum() + curve["drawdown_bps"] = curve["equity_bps"].cummax() - curve["equity_bps"] + return curve[["event_time", "trade_index", "weighted_edge_bps", "equity_bps", "drawdown_bps"]] + + +def _regime_metrics(trades: pd.DataFrame) -> pd.DataFrame: + if trades.empty: + return pd.DataFrame(columns=["split_id", "side", "trade_count", "win_rate", "avg_actual_edge_bps", "total_weighted_edge_bps"]) + rows = [] + for (split_id, side), group in trades.groupby(["split_id", "side"], sort=True): + metrics = _trade_metrics(group) + rows.append( + { + "split_id": split_id, + "side": side, + "trade_count": metrics["trade_count"], + "win_rate": metrics["win_rate"], + "avg_actual_edge_bps": metrics["avg_actual_edge_bps"], + "total_weighted_edge_bps": metrics["total_weighted_edge_bps"], + } + ) + return pd.DataFrame(rows) + + +def _write_pm_report(path, candidates: pd.DataFrame, best_thresholds: dict[str, float], best_metrics: dict[str, Any]) -> None: + top = candidates.head(10) + lines = [ + "# PM Threshold Report", + "", + "本次不是固定写死阈值,而是在验证集上试一组可复现的阈值,选择净收益、回撤、交易数量综合更好的那组。", + "", + "## Best Thresholds", + "", + "```json", + str(best_thresholds).replace("'", '"'), + "```", + "", + "## Best Metrics", + "", + "```json", + str(best_metrics).replace("'", '"'), + "```", + "", + "## Top Candidates", + "", + _markdown_table(top.to_dict("records"), list(top.columns)), + "", + ] + write_text(path, "\n".join(lines)) + + +def _write_backtest_report(path, result: dict[str, Any]) -> None: + lines = [ + "# Integrated Backtest Report", + "", + "这里用验证集模型输出和 PM 阈值生成交易明细,统计净收益、胜率、回撤和分段表现。", + "", + "```json", + str(result).replace("'", '"'), + "```", + "", + ] + write_text(path, "\n".join(lines)) + + +def _write_failure_cases(path, trades: pd.DataFrame) -> None: + worst = trades.sort_values("weighted_edge_bps").head(20) if not trades.empty else trades + lines = [ + "# Backtest Failure Cases", + "", + "按加权净收益从差到好列出最差样本,方便回看特征、标签和阈值。", + "", + _markdown_table(worst.to_dict("records"), list(worst.columns)) if not worst.empty else "无交易样本。", + "", + ] + write_text(path, "\n".join(lines)) + + +def _write_no_baseline_ablation(path) -> None: + lines = [ + "# Direction Ablation Backtest Report", + "", + "- status: NO_BASELINE", + "- reason: 当前 run 目录没有旧 Direction 基准模型包,所以首版不能做只替换 Direction 的消融回测。", + "- action: 后续版本必须拿上一版 ACTIVE 包做 baseline,再比较新 Direction 是否真的提升。", + "", + ] + write_text(path, "\n".join(lines)) + + +def _markdown_table(rows: list[dict[str, Any]], columns: list[str]) -> str: + lines = ["| " + " | ".join(columns) + " |", "| " + " | ".join("---" for _ in columns) + " |"] + for row in rows: + lines.append("| " + " | ".join(str(row.get(column, "")) for column in columns) + " |") + return "\n".join(lines) diff --git a/training/trader_training/promote.py b/training/trader_training/promote.py new file mode 100644 index 0000000..2c06621 --- /dev/null +++ b/training/trader_training/promote.py @@ -0,0 +1,129 @@ +from __future__ import annotations + +import logging +from typing import Any + +from trader_training.io_utils import read_json, utc_now_text, write_json, write_text + + +def promote_artifact_bundle(args: Any) -> None: + artifact_root = args.artifact_root + validation_path = artifact_root.parent / "artifact_validation_result.json" + validation = read_json(validation_path) + if validation.get("status") != "PASS": + _refuse(artifact_root, args.reason, "validation result is not PASS", validation) + if validation.get("release_gate_status") != "PASS": + _refuse(artifact_root, args.reason, "release gate is not PASS", validation) + + promotion = { + "promoted_at": utc_now_text(), + "reason": args.reason, + "validation_result_path": str(validation_path), + } + + bundle_path = artifact_root / "manifests" / "model_bundle_manifest.json" + bundle = read_json(bundle_path) + + model_manifest_path = artifact_root / "manifests" / "model_manifest.json" + model_rows = read_json(model_manifest_path) + + calibration_manifest_path = artifact_root / "manifests" / "calibration_manifest.json" + calibration_rows = read_json(calibration_manifest_path) + + pm_manifest_path = artifact_root / "manifests" / "position_manager_manifest.json" + pm_manifest = read_json(pm_manifest_path) + + export_manifest_path = artifact_root / "manifests" / "training_export_manifest.json" + export_manifest = read_json(export_manifest_path) + + # Check every manifest before writing any ACTIVE status, so a bad bundle + # cannot be left half-promoted if one file fails late. + _require_candidate("model_bundle_manifest", bundle.get("status")) + for row in model_rows: + _require_candidate(f"model_manifest.{row.get('model_name')}", row.get("status")) + for row in calibration_rows: + _require_candidate(f"calibration_manifest.{row.get('model_name')}", row.get("status")) + _require_candidate("position_manager_manifest", pm_manifest.get("status")) + _require_candidate("training_export_manifest", export_manifest.get("status")) + + bundle["status"] = "ACTIVE" + bundle["promotion_json"] = promotion + for row in model_rows: + row["status"] = "ACTIVE" + row["promotion_json"] = promotion + for row in calibration_rows: + row["status"] = "ACTIVE" + row["promotion_json"] = promotion + pm_manifest["status"] = "ACTIVE" + pm_manifest["promotion_json"] = promotion + export_manifest["status"] = "ACTIVE" + export_manifest["promotion_json"] = promotion + + write_json(bundle_path, bundle) + write_json(model_manifest_path, model_rows) + write_json(calibration_manifest_path, calibration_rows) + write_json(pm_manifest_path, pm_manifest) + write_json(export_manifest_path, export_manifest) + + result = {"status": "ACTIVE", "artifact_root": str(artifact_root), "promotion": promotion} + write_json(artifact_root.parent / "artifact_promotion_result.json", result) + write_text( + artifact_root.parent / "artifact_promotion_report.md", + "\n".join( + [ + "# Artifact Promotion Report", + "", + "- status: ACTIVE", + f"- artifact_root: {artifact_root}", + f"- reason: {args.reason}", + f"- promoted_at: {promotion['promoted_at']}", + "", + ] + ), + ) + logging.info("trader.training.artifact_promoted status=ACTIVE path=%s reason=%s", artifact_root, args.reason) + + +def _require_candidate(name: str, status: str | None) -> None: + if status != "CANDIDATE": + raise SystemExit(f"artifact promotion refused: {name} status must be CANDIDATE, actual={status}") + + +def _refuse(artifact_root: Any, reason: str, message: str, validation: dict[str, Any]) -> None: + result = { + "status": "REFUSED", + "artifact_root": str(artifact_root), + "reason": reason, + "message": message, + "validation_status": validation.get("status"), + "release_gate_status": validation.get("release_gate_status"), + "release_gate_reasons": validation.get("release_gate_reasons", []), + "refused_at": utc_now_text(), + } + write_json(artifact_root.parent / "artifact_promotion_result.json", result) + write_text( + artifact_root.parent / "artifact_promotion_report.md", + "\n".join( + [ + "# Artifact Promotion Report", + "", + "- status: REFUSED", + f"- artifact_root: {artifact_root}", + f"- reason: {reason}", + f"- message: {message}", + f"- validation_status: {result['validation_status']}", + f"- release_gate_status: {result['release_gate_status']}", + f"- release_gate_reasons: {result['release_gate_reasons']}", + f"- refused_at: {result['refused_at']}", + "", + ] + ), + ) + logging.warning( + "trader.training.artifact_promotion_refused status=REFUSED path=%s reason=%s message=%s releaseGate=%s", + artifact_root, + reason, + message, + result["release_gate_status"], + ) + raise SystemExit(f"artifact promotion refused: {message}, reasons={result['release_gate_reasons']}") diff --git a/training/trader_training/replay.py b/training/trader_training/replay.py new file mode 100644 index 0000000..537b4d4 --- /dev/null +++ b/training/trader_training/replay.py @@ -0,0 +1,496 @@ +from __future__ import annotations + +import json +import logging +from pathlib import Path +from typing import Any + +import numpy as np +import pandas as pd + +from trader_training.io_utils import ( + DEFAULT_RAW_ROOT, + ensure_dir, + manifest, + open_time_ms, + partition_files, + read_json, + read_partitioned_table, + read_parquet, + require_columns, + run_root, + to_utc_series, + utc_now_text, + write_json, + write_parquet, + write_text, +) +from trader_training.schemas import FIT_SPLIT, LATEST_STRESS_SPLIT, SPLIT_VERSION, TRAINING_SPLITS, TUNE_SPLIT, VALIDATION_LOCKED_SPLIT + + +def audit_source_data(data_root: Path, symbol: str, start_date: str | None, end_date: str | None, min_ready_days: int = 250) -> dict[str, Any]: + raw_root = data_root / "crypto-lake" / "raw" + required_tables = ("candles", "trades", "level_1", "funding", "open_interest") + optional_tables = ("liquidations",) + rows: list[dict[str, Any]] = [] + table_dates: dict[str, set[str]] = {} + for table in required_tables + optional_tables: + files = partition_files(raw_root, table, symbol, start_date, end_date) + dates = sorted({next((part.split("=", 1)[1] for part in file.parts if part.startswith("dt=")), "") for file in files}) + table_dates[table] = set(dates) + rows.append( + { + "table": table, + "required": table in required_tables, + "file_count": len(files), + "first_date": dates[0] if dates else None, + "last_date": dates[-1] if dates else None, + "status": "OK" if files or table in optional_tables else "MISSING", + } + ) + all_dates = _audit_date_range(table_dates, required_tables, start_date, end_date) + replay_ready_days = [] + excluded_days = [] + for day in all_dates: + missing_required = [table for table in required_tables if day not in table_dates[table]] + missing_optional = [table for table in optional_tables if day not in table_dates[table]] + if missing_required: + excluded_days.append({"date": day, "reason": "MISSING_REQUIRED_TABLE", "missing_required_tables": missing_required, "missing_optional_tables": missing_optional}) + else: + replay_ready_days.append(day) + result = { + "symbol": symbol, + "start_date": start_date, + "end_date": end_date, + "raw_root": str(raw_root), + "tables": rows, + "replay_ready_day_count": len(replay_ready_days), + "excluded_day_count": len(excluded_days), + "replay_ready_days": replay_ready_days, + "excluded_days": excluded_days, + "created_at": utc_now_text(), + "ready": all(row["status"] == "OK" for row in rows if row["required"]) and len(replay_ready_days) >= min_ready_days, + } + return result + + +def write_audit_outputs(args: Any) -> None: + root = run_root(args) + result = audit_source_data(args.data_root, args.symbol, args.start_date, args.end_date, int(args.min_ready_days)) + path = root / "raw-manifest" / "source_data_audit.json" + write_json(path, result) + write_json(root / "raw-manifest" / "source_data_manifest.json", result) + write_json(root / "raw-manifest" / "excluded_days.json", result["excluded_days"]) + write_text(root / "raw-manifest" / "replay_ready_days.txt", "\n".join(result["replay_ready_days"]) + ("\n" if result["replay_ready_days"] else "")) + report_lines = [ + "# Trader Source Data Audit", + "", + f"- symbol: {result['symbol']}", + f"- raw_root: {result['raw_root']}", + f"- ready: {result['ready']}", + f"- replay_ready_day_count: {result['replay_ready_day_count']}", + f"- excluded_day_count: {result['excluded_day_count']}", + "", + "| table | required | file_count | first_date | last_date | status |", + "| --- | --- | ---: | --- | --- | --- |", + ] + for row in result["tables"]: + report_lines.append( + f"| {row['table']} | {row['required']} | {row['file_count']} | {row['first_date']} | {row['last_date']} | {row['status']} |" + ) + write_text(root / "raw-manifest" / "source_data_audit.md", "\n".join(report_lines) + "\n") + logging.info( + "trader.training.audit_written runId=%s ready=%s readyDays=%s excludedDays=%s path=%s", + args.run_id, + result["ready"], + result["replay_ready_day_count"], + result["excluded_day_count"], + path, + ) + if not result["ready"]: + raise SystemExit("required raw tables are missing; see source_data_audit.md") + + +def _audit_date_range(table_dates: dict[str, set[str]], required_tables: tuple[str, ...], start_date: str | None, end_date: str | None) -> list[str]: + if start_date and end_date: + start = pd.Timestamp(start_date) + end = pd.Timestamp(end_date) + else: + dates = sorted(set().union(*(table_dates[table] for table in required_tables))) + if not dates: + return [] + start = pd.Timestamp(start_date or dates[0]) + end = pd.Timestamp(end_date or dates[-1]) + return [day.strftime("%Y-%m-%d") for day in pd.date_range(start, end, freq="D")] + + +def _minute_frame(frame: pd.DataFrame, time_column: str = "origin_time") -> pd.DataFrame: + frame = frame.copy() + frame["event_time"] = to_utc_series(frame[time_column]).dt.floor("min") + frame["open_time_ms"] = open_time_ms(frame["event_time"]) + return frame + + +def _read_candles(raw_root: Path, symbol: str, start_date: str | None, end_date: str | None) -> pd.DataFrame: + candles = read_partitioned_table( + raw_root, + "candles", + symbol, + start_date, + end_date, + columns=("origin_time", "start", "open", "high", "low", "close", "volume", "symbol"), + ) + if candles.empty: + raise ValueError("candles raw data is required to build replay_1m") + time_col = "start" if "start" in candles.columns else "origin_time" + candles = _minute_frame(candles, time_col) + keep = ["symbol", "event_time", "open_time_ms", "open", "high", "low", "close", "volume"] + candles = candles[keep].sort_values(["symbol", "event_time"]).drop_duplicates(["symbol", "event_time"], keep="last") + for column in ("open", "high", "low", "close", "volume"): + candles[column] = pd.to_numeric(candles[column], errors="coerce") + return candles + + +def _read_trades(raw_root: Path, symbol: str, start_date: str | None, end_date: str | None) -> pd.DataFrame: + trades = read_partitioned_table( + raw_root, + "trades", + symbol, + start_date, + end_date, + columns=("origin_time", "side", "quantity", "symbol"), + ) + if trades.empty: + raise ValueError("trades raw data is required for taker imbalance") + trades = _minute_frame(trades) + trades["quantity"] = pd.to_numeric(trades["quantity"], errors="coerce").fillna(0.0) + side = trades["side"].astype(str).str.upper() + trades["taker_buy_volume"] = np.where(side.eq("BUY"), trades["quantity"], 0.0) + trades["taker_sell_volume"] = np.where(side.eq("SELL"), trades["quantity"], 0.0) + return ( + trades.groupby(["symbol", "event_time", "open_time_ms"], as_index=False, observed=True)[["taker_buy_volume", "taker_sell_volume"]] + .sum() + ) + + +def _read_level1(raw_root: Path, symbol: str, start_date: str | None, end_date: str | None) -> pd.DataFrame: + level1 = read_partitioned_table( + raw_root, + "level_1", + symbol, + start_date, + end_date, + columns=("origin_time", "bid_0_price", "ask_0_price", "bid_0_size", "ask_0_size", "symbol"), + ) + if level1.empty: + raise ValueError("level_1 raw data is required for spread and OFI") + level1 = _minute_frame(level1) + for column in ("bid_0_price", "ask_0_price", "bid_0_size", "ask_0_size"): + level1[column] = pd.to_numeric(level1[column], errors="coerce") + level1 = level1.dropna(subset=["bid_0_price", "ask_0_price", "bid_0_size", "ask_0_size"]) + level1 = level1.sort_values(["symbol", "event_time", "origin_time"]) + group = level1.groupby("symbol", sort=False, observed=True) + prev_bid_price = group["bid_0_price"].shift(1) + prev_bid_size = group["bid_0_size"].shift(1) + prev_ask_price = group["ask_0_price"].shift(1) + prev_ask_size = group["ask_0_size"].shift(1) + bid_ofi = np.select( + [level1["bid_0_price"] > prev_bid_price, level1["bid_0_price"].eq(prev_bid_price)], + [level1["bid_0_size"], level1["bid_0_size"] - prev_bid_size], + default=-prev_bid_size, + ) + ask_ofi = np.select( + [level1["ask_0_price"] < prev_ask_price, level1["ask_0_price"].eq(prev_ask_price)], + [level1["ask_0_size"], prev_ask_size - level1["ask_0_size"]], + default=-prev_ask_size, + ) + level1["ofi_raw"] = np.nan_to_num(bid_ofi + ask_ofi, nan=0.0) + level1["depth"] = (level1["bid_0_size"] + level1["ask_0_size"]).clip(lower=1e-12) + level1["mid"] = (level1["bid_0_price"] + level1["ask_0_price"]) / 2.0 + level1["spread_bps"] = (level1["ask_0_price"] - level1["bid_0_price"]) / level1["mid"] * 10000.0 + agg = level1.groupby(["symbol", "event_time", "open_time_ms"], as_index=False, observed=True).agg( + best_bid_price=("bid_0_price", "last"), + best_ask_price=("ask_0_price", "last"), + spread_bps=("spread_bps", "last"), + ofi_sum=("ofi_raw", "sum"), + depth_mean=("depth", "mean"), + ) + agg["level1_ofi_1m"] = agg["ofi_sum"] / agg["depth_mean"].clip(lower=1e-12) + return agg.drop(columns=["ofi_sum", "depth_mean"]) + + +def _read_liquidations(raw_root: Path, symbol: str, start_date: str | None, end_date: str | None) -> pd.DataFrame: + files = partition_files(raw_root, "liquidations", symbol, start_date, end_date) + if not files: + return pd.DataFrame(columns=["symbol", "event_time", "open_time_ms", "liquidation_buy_notional_1m", "liquidation_sell_notional_1m", "liquidation_available"]) + liquidations = read_partitioned_table( + raw_root, + "liquidations", + symbol, + start_date, + end_date, + columns=("origin_time", "side", "quantity", "price", "symbol"), + ) + liquidations = _minute_frame(liquidations) + liquidations["quantity"] = pd.to_numeric(liquidations["quantity"], errors="coerce").fillna(0.0) + liquidations["price"] = pd.to_numeric(liquidations["price"], errors="coerce").fillna(0.0) + liquidations["notional"] = liquidations["quantity"] * liquidations["price"] + side = liquidations["side"].astype(str).str.upper() + liquidations["liquidation_buy_notional_1m"] = np.where(side.eq("BUY"), liquidations["notional"], 0.0) + liquidations["liquidation_sell_notional_1m"] = np.where(side.eq("SELL"), liquidations["notional"], 0.0) + agg = liquidations.groupby(["symbol", "event_time", "open_time_ms"], as_index=False, observed=True)[ + ["liquidation_buy_notional_1m", "liquidation_sell_notional_1m"] + ].sum() + agg["liquidation_available"] = 1.0 + return agg + + +def _asof_column( + replay: pd.DataFrame, + raw_root: Path, + table: str, + symbol: str, + start_date: str | None, + end_date: str | None, + columns: tuple[str, ...], +) -> pd.DataFrame: + frame = read_partitioned_table(raw_root, table, symbol, start_date, end_date, columns=("origin_time", "symbol", *columns)) + if frame.empty: + raise ValueError(f"{table} raw data is required") + frame = _minute_frame(frame) + for column in columns: + if column.endswith("time"): + continue + frame[column] = pd.to_numeric(frame[column], errors="coerce") + frame = frame.sort_values(["symbol", "event_time"]) + left = replay[["symbol", "event_time"]].sort_values(["symbol", "event_time"]) + merged = pd.merge_asof( + left, + frame[["symbol", "event_time", *columns]].sort_values(["symbol", "event_time"]), + by="symbol", + on="event_time", + direction="backward", + tolerance=pd.Timedelta(hours=12), + ) + return merged + + +def build_replay_1m(args: Any) -> None: + root = run_root(args) + raw_root = args.raw_root or DEFAULT_RAW_ROOT + logging.info("trader.training.replay_started runId=%s symbol=%s rawRoot=%s", args.run_id, args.symbol, raw_root) + replay = _read_candles(raw_root, args.symbol, args.start_date, args.end_date) + trades = _read_trades(raw_root, args.symbol, args.start_date, args.end_date) + level1 = _read_level1(raw_root, args.symbol, args.start_date, args.end_date) + liquidations = _read_liquidations(raw_root, args.symbol, args.start_date, args.end_date) + replay = replay.merge(trades, on=["symbol", "event_time", "open_time_ms"], how="left") + replay = replay.merge(level1, on=["symbol", "event_time", "open_time_ms"], how="left") + replay = replay.merge(liquidations, on=["symbol", "event_time", "open_time_ms"], how="left") + replay[["taker_buy_volume", "taker_sell_volume"]] = replay[["taker_buy_volume", "taker_sell_volume"]].fillna(0.0) + for column in ("liquidation_buy_notional_1m", "liquidation_sell_notional_1m", "liquidation_available"): + replay[column] = replay[column].fillna(0.0) + + funding = _asof_column(replay, raw_root, "funding", args.symbol, args.start_date, args.end_date, ("rate", "mark_price", "index_price", "next_funding_time")) + funding = funding.rename(columns={"rate": "funding_rate"}) + funding["funding_bps"] = pd.to_numeric(funding["funding_rate"], errors="coerce") * 10000.0 + replay = replay.merge(funding.drop(columns=["funding_rate"]), on=["symbol", "event_time"], how="left") + replay["next_funding_time"] = to_utc_series(replay["next_funding_time"]) + + oi = _asof_column(replay, raw_root, "open_interest", args.symbol, args.start_date, args.end_date, ("open_interest",)) + replay = replay.merge(oi, on=["symbol", "event_time"], how="left") + replay["timeframe"] = "1m" + replay["source_coverage"] = "crypto_lake_raw" + + required = [ + "open", + "high", + "low", + "close", + "volume", + "best_bid_price", + "best_ask_price", + "spread_bps", + "level1_ofi_1m", + "funding_bps", + "mark_price", + "index_price", + "open_interest", + ] + replay["event_date"] = replay["event_time"].dt.strftime("%Y-%m-%d") + missing_required = replay[required].isna().any(axis=1) + day_quality = ( + replay.assign(missing_required=missing_required.astype(int)) + .groupby("event_date", as_index=False, observed=True) + .agg(row_count=("event_time", "count"), missing_required_rows=("missing_required", "sum")) + ) + day_quality["ready"] = (day_quality["row_count"] >= int(args.min_minutes_per_day)) & day_quality["missing_required_rows"].eq(0) + ready_days = sorted(day_quality.loc[day_quality["ready"], "event_date"].astype(str).tolist()) + excluded_days = [ + { + "date": row.event_date, + "row_count": int(row.row_count), + "missing_required_rows": int(row.missing_required_rows), + "reason": "MISSING_REQUIRED_MARKET_FIELDS" if int(row.missing_required_rows) else "INCOMPLETE_MINUTE_COUNT", + } + for row in day_quality.loc[~day_quality["ready"]].itertuples(index=False) + ] + if len(ready_days) < int(args.min_replay_ready_days): + write_json(root / "replay" / "excluded_days.json", excluded_days) + write_text(root / "replay" / "replay_ready_days.txt", "\n".join(ready_days) + ("\n" if ready_days else "")) + raise ValueError(f"replay_1m has only {len(ready_days)} replay-ready days, required {args.min_replay_ready_days}") + before_filter = len(replay) + replay = replay[replay["event_date"].isin(ready_days)].copy() + logging.info( + "trader.training.replay_ready_days_selected runId=%s readyDays=%s excludedDays=%s rowBefore=%s rowAfter=%s", + args.run_id, + len(ready_days), + len(excluded_days), + before_filter, + len(replay), + ) + + columns = [ + "symbol", + "timeframe", + "event_time", + "open_time_ms", + "open", + "high", + "low", + "close", + "volume", + "taker_buy_volume", + "taker_sell_volume", + "funding_bps", + "mark_price", + "index_price", + "next_funding_time", + "open_interest", + "best_bid_price", + "best_ask_price", + "spread_bps", + "level1_ofi_1m", + "liquidation_buy_notional_1m", + "liquidation_sell_notional_1m", + "liquidation_available", + "source_coverage", + ] + replay = replay[columns].sort_values(["symbol", "event_time"]).reset_index(drop=True) + path = root / "replay" / "replay_1m.parquet" + data_hash = write_parquet(path, replay) + write_json( + root / "replay" / "replay_1m.manifest.json", + manifest( + path, + { + "row_count": len(replay), + "hash_sha256": data_hash, + "replay_ready_day_count": len(ready_days), + "excluded_day_count": len(excluded_days), + "min_minutes_per_day": int(args.min_minutes_per_day), + }, + ), + ) + write_json(root / "replay" / "excluded_days.json", excluded_days) + write_text(root / "replay" / "replay_ready_days.txt", "\n".join(ready_days) + "\n") + logging.info("trader.training.replay_written runId=%s rowCount=%s readyDays=%s path=%s", args.run_id, len(replay), len(ready_days), path) + + +def build_splits(args: Any) -> None: + root = run_root(args) + replay_path = args.replay_path or root / "replay" / "replay_1m.parquet" + replay = read_parquet(replay_path) + require_columns(replay, ("event_time", "symbol"), "replay_1m") + replay["event_time"] = to_utc_series(replay["event_time"]) + replay = replay.sort_values(["event_time", "symbol"]).reset_index(drop=True) + if len(replay) < 10: + raise ValueError("not enough replay rows to build time splits") + gap = int(args.gap_minutes) + intervals = _fixed_split_intervals(args, gap) + replay_start = replay["event_time"].min() + replay_end = replay["event_time"].max() + intervals = [ + (split_id, max(start, replay_start), min(end, replay_end)) + for split_id, start, end in intervals + if max(start, replay_start) <= min(end, replay_end) + ] + if {item[0] for item in intervals} != set(TRAINING_SPLITS): + raise ValueError(f"fixed split dates do not fit replay coverage: replay_start={replay_start} replay_end={replay_end}") + split_manifest = { + "split_version": SPLIT_VERSION, + "created_at": utc_now_text(), + "source_replay_path": str(replay_path), + "gap_minutes": gap, + # Sealed splits are withheld from broad parameter search. They only answer + # whether a finished candidate survives final validation and recent stress. + "sealed_splits": [VALIDATION_LOCKED_SPLIT, LATEST_STRESS_SPLIT], + "latest_stress_policy": "FINAL_GATE_ONLY", + "requested_splits": { + FIT_SPLIT: [args.fit_inner_start, args.fit_inner_end], + TUNE_SPLIT: [args.tune_inner_start, args.tune_inner_end], + VALIDATION_LOCKED_SPLIT: [args.validation_locked_start, args.validation_locked_end], + LATEST_STRESS_SPLIT: [args.latest_stress_start, args.latest_stress_end], + }, + "splits": [ + {"split_id": split_id, "start": start.isoformat().replace("+00:00", "Z"), "end": end.isoformat().replace("+00:00", "Z")} + for split_id, start, end in intervals + if start <= end + ], + } + fold_count = max(1, int(args.fold_count)) + fit_interval = next(item for item in intervals if item[0] == FIT_SPLIT) + tune_interval = next(item for item in intervals if item[0] == TUNE_SPLIT) + train_times = pd.Series(pd.date_range(fit_interval[1], fit_interval[2], periods=fold_count + 1)) + folds = [] + for idx in range(fold_count): + folds.append( + { + "walk_forward_fold": f"fold_{idx + 1:02d}", + "train_start": fit_interval[1].isoformat().replace("+00:00", "Z"), + "train_end": train_times.iloc[idx + 1].isoformat().replace("+00:00", "Z"), + "validation_start": tune_interval[1].isoformat().replace("+00:00", "Z"), + "validation_end": tune_interval[2].isoformat().replace("+00:00", "Z"), + } + ) + ensure_dir(root / "split") + write_json(root / "split" / "split_manifest.json", split_manifest) + write_json(root / "split" / "walk_forward_folds.json", {"split_version": SPLIT_VERSION, "folds": folds}) + _write_purge_embargo_report(root / "split" / "purge_embargo_report.md", intervals, gap) + logging.info("trader.training.splits_written runId=%s splitCount=%s foldCount=%s", args.run_id, len(split_manifest["splits"]), len(folds)) + + +def assign_split(event_times: pd.Series, split_manifest_path: Path) -> pd.Series: + manifest_data = read_json(split_manifest_path) + result = pd.Series("NO_SPLIT", index=event_times.index, dtype="object") + values = to_utc_series(event_times) + for item in manifest_data["splits"]: + start = pd.Timestamp(item["start"]) + end = pd.Timestamp(item["end"]) + mask = values.between(start, end, inclusive="both") + result.loc[mask] = item["split_id"] + return result + + +def _fixed_split_intervals(args: Any, gap_minutes: int) -> list[tuple[str, pd.Timestamp, pd.Timestamp]]: + gap = pd.Timedelta(minutes=gap_minutes) + return [ + (FIT_SPLIT, _start_of_day(args.fit_inner_start), _end_of_day(args.fit_inner_end) - gap), + (TUNE_SPLIT, _start_of_day(args.tune_inner_start) + gap, _end_of_day(args.tune_inner_end) - gap), + (VALIDATION_LOCKED_SPLIT, _start_of_day(args.validation_locked_start) + gap, _end_of_day(args.validation_locked_end) - gap), + (LATEST_STRESS_SPLIT, _start_of_day(args.latest_stress_start) + gap, _end_of_day(args.latest_stress_end)), + ] + + +def _start_of_day(value: str) -> pd.Timestamp: + return pd.Timestamp(value, tz="UTC") + + +def _end_of_day(value: str) -> pd.Timestamp: + return pd.Timestamp(value, tz="UTC") + pd.Timedelta(days=1) - pd.Timedelta(minutes=1) + + +def _write_purge_embargo_report(path: Path, intervals: list[tuple[str, pd.Timestamp, pd.Timestamp]], gap_minutes: int) -> None: + lines = ["# Purge Embargo Report", "", f"- gap_minutes: {gap_minutes}", "", "| split_id | start | end |", "| --- | --- | --- |"] + for split_id, start, end in intervals: + lines.append(f"| {split_id} | {start.isoformat()} | {end.isoformat()} |") + write_text(path, "\n".join(lines) + "\n") diff --git a/training/trader_training/schemas.py b/training/trader_training/schemas.py new file mode 100644 index 0000000..8a7ff2c --- /dev/null +++ b/training/trader_training/schemas.py @@ -0,0 +1,206 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any + + +FEATURE_VERSION = "feature-v4-p0" +LABEL_VERSION = "label-v4-p0" +SPLIT_VERSION = "split-v4-p0" +MODEL_BUNDLE_VERSION = "trader-v4-btc-p0" +CALIBRATION_BUNDLE_VERSION = "cal-v4-btc-p0" +PM_CONFIG_VERSION = "pm-v4-btc-p0" +OUTPUT_SCHEMA_VERSION = "output-schema-v4-btc-p0" + +FIT_SPLIT = "fit_inner" +TUNE_SPLIT = "tune_inner" +VALIDATION_LOCKED_SPLIT = "validation_locked" +LATEST_STRESS_SPLIT = "latest_stress" +TRAINING_SPLITS = (FIT_SPLIT, TUNE_SPLIT, VALIDATION_LOCKED_SPLIT, LATEST_STRESS_SPLIT) + + +@dataclass(frozen=True) +class FeatureDef: + order: int + name: str + cn_name: str + meaning: str + source_tables: tuple[str, ...] + formula: str + lookback_window: str + unit: str + dtype: str + null_rule: str + live_available: bool + leakage_check: str + owner_models: tuple[str, ...] + + def as_json(self) -> dict[str, Any]: + return { + "order": self.order, + "name": self.name, + "cn_name": self.cn_name, + "meaning": self.meaning, + "source_tables": list(self.source_tables), + "formula": self.formula, + "lookback_window": self.lookback_window, + "unit": self.unit, + "dtype": self.dtype, + "null_rule": self.null_rule, + "live_available": self.live_available, + "leakage_check": self.leakage_check, + "owner_models": list(self.owner_models), + } + + +ALL_MODELS = ("Direction", "Entry", "Continue", "Exit", "Risk") + + +FEATURES: tuple[FeatureDef, ...] = ( + FeatureDef(1, "ret_1m_bps", "最近1分钟收益", "Latest short return.", ("replay_1m",), "close_t / close_t-1m - 1", "1m", "bps", "float32", "WARMUP", True, "uses <= t close only", ALL_MODELS), + FeatureDef(2, "ret_5m_bps", "最近5分钟收益", "Short trend.", ("replay_1m",), "close_t / close_t-5m - 1", "5m", "bps", "float32", "WARMUP", True, "uses <= t close only", ("Direction", "Entry", "Continue", "Exit")), + FeatureDef(3, "ret_15m_bps", "最近15分钟收益", "Near trend.", ("replay_1m",), "close_t / close_t-15m - 1", "15m", "bps", "float32", "WARMUP", True, "uses <= t close only", ("Direction", "Entry", "Continue", "Exit")), + FeatureDef(4, "ret_60m_bps", "最近60分钟收益", "Baseline trend.", ("replay_1m",), "close_t / close_t-60m - 1", "60m", "bps", "float32", "WARMUP", True, "uses <= t close only", ("Direction", "Continue", "Exit", "Risk")), + FeatureDef(5, "ret_240m_bps", "最近240分钟收益", "Four-hour trend.", ("replay_1m",), "close_t / close_t-240m - 1", "240m", "bps", "float32", "WARMUP", True, "uses <= t close only", ("Direction", "Continue", "Risk")), + FeatureDef(6, "realized_vol_15m_bps", "15分钟波动", "Near realized volatility.", ("replay_1m",), "std(log_return_1m, 15) * 10000", "15m", "bps", "float32", "WARMUP", True, "uses <= t returns only", ("Direction", "Entry", "Exit", "Risk")), + FeatureDef(7, "realized_vol_60m_bps", "60分钟波动", "Baseline realized volatility.", ("replay_1m",), "std(log_return_1m, 60) * 10000", "60m", "bps", "float32", "WARMUP", True, "uses <= t returns only", ("Direction", "Entry", "Exit", "Risk")), + FeatureDef(8, "vol_ratio_15m_60m", "近端波动放大", "Near volatility versus baseline.", ("feature",), "realized_vol_15m_bps / max(realized_vol_60m_bps, 1)", "15m/60m", "ratio", "float32", "WARMUP", True, "derived from <= t features", ("Entry", "Exit", "Risk")), + FeatureDef(9, "range_15m_bps", "15分钟振幅", "Near high-low range.", ("replay_1m",), "max(high_15m) / min(low_15m) - 1", "15m", "bps", "float32", "WARMUP", True, "uses <= t high/low only", ("Entry", "Exit", "Risk")), + FeatureDef(10, "range_60m_bps", "60分钟振幅", "Baseline high-low range.", ("replay_1m",), "max(high_60m) / min(low_60m) - 1", "60m", "bps", "float32", "WARMUP", True, "uses <= t high/low only", ("Direction", "Entry", "Risk")), + FeatureDef(11, "volume_zscore_60m", "60分钟成交量异常", "Current volume abnormality.", ("replay_1m",), "(volume_t - mean(volume_60m)) / std(volume_60m)", "60m", "zscore", "float32", "std=0 -> 0", True, "uses <= t volume only", ("Direction", "Entry", "Risk")), + FeatureDef(12, "trend_consistency_15m", "15分钟方向连续性", "Signed return consistency.", ("replay_1m",), "mean(sign(ret_1m), 15)", "15m", "ratio", "float32", "WARMUP", True, "uses <= t returns only", ("Direction", "Continue", "Exit")), + FeatureDef(13, "channel_position_60m_pct", "60分钟通道位置", "Close position in recent channel.", ("replay_1m",), "(close_t - low_60m) / max(high_60m - low_60m, tick)", "60m", "pct", "float32", "WARMUP", True, "uses <= t high/low/close only", ("Direction", "Entry", "Continue")), + FeatureDef(14, "upper_breakout_60m_bps", "向上突破距离", "Upper breakout distance.", ("replay_1m",), "max(0, close_t / prev_high_60m_excl_t - 1) * 10000", "60m", "bps", "float32", "WARMUP", True, "current close versus prior window only", ("Direction", "Entry", "Continue")), + FeatureDef(15, "lower_breakout_60m_bps", "向下跌破距离", "Lower breakdown distance.", ("replay_1m",), "max(0, prev_low_60m_excl_t / close_t - 1) * 10000", "60m", "bps", "float32", "WARMUP", True, "current close versus prior window only", ("Direction", "Entry", "Continue")), + FeatureDef(16, "upper_failed_break_reclaim_15m_bps", "上破失败回落", "Failed upper breakout reclaim.", ("replay_1m",), "if high_15m broke prior high then max(0, prev_high_60m - close_t) / close_t * 10000", "15m/60m", "bps", "float32", "no event -> 0", True, "prior high excludes t", ("Entry", "Exit", "Risk")), + FeatureDef(17, "lower_failed_break_reclaim_15m_bps", "下破失败收回", "Failed lower breakdown reclaim.", ("replay_1m",), "if low_15m broke prior low then max(0, close_t - prev_low_60m) / close_t * 10000", "15m/60m", "bps", "float32", "no event -> 0", True, "prior low excludes t", ("Entry", "Exit", "Risk")), + FeatureDef(18, "sweep_up_15m_bps", "上影扫高", "Upper sweep size.", ("replay_1m",), "max(0, max(high_15m) / close_t - 1) * 10000", "15m", "bps", "float32", "WARMUP", True, "uses <= t high/close only", ("Exit", "Risk")), + FeatureDef(19, "sweep_down_15m_bps", "下影扫低", "Lower sweep size.", ("replay_1m",), "max(0, close_t / min(low_15m) - 1) * 10000", "15m", "bps", "float32", "WARMUP", True, "uses <= t low/close only", ("Exit", "Risk")), + FeatureDef(20, "compression_score_4h_pct", "4小时压缩分位", "Higher means recent range is compressed.", ("feature",), "1 - percentile_rank(range_15m_bps over 240m)", "240m", "pct", "float32", "WARMUP", True, "rolling rank uses <= t", ("Direction", "Entry")), + FeatureDef(21, "compression_release_15m_bps", "压缩释放幅度", "Range release versus 4h median.", ("feature",), "max(0, range_15m_bps - median(range_15m_bps over 240m))", "15m/240m", "bps", "float32", "WARMUP", True, "rolling median uses <= t", ("Direction", "Entry", "Risk")), + FeatureDef(22, "taker_imbalance_1m", "1分钟主动买卖差", "Taker buy/sell imbalance.", ("trades", "replay_1m"), "(buy_1m - sell_1m) / max(total_1m, eps)", "1m", "ratio", "float32", "volume=0 -> 0", True, "uses current closed minute trades only", ("Direction", "Entry", "Continue")), + FeatureDef(23, "taker_imbalance_5m", "5分钟主动买卖差", "Short taker imbalance.", ("trades", "replay_1m"), "(buy_5m - sell_5m) / max(total_5m, eps)", "5m", "ratio", "float32", "WARMUP", True, "uses <= t trades only", ("Direction", "Entry", "Continue")), + FeatureDef(24, "taker_imbalance_15m", "15分钟主动买卖差", "Near taker imbalance.", ("trades", "replay_1m"), "(buy_15m - sell_15m) / max(total_15m, eps)", "15m", "ratio", "float32", "WARMUP", True, "uses <= t trades only", ("Direction", "Continue", "Exit")), + FeatureDef(25, "level1_ofi_1m", "1分钟盘口订单流", "Best bid/ask order-flow imbalance.", ("level_1", "replay_1m"), "sum(OFI changes in minute) / mean(level1 depth)", "1m", "ratio", "float32", "missing -> fail", True, "uses current closed minute L1 only", ("Direction", "Entry", "Risk")), + FeatureDef(26, "spread_bps", "买卖价差", "Best bid/ask spread.", ("level_1", "replay_1m"), "(best_ask - best_bid) / mid * 10000", "1m", "bps", "float32", "missing -> fail", True, "uses current closed minute L1 only", ("Entry", "Exit", "Risk")), + FeatureDef(27, "spread_rank_24h_pct", "24小时价差分位", "Spread congestion rank.", ("feature",), "percentile_rank(spread_bps over 24h)", "24h", "pct", "float32", "WARMUP", True, "rolling rank uses <= t", ("Entry", "Exit", "Risk")), + FeatureDef(28, "oi_delta_15m_bps", "15分钟持仓变化", "Open-interest short change.", ("open_interest", "replay_1m"), "open_interest_t / open_interest_t-15m - 1", "15m", "bps", "float32", "WARMUP", True, "uses <= t OI only", ("Direction", "Continue", "Risk")), + FeatureDef(29, "oi_delta_60m_bps", "60分钟持仓变化", "Open-interest baseline change.", ("open_interest", "replay_1m"), "open_interest_t / open_interest_t-60m - 1", "60m", "bps", "float32", "WARMUP", True, "uses <= t OI only", ("Direction", "Continue", "Risk")), + FeatureDef(30, "funding_bps", "资金费率", "Current funding rate.", ("funding", "replay_1m"), "rate * 10000", "as-of", "bps", "float32", "as-of > 12h -> fail", True, "backward as-of only", ("Direction", "Entry", "Risk")), + FeatureDef(31, "mark_index_basis_bps", "标记价指数价偏离", "Mark-index basis.", ("funding", "replay_1m"), "mark_price / index_price - 1", "as-of", "bps", "float32", "as-of > 12h -> fail", True, "backward as-of only", ("Direction", "Entry", "Risk")), + FeatureDef(32, "liquidation_buy_notional_1m", "1分钟买向爆仓金额", "Buy-side liquidation notional.", ("liquidations", "replay_1m"), "sum(quantity * price for BUY)", "1m", "quote", "float32", "missing partition -> 0 with flag", True, "uses current closed minute liquidations only", ("Entry", "Exit", "Risk")), + FeatureDef(33, "liquidation_sell_notional_1m", "1分钟卖向爆仓金额", "Sell-side liquidation notional.", ("liquidations", "replay_1m"), "sum(quantity * price for SELL)", "1m", "quote", "float32", "missing partition -> 0 with flag", True, "uses current closed minute liquidations only", ("Entry", "Exit", "Risk")), + FeatureDef(34, "liquidation_imbalance_15m", "15分钟爆仓方向差", "Liquidation imbalance.", ("liquidations", "replay_1m"), "(buy_15m - sell_15m) / max(total_15m, eps)", "15m", "ratio", "float32", "missing partition -> 0 with flag", True, "uses <= t liquidations only", ("Direction", "Entry", "Exit", "Risk")), + FeatureDef(35, "liquidation_notional_zscore_15m", "爆仓金额异常", "Liquidation notional zscore.", ("liquidations", "replay_1m"), "(liq_15m - mean_24h) / std_24h", "15m/24h", "zscore", "float32", "missing partition -> 0 with flag", True, "rolling window uses <= t", ("Entry", "Exit", "Risk")), + FeatureDef(36, "liquidation_available", "爆仓数据可用", "Whether liquidation data exists.", ("liquidations", "replay_1m"), "day partition exists", "day", "0/1", "float32", "never null", True, "partition availability known by event day", ("Entry", "Exit", "Risk")), + FeatureDef(37, "minute_of_day_sin", "日内时间正弦", "Time of day cyclic feature.", ("event_time",), "sin(2*pi*minute_of_day/1440)", "event_time", "ratio", "float32", "never null", True, "event timestamp only", ("Direction", "Entry", "Risk")), + FeatureDef(38, "minute_of_day_cos", "日内时间余弦", "Time of day cyclic feature.", ("event_time",), "cos(2*pi*minute_of_day/1440)", "event_time", "ratio", "float32", "never null", True, "event timestamp only", ("Direction", "Entry", "Risk")), + FeatureDef(39, "minutes_to_next_funding", "距离下次资金费分钟", "Minutes to next funding settlement.", ("funding", "replay_1m"), "clip((next_funding_time - event_time) / 60000, 0, 480)", "as-of", "minute", "float32", "as-of > 12h -> fail", True, "backward as-of only", ("Entry", "Continue", "Risk")), +) + + +FEATURE_ORDER = [feature.name for feature in FEATURES] + + +OUTPUT_SCHEMA: dict[str, Any] = { + "output_schema_version": OUTPUT_SCHEMA_VERSION, + "direction": { + "longProb": {"type": "decimal", "range": [0.0, 1.0]}, + "shortProb": {"type": "decimal", "range": [0.0, 1.0]}, + "neutralProb": {"type": "decimal", "range": [0.0, 1.0]}, + "sum_rule": "longProb + shortProb + neutralProb must equal 1.0 within 0.000001", + }, + "entry": { + "longEntryProb": {"type": "decimal", "range": [0.0, 1.0]}, + "shortEntryProb": {"type": "decimal", "range": [0.0, 1.0]}, + "longExpectedNetEdgeBps": {"type": "decimal", "range": [-500.0, 500.0]}, + "shortExpectedNetEdgeBps": {"type": "decimal", "range": [-500.0, 500.0]}, + }, + "continuation": { + "longContinueProb": {"type": "decimal", "range": [0.0, 1.0]}, + "shortContinueProb": {"type": "decimal", "range": [0.0, 1.0]}, + "longExpectedContinueEdgeBps": {"type": "decimal", "range": [-500.0, 500.0]}, + "shortExpectedContinueEdgeBps": {"type": "decimal", "range": [-500.0, 500.0]}, + }, + "exit": { + "longExitProb": {"type": "decimal", "range": [0.0, 1.0]}, + "shortExitProb": {"type": "decimal", "range": [0.0, 1.0]}, + "longAdverseMoveBps": {"type": "decimal", "range": [0.0, 500.0]}, + "shortAdverseMoveBps": {"type": "decimal", "range": [0.0, 500.0]}, + "exitReasonScores": { + "adverse_move_prob": {"type": "decimal", "range": [0.0, 1.0]}, + "reversal_prob": {"type": "decimal", "range": [0.0, 1.0]}, + "stop_hit_prob": {"type": "decimal", "range": [0.0, 1.0]}, + "stagnation_prob": {"type": "decimal", "range": [0.0, 1.0]}, + }, + }, + "risk": { + "marketRiskProb": {"type": "decimal", "range": [0.0, 1.0]}, + "longPositionRiskProb": {"type": "decimal", "range": [0.0, 1.0]}, + "shortPositionRiskProb": {"type": "decimal", "range": [0.0, 1.0]}, + "marketPathRiskBps": {"type": "decimal", "range": [0.0, 1000.0]}, + "longPositionPathRiskBps": {"type": "decimal", "range": [0.0, 1000.0]}, + "shortPositionPathRiskBps": {"type": "decimal", "range": [0.0, 1000.0]}, + "riskReasonScores": { + "market_drawdown_prob": {"type": "decimal", "range": [0.0, 1.0]}, + "volatility_expansion_prob": {"type": "decimal", "range": [0.0, 1.0]}, + "spike_prob": {"type": "decimal", "range": [0.0, 1.0]}, + "liquidity_deterioration_prob": {"type": "decimal", "range": [0.0, 1.0]}, + "position_drawdown_prob": {"type": "decimal", "range": [0.0, 1.0]}, + }, + }, +} + + +MODEL_OUTPUTS: dict[str, list[str]] = { + "DIRECTION": ["long_prob", "short_prob", "neutral_prob"], + "ENTRY": ["long_entry_prob", "short_entry_prob", "long_expected_net_edge_bps", "short_expected_net_edge_bps"], + "CONTINUE": ["long_continue_prob", "short_continue_prob", "long_expected_continue_edge_bps", "short_expected_continue_edge_bps"], + "EXIT": [ + "long_exit_prob", + "short_exit_prob", + "long_adverse_move_bps", + "short_adverse_move_bps", + "adverse_move_prob", + "reversal_prob", + "stop_hit_prob", + "stagnation_prob", + ], + "RISK": [ + "market_risk_prob", + "long_position_risk_prob", + "short_position_risk_prob", + "market_path_risk_bps", + "long_position_path_risk_bps", + "short_position_path_risk_bps", + "market_drawdown_prob", + "volatility_expansion_prob", + "spike_prob", + "liquidity_deterioration_prob", + "position_drawdown_prob", + ], +} + + +PROBABILITY_TARGET_NAMES: dict[str, list[str]] = { + "DIRECTION": ["longProb", "shortProb", "neutralProb"], + "ENTRY": ["longEntryProb", "shortEntryProb"], + "CONTINUE": ["longContinueProb", "shortContinueProb"], + "EXIT": ["longExitProb", "shortExitProb", "adverse_move_prob", "reversal_prob", "stop_hit_prob", "stagnation_prob"], + "RISK": [ + "marketRiskProb", + "longPositionRiskProb", + "shortPositionRiskProb", + "market_drawdown_prob", + "volatility_expansion_prob", + "spike_prob", + "liquidity_deterioration_prob", + "position_drawdown_prob", + ], +} + + +OUTPUT_MAPPING: dict[str, dict[str, str]] = { + model: {field: f"prediction[{index}]" for index, field in enumerate(fields)} + for model, fields in MODEL_OUTPUTS.items() +} diff --git a/training/trader_training/training.py b/training/trader_training/training.py new file mode 100644 index 0000000..a10fedf --- /dev/null +++ b/training/trader_training/training.py @@ -0,0 +1,581 @@ +from __future__ import annotations + +import json +import logging +from dataclasses import dataclass +from pathlib import Path +from typing import Any + +import numpy as np +import pandas as pd +from sklearn.linear_model import LogisticRegression, Ridge +from sklearn.metrics import accuracy_score, log_loss, mean_absolute_error, roc_auc_score +from sklearn.preprocessing import StandardScaler + +from trader_training.io_utils import read_parquet, run_root, sha256_file, write_json, write_parquet, write_text +from trader_training.onnx_export import LinearHead, export_heads +from trader_training.schemas import FEATURE_ORDER, FIT_SPLIT, LATEST_STRESS_SPLIT, MODEL_OUTPUTS, PROBABILITY_TARGET_NAMES, TUNE_SPLIT, VALIDATION_LOCKED_SPLIT + + +@dataclass +class HeadResult: + field: str + target_name: str | None + kind: str + weight: np.ndarray + bias: np.ndarray + metrics: dict[str, Any] + tune_prediction: np.ndarray + tune_target: np.ndarray | None + + +TARGETS = { + "DIRECTION": { + "dataset": "direction_train.parquet", + "heads": [("direction", "multiclass", ["long_target", "short_target", "neutral_target"], ["long_prob", "short_prob", "neutral_prob"], ["longProb", "shortProb", "neutralProb"])], + }, + "ENTRY": { + "dataset": "entry_train.parquet", + "heads": [ + ("long_entry_prob", "binary", "long_entry_target", ["long_entry_prob"], ["longEntryProb"]), + ("short_entry_prob", "binary", "short_entry_target", ["short_entry_prob"], ["shortEntryProb"]), + ("long_expected_net_edge_bps", "regression", "long_expected_net_edge_bps", ["long_expected_net_edge_bps"], [None]), + ("short_expected_net_edge_bps", "regression", "short_expected_net_edge_bps", ["short_expected_net_edge_bps"], [None]), + ], + }, + "CONTINUE": { + "dataset": "continue_train.parquet", + "heads": [ + ("long_continue_prob", "binary", "long_continue_target", ["long_continue_prob"], ["longContinueProb"]), + ("short_continue_prob", "binary", "short_continue_target", ["short_continue_prob"], ["shortContinueProb"]), + ("long_expected_continue_edge_bps", "regression", "long_expected_continue_edge_bps", ["long_expected_continue_edge_bps"], [None]), + ("short_expected_continue_edge_bps", "regression", "short_expected_continue_edge_bps", ["short_expected_continue_edge_bps"], [None]), + ], + }, + "EXIT": { + "dataset": "exit_train.parquet", + "heads": [ + ("long_exit_prob", "binary", "long_exit_target", ["long_exit_prob"], ["longExitProb"]), + ("short_exit_prob", "binary", "short_exit_target", ["short_exit_prob"], ["shortExitProb"]), + ("long_adverse_move_bps", "regression", "long_adverse_move_bps", ["long_adverse_move_bps"], [None]), + ("short_adverse_move_bps", "regression", "short_adverse_move_bps", ["short_adverse_move_bps"], [None]), + ("adverse_move_prob", "binary", "adverse_move_prob_label", ["adverse_move_prob"], ["adverse_move_prob"]), + ("reversal_prob", "binary", "reversal_prob_label", ["reversal_prob"], ["reversal_prob"]), + ("stop_hit_prob", "binary", "stop_hit_prob_label", ["stop_hit_prob"], ["stop_hit_prob"]), + ("stagnation_prob", "binary", "stagnation_prob_label", ["stagnation_prob"], ["stagnation_prob"]), + ], + }, + "RISK": { + "dataset": "risk_train.parquet", + "heads": [ + ("market_risk_prob", "binary", "market_risk_target", ["market_risk_prob"], ["marketRiskProb"]), + ("long_position_risk_prob", "binary", "long_position_risk_target", ["long_position_risk_prob"], ["longPositionRiskProb"]), + ("short_position_risk_prob", "binary", "short_position_risk_target", ["short_position_risk_prob"], ["shortPositionRiskProb"]), + ("market_path_risk_bps", "regression", "market_path_risk_bps", ["market_path_risk_bps"], [None]), + ("long_position_path_risk_bps", "regression", "long_position_path_risk_bps", ["long_position_path_risk_bps"], [None]), + ("short_position_path_risk_bps", "regression", "short_position_path_risk_bps", ["short_position_path_risk_bps"], [None]), + ("market_drawdown_prob", "binary", "market_drawdown_prob_label", ["market_drawdown_prob"], ["market_drawdown_prob"]), + ("volatility_expansion_prob", "binary", "volatility_expansion_prob_label", ["volatility_expansion_prob"], ["volatility_expansion_prob"]), + ("spike_prob", "binary", "spike_prob_label", ["spike_prob"], ["spike_prob"]), + ("liquidity_deterioration_prob", "binary", "liquidity_deterioration_prob_label", ["liquidity_deterioration_prob"], ["liquidity_deterioration_prob"]), + ("position_drawdown_prob", "binary", "position_drawdown_prob_label", ["position_drawdown_prob"], ["position_drawdown_prob"]), + ], + }, +} + + +def train_small_models(args: Any) -> None: + root = run_root(args) + model_manifest: dict[str, Any] = {} + for model_name, spec in TARGETS.items(): + dataset = read_parquet(root / "dataset" / spec["dataset"]) + if args.max_rows and len(dataset) > args.max_rows: + dataset = dataset.sort_values("event_time").tail(args.max_rows).copy() + if dataset.empty: + raise ValueError(f"dataset is empty for {model_name}") + train = dataset[dataset["split_id"] == FIT_SPLIT].copy() + tune = dataset[dataset["split_id"] == TUNE_SPLIT].copy() + validation_locked = dataset[dataset["split_id"] == VALIDATION_LOCKED_SPLIT].copy() + latest_stress = dataset[dataset["split_id"] == LATEST_STRESS_SPLIT].copy() + if train.empty or tune.empty: + raise ValueError(f"{model_name} needs {FIT_SPLIT} and {TUNE_SPLIT} rows") + logging.info( + "trader.training.model_dataset_loaded runId=%s model=%s totalRows=%s trainRows=%s tuneRows=%s validationLockedRows=%s latestStressRows=%s splitCounts=%s", + args.run_id, + model_name, + len(dataset), + len(train), + len(tune), + len(validation_locked), + len(latest_stress), + dataset["split_id"].value_counts().to_dict(), + ) + scaler = StandardScaler() + x_train_scaled = scaler.fit_transform(train[FEATURE_ORDER].astype("float32")) + x_tune_scaled = scaler.transform(tune[FEATURE_ORDER].astype("float32")) + heads: list[LinearHead] = [] + head_results: list[HeadResult] = [] + for item in spec["heads"]: + head_results.extend(_fit_head(item, x_train_scaled, x_tune_scaled, train, tune, scaler)) + for result in head_results: + logging.info( + "trader.training.model_head_trained runId=%s model=%s head=%s kind=%s metrics=%s", + args.run_id, + model_name, + result.field, + result.kind, + result.metrics, + ) + for result in head_results: + heads.append(LinearHead(result.field, _onnx_kind(result.kind), result.weight, result.bias)) + model_dir = root / "model" / model_name.lower() + model_path = model_dir / f"{model_name.lower()}.onnx" + export_heads(model_path, heads, feature_count=len(FEATURE_ORDER), opset=17) + predictions = _tune_prediction_frame(tune, head_results) + write_parquet(model_dir / "tune_predictions.parquet", predictions) + if not validation_locked.empty: + write_parquet(model_dir / "validation_locked_predictions.parquet", _predict_frame(validation_locked, head_results, include_labels=True)) + if not latest_stress.empty: + write_parquet(model_dir / "latest_stress_predictions.parquet", _predict_frame(latest_stress, head_results, include_labels=False)) + metrics = {result.field: result.metrics for result in head_results} + model_hash = sha256_file(model_path) + quality_status, quality_reasons = _model_quality(head_results) + write_json( + model_dir / "model_train_result.json", + { + "model_name": model_name, + "metrics": metrics, + "quality_status": quality_status, + "quality_reasons": quality_reasons, + "artifact_hash_sha256": model_hash, + }, + ) + write_json( + model_dir / "model_manifest.json", + { + "model_name": model_name, + "model_path": str(model_path), + "model_format": "ONNX", + "input_tensor_name": "features", + "input_feature_count": len(FEATURE_ORDER), + "output_tensor_name": "prediction", + "output_fields": MODEL_OUTPUTS[model_name], + "quality_status": quality_status, + "quality_reasons": quality_reasons, + "artifact_hash_sha256": model_hash, + }, + ) + _write_model_examples(model_dir, model_name, tune, predictions) + _write_feature_importance(model_dir / "feature_importance.csv", head_results) + _write_version_compare(model_dir, model_name, metrics) + _write_training_report(model_dir / "training_report.md", model_name, metrics, quality_status, quality_reasons) + model_manifest[model_name] = {"path": str(model_path), "hash_sha256": model_hash, "metrics": metrics, "quality_status": quality_status, "quality_reasons": quality_reasons} + logging.info( + "trader.training.model_trained runId=%s model=%s qualityStatus=%s qualityReasons=%s path=%s tunePredictionRows=%s featureImportancePath=%s", + args.run_id, + model_name, + quality_status, + quality_reasons, + model_path, + len(predictions), + model_dir / "feature_importance.csv", + ) + write_json(root / "model" / "model_train_manifest.json", model_manifest) + + +def _fit_head(item, x_train, x_tune, train: pd.DataFrame, tune: pd.DataFrame, scaler: StandardScaler) -> list[HeadResult]: + name, kind, target, fields, target_names = item + if kind == "multiclass": + y_train = train[target].to_numpy().argmax(axis=1) + y_val = tune[target].to_numpy().argmax(axis=1) + model = LogisticRegression(max_iter=500) + model.fit(x_train, y_train) + proba = model.predict_proba(x_tune) + weight, bias = _fold_scaler(model.coef_.T, model.intercept_, scaler) + train_prior = train[target].to_numpy().mean(axis=0) + metrics = _multiclass_metrics(y_train, y_val, proba, train_prior) + return [HeadResult("direction", target_names[0], "softmax", weight, bias, metrics, proba, y_val)] + if kind == "binary": + y_train = pd.to_numeric(train[target], errors="coerce").fillna(0).astype(int).to_numpy() + y_val = pd.to_numeric(tune[target], errors="coerce").fillna(0).astype(int).to_numpy() + if len(np.unique(y_train)) < 2: + prevalence = float(np.clip(y_train.mean(), 1e-6, 1 - 1e-6)) + coef = np.zeros((1, len(FEATURE_ORDER)), dtype=np.float32) + intercept = np.array([np.log(prevalence / (1 - prevalence))], dtype=np.float32) + proba = np.full(len(y_val), prevalence, dtype=np.float32) + else: + model = LogisticRegression(max_iter=500) + model.fit(x_train, y_train) + coef = model.coef_ + intercept = model.intercept_ + proba = model.predict_proba(x_tune)[:, 1] + weight, bias = _fold_scaler(coef.T, intercept, scaler) + metrics = _binary_metrics(y_train, y_val, proba) + if len(np.unique(y_val)) == 2: + metrics["auc"] = float(roc_auc_score(y_val, proba)) + return [HeadResult(fields[0], target_names[0], "sigmoid", weight, bias, metrics, proba.reshape(-1, 1), y_val)] + if kind == "regression": + y_train = pd.to_numeric(train[target], errors="coerce").fillna(0.0).to_numpy() + y_val = pd.to_numeric(tune[target], errors="coerce").fillna(0.0).to_numpy() + model = Ridge(alpha=1.0) + model.fit(x_train, y_train) + pred = model.predict(x_tune) + weight, bias = _fold_scaler(model.coef_.reshape(1, -1).T, np.array([model.intercept_]), scaler) + return [HeadResult(fields[0], None, "identity", weight, bias, _regression_metrics(y_train, y_val, pred), pred.reshape(-1, 1), y_val)] + raise ValueError(f"unsupported head kind: {kind}") + + +def _fold_scaler(weight_scaled: np.ndarray, bias_scaled: np.ndarray, scaler: StandardScaler) -> tuple[np.ndarray, np.ndarray]: + scale = np.where(scaler.scale_ == 0, 1.0, scaler.scale_) + weight = weight_scaled / scale.reshape(-1, 1) + bias = bias_scaled - np.sum((scaler.mean_ / scale).reshape(-1, 1) * weight_scaled, axis=0) + return weight.astype(np.float32), bias.astype(np.float32) + + +def _onnx_kind(kind: str) -> str: + if kind in ("softmax", "sigmoid", "identity"): + return kind + raise ValueError(f"unsupported result kind: {kind}") + + +def _multiclass_metrics(y_train: np.ndarray, y_val: np.ndarray, proba: np.ndarray, train_prior: np.ndarray) -> dict[str, Any]: + one_hot = np.eye(proba.shape[1], dtype=float)[y_val] + train_prior = np.asarray(train_prior, dtype=float) + train_prior = train_prior / train_prior.sum() if train_prior.sum() > 0 else np.full(proba.shape[1], 1.0 / proba.shape[1]) + constant = np.tile(train_prior.reshape(1, -1), (len(y_val), 1)) + proba_for_logloss = _clip_normalize(proba) + constant_for_logloss = _clip_normalize(constant) + metrics: dict[str, Any] = { + "accuracy": float(accuracy_score(y_val, proba.argmax(axis=1))), + "logloss": float(log_loss(y_val, proba_for_logloss, labels=list(range(proba.shape[1])))), + "constant_logloss": float(log_loss(y_val, constant_for_logloss, labels=list(range(proba.shape[1])))), + "brier_multiclass": float(np.mean(np.sum((one_hot - proba) ** 2, axis=1))), + "constant_brier_multiclass": float(np.mean(np.sum((one_hot - constant) ** 2, axis=1))), + } + for idx, name in enumerate(("long", "short", "neutral")): + binary_target = (y_val == idx).astype(int) + positives = int(binary_target.sum()) + negatives = int(len(binary_target) - positives) + if positives >= 200 and negatives >= 200: + metrics[f"{name}_auc"] = float(roc_auc_score(binary_target, proba[:, idx])) + else: + metrics[f"{name}_auc_status"] = "INSUFFICIENT_SAMPLE" + metrics[f"{name}_positive_count"] = positives + metrics[f"{name}_negative_count"] = negatives + max_prob = proba.max(axis=1) + predicted_class = proba.argmax(axis=1) + top_count = max(1, int(len(y_val) * 0.10)) + top_idx = np.argsort(max_prob)[-top_count:] + metrics["top10_hit_rate"] = float((predicted_class[top_idx] == y_val[top_idx]).mean()) + metrics["all_hit_rate"] = float((predicted_class == y_val).mean()) + return _with_quality(metrics) + + +def _clip_normalize(values: np.ndarray) -> np.ndarray: + values = np.clip(np.asarray(values, dtype=float), 1e-6, 1.0) + return values / values.sum(axis=1, keepdims=True) + + +def _binary_metrics(y_train: np.ndarray, y_val: np.ndarray, proba: np.ndarray) -> dict[str, Any]: + proba = np.asarray(proba, dtype=float) + train_rate = float(np.mean(y_train)) if len(y_train) else 0.0 + constant = np.full(len(y_val), train_rate) + metrics: dict[str, Any] = { + "positive_rate": train_rate, + "tune_positive_rate": float(np.mean(y_val)) if len(y_val) else 0.0, + "brier": float(np.mean((y_val - proba) ** 2)) if len(y_val) else 0.0, + "constant_brier": float(np.mean((y_val - constant) ** 2)) if len(y_val) else 0.0, + } + if len(y_val): + top_count = max(1, int(len(y_val) * 0.10)) + top_idx = np.argsort(proba)[-top_count:] + metrics["top10_hit_rate"] = float(np.mean(y_val[top_idx])) + metrics["all_hit_rate"] = float(np.mean(y_val)) + return _with_quality(metrics) + + +def _regression_metrics(y_train: np.ndarray, y_val: np.ndarray, pred: np.ndarray) -> dict[str, Any]: + mae = float(mean_absolute_error(y_val, pred)) + train_std = float(np.std(y_train)) + metrics: dict[str, Any] = { + "mae": mae, + "train_target_std": train_std, + "mae_vs_train_std_ratio": float(mae / train_std) if train_std > 0 else None, + } + return _with_quality(metrics) + + +def _with_quality(metrics: dict[str, Any]) -> dict[str, Any]: + reasons: list[str] = [] + for key, value in metrics.items(): + if key.endswith("_auc") and isinstance(value, float) and value < 0.53: + reasons.append(f"{key}_below_0.53") + if "brier" in metrics and metrics.get("constant_brier") is not None and metrics["brier"] >= metrics["constant_brier"]: + reasons.append("brier_not_better_than_constant") + if "brier_multiclass" in metrics and metrics["brier_multiclass"] >= metrics["constant_brier_multiclass"]: + reasons.append("brier_not_better_than_constant") + if "mae" in metrics and metrics.get("train_target_std") is not None and metrics["train_target_std"] > 0 and metrics["mae"] > metrics["train_target_std"]: + reasons.append("mae_above_train_target_std") + if "top10_hit_rate" in metrics and "all_hit_rate" in metrics and metrics["top10_hit_rate"] <= metrics["all_hit_rate"]: + reasons.append("top10_not_better_than_all") + metrics["quality_status"] = "REJECTED" if reasons else "PASS" + metrics["quality_reasons"] = reasons + return metrics + + +def _model_quality(results: list[HeadResult]) -> tuple[str, list[str]]: + reasons = [] + for result in results: + if result.metrics.get("quality_status") == "REJECTED": + for reason in result.metrics.get("quality_reasons", []): + reasons.append(f"{result.field}:{reason}") + return ("REJECTED", reasons) if reasons else ("PASS", []) + + +def _tune_prediction_frame(tune: pd.DataFrame, results: list[HeadResult]) -> pd.DataFrame: + out = tune[["sample_id", "symbol", "event_time", "split_id"]].copy().reset_index(drop=True) + for result in results: + values = result.tune_prediction + if result.kind == "softmax": + for idx, field in enumerate(MODEL_OUTPUTS["DIRECTION"]): + out[field] = values[:, idx] + if result.tune_target is not None: + out["label__longProb"] = (result.tune_target == 0).astype(int) + out["label__shortProb"] = (result.tune_target == 1).astype(int) + out["label__neutralProb"] = (result.tune_target == 2).astype(int) + else: + out[result.field] = values.reshape(-1) + if result.kind != "softmax" and result.target_name and result.tune_target is not None: + out[f"label__{result.target_name}"] = result.tune_target + return out + + +def _predict_frame(frame: pd.DataFrame, results: list[HeadResult], include_labels: bool) -> pd.DataFrame: + out = frame[["sample_id", "symbol", "event_time", "split_id"]].copy().reset_index(drop=True) + features = frame[FEATURE_ORDER].astype("float32").to_numpy() + for result in results: + values = features @ result.weight + result.bias.reshape(1, -1) + if result.kind == "softmax": + values = _softmax(values) + for idx, field in enumerate(MODEL_OUTPUTS["DIRECTION"]): + out[field] = values[:, idx] + elif result.kind == "sigmoid": + out[result.field] = (1.0 / (1.0 + np.exp(-values))).reshape(-1) + else: + out[result.field] = values.reshape(-1) + if include_labels and result.kind != "softmax" and result.target_name and result.target_name in frame.columns: + out[f"label__{result.target_name}"] = frame[result.target_name].to_numpy() + return out + + +def _softmax(values: np.ndarray) -> np.ndarray: + shifted = values - np.max(values, axis=1, keepdims=True) + exp = np.exp(shifted) + return exp / exp.sum(axis=1, keepdims=True) + + +def _write_training_report(path: Path, model_name: str, metrics: dict[str, Any], quality_status: str, quality_reasons: list[str]) -> None: + lines = [ + "# Trader Model Training Report", + "", + f"- model: {model_name}", + f"- quality_status: {quality_status}", + f"- quality_reasons: {quality_reasons}", + "", + "```json", + json.dumps(metrics, indent=2, sort_keys=True), + "```", + "", + ] + write_text(path, "\n".join(lines)) + + +def _write_model_examples(model_dir: Path, model_name: str, tune: pd.DataFrame, predictions: pd.DataFrame) -> None: + sample_input = {feature: float(tune.iloc[0][feature]) for feature in FEATURE_ORDER} + sample_output = {field: float(predictions.iloc[0][field]) for field in MODEL_OUTPUTS[model_name]} + write_json(model_dir / "sample_input.json", sample_input) + write_json(model_dir / "sample_output.json", sample_output) + + +def _write_feature_importance(path: Path, results: list[HeadResult]) -> None: + rows = [] + for result in results: + importance = np.mean(np.abs(result.weight), axis=1) + for feature, value in zip(FEATURE_ORDER, importance): + rows.append({"head": result.field, "feature": feature, "abs_weight": float(value)}) + frame = pd.DataFrame(rows).sort_values(["head", "abs_weight"], ascending=[True, False]) + write_text(path, frame.to_csv(index=False)) + + +def _write_version_compare(model_dir: Path, model_name: str, metrics: dict[str, Any]) -> None: + payload = { + "model_name": model_name, + "status": "NO_BASELINE", + "reason": "first executable V4 training chain has no previous approved artifact bundle under this run root", + "current_metrics": metrics, + } + write_json(model_dir / "version_compare_metrics.json", payload) + write_text(model_dir / "version_compare_by_regime.csv", "regime,status,reason\nNO_BASELINE,NO_BASELINE,no previous approved artifact bundle\n") + write_text(model_dir / "version_compare_top_bucket.csv", "bucket,status,reason\nNO_BASELINE,NO_BASELINE,no previous approved artifact bundle\n") + lines = [ + "# Version Compare Report", + "", + f"- model: {model_name}", + "- status: NO_BASELINE", + "- reason: 当前 run 目录没有上一版已验收模型包,所以首版只能记录当前指标,不能做新旧优劣判断。", + "", + ] + write_text(model_dir / "version_compare_report.md", "\n".join(lines)) + + +def build_calibrators(args: Any) -> None: + root = run_root(args) + manifest_rows = [] + for model_name, target_names in PROBABILITY_TARGET_NAMES.items(): + prediction_path = root / "model" / model_name.lower() / "tune_predictions.parquet" + predictions = read_parquet(prediction_path) + targets = {} + reliability_rows = [] + quality_reasons = [] + for target_name in target_names: + raw_field = _target_to_raw_field(model_name, target_name) + label_field = f"label__{target_name}" + labels = predictions[label_field].to_numpy() if label_field in predictions.columns else None + raw = predictions[raw_field].to_numpy() + bins = _calibration_bins(raw, labels) + metrics, rows = _calibration_metrics(raw, labels, bins, target_name) + targets[target_name] = {"bins": bins, "metrics": metrics} + reliability_rows.extend(rows) + if metrics.get("quality_status") == "REJECTED": + quality_reasons.append(f"{target_name}:{metrics.get('quality_reason')}") + calibrator = { + "calibrator_version": f"{model_name.lower()}-cal-v4-btc-p0", + "method": "BINNING", + "targets": targets, + "clip": {"min": 0.0, "max": 1.0}, + "fallback_policy": "FAIL_FAST", + } + path = root / "calibration" / model_name.lower() / "calibrator.json" + cal_hash = write_json(path, calibrator) + quality_status = "REJECTED" if quality_reasons else "PASS" + manifest_rows.append( + { + "model_name": model_name, + "calibrator_path": str(path), + "calibrator_hash_sha256": cal_hash, + "target_count": len(targets), + "quality_status": quality_status, + "quality_reasons": quality_reasons, + } + ) + write_text(root / "calibration" / model_name.lower() / "reliability_curve.csv", pd.DataFrame(reliability_rows).to_csv(index=False)) + _write_calibration_report(root / "calibration" / model_name.lower() / "calibration_report.md", model_name, targets, quality_status, quality_reasons) + logging.info("trader.training.calibrator_written runId=%s model=%s path=%s", args.run_id, model_name, path) + write_json(root / "calibration" / "calibration_train_manifest.json", {"calibrators": manifest_rows}) + + +def _target_to_raw_field(model_name: str, target_name: str) -> str: + mapping = { + "longProb": "long_prob", + "shortProb": "short_prob", + "neutralProb": "neutral_prob", + "longEntryProb": "long_entry_prob", + "shortEntryProb": "short_entry_prob", + "longContinueProb": "long_continue_prob", + "shortContinueProb": "short_continue_prob", + "longExitProb": "long_exit_prob", + "shortExitProb": "short_exit_prob", + "marketRiskProb": "market_risk_prob", + "longPositionRiskProb": "long_position_risk_prob", + "shortPositionRiskProb": "short_position_risk_prob", + } + return mapping.get(target_name, target_name) + + +def _calibration_bins(raw: np.ndarray, labels: np.ndarray | None) -> list[dict[str, float]]: + raw = np.asarray(raw, dtype=float) + raw = np.clip(np.nan_to_num(raw, nan=0.5), 0.0, 1.0) + if labels is None or len(labels) != len(raw): + return [{"min": 0.0, "max": 1.0, "calibrated": 0.5}] + labels = np.asarray(labels, dtype=float) + bins = [] + edges = np.linspace(0.0, 1.0, 11) + for left, right in zip(edges[:-1], edges[1:]): + if right == 1.0: + mask = (raw >= left) & (raw <= right) + else: + mask = (raw >= left) & (raw < right) + calibrated = float(labels[mask].mean()) if mask.any() else float((left + right) / 2.0) + bins.append({"min": float(left), "max": float(right), "calibrated": float(np.clip(calibrated, 0.0, 1.0))}) + return bins + + +def _apply_calibration(raw: np.ndarray, bins: list[dict[str, float]]) -> np.ndarray: + out = np.zeros_like(raw, dtype=float) + for item in bins: + left = float(item["min"]) + right = float(item["max"]) + if right >= 1.0: + mask = (raw >= left) & (raw <= right) + else: + mask = (raw >= left) & (raw < right) + out[mask] = float(item["calibrated"]) + return np.clip(out, 0.0, 1.0) + + +def _calibration_metrics(raw: np.ndarray, labels: np.ndarray | None, bins: list[dict[str, float]], target_name: str) -> tuple[dict[str, Any], list[dict[str, Any]]]: + raw = np.clip(np.asarray(raw, dtype=float), 0.0, 1.0) + if labels is None or len(labels) != len(raw): + return {"quality_status": "REJECTED", "quality_reason": "missing_labels"}, [] + labels = np.asarray(labels, dtype=float) + calibrated = _apply_calibration(raw, bins) + raw_ece, rows = _ece(raw, labels, target_name, "raw") + calibrated_ece, calibrated_rows = _ece(calibrated, labels, target_name, "calibrated") + rows.extend(calibrated_rows) + quality_status = "PASS" if calibrated_ece <= raw_ece else "REJECTED" + return ( + { + "raw_ece": raw_ece, + "calibrated_ece": calibrated_ece, + "quality_status": quality_status, + "quality_reason": None if quality_status == "PASS" else "calibrated_ece_not_improved", + }, + rows, + ) + + +def _ece(values: np.ndarray, labels: np.ndarray, target_name: str, series_name: str) -> tuple[float, list[dict[str, Any]]]: + rows = [] + total = len(values) + ece = 0.0 + edges = np.linspace(0.0, 1.0, 11) + for left, right in zip(edges[:-1], edges[1:]): + mask = (values >= left) & (values <= right) if right >= 1.0 else (values >= left) & (values < right) + count = int(mask.sum()) + confidence = float(values[mask].mean()) if count else float((left + right) / 2.0) + accuracy = float(labels[mask].mean()) if count else 0.0 + ece += (count / total) * abs(confidence - accuracy) if total else 0.0 + rows.append( + { + "target": target_name, + "series": series_name, + "bin_min": left, + "bin_max": right, + "count": count, + "confidence": confidence, + "accuracy": accuracy, + } + ) + return float(ece), rows + + +def _write_calibration_report(path: Path, model_name: str, targets: dict[str, Any], quality_status: str, quality_reasons: list[str]) -> None: + lines = ["# Trader Calibration Report", "", f"- model: {model_name}", f"- target_count: {len(targets)}", f"- quality_status: {quality_status}", f"- quality_reasons: {quality_reasons}", ""] + for target_name, payload in targets.items(): + lines.append(f"## {target_name}") + lines.append("") + lines.append("```json") + lines.append(json.dumps(payload.get("metrics", {}), indent=2, sort_keys=True)) + lines.append("```") + lines.append("") + write_text(path, "\n".join(lines)) diff --git a/training/trader_training/validator.py b/training/trader_training/validator.py new file mode 100644 index 0000000..2ddeb1b --- /dev/null +++ b/training/trader_training/validator.py @@ -0,0 +1,149 @@ +from __future__ import annotations + +import logging +from pathlib import Path +from typing import Any + +import numpy as np + +from trader_training.io_utils import read_json, sha256_file, write_json, write_text +from trader_training.schemas import FEATURE_ORDER, MODEL_OUTPUTS, OUTPUT_MAPPING + + +def validate_artifact_bundle(args: Any) -> None: + root = args.artifact_root + errors: list[str] = [] + release_gate = {"release_gate_status": "UNKNOWN", "release_gate_reasons": ["artifact content not checked"]} + required = [ + "manifests/model_bundle_manifest.json", + "manifests/model_manifest.json", + "manifests/calibration_manifest.json", + "manifests/position_manager_manifest.json", + "schemas/feature_schema.json", + "schemas/feature_order.json", + "schemas/output_schema.json", + "price_plan_context.json", + "examples/sample_input.json", + "examples/sample_output.json", + ] + for relative in required: + if not (root / relative).is_file(): + errors.append(f"missing file: {relative}") + if not errors: + _validate_content(root, errors, args.require_active, args.run_onnx) + # Structure validation only proves the bundle is loadable. The release + # gate is the separate business decision for whether it may become ACTIVE. + release_gate = _release_gate(root) + if args.require_active and release_gate["release_gate_status"] != "PASS": + errors.append(f"release gate must be PASS for ACTIVE, actual={release_gate['release_gate_status']}") + status = "PASS" if not errors else "FAIL" + result = {"status": status, "error_count": len(errors), "errors": errors, "artifact_root": str(root), **release_gate} + output_root = root.parent + write_json(output_root / "artifact_validation_result.json", result) + lines = ["# Artifact Validation Report", "", f"- status: {status}", f"- error_count: {len(errors)}", ""] + for error in errors: + lines.append(f"- {error}") + write_text(output_root / "artifact_validation_report.md", "\n".join(lines) + "\n") + logging.info("trader.training.artifact_validated status=%s errorCount=%s path=%s", status, len(errors), root) + if errors: + raise SystemExit("artifact validation failed; see artifact_validation_report.md") + + +def _validate_content(root: Path, errors: list[str], require_active: bool, run_onnx: bool) -> None: + feature_order = read_json(root / "schemas/feature_order.json") + if feature_order != FEATURE_ORDER: + errors.append("feature_order.json does not match V4 39-feature order") + model_bundle = read_json(root / "manifests/model_bundle_manifest.json") + if require_active and model_bundle.get("status") != "ACTIVE": + errors.append("model_bundle_manifest.status must be ACTIVE for Java SHADOW") + if model_bundle.get("status") not in {"CANDIDATE", "ACTIVE"}: + errors.append("model_bundle_manifest.status must be CANDIDATE or ACTIVE") + manifests = read_json(root / "manifests/model_manifest.json") + if len(manifests) != 5: + errors.append("model_manifest.json must contain exactly five logical models") + seen = {item.get("model_type") for item in manifests} + if seen != set(MODEL_OUTPUTS): + errors.append(f"model types mismatch: {seen}") + for item in manifests: + model_type = item.get("model_type") + if item.get("model_format") != "ONNX": + errors.append(f"{model_type} model_format must be ONNX") + if item.get("input_tensor_name") != "features": + errors.append(f"{model_type} input tensor must be features") + if item.get("input_shape_json", {}).get("features") != 39: + errors.append(f"{model_type} input_shape_json.features must be 39") + if item.get("onnx_opset_version") != 17: + errors.append(f"{model_type} opset must be 17") + if item.get("output_mapping_json") != OUTPUT_MAPPING.get(model_type): + errors.append(f"{model_type} output_mapping_json does not match Java contract") + _check_hash(root, item.get("artifact_path"), item.get("artifact_hash_sha256"), errors) + _check_hash(root, item.get("feature_schema_path"), item.get("feature_schema_hash"), errors) + _check_hash(root, item.get("feature_order_path"), item.get("feature_order_hash"), errors) + _check_hash(root, item.get("output_schema_path"), item.get("output_schema_hash"), errors) + if require_active and item.get("status") != "ACTIVE": + errors.append(f"{model_type} status must be ACTIVE for Java SHADOW") + if item.get("status") != model_bundle.get("status"): + errors.append(f"{model_type} status does not match model_bundle_manifest.status") + calibrators = read_json(root / "manifests/calibration_manifest.json") + if len(calibrators) != 5: + errors.append("calibration_manifest.json must contain five calibrators") + for item in calibrators: + _check_hash(root, item.get("calibrator_path"), item.get("calibrator_hash_sha256"), errors) + if require_active and item.get("status") != "ACTIVE": + errors.append(f"{item.get('model_name')} calibrator status must be ACTIVE") + if item.get("status") != model_bundle.get("status"): + errors.append(f"{item.get('model_name')} calibrator status does not match model_bundle_manifest.status") + pm_manifest = read_json(root / "manifests/position_manager_manifest.json") + if require_active and pm_manifest.get("status") != "ACTIVE": + errors.append("position_manager_manifest.status must be ACTIVE for Java SHADOW") + if pm_manifest.get("status") != model_bundle.get("status"): + errors.append("position_manager_manifest.status does not match model_bundle_manifest.status") + if run_onnx and not errors: + _run_sample_inference(root, manifests, errors) + + +def _release_gate(root: Path) -> dict[str, Any]: + reasons: list[str] = [] + model_bundle = read_json(root / "manifests/model_bundle_manifest.json") + if model_bundle.get("backtest_status") != "PASS": + reasons.append(f"backtest_status={model_bundle.get('backtest_status')}") + reasons.extend(model_bundle.get("backtest_status_reasons_json", [])) + for item in read_json(root / "manifests/model_manifest.json"): + if item.get("quality_status") != "PASS": + reasons.append(f"{item.get('model_type')}.quality_status={item.get('quality_status')}") + reasons.extend([f"{item.get('model_type')}:{reason}" for reason in item.get("quality_reasons_json", [])]) + for item in read_json(root / "manifests/calibration_manifest.json"): + if item.get("quality_status") != "PASS": + reasons.append(f"{item.get('model_name')}.calibration_quality_status={item.get('quality_status')}") + reasons.extend([f"{item.get('model_name')}:calibration:{reason}" for reason in item.get("quality_reasons_json", [])]) + return { + "release_gate_status": "PASS" if not reasons else "REJECTED", + "release_gate_reasons": reasons, + } + + +def _check_hash(root: Path, relative: str | None, expected: str | None, errors: list[str]) -> None: + if not relative or not expected: + errors.append(f"missing hash contract for {relative}") + return + path = root / relative + if not path.is_file(): + errors.append(f"hash target missing: {relative}") + return + actual = sha256_file(path) + if actual != expected: + errors.append(f"hash mismatch: {relative}") + + +def _run_sample_inference(root: Path, manifests: list[dict[str, Any]], errors: list[str]) -> None: + try: + import onnxruntime as ort + except ModuleNotFoundError as exc: + raise SystemExit("Python package 'onnxruntime' is required for --run-onnx validation.") from exc + sample = read_json(root / "examples/sample_input.json") + features = np.array([[float(sample[name]) for name in FEATURE_ORDER]], dtype=np.float32) + for item in manifests: + session = ort.InferenceSession(str(root / item["artifact_path"])) + outputs = session.run(None, {"features": features}) + if not outputs or outputs[0].shape[-1] != len(MODEL_OUTPUTS[item["model_type"]]): + errors.append(f"{item['model_type']} sample output shape is invalid")