Implement Trader V4 training artifact pipeline
This commit is contained in:
@@ -3,6 +3,8 @@ target/
|
||||
*.iml
|
||||
*.log
|
||||
.DS_Store
|
||||
__pycache__/
|
||||
*.pyc
|
||||
|
||||
# Runtime and local data stay outside source control.
|
||||
logs/
|
||||
|
||||
@@ -47,10 +47,19 @@
|
||||
<groupId>org.springframework.boot</groupId>
|
||||
<artifactId>spring-boot-starter-jdbc</artifactId>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.springframework.boot</groupId>
|
||||
<artifactId>spring-boot-starter-data-redis</artifactId>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.springframework.boot</groupId>
|
||||
<artifactId>spring-boot-starter-flyway</artifactId>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>com.microsoft.onnxruntime</groupId>
|
||||
<artifactId>onnxruntime</artifactId>
|
||||
<version>1.22.0</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.flywaydb</groupId>
|
||||
<artifactId>flyway-mysql</artifactId>
|
||||
|
||||
@@ -0,0 +1,177 @@
|
||||
package com.quantai.trader.artifact;
|
||||
|
||||
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||
import com.quantai.trader.enums.TraderRunMode;
|
||||
import com.quantai.trader.persistence.TraderJsonCodec;
|
||||
import org.springframework.jdbc.core.JdbcTemplate;
|
||||
import org.springframework.stereotype.Repository;
|
||||
|
||||
@Repository
|
||||
public class JdbcTraderArtifactManifestRepository implements TraderArtifactManifestRepository {
|
||||
private final JdbcTemplate jdbcTemplate;
|
||||
private final TraderJsonCodec jsonCodec;
|
||||
|
||||
public JdbcTraderArtifactManifestRepository(JdbcTemplate jdbcTemplate, ObjectMapper objectMapper) {
|
||||
this.jdbcTemplate = jdbcTemplate;
|
||||
this.jsonCodec = new TraderJsonCodec(objectMapper);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void upsertActiveBundle(TraderArtifactBundle bundle) {
|
||||
upsertModelBundle(bundle.modelBundleManifest());
|
||||
bundle.modelManifests().forEach(this::upsertModelManifest);
|
||||
bundle.calibrationManifests().forEach(this::upsertCalibrationManifest);
|
||||
upsertPmConfigManifest(bundle.pmConfigManifest());
|
||||
}
|
||||
|
||||
private void upsertModelBundle(TraderModelBundleManifest manifest) {
|
||||
jdbcTemplate.update("""
|
||||
insert into trader_model_bundle_manifest
|
||||
(manifest_schema_version, model_bundle_version, calibration_bundle_version,
|
||||
feature_version, label_version, split_version, training_run_id, training_export_id,
|
||||
backtest_manifest_id, required_models_json, provided_models_json, missing_models_json,
|
||||
allowed_run_modes_json, bundle_hash_sha256, complete, status)
|
||||
values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
on duplicate key update
|
||||
manifest_schema_version = values(manifest_schema_version),
|
||||
feature_version = values(feature_version),
|
||||
label_version = values(label_version),
|
||||
split_version = values(split_version),
|
||||
training_run_id = values(training_run_id),
|
||||
training_export_id = values(training_export_id),
|
||||
backtest_manifest_id = values(backtest_manifest_id),
|
||||
required_models_json = values(required_models_json),
|
||||
provided_models_json = values(provided_models_json),
|
||||
missing_models_json = values(missing_models_json),
|
||||
allowed_run_modes_json = values(allowed_run_modes_json),
|
||||
bundle_hash_sha256 = values(bundle_hash_sha256),
|
||||
complete = values(complete),
|
||||
status = values(status)
|
||||
""",
|
||||
manifest.manifestSchemaVersion(), manifest.modelBundleVersion(), manifest.calibrationBundleVersion(),
|
||||
manifest.featureVersion(), manifest.labelVersion(), manifest.splitVersion(),
|
||||
manifest.trainingRunId(), manifest.trainingExportId(), manifest.backtestManifestId(),
|
||||
jsonCodec.toJson(manifest.requiredModels()),
|
||||
jsonCodec.toJson(manifest.providedModels()), jsonCodec.toJson(manifest.missingModels()),
|
||||
jsonCodec.toJson(manifest.allowedRunModes().stream().map(TraderRunMode::name).toList()),
|
||||
manifest.bundleHashSha256(), manifest.complete(), manifest.status());
|
||||
}
|
||||
|
||||
private void upsertModelManifest(TraderModelManifest manifest) {
|
||||
jdbcTemplate.update("""
|
||||
insert into trader_model_manifest
|
||||
(model_bundle_version, calibration_bundle_version, model_name, model_type, side,
|
||||
symbol_scope_json, bar_interval, horizon_minutes, model_format, model_runtime,
|
||||
model_runtime_version, onnx_opset_version, producer_name, producer_version,
|
||||
artifact_path, artifact_hash_sha256, source_hash, feature_version, feature_schema_path,
|
||||
feature_schema_hash, feature_order_path, feature_order_hash, input_tensor_name, input_dtype, input_shape_json,
|
||||
input_example_path, output_schema_path, output_schema_hash, output_tensor_names_json,
|
||||
output_mapping_json, output_value_rules_json, label_version, split_version, training_fold,
|
||||
train_start, train_end, validation_start, validation_end, test_start, test_end,
|
||||
metrics_json, status)
|
||||
values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
on duplicate key update
|
||||
model_type = values(model_type),
|
||||
symbol_scope_json = values(symbol_scope_json),
|
||||
bar_interval = values(bar_interval),
|
||||
model_format = values(model_format),
|
||||
model_runtime = values(model_runtime),
|
||||
model_runtime_version = values(model_runtime_version),
|
||||
onnx_opset_version = values(onnx_opset_version),
|
||||
producer_name = values(producer_name),
|
||||
producer_version = values(producer_version),
|
||||
artifact_path = values(artifact_path),
|
||||
artifact_hash_sha256 = values(artifact_hash_sha256),
|
||||
source_hash = values(source_hash),
|
||||
feature_version = values(feature_version),
|
||||
feature_schema_path = values(feature_schema_path),
|
||||
feature_schema_hash = values(feature_schema_hash),
|
||||
feature_order_path = values(feature_order_path),
|
||||
feature_order_hash = values(feature_order_hash),
|
||||
input_tensor_name = values(input_tensor_name),
|
||||
input_dtype = values(input_dtype),
|
||||
input_shape_json = values(input_shape_json),
|
||||
input_example_path = values(input_example_path),
|
||||
output_schema_path = values(output_schema_path),
|
||||
output_schema_hash = values(output_schema_hash),
|
||||
output_tensor_names_json = values(output_tensor_names_json),
|
||||
output_mapping_json = values(output_mapping_json),
|
||||
output_value_rules_json = values(output_value_rules_json),
|
||||
label_version = values(label_version),
|
||||
split_version = values(split_version),
|
||||
training_fold = values(training_fold),
|
||||
train_start = values(train_start),
|
||||
train_end = values(train_end),
|
||||
validation_start = values(validation_start),
|
||||
validation_end = values(validation_end),
|
||||
test_start = values(test_start),
|
||||
test_end = values(test_end),
|
||||
metrics_json = values(metrics_json),
|
||||
status = values(status)
|
||||
""",
|
||||
manifest.modelBundleVersion(), manifest.calibrationBundleVersion(), manifest.modelName(),
|
||||
manifest.modelType(), manifest.side(), jsonCodec.toJson(manifest.symbolScope()),
|
||||
manifest.barInterval(), manifest.horizonMinutes(), manifest.modelFormat(), manifest.modelRuntime(),
|
||||
manifest.modelRuntimeVersion(), manifest.onnxOpsetVersion(), manifest.producerName(),
|
||||
manifest.producerVersion(), manifest.artifactPath(), manifest.artifactHashSha256(),
|
||||
manifest.sourceHash(), manifest.featureVersion(), manifest.featureSchemaPath(),
|
||||
manifest.featureSchemaHash(), manifest.featureOrderPath(), manifest.featureOrderHash(), manifest.inputTensorName(),
|
||||
manifest.inputDtype(), jsonCodec.toJson(manifest.inputShapeJson()), manifest.inputExamplePath(),
|
||||
manifest.outputSchemaPath(), manifest.outputSchemaHash(), jsonCodec.toJson(manifest.outputTensorNames()),
|
||||
jsonCodec.toJson(manifest.outputMapping()), jsonCodec.toJson(manifest.outputValueRules()),
|
||||
manifest.labelVersion(), manifest.splitVersion(), manifest.trainingFold(),
|
||||
java.sql.Timestamp.from(manifest.trainStart()), java.sql.Timestamp.from(manifest.trainEnd()),
|
||||
java.sql.Timestamp.from(manifest.validationStart()), java.sql.Timestamp.from(manifest.validationEnd()),
|
||||
java.sql.Timestamp.from(manifest.testStart()), java.sql.Timestamp.from(manifest.testEnd()),
|
||||
jsonCodec.toJson(manifest.metricsJson()), manifest.status());
|
||||
}
|
||||
|
||||
private void upsertCalibrationManifest(TraderCalibrationManifest manifest) {
|
||||
jdbcTemplate.update("""
|
||||
insert into trader_calibration_manifest
|
||||
(calibration_bundle_version, model_bundle_version, model_name, calibrator_version,
|
||||
calibration_method, calibrator_path, calibrator_hash_sha256,
|
||||
calibration_window_from, calibration_window_to, calibration_metrics_json,
|
||||
bucket_metrics_json, output_after_calibration_schema_hash, status)
|
||||
values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
on duplicate key update
|
||||
calibrator_version = values(calibrator_version),
|
||||
calibration_method = values(calibration_method),
|
||||
calibrator_path = values(calibrator_path),
|
||||
calibrator_hash_sha256 = values(calibrator_hash_sha256),
|
||||
calibration_window_from = values(calibration_window_from),
|
||||
calibration_window_to = values(calibration_window_to),
|
||||
calibration_metrics_json = values(calibration_metrics_json),
|
||||
bucket_metrics_json = values(bucket_metrics_json),
|
||||
output_after_calibration_schema_hash = values(output_after_calibration_schema_hash),
|
||||
status = values(status)
|
||||
""",
|
||||
manifest.calibrationBundleVersion(), manifest.modelBundleVersion(), manifest.modelName(),
|
||||
manifest.calibratorVersion(), manifest.calibrationMethod(), manifest.calibratorPath(),
|
||||
manifest.calibratorHashSha256(), java.sql.Timestamp.from(manifest.calibrationWindowFrom()),
|
||||
java.sql.Timestamp.from(manifest.calibrationWindowTo()), jsonCodec.toJson(manifest.calibrationMetrics()),
|
||||
jsonCodec.toJson(manifest.bucketMetricsJson()), manifest.outputAfterCalibrationSchemaHash(),
|
||||
manifest.status());
|
||||
}
|
||||
|
||||
private void upsertPmConfigManifest(TraderPmConfigManifest manifest) {
|
||||
jdbcTemplate.update("""
|
||||
insert into trader_pm_config_manifest
|
||||
(pm_config_version, model_bundle_version, calibration_bundle_version, threshold_stability_json,
|
||||
allowed_run_modes_json, config_json, config_hash_sha256, status)
|
||||
values (?, ?, ?, ?, ?, ?, ?, ?)
|
||||
on duplicate key update
|
||||
model_bundle_version = values(model_bundle_version),
|
||||
calibration_bundle_version = values(calibration_bundle_version),
|
||||
threshold_stability_json = values(threshold_stability_json),
|
||||
allowed_run_modes_json = values(allowed_run_modes_json),
|
||||
config_json = values(config_json),
|
||||
config_hash_sha256 = values(config_hash_sha256),
|
||||
status = values(status)
|
||||
""",
|
||||
manifest.pmConfigVersion(), manifest.modelBundleVersion(), manifest.calibrationBundleVersion(),
|
||||
jsonCodec.toJson(manifest.thresholdStabilityJson()),
|
||||
jsonCodec.toJson(manifest.allowedRunModes().stream().map(TraderRunMode::name).toList()),
|
||||
jsonCodec.toJson(manifest.config()), manifest.configHashSha256(), manifest.status());
|
||||
}
|
||||
}
|
||||
@@ -1,7 +1,11 @@
|
||||
package com.quantai.trader.artifact;
|
||||
|
||||
import com.quantai.trader.domain.TraderException;
|
||||
import com.quantai.trader.domain.TraderPmConfig;
|
||||
import com.quantai.trader.domain.TraderPricePlanContext;
|
||||
import com.quantai.trader.enums.TraderErrorCode;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Set;
|
||||
|
||||
public record TraderArtifactBundle(
|
||||
@@ -10,14 +14,31 @@ public record TraderArtifactBundle(
|
||||
String pmConfigVersion,
|
||||
String bundleHashSha256,
|
||||
Set<String> providedModels,
|
||||
TraderModelBundleManifest modelBundleManifest,
|
||||
List<TraderModelManifest> modelManifests,
|
||||
List<TraderCalibrationManifest> calibrationManifests,
|
||||
TraderPmConfigManifest pmConfigManifest,
|
||||
TraderPmConfig pmConfig,
|
||||
TraderArtifactModelPolicy modelPolicy
|
||||
TraderPricePlanContext pricePlanContext,
|
||||
TraderReplayModelFixture replayModelFixture
|
||||
) {
|
||||
public TraderArtifactBundle {
|
||||
if (providedModels == null || !providedModels.containsAll(Set.of("DIRECTION", "ENTRY", "CONTINUE", "EXIT", "RISK"))) {
|
||||
throw new IllegalArgumentException("artifact bundle must provide all five V4 models");
|
||||
}
|
||||
modelBundleManifest = java.util.Objects.requireNonNull(modelBundleManifest, "modelBundleManifest is required");
|
||||
modelManifests = List.copyOf(modelManifests);
|
||||
calibrationManifests = List.copyOf(calibrationManifests);
|
||||
pmConfigManifest = java.util.Objects.requireNonNull(pmConfigManifest, "pmConfigManifest is required");
|
||||
pmConfig = java.util.Objects.requireNonNull(pmConfig, "pmConfig is required");
|
||||
modelPolicy = java.util.Objects.requireNonNull(modelPolicy, "modelPolicy is required");
|
||||
pricePlanContext = java.util.Objects.requireNonNull(pricePlanContext, "pricePlanContext is required");
|
||||
}
|
||||
|
||||
public TraderReplayModelFixture requireReplayModelFixture() {
|
||||
if (replayModelFixture == null) {
|
||||
throw new TraderException(TraderErrorCode.TRADER_MODEL_ARTIFACT_MISSING,
|
||||
"replay model fixture is required for replay fixture inference");
|
||||
}
|
||||
return replayModelFixture;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,9 +2,11 @@ package com.quantai.trader.artifact;
|
||||
|
||||
import com.fasterxml.jackson.databind.JsonNode;
|
||||
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||
import com.fasterxml.jackson.core.type.TypeReference;
|
||||
import com.quantai.trader.config.TraderProperties;
|
||||
import com.quantai.trader.domain.TraderException;
|
||||
import com.quantai.trader.domain.TraderPmConfig;
|
||||
import com.quantai.trader.domain.TraderPricePlanContext;
|
||||
import com.quantai.trader.enums.TraderRunMode;
|
||||
import com.quantai.trader.enums.TraderErrorCode;
|
||||
import org.slf4j.Logger;
|
||||
@@ -14,6 +16,11 @@ import org.springframework.stereotype.Component;
|
||||
import java.io.IOException;
|
||||
import java.nio.file.Files;
|
||||
import java.nio.file.Path;
|
||||
import java.security.MessageDigest;
|
||||
import java.security.NoSuchAlgorithmException;
|
||||
import java.time.Instant;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
import java.util.stream.StreamSupport;
|
||||
@@ -22,6 +29,27 @@ import java.util.stream.StreamSupport;
|
||||
public class TraderArtifactLoader {
|
||||
private static final Logger log = LoggerFactory.getLogger(TraderArtifactLoader.class);
|
||||
private static final Set<String> REQUIRED_MODELS = Set.of("DIRECTION", "ENTRY", "CONTINUE", "EXIT", "RISK");
|
||||
private static final int REQUIRED_FEATURE_COUNT = 39;
|
||||
private static final int REQUIRED_ONNX_OPSET_VERSION = 17;
|
||||
private static final Map<String, Set<String>> REQUIRED_OUTPUT_MAPPING_KEYS = Map.of(
|
||||
"DIRECTION", Set.of("long_prob", "short_prob", "neutral_prob"),
|
||||
"ENTRY", Set.of("long_entry_prob", "short_entry_prob", "long_expected_net_edge_bps", "short_expected_net_edge_bps"),
|
||||
"CONTINUE", Set.of("long_continue_prob", "short_continue_prob", "long_expected_continue_edge_bps", "short_expected_continue_edge_bps"),
|
||||
"EXIT", Set.of("long_exit_prob", "short_exit_prob", "long_adverse_move_bps", "short_adverse_move_bps",
|
||||
"adverse_move_prob", "reversal_prob", "stop_hit_prob", "stagnation_prob"),
|
||||
"RISK", Set.of("market_risk_prob", "long_position_risk_prob", "short_position_risk_prob",
|
||||
"market_path_risk_bps", "long_position_path_risk_bps", "short_position_path_risk_bps",
|
||||
"market_drawdown_prob", "volatility_expansion_prob", "spike_prob",
|
||||
"liquidity_deterioration_prob", "position_drawdown_prob")
|
||||
);
|
||||
private static final Set<String> REJECTED_OUTPUT_MAPPING_KEYS = Set.of(
|
||||
"expected_net_edge_bps",
|
||||
"expected_continue_edge_bps",
|
||||
"expected_giveback_bps",
|
||||
"market_expected_shortfall_bps",
|
||||
"position_expected_shortfall_bps",
|
||||
"position_risk_target"
|
||||
);
|
||||
|
||||
private final TraderProperties properties;
|
||||
private final ObjectMapper objectMapper;
|
||||
@@ -35,10 +63,15 @@ public class TraderArtifactLoader {
|
||||
TraderProperties.Artifact artifact = properties.artifact();
|
||||
Path root = Path.of(artifact.artifactRoot());
|
||||
TraderModelBundleManifest modelManifest = readModelBundleManifest(root.resolve("manifests/model_bundle_manifest.json"));
|
||||
List<TraderModelManifest> modelManifests = readModelManifests(root.resolve("manifests/model_manifest.json"));
|
||||
List<TraderCalibrationManifest> calibrationManifests = readCalibrationManifests(root.resolve("manifests/calibration_manifest.json"));
|
||||
TraderPmConfigManifest pmManifest = readPmConfigManifest(root.resolve("manifests/position_manager_manifest.json"));
|
||||
TraderArtifactModelPolicy modelPolicy = readJson(root.resolve("model_output_policy.json"), TraderArtifactModelPolicy.class);
|
||||
TraderPricePlanContext pricePlanContext = readJson(root.resolve("price_plan_context.json"), TraderPricePlanContext.class);
|
||||
TraderReplayModelFixture replayModelFixture = readOptionalJson(root.resolve("replay_model_fixture.json"), TraderReplayModelFixture.class);
|
||||
validateVersions(artifact, modelManifest, pmManifest);
|
||||
validateModelManifest(modelManifest);
|
||||
validateModelArtifacts(root, modelManifest, modelManifests);
|
||||
validateCalibrationArtifacts(root, modelManifest, calibrationManifests);
|
||||
validatePmManifest(pmManifest, properties.runMode());
|
||||
TraderArtifactBundle bundle = new TraderArtifactBundle(
|
||||
modelManifest.modelBundleVersion(),
|
||||
@@ -46,24 +79,111 @@ public class TraderArtifactLoader {
|
||||
pmManifest.pmConfigVersion(),
|
||||
modelManifest.bundleHashSha256(),
|
||||
modelManifest.providedModels(),
|
||||
modelManifest,
|
||||
modelManifests,
|
||||
calibrationManifests,
|
||||
pmManifest,
|
||||
pmManifest.config(),
|
||||
modelPolicy);
|
||||
pricePlanContext,
|
||||
replayModelFixture);
|
||||
log.info("event=trader.artifact.loaded modelBundleVersion={} calibrationBundleVersion={} pmConfigVersion={} providedModels={}",
|
||||
bundle.modelBundleVersion(), bundle.calibrationBundleVersion(), bundle.pmConfigVersion(), bundle.providedModels());
|
||||
return bundle;
|
||||
}
|
||||
|
||||
private List<TraderModelManifest> readModelManifests(Path path) {
|
||||
JsonNode root = readJsonNode(path);
|
||||
if (!root.isArray()) {
|
||||
throw new TraderException(TraderErrorCode.TRADER_MODEL_ARTIFACT_MISSING,
|
||||
"model manifest must be an array: " + path);
|
||||
}
|
||||
return StreamSupport.stream(root.spliterator(), false)
|
||||
.map(node -> new TraderModelManifest(
|
||||
requiredText(node, "model_bundle_version", path),
|
||||
requiredText(node, "calibration_bundle_version", path),
|
||||
requiredText(node, "model_name", path),
|
||||
requiredText(node, "model_type", path),
|
||||
requiredText(node, "side", path),
|
||||
textSet(node, "symbol_scope_json", path),
|
||||
requiredText(node, "bar_interval", path),
|
||||
node.path("horizon_minutes").asInt(-1),
|
||||
requiredText(node, "model_format", path),
|
||||
requiredText(node, "model_runtime", path),
|
||||
requiredText(node, "model_runtime_version", path),
|
||||
node.path("onnx_opset_version").asInt(-1),
|
||||
requiredText(node, "producer_name", path),
|
||||
requiredText(node, "producer_version", path),
|
||||
requiredText(node, "feature_version", path),
|
||||
requiredText(node, "feature_schema_path", path),
|
||||
requiredText(node, "feature_schema_hash", path),
|
||||
requiredText(node, "feature_order_path", path),
|
||||
requiredText(node, "feature_order_hash", path),
|
||||
requiredText(node, "input_tensor_name", path),
|
||||
requiredText(node, "input_dtype", path),
|
||||
objectMap(node, "input_shape_json", path),
|
||||
requiredText(node, "input_example_path", path),
|
||||
requiredText(node, "output_schema_path", path),
|
||||
requiredText(node, "output_schema_hash", path),
|
||||
textSet(node, "output_tensor_names_json", path),
|
||||
objectMap(node, "output_mapping_json", path),
|
||||
objectMap(node, "output_value_rules_json", path),
|
||||
requiredText(node, "label_version", path),
|
||||
requiredText(node, "split_version", path),
|
||||
requiredText(node, "training_fold", path),
|
||||
instant(node, "train_start", path),
|
||||
instant(node, "train_end", path),
|
||||
instant(node, "validation_start", path),
|
||||
instant(node, "validation_end", path),
|
||||
instant(node, "test_start", path),
|
||||
instant(node, "test_end", path),
|
||||
objectMap(node, "metrics_json", path),
|
||||
requiredText(node, "artifact_path", path),
|
||||
requiredText(node, "artifact_hash_sha256", path),
|
||||
requiredText(node, "source_hash", path),
|
||||
requiredText(node, "status", path)))
|
||||
.toList();
|
||||
}
|
||||
|
||||
private List<TraderCalibrationManifest> readCalibrationManifests(Path path) {
|
||||
JsonNode root = readJsonNode(path);
|
||||
if (!root.isArray()) {
|
||||
throw new TraderException(TraderErrorCode.TRADER_MODEL_ARTIFACT_MISSING,
|
||||
"calibration manifest must be an array: " + path);
|
||||
}
|
||||
return StreamSupport.stream(root.spliterator(), false)
|
||||
.map(node -> new TraderCalibrationManifest(
|
||||
requiredText(node, "calibration_bundle_version", path),
|
||||
requiredText(node, "model_bundle_version", path),
|
||||
requiredText(node, "model_name", path),
|
||||
requiredText(node, "calibrator_version", path),
|
||||
requiredText(node, "calibration_method", path),
|
||||
requiredText(node, "calibrator_path", path),
|
||||
requiredText(node, "calibrator_hash_sha256", path),
|
||||
instant(node, "calibration_window_from", path),
|
||||
instant(node, "calibration_window_to", path),
|
||||
objectMap(node, "calibration_metrics_json", path),
|
||||
objectMap(node, "bucket_metrics_json", path),
|
||||
requiredText(node, "output_after_calibration_schema_hash", path),
|
||||
requiredText(node, "status", path)))
|
||||
.toList();
|
||||
}
|
||||
|
||||
private TraderModelBundleManifest readModelBundleManifest(Path path) {
|
||||
JsonNode root = readJsonNode(path);
|
||||
return new TraderModelBundleManifest(
|
||||
requiredText(root, "manifest_schema_version", path),
|
||||
requiredText(root, "model_bundle_version", path),
|
||||
requiredText(root, "calibration_bundle_version", path),
|
||||
requiredText(root, "feature_version", path),
|
||||
requiredText(root, "label_version", path),
|
||||
requiredText(root, "split_version", path),
|
||||
requiredText(root, "training_run_id", path),
|
||||
requiredText(root, "training_export_id", path),
|
||||
requiredText(root, "backtest_manifest_id", path),
|
||||
textSet(root, "required_models_json", path),
|
||||
textSet(root, "provided_models_json", path),
|
||||
textSet(root, "missing_models_json", path),
|
||||
enumSet(root, "allowed_run_modes_json", TraderRunMode.class, path),
|
||||
requiredText(root, "bundle_hash_sha256", path),
|
||||
root.path("complete").asBoolean(false),
|
||||
requiredText(root, "status", path));
|
||||
@@ -75,6 +195,7 @@ public class TraderArtifactLoader {
|
||||
requiredText(root, "pm_config_version", path),
|
||||
requiredText(root, "model_bundle_version", path),
|
||||
requiredText(root, "calibration_bundle_version", path),
|
||||
objectMap(root, "threshold_stability_json", path),
|
||||
enumSet(root, "allowed_run_modes_json", TraderRunMode.class, path),
|
||||
convert(root.path("config_json"), TraderPmConfig.class, path),
|
||||
requiredText(root, "config_hash_sha256", path),
|
||||
@@ -106,6 +227,86 @@ public class TraderArtifactLoader {
|
||||
throw new TraderException(TraderErrorCode.TRADER_MODEL_ARTIFACT_MISSING,
|
||||
"model bundle must provide all five V4 models with no missing model");
|
||||
}
|
||||
if (!manifest.allowedRunModes().contains(properties.runMode())) {
|
||||
throw new TraderException(TraderErrorCode.TRADER_MODEL_ARTIFACT_MISSING,
|
||||
"model bundle does not allow current run mode");
|
||||
}
|
||||
}
|
||||
|
||||
private void validateModelArtifacts(Path root, TraderModelBundleManifest bundleManifest,
|
||||
List<TraderModelManifest> manifests) {
|
||||
Set<String> modelNames = manifests.stream().map(TraderModelManifest::modelName).collect(Collectors.toUnmodifiableSet());
|
||||
if (!modelNames.containsAll(REQUIRED_MODELS)) {
|
||||
throw new TraderException(TraderErrorCode.TRADER_MODEL_ARTIFACT_MISSING,
|
||||
"model_manifest.json must contain all five V4 model families");
|
||||
}
|
||||
for (TraderModelManifest manifest : manifests) {
|
||||
if (!bundleManifest.modelBundleVersion().equals(manifest.modelBundleVersion())
|
||||
|| !bundleManifest.calibrationBundleVersion().equals(manifest.calibrationBundleVersion())) {
|
||||
throw new TraderException(TraderErrorCode.TRADER_CALIBRATION_MISMATCH,
|
||||
"model manifest version does not match bundle manifest");
|
||||
}
|
||||
if (!"ACTIVE".equals(manifest.status())) {
|
||||
throw new TraderException(TraderErrorCode.TRADER_MODEL_ARTIFACT_MISSING,
|
||||
"model manifest must be ACTIVE: " + manifest.modelName());
|
||||
}
|
||||
if (!REQUIRED_MODELS.contains(manifest.modelType())
|
||||
|| !"ONNX".equals(manifest.modelFormat())
|
||||
|| !"ONNX_RUNTIME_JAVA".equals(manifest.modelRuntime())
|
||||
|| !"FLOAT32".equals(manifest.inputDtype())
|
||||
|| manifest.horizonMinutes() <= 0
|
||||
|| manifest.onnxOpsetVersion() != REQUIRED_ONNX_OPSET_VERSION
|
||||
|| inputFeatureCount(manifest) != REQUIRED_FEATURE_COUNT
|
||||
|| !"features".equals(manifest.inputTensorName())) {
|
||||
throw new TraderException(TraderErrorCode.TRADER_MODEL_ARTIFACT_MISSING,
|
||||
"model manifest runtime contract is invalid: " + manifest.modelName());
|
||||
}
|
||||
validateOutputMapping(manifest);
|
||||
validateSha256(root.resolve(manifest.artifactPath()), manifest.artifactHashSha256());
|
||||
validateSha256(root.resolve(manifest.featureSchemaPath()), manifest.featureSchemaHash());
|
||||
validateSha256(root.resolve(manifest.featureOrderPath()), manifest.featureOrderHash());
|
||||
validateFeatureOrder(root.resolve(manifest.featureOrderPath()));
|
||||
validateSha256(root.resolve(manifest.outputSchemaPath()), manifest.outputSchemaHash());
|
||||
if (!Files.isRegularFile(root.resolve(manifest.inputExamplePath()))) {
|
||||
throw new TraderException(TraderErrorCode.TRADER_MODEL_ARTIFACT_MISSING,
|
||||
"model input example is missing: " + manifest.inputExamplePath());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private void validateOutputMapping(TraderModelManifest manifest) {
|
||||
if (manifest.outputMapping().keySet().stream().anyMatch(REJECTED_OUTPUT_MAPPING_KEYS::contains)) {
|
||||
throw new TraderException(TraderErrorCode.TRADER_MODEL_ARTIFACT_MISSING,
|
||||
"model output mapping contains rejected legacy key: " + manifest.modelName());
|
||||
}
|
||||
Set<String> requiredKeys = REQUIRED_OUTPUT_MAPPING_KEYS.get(manifest.modelType());
|
||||
if (requiredKeys == null || !manifest.outputMapping().keySet().containsAll(requiredKeys)) {
|
||||
throw new TraderException(TraderErrorCode.TRADER_MODEL_ARTIFACT_MISSING,
|
||||
"model output mapping does not match V4 contract: " + manifest.modelName());
|
||||
}
|
||||
}
|
||||
|
||||
private void validateCalibrationArtifacts(Path root, TraderModelBundleManifest modelManifest,
|
||||
List<TraderCalibrationManifest> calibrationManifests) {
|
||||
Set<String> calibratedModels = calibrationManifests.stream()
|
||||
.map(TraderCalibrationManifest::modelName)
|
||||
.collect(Collectors.toUnmodifiableSet());
|
||||
if (!calibratedModels.containsAll(REQUIRED_MODELS)) {
|
||||
throw new TraderException(TraderErrorCode.TRADER_CALIBRATION_MISMATCH,
|
||||
"calibration_manifest.json must contain all five V4 model families");
|
||||
}
|
||||
for (TraderCalibrationManifest calibrationManifest : calibrationManifests) {
|
||||
if (!modelManifest.modelBundleVersion().equals(calibrationManifest.modelBundleVersion())
|
||||
|| !modelManifest.calibrationBundleVersion().equals(calibrationManifest.calibrationBundleVersion())) {
|
||||
throw new TraderException(TraderErrorCode.TRADER_CALIBRATION_MISMATCH,
|
||||
"calibration manifest version does not match bundle manifest");
|
||||
}
|
||||
if (!"ACTIVE".equals(calibrationManifest.status())) {
|
||||
throw new TraderException(TraderErrorCode.TRADER_CALIBRATION_MISMATCH,
|
||||
"calibration manifest must be ACTIVE");
|
||||
}
|
||||
validateSha256(root.resolve(calibrationManifest.calibratorPath()), calibrationManifest.calibratorHashSha256());
|
||||
}
|
||||
}
|
||||
|
||||
private void validatePmManifest(TraderPmConfigManifest manifest, TraderRunMode runMode) {
|
||||
@@ -119,6 +320,29 @@ public class TraderArtifactLoader {
|
||||
}
|
||||
}
|
||||
|
||||
private int inputFeatureCount(TraderModelManifest manifest) {
|
||||
Object features = manifest.inputShapeJson().get("features");
|
||||
if (features instanceof Number number) {
|
||||
return number.intValue();
|
||||
}
|
||||
if (features instanceof String text) {
|
||||
try {
|
||||
return Integer.parseInt(text);
|
||||
} catch (NumberFormatException ignored) {
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
return -1;
|
||||
}
|
||||
|
||||
private void validateFeatureOrder(Path path) {
|
||||
JsonNode featureOrder = readJsonNode(path);
|
||||
if (!featureOrder.isArray() || featureOrder.size() != REQUIRED_FEATURE_COUNT) {
|
||||
throw new TraderException(TraderErrorCode.TRADER_MODEL_ARTIFACT_MISSING,
|
||||
"feature_order.json must contain exactly " + REQUIRED_FEATURE_COUNT + " fields: " + path);
|
||||
}
|
||||
}
|
||||
|
||||
private JsonNode readJsonNode(Path path) {
|
||||
if (!Files.isRegularFile(path)) {
|
||||
throw new TraderException(TraderErrorCode.TRADER_MODEL_ARTIFACT_MISSING,
|
||||
@@ -132,6 +356,38 @@ public class TraderArtifactLoader {
|
||||
}
|
||||
}
|
||||
|
||||
private void validateSha256(Path path, String expectedHash) {
|
||||
if (!Files.isRegularFile(path)) {
|
||||
throw new TraderException(TraderErrorCode.TRADER_MODEL_ARTIFACT_MISSING,
|
||||
"artifact referenced by manifest is missing: " + path);
|
||||
}
|
||||
String actualHash;
|
||||
try {
|
||||
actualHash = sha256(path);
|
||||
} catch (IOException exception) {
|
||||
throw new TraderException(TraderErrorCode.TRADER_MODEL_ARTIFACT_MISSING,
|
||||
"artifact referenced by manifest cannot be read: " + path);
|
||||
}
|
||||
if (!expectedHash.equals(actualHash)) {
|
||||
throw new TraderException(TraderErrorCode.TRADER_MODEL_ARTIFACT_MISSING,
|
||||
"artifact sha256 mismatch: " + path);
|
||||
}
|
||||
}
|
||||
|
||||
private String sha256(Path path) throws IOException {
|
||||
try {
|
||||
MessageDigest digest = MessageDigest.getInstance("SHA-256");
|
||||
byte[] hash = digest.digest(Files.readAllBytes(path));
|
||||
StringBuilder builder = new StringBuilder(hash.length * 2);
|
||||
for (byte value : hash) {
|
||||
builder.append(String.format("%02x", value & 0xff));
|
||||
}
|
||||
return builder.toString();
|
||||
} catch (NoSuchAlgorithmException exception) {
|
||||
throw new IllegalStateException("SHA-256 is not available", exception);
|
||||
}
|
||||
}
|
||||
|
||||
private <T> T readJson(Path path, Class<T> type) {
|
||||
if (!Files.isRegularFile(path)) {
|
||||
throw new TraderException(TraderErrorCode.TRADER_MODEL_ARTIFACT_MISSING,
|
||||
@@ -145,6 +401,13 @@ public class TraderArtifactLoader {
|
||||
}
|
||||
}
|
||||
|
||||
private <T> T readOptionalJson(Path path, Class<T> type) {
|
||||
if (!Files.isRegularFile(path)) {
|
||||
return null;
|
||||
}
|
||||
return readJson(path, type);
|
||||
}
|
||||
|
||||
private <T> T convert(JsonNode node, Class<T> type, Path path) {
|
||||
if (node == null || node.isMissingNode() || node.isNull()) {
|
||||
throw new TraderException(TraderErrorCode.TRADER_MODEL_ARTIFACT_MISSING,
|
||||
@@ -167,6 +430,20 @@ public class TraderArtifactLoader {
|
||||
return value;
|
||||
}
|
||||
|
||||
private Instant instant(JsonNode node, String field, Path path) {
|
||||
return Instant.parse(requiredText(node, field, path));
|
||||
}
|
||||
|
||||
private Map<String, Object> objectMap(JsonNode node, String field, Path path) {
|
||||
JsonNode value = node.path(field);
|
||||
if (value.isMissingNode() || value.isNull() || !value.isObject()) {
|
||||
throw new TraderException(TraderErrorCode.TRADER_MODEL_ARTIFACT_MISSING,
|
||||
"artifact object field is required: " + path + "#" + field);
|
||||
}
|
||||
return objectMapper.convertValue(value, new TypeReference<>() {
|
||||
});
|
||||
}
|
||||
|
||||
private Set<String> textSet(JsonNode node, String field, Path path) {
|
||||
JsonNode array = node.path(field);
|
||||
if (!array.isArray()) {
|
||||
|
||||
@@ -0,0 +1,5 @@
|
||||
package com.quantai.trader.artifact;
|
||||
|
||||
public interface TraderArtifactManifestRepository {
|
||||
void upsertActiveBundle(TraderArtifactBundle bundle);
|
||||
}
|
||||
@@ -1,76 +0,0 @@
|
||||
package com.quantai.trader.artifact;
|
||||
|
||||
import java.math.BigDecimal;
|
||||
|
||||
public record TraderArtifactModelPolicy(
|
||||
DirectionPolicy direction,
|
||||
EntryPolicy entry,
|
||||
ContinuePolicy continuation,
|
||||
ExitPolicy exit,
|
||||
RiskPolicy risk,
|
||||
BigDecimal uncertainty,
|
||||
BigDecimal oodScore
|
||||
) {
|
||||
public record DirectionPolicy(
|
||||
BigDecimal longProbWhenMarkGteIndex,
|
||||
BigDecimal longProbWhenMarkLtIndex,
|
||||
BigDecimal neutralProb,
|
||||
BigDecimal expectedReturnBps,
|
||||
int horizonMinutes,
|
||||
String modelVersion
|
||||
) {
|
||||
}
|
||||
|
||||
public record EntryPolicy(
|
||||
BigDecimal longEntryProb,
|
||||
BigDecimal shortEntryProb,
|
||||
BigDecimal entryQualityScore,
|
||||
BigDecimal expectedEdgeBps,
|
||||
String pricePlanId,
|
||||
String pricePlanConfigHash,
|
||||
BigDecimal stopDistanceBps,
|
||||
BigDecimal targetDistanceBps,
|
||||
int maxHoldMinutes,
|
||||
BigDecimal costBps,
|
||||
String modelVersion
|
||||
) {
|
||||
}
|
||||
|
||||
public record ContinuePolicy(
|
||||
BigDecimal longContinueProb,
|
||||
BigDecimal shortContinueProb,
|
||||
BigDecimal trendPersistenceProb,
|
||||
BigDecimal holdEdgeBps,
|
||||
BigDecimal continueVsExitEdgeBps,
|
||||
String modelVersion
|
||||
) {
|
||||
}
|
||||
|
||||
public record ExitPolicy(
|
||||
BigDecimal longExitProb,
|
||||
BigDecimal shortExitProb,
|
||||
BigDecimal profitGivebackProb,
|
||||
BigDecimal reversalProb,
|
||||
BigDecimal stopRiskProb,
|
||||
BigDecimal stagnationProb,
|
||||
BigDecimal expectedGivebackBps,
|
||||
String modelVersion
|
||||
) {
|
||||
}
|
||||
|
||||
public record RiskPolicy(
|
||||
BigDecimal marketRiskProb,
|
||||
BigDecimal positionRiskProb,
|
||||
BigDecimal marketRiskSeverityBps,
|
||||
BigDecimal positionRiskSeverityBps,
|
||||
BigDecimal drawdownProb,
|
||||
BigDecimal expectedShortfallBps,
|
||||
BigDecimal volatilityExpansionProb,
|
||||
BigDecimal spikeProb,
|
||||
BigDecimal liquidityRiskProb,
|
||||
BigDecimal liquidityCapacityRatioWhenReady,
|
||||
BigDecimal liquidityCapacityRatioWhenNotReady,
|
||||
String modelVersion
|
||||
) {
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,25 @@
|
||||
package com.quantai.trader.artifact;
|
||||
|
||||
import java.time.Instant;
|
||||
import java.util.Map;
|
||||
|
||||
public record TraderCalibrationManifest(
|
||||
String calibrationBundleVersion,
|
||||
String modelBundleVersion,
|
||||
String modelName,
|
||||
String calibratorVersion,
|
||||
String calibrationMethod,
|
||||
String calibratorPath,
|
||||
String calibratorHashSha256,
|
||||
Instant calibrationWindowFrom,
|
||||
Instant calibrationWindowTo,
|
||||
Map<String, Object> calibrationMetrics,
|
||||
Map<String, Object> bucketMetricsJson,
|
||||
String outputAfterCalibrationSchemaHash,
|
||||
String status
|
||||
) {
|
||||
public TraderCalibrationManifest {
|
||||
calibrationMetrics = Map.copyOf(calibrationMetrics == null ? Map.of() : calibrationMetrics);
|
||||
bucketMetricsJson = Map.copyOf(bucketMetricsJson == null ? Map.of() : bucketMetricsJson);
|
||||
}
|
||||
}
|
||||
@@ -1,16 +1,23 @@
|
||||
package com.quantai.trader.artifact;
|
||||
|
||||
import com.quantai.trader.enums.TraderRunMode;
|
||||
|
||||
import java.util.Set;
|
||||
|
||||
public record TraderModelBundleManifest(
|
||||
String manifestSchemaVersion,
|
||||
String modelBundleVersion,
|
||||
String calibrationBundleVersion,
|
||||
String featureVersion,
|
||||
String labelVersion,
|
||||
String splitVersion,
|
||||
String trainingRunId,
|
||||
String trainingExportId,
|
||||
String backtestManifestId,
|
||||
Set<String> requiredModels,
|
||||
Set<String> providedModels,
|
||||
Set<String> missingModels,
|
||||
Set<TraderRunMode> allowedRunModes,
|
||||
String bundleHashSha256,
|
||||
boolean complete,
|
||||
String status
|
||||
|
||||
@@ -0,0 +1,59 @@
|
||||
package com.quantai.trader.artifact;
|
||||
|
||||
import java.time.Instant;
|
||||
import java.util.Map;
|
||||
import java.util.Set;
|
||||
|
||||
public record TraderModelManifest(
|
||||
String modelBundleVersion,
|
||||
String calibrationBundleVersion,
|
||||
String modelName,
|
||||
String modelType,
|
||||
String side,
|
||||
Set<String> symbolScope,
|
||||
String barInterval,
|
||||
int horizonMinutes,
|
||||
String modelFormat,
|
||||
String modelRuntime,
|
||||
String modelRuntimeVersion,
|
||||
int onnxOpsetVersion,
|
||||
String producerName,
|
||||
String producerVersion,
|
||||
String featureVersion,
|
||||
String featureSchemaPath,
|
||||
String featureSchemaHash,
|
||||
String featureOrderPath,
|
||||
String featureOrderHash,
|
||||
String inputTensorName,
|
||||
String inputDtype,
|
||||
Map<String, Object> inputShapeJson,
|
||||
String inputExamplePath,
|
||||
String outputSchemaPath,
|
||||
String outputSchemaHash,
|
||||
Set<String> outputTensorNames,
|
||||
Map<String, Object> outputMapping,
|
||||
Map<String, Object> outputValueRules,
|
||||
String labelVersion,
|
||||
String splitVersion,
|
||||
String trainingFold,
|
||||
Instant trainStart,
|
||||
Instant trainEnd,
|
||||
Instant validationStart,
|
||||
Instant validationEnd,
|
||||
Instant testStart,
|
||||
Instant testEnd,
|
||||
Map<String, Object> metricsJson,
|
||||
String artifactPath,
|
||||
String artifactHashSha256,
|
||||
String sourceHash,
|
||||
String status
|
||||
) {
|
||||
public TraderModelManifest {
|
||||
symbolScope = Set.copyOf(symbolScope == null ? Set.of() : symbolScope);
|
||||
inputShapeJson = Map.copyOf(inputShapeJson == null ? Map.of() : inputShapeJson);
|
||||
outputTensorNames = Set.copyOf(outputTensorNames == null ? Set.of() : outputTensorNames);
|
||||
outputMapping = Map.copyOf(outputMapping == null ? Map.of() : outputMapping);
|
||||
outputValueRules = Map.copyOf(outputValueRules == null ? Map.of() : outputValueRules);
|
||||
metricsJson = Map.copyOf(metricsJson == null ? Map.of() : metricsJson);
|
||||
}
|
||||
}
|
||||
@@ -3,15 +3,20 @@ package com.quantai.trader.artifact;
|
||||
import com.quantai.trader.domain.TraderPmConfig;
|
||||
import com.quantai.trader.enums.TraderRunMode;
|
||||
|
||||
import java.util.Map;
|
||||
import java.util.Set;
|
||||
|
||||
public record TraderPmConfigManifest(
|
||||
String pmConfigVersion,
|
||||
String modelBundleVersion,
|
||||
String calibrationBundleVersion,
|
||||
Map<String, Object> thresholdStabilityJson,
|
||||
Set<TraderRunMode> allowedRunModes,
|
||||
TraderPmConfig config,
|
||||
String configHashSha256,
|
||||
String status
|
||||
) {
|
||||
public TraderPmConfigManifest {
|
||||
thresholdStabilityJson = Map.copyOf(thresholdStabilityJson == null ? Map.of() : thresholdStabilityJson);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,66 @@
|
||||
package com.quantai.trader.artifact;
|
||||
|
||||
import java.math.BigDecimal;
|
||||
import java.util.Map;
|
||||
|
||||
public record TraderReplayModelFixture(
|
||||
DirectionFixture direction,
|
||||
EntryFixture entry,
|
||||
ContinueFixture continuation,
|
||||
ExitFixture exit,
|
||||
RiskFixture risk,
|
||||
BigDecimal uncertainty,
|
||||
BigDecimal oodScore,
|
||||
String featureSchemaHash,
|
||||
String featureOrderHash,
|
||||
String outputSchemaHash
|
||||
) {
|
||||
public record DirectionFixture(
|
||||
BigDecimal longProbWhenMarkGteIndex,
|
||||
BigDecimal longProbWhenMarkLtIndex,
|
||||
BigDecimal neutralProb
|
||||
) {
|
||||
}
|
||||
|
||||
public record EntryFixture(
|
||||
BigDecimal longEntryProb,
|
||||
BigDecimal shortEntryProb,
|
||||
BigDecimal longExpectedNetEdgeBps,
|
||||
BigDecimal shortExpectedNetEdgeBps
|
||||
) {
|
||||
}
|
||||
|
||||
public record ContinueFixture(
|
||||
BigDecimal longContinueProb,
|
||||
BigDecimal shortContinueProb,
|
||||
BigDecimal longExpectedContinueEdgeBps,
|
||||
BigDecimal shortExpectedContinueEdgeBps
|
||||
) {
|
||||
}
|
||||
|
||||
public record ExitFixture(
|
||||
BigDecimal longExitProb,
|
||||
BigDecimal shortExitProb,
|
||||
BigDecimal longAdverseMoveBps,
|
||||
BigDecimal shortAdverseMoveBps,
|
||||
Map<String, BigDecimal> exitReasonScores
|
||||
) {
|
||||
public ExitFixture {
|
||||
exitReasonScores = Map.copyOf(exitReasonScores == null ? Map.of() : exitReasonScores);
|
||||
}
|
||||
}
|
||||
|
||||
public record RiskFixture(
|
||||
BigDecimal marketRiskProb,
|
||||
BigDecimal longPositionRiskProb,
|
||||
BigDecimal shortPositionRiskProb,
|
||||
BigDecimal marketPathRiskBps,
|
||||
BigDecimal longPositionPathRiskBps,
|
||||
BigDecimal shortPositionPathRiskBps,
|
||||
Map<String, BigDecimal> riskReasonScores
|
||||
) {
|
||||
public RiskFixture {
|
||||
riskReasonScores = Map.copyOf(riskReasonScores == null ? Map.of() : riskReasonScores);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -6,6 +6,8 @@ import org.springframework.boot.context.properties.ConfigurationProperties;
|
||||
|
||||
import java.math.BigDecimal;
|
||||
|
||||
import static com.quantai.trader.util.TraderNumbers.nonNegative;
|
||||
import static com.quantai.trader.util.TraderNumbers.positive;
|
||||
import static com.quantai.trader.util.TraderNumbers.requiredText;
|
||||
|
||||
@ConfigurationProperties(prefix = "trader")
|
||||
@@ -23,21 +25,24 @@ public record TraderProperties(
|
||||
PositionManager positionManager
|
||||
) {
|
||||
public TraderProperties {
|
||||
serviceName = defaultText(serviceName, "quant-trader-service");
|
||||
runMode = runMode == null ? TraderRunMode.SHADOW : runMode;
|
||||
symbol = defaultText(symbol, "BTC-USDT-PERP");
|
||||
artifact = artifact == null ? new Artifact("trader-v4-btc-p0", "cal-v4-btc-p0", "pm-v4-btc-p0", ".") : artifact;
|
||||
feedback = feedback == null ? new Feedback(false) : feedback;
|
||||
execution = execution == null ? new Execution(TraderExecutionMode.SHADOW, 3, 1500) : execution;
|
||||
runtime = runtime == null ? new Runtime("trader:v4", true, false) : runtime;
|
||||
outbox = outbox == null ? new Outbox(true, 5) : outbox;
|
||||
release = release == null ? new Release(true, true, true) : release;
|
||||
risk = risk == null ? new Risk(new BigDecimal("200"), BigDecimal.ONE, new BigDecimal("500")) : risk;
|
||||
positionManager = positionManager == null ? new PositionManager(BigDecimal.ONE, BigDecimal.ONE) : positionManager;
|
||||
serviceName = requiredText(serviceName, "serviceName");
|
||||
runMode = require(runMode, "runMode");
|
||||
symbol = requiredText(symbol, "symbol");
|
||||
artifact = require(artifact, "artifact");
|
||||
feedback = require(feedback, "feedback");
|
||||
execution = require(execution, "execution");
|
||||
runtime = require(runtime, "runtime");
|
||||
outbox = require(outbox, "outbox");
|
||||
release = require(release, "release");
|
||||
risk = require(risk, "risk");
|
||||
positionManager = require(positionManager, "positionManager");
|
||||
}
|
||||
|
||||
private static String defaultText(String value, String defaultValue) {
|
||||
return value == null || value.isBlank() ? defaultValue : value;
|
||||
private static <T> T require(T value, String field) {
|
||||
if (value == null) {
|
||||
throw new IllegalArgumentException(field + " is required");
|
||||
}
|
||||
return value;
|
||||
}
|
||||
|
||||
public record Artifact(
|
||||
@@ -57,19 +62,30 @@ public record TraderProperties(
|
||||
public record Feedback(boolean httpEnabled) {
|
||||
}
|
||||
|
||||
public record Execution(TraderExecutionMode mode, int maxApiErrorCount, long maxExchangeLatencyMs) {
|
||||
public record Execution(TraderExecutionMode mode, Integer maxApiErrorCount, Long maxExchangeLatencyMs) {
|
||||
public Execution {
|
||||
mode = mode == null ? TraderExecutionMode.SHADOW : mode;
|
||||
mode = require(mode, "execution.mode");
|
||||
if (require(maxApiErrorCount, "execution.maxApiErrorCount") < 0) {
|
||||
throw new IllegalArgumentException("execution.maxApiErrorCount must be >= 0");
|
||||
}
|
||||
if (require(maxExchangeLatencyMs, "execution.maxExchangeLatencyMs") <= 0) {
|
||||
throw new IllegalArgumentException("execution.maxExchangeLatencyMs must be > 0");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public record Runtime(String redisKeyPrefix, boolean requireRedisForOpenAdd, boolean tradingEnabled) {
|
||||
public Runtime {
|
||||
redisKeyPrefix = defaultText(redisKeyPrefix, "trader:v4");
|
||||
redisKeyPrefix = requiredText(redisKeyPrefix, "runtime.redisKeyPrefix");
|
||||
}
|
||||
}
|
||||
|
||||
public record Outbox(boolean enabled, int maxRetryCount) {
|
||||
public record Outbox(boolean enabled, Integer maxRetryCount) {
|
||||
public Outbox {
|
||||
if (require(maxRetryCount, "outbox.maxRetryCount") < 0) {
|
||||
throw new IllegalArgumentException("outbox.maxRetryCount must be >= 0");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public record Release(boolean requireReviewForPaper, boolean requireReviewForLiveProbe, boolean activePointerCheckEnabled) {
|
||||
@@ -77,16 +93,16 @@ public record TraderProperties(
|
||||
|
||||
public record Risk(BigDecimal maxDailyLossBps, BigDecimal maxTotalExposureRatio, BigDecimal minLiquidationBufferBps) {
|
||||
public Risk {
|
||||
maxDailyLossBps = maxDailyLossBps == null ? new BigDecimal("200") : maxDailyLossBps;
|
||||
maxTotalExposureRatio = maxTotalExposureRatio == null ? BigDecimal.ONE : maxTotalExposureRatio;
|
||||
minLiquidationBufferBps = minLiquidationBufferBps == null ? new BigDecimal("500") : minLiquidationBufferBps;
|
||||
maxDailyLossBps = nonNegative(maxDailyLossBps, "risk.maxDailyLossBps");
|
||||
maxTotalExposureRatio = positive(maxTotalExposureRatio, "risk.maxTotalExposureRatio");
|
||||
minLiquidationBufferBps = nonNegative(minLiquidationBufferBps, "risk.minLiquidationBufferBps");
|
||||
}
|
||||
}
|
||||
|
||||
public record PositionManager(BigDecimal maxSingleLegRatio, BigDecimal maxTotalPositionRatio) {
|
||||
public PositionManager {
|
||||
maxSingleLegRatio = maxSingleLegRatio == null ? BigDecimal.ONE : maxSingleLegRatio;
|
||||
maxTotalPositionRatio = maxTotalPositionRatio == null ? BigDecimal.ONE : maxTotalPositionRatio;
|
||||
maxSingleLegRatio = positive(maxSingleLegRatio, "positionManager.maxSingleLegRatio");
|
||||
maxTotalPositionRatio = positive(maxTotalPositionRatio, "positionManager.maxTotalPositionRatio");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -14,6 +14,6 @@ public class TraderApiExceptionHandler {
|
||||
|
||||
@ExceptionHandler(IllegalArgumentException.class)
|
||||
ResponseEntity<TraderApiError> illegalArgument(IllegalArgumentException exception) {
|
||||
return ResponseEntity.badRequest().body(new TraderApiError(com.quantai.trader.enums.TraderErrorCode.TRADER_MODEL_OUTPUT_INVALID, exception.getMessage()));
|
||||
return ResponseEntity.badRequest().body(new TraderApiError(com.quantai.trader.enums.TraderErrorCode.TRADER_REQUEST_INVALID, exception.getMessage()));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -5,6 +5,7 @@ import com.quantai.trader.domain.FeedbackValidator;
|
||||
import com.quantai.trader.domain.TraderAppFeedback;
|
||||
import com.quantai.trader.domain.TraderException;
|
||||
import com.quantai.trader.enums.TraderErrorCode;
|
||||
import com.quantai.trader.feedback.TraderFeedbackRepository;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
import org.springframework.web.bind.annotation.PostMapping;
|
||||
@@ -18,10 +19,13 @@ public class TraderFeedbackController {
|
||||
private static final Logger log = LoggerFactory.getLogger(TraderFeedbackController.class);
|
||||
private final TraderProperties properties;
|
||||
private final FeedbackValidator feedbackValidator;
|
||||
private final TraderFeedbackRepository feedbackRepository;
|
||||
|
||||
public TraderFeedbackController(TraderProperties properties, FeedbackValidator feedbackValidator) {
|
||||
public TraderFeedbackController(TraderProperties properties, FeedbackValidator feedbackValidator,
|
||||
TraderFeedbackRepository feedbackRepository) {
|
||||
this.properties = properties;
|
||||
this.feedbackValidator = feedbackValidator;
|
||||
this.feedbackRepository = feedbackRepository;
|
||||
}
|
||||
|
||||
@PostMapping("/api/trader/feedback")
|
||||
@@ -30,6 +34,7 @@ public class TraderFeedbackController {
|
||||
throw new TraderException(TraderErrorCode.TRADER_FEEDBACK_INVALID, "HTTP feedback is disabled in P0");
|
||||
}
|
||||
feedbackValidator.validateP0(feedback);
|
||||
feedbackRepository.insert(feedback);
|
||||
log.info("event=trader.feedback.accepted runId={} cycleId={} actionId={} source={}",
|
||||
feedback.runId(), feedback.cycleId(), feedback.actionId(), feedback.feedbackSource());
|
||||
return Map.of("accepted", true, "feedbackId", feedback.feedbackId());
|
||||
|
||||
@@ -2,27 +2,29 @@ package com.quantai.trader.domain;
|
||||
|
||||
import static com.quantai.trader.util.TraderNumbers.*;
|
||||
|
||||
import java.math.BigDecimal;
|
||||
import java.util.Map;
|
||||
import com.quantai.trader.enums.PositionSide;
|
||||
|
||||
import java.math.BigDecimal;
|
||||
public record ContinueOutput(
|
||||
BigDecimal longContinueProb,
|
||||
BigDecimal shortContinueProb,
|
||||
BigDecimal trendPersistenceProb,
|
||||
BigDecimal holdEdgeBps,
|
||||
BigDecimal continueVsExitEdgeBps,
|
||||
String modelVersion,
|
||||
String calibrationVersion,
|
||||
Map<String, Object> explanation
|
||||
BigDecimal longExpectedContinueEdgeBps,
|
||||
BigDecimal shortExpectedContinueEdgeBps
|
||||
) {
|
||||
public ContinueOutput {
|
||||
longContinueProb = probability(longContinueProb, "continue.longContinueProb");
|
||||
shortContinueProb = probability(shortContinueProb, "continue.shortContinueProb");
|
||||
trendPersistenceProb = probability(trendPersistenceProb, "continue.trendPersistenceProb");
|
||||
holdEdgeBps = required(holdEdgeBps, "continue.holdEdgeBps");
|
||||
continueVsExitEdgeBps = required(continueVsExitEdgeBps, "continue.continueVsExitEdgeBps");
|
||||
modelVersion = requiredText(modelVersion, "continue.modelVersion");
|
||||
calibrationVersion = requiredText(calibrationVersion, "continue.calibrationVersion");
|
||||
explanation = Map.copyOf(explanation == null ? Map.of() : explanation);
|
||||
longExpectedContinueEdgeBps = required(longExpectedContinueEdgeBps, "continue.longExpectedContinueEdgeBps");
|
||||
shortExpectedContinueEdgeBps = required(shortExpectedContinueEdgeBps, "continue.shortExpectedContinueEdgeBps");
|
||||
}
|
||||
|
||||
public BigDecimal continueEdgeBpsFor(PositionSide side) {
|
||||
if (side == PositionSide.LONG) {
|
||||
return longExpectedContinueEdgeBps;
|
||||
}
|
||||
if (side == PositionSide.SHORT) {
|
||||
return shortExpectedContinueEdgeBps;
|
||||
}
|
||||
throw new IllegalArgumentException("continue edge requires LONG or SHORT side");
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,32 +3,27 @@ package com.quantai.trader.domain;
|
||||
import static com.quantai.trader.util.TraderNumbers.*;
|
||||
|
||||
import java.math.BigDecimal;
|
||||
import java.util.Map;
|
||||
|
||||
public record DirectionOutput(
|
||||
BigDecimal longProb,
|
||||
BigDecimal shortProb,
|
||||
BigDecimal neutralProb,
|
||||
BigDecimal directionConfidence,
|
||||
BigDecimal directionMargin,
|
||||
BigDecimal expectedReturnBps,
|
||||
Integer horizonMinutes,
|
||||
String modelVersion,
|
||||
String calibrationVersion,
|
||||
Map<String, Object> explanation
|
||||
BigDecimal neutralProb
|
||||
) {
|
||||
public DirectionOutput {
|
||||
longProb = probability(longProb, "direction.longProb");
|
||||
shortProb = probability(shortProb, "direction.shortProb");
|
||||
neutralProb = probability(neutralProb, "direction.neutralProb");
|
||||
directionConfidence = probability(directionConfidence, "direction.directionConfidence");
|
||||
directionMargin = nonNegative(directionMargin, "direction.directionMargin");
|
||||
expectedReturnBps = required(expectedReturnBps, "direction.expectedReturnBps");
|
||||
if (horizonMinutes == null || horizonMinutes <= 0) {
|
||||
throw new IllegalArgumentException("direction.horizonMinutes must be > 0");
|
||||
BigDecimal sum = longProb.add(shortProb).add(neutralProb);
|
||||
if (sum.subtract(BigDecimal.ONE).abs().compareTo(new BigDecimal("0.000001")) > 0) {
|
||||
throw new IllegalArgumentException("direction probabilities must sum to 1");
|
||||
}
|
||||
modelVersion = requiredText(modelVersion, "direction.modelVersion");
|
||||
calibrationVersion = requiredText(calibrationVersion, "direction.calibrationVersion");
|
||||
explanation = Map.copyOf(explanation == null ? Map.of() : explanation);
|
||||
}
|
||||
|
||||
public BigDecimal margin() {
|
||||
return longProb.subtract(shortProb).abs();
|
||||
}
|
||||
|
||||
public BigDecimal confidence() {
|
||||
return longProb.max(shortProb).max(neutralProb);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,39 +2,29 @@ package com.quantai.trader.domain;
|
||||
|
||||
import static com.quantai.trader.util.TraderNumbers.*;
|
||||
|
||||
import java.math.BigDecimal;
|
||||
import java.util.Map;
|
||||
import com.quantai.trader.enums.PositionSide;
|
||||
|
||||
import java.math.BigDecimal;
|
||||
public record EntryOutput(
|
||||
BigDecimal longEntryProb,
|
||||
BigDecimal shortEntryProb,
|
||||
BigDecimal entryQualityScore,
|
||||
BigDecimal expectedEdgeBps,
|
||||
String pricePlanId,
|
||||
String pricePlanConfigHash,
|
||||
BigDecimal stopDistanceBps,
|
||||
BigDecimal targetDistanceBps,
|
||||
Integer maxHoldMinutes,
|
||||
BigDecimal costBps,
|
||||
String modelVersion,
|
||||
String calibrationVersion,
|
||||
Map<String, Object> explanation
|
||||
BigDecimal longExpectedNetEdgeBps,
|
||||
BigDecimal shortExpectedNetEdgeBps
|
||||
) {
|
||||
public EntryOutput {
|
||||
longEntryProb = probability(longEntryProb, "entry.longEntryProb");
|
||||
shortEntryProb = probability(shortEntryProb, "entry.shortEntryProb");
|
||||
entryQualityScore = probability(entryQualityScore, "entry.entryQualityScore");
|
||||
expectedEdgeBps = required(expectedEdgeBps, "entry.expectedEdgeBps");
|
||||
pricePlanId = requiredText(pricePlanId, "entry.pricePlanId");
|
||||
pricePlanConfigHash = requiredText(pricePlanConfigHash, "entry.pricePlanConfigHash");
|
||||
stopDistanceBps = positive(stopDistanceBps, "entry.stopDistanceBps");
|
||||
targetDistanceBps = positive(targetDistanceBps, "entry.targetDistanceBps");
|
||||
if (maxHoldMinutes == null || maxHoldMinutes <= 0) {
|
||||
throw new IllegalArgumentException("entry.maxHoldMinutes must be > 0");
|
||||
longExpectedNetEdgeBps = required(longExpectedNetEdgeBps, "entry.longExpectedNetEdgeBps");
|
||||
shortExpectedNetEdgeBps = required(shortExpectedNetEdgeBps, "entry.shortExpectedNetEdgeBps");
|
||||
}
|
||||
|
||||
public BigDecimal netEdgeBpsFor(PositionSide side) {
|
||||
if (side == PositionSide.LONG) {
|
||||
return longExpectedNetEdgeBps;
|
||||
}
|
||||
costBps = nonNegative(costBps, "entry.costBps");
|
||||
modelVersion = requiredText(modelVersion, "entry.modelVersion");
|
||||
calibrationVersion = requiredText(calibrationVersion, "entry.calibrationVersion");
|
||||
explanation = Map.copyOf(explanation == null ? Map.of() : explanation);
|
||||
if (side == PositionSide.SHORT) {
|
||||
return shortExpectedNetEdgeBps;
|
||||
}
|
||||
throw new IllegalArgumentException("entry edge requires LONG or SHORT side");
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4,29 +4,39 @@ import static com.quantai.trader.util.TraderNumbers.*;
|
||||
|
||||
import java.math.BigDecimal;
|
||||
import java.util.Map;
|
||||
import java.util.Set;
|
||||
|
||||
public record ExitOutput(
|
||||
BigDecimal longExitProb,
|
||||
BigDecimal shortExitProb,
|
||||
BigDecimal profitGivebackProb,
|
||||
BigDecimal reversalProb,
|
||||
BigDecimal stopRiskProb,
|
||||
BigDecimal stagnationProb,
|
||||
BigDecimal expectedGivebackBps,
|
||||
String modelVersion,
|
||||
String calibrationVersion,
|
||||
Map<String, Object> explanation
|
||||
BigDecimal longAdverseMoveBps,
|
||||
BigDecimal shortAdverseMoveBps,
|
||||
Map<String, BigDecimal> exitReasonScores
|
||||
) {
|
||||
private static final Set<String> REQUIRED_REASON_KEYS = Set.of(
|
||||
"adverse_move_prob", "reversal_prob", "stop_hit_prob", "stagnation_prob");
|
||||
|
||||
public ExitOutput {
|
||||
longExitProb = probability(longExitProb, "exit.longExitProb");
|
||||
shortExitProb = probability(shortExitProb, "exit.shortExitProb");
|
||||
profitGivebackProb = probability(profitGivebackProb, "exit.profitGivebackProb");
|
||||
reversalProb = probability(reversalProb, "exit.reversalProb");
|
||||
stopRiskProb = probability(stopRiskProb, "exit.stopRiskProb");
|
||||
stagnationProb = probability(stagnationProb, "exit.stagnationProb");
|
||||
expectedGivebackBps = nonNegative(expectedGivebackBps, "exit.expectedGivebackBps");
|
||||
modelVersion = requiredText(modelVersion, "exit.modelVersion");
|
||||
calibrationVersion = requiredText(calibrationVersion, "exit.calibrationVersion");
|
||||
explanation = Map.copyOf(explanation == null ? Map.of() : explanation);
|
||||
longAdverseMoveBps = nonNegative(longAdverseMoveBps, "exit.longAdverseMoveBps");
|
||||
shortAdverseMoveBps = nonNegative(shortAdverseMoveBps, "exit.shortAdverseMoveBps");
|
||||
exitReasonScores = checkedProbabilities(exitReasonScores, "exit.exitReasonScores");
|
||||
}
|
||||
|
||||
public BigDecimal reasonScore(String key) {
|
||||
return exitReasonScores.get(requiredText(key, "exit reason key"));
|
||||
}
|
||||
|
||||
private static Map<String, BigDecimal> checkedProbabilities(Map<String, BigDecimal> scores, String field) {
|
||||
Map<String, BigDecimal> source = scores == null ? Map.of() : scores;
|
||||
if (!source.keySet().containsAll(REQUIRED_REASON_KEYS)) {
|
||||
throw new IllegalArgumentException(field + " must contain " + REQUIRED_REASON_KEYS);
|
||||
}
|
||||
source.forEach((key, value) -> {
|
||||
requiredText(key, field + ".key");
|
||||
probability(value, field + "." + key);
|
||||
});
|
||||
return Map.copyOf(source);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -6,6 +6,7 @@ public record PositionManagerInput(
|
||||
TraderDecisionCycle cycle,
|
||||
TraderMarketSnapshot snapshot,
|
||||
TraderModelOutput modelOutput,
|
||||
TraderPricePlanContext pricePlanContext,
|
||||
TraderPositionState positionState,
|
||||
TraderAccountState accountState,
|
||||
TraderExecutionState executionState,
|
||||
@@ -15,6 +16,7 @@ public record PositionManagerInput(
|
||||
cycle = Objects.requireNonNull(cycle, "cycle is required");
|
||||
snapshot = Objects.requireNonNull(snapshot, "snapshot is required");
|
||||
modelOutput = Objects.requireNonNull(modelOutput, "modelOutput is required");
|
||||
pricePlanContext = Objects.requireNonNull(pricePlanContext, "pricePlanContext is required");
|
||||
positionState = Objects.requireNonNull(positionState, "positionState is required");
|
||||
accountState = Objects.requireNonNull(accountState, "accountState is required");
|
||||
executionState = Objects.requireNonNull(executionState, "executionState is required");
|
||||
|
||||
@@ -2,37 +2,68 @@ package com.quantai.trader.domain;
|
||||
|
||||
import static com.quantai.trader.util.TraderNumbers.*;
|
||||
|
||||
import com.quantai.trader.enums.PositionSide;
|
||||
|
||||
import java.math.BigDecimal;
|
||||
import java.util.Map;
|
||||
import java.util.Set;
|
||||
|
||||
public record RiskOutput(
|
||||
BigDecimal marketRiskProb,
|
||||
BigDecimal positionRiskProb,
|
||||
BigDecimal marketRiskSeverityBps,
|
||||
BigDecimal positionRiskSeverityBps,
|
||||
BigDecimal drawdownProb,
|
||||
BigDecimal expectedShortfallBps,
|
||||
BigDecimal volatilityExpansionProb,
|
||||
BigDecimal spikeProb,
|
||||
BigDecimal liquidityRiskProb,
|
||||
BigDecimal liquidityCapacityRatio,
|
||||
String modelVersion,
|
||||
String calibrationVersion,
|
||||
Map<String, Object> explanation
|
||||
BigDecimal longPositionRiskProb,
|
||||
BigDecimal shortPositionRiskProb,
|
||||
BigDecimal marketPathRiskBps,
|
||||
BigDecimal longPositionPathRiskBps,
|
||||
BigDecimal shortPositionPathRiskBps,
|
||||
Map<String, BigDecimal> riskReasonScores
|
||||
) {
|
||||
private static final Set<String> REQUIRED_REASON_KEYS = Set.of(
|
||||
"market_drawdown_prob", "volatility_expansion_prob", "spike_prob",
|
||||
"liquidity_deterioration_prob", "position_drawdown_prob");
|
||||
|
||||
public RiskOutput {
|
||||
marketRiskProb = probability(marketRiskProb, "risk.marketRiskProb");
|
||||
positionRiskProb = probability(positionRiskProb, "risk.positionRiskProb");
|
||||
marketRiskSeverityBps = nonNegative(marketRiskSeverityBps, "risk.marketRiskSeverityBps");
|
||||
positionRiskSeverityBps = nonNegative(positionRiskSeverityBps, "risk.positionRiskSeverityBps");
|
||||
drawdownProb = probability(drawdownProb, "risk.drawdownProb");
|
||||
expectedShortfallBps = nonNegative(expectedShortfallBps, "risk.expectedShortfallBps");
|
||||
volatilityExpansionProb = probability(volatilityExpansionProb, "risk.volatilityExpansionProb");
|
||||
spikeProb = probability(spikeProb, "risk.spikeProb");
|
||||
liquidityRiskProb = probability(liquidityRiskProb, "risk.liquidityRiskProb");
|
||||
liquidityCapacityRatio = nonNegative(liquidityCapacityRatio, "risk.liquidityCapacityRatio");
|
||||
modelVersion = requiredText(modelVersion, "risk.modelVersion");
|
||||
calibrationVersion = requiredText(calibrationVersion, "risk.calibrationVersion");
|
||||
explanation = Map.copyOf(explanation == null ? Map.of() : explanation);
|
||||
longPositionRiskProb = probability(longPositionRiskProb, "risk.longPositionRiskProb");
|
||||
shortPositionRiskProb = probability(shortPositionRiskProb, "risk.shortPositionRiskProb");
|
||||
marketPathRiskBps = nonNegative(marketPathRiskBps, "risk.marketPathRiskBps");
|
||||
longPositionPathRiskBps = nonNegative(longPositionPathRiskBps, "risk.longPositionPathRiskBps");
|
||||
shortPositionPathRiskBps = nonNegative(shortPositionPathRiskBps, "risk.shortPositionPathRiskBps");
|
||||
riskReasonScores = checkedProbabilities(riskReasonScores, "risk.riskReasonScores");
|
||||
}
|
||||
|
||||
public BigDecimal reasonScore(String key) {
|
||||
return riskReasonScores.get(requiredText(key, "risk reason key"));
|
||||
}
|
||||
|
||||
public BigDecimal sideRiskProbFor(PositionSide side) {
|
||||
if (side == PositionSide.LONG) {
|
||||
return longPositionRiskProb;
|
||||
}
|
||||
if (side == PositionSide.SHORT) {
|
||||
return shortPositionRiskProb;
|
||||
}
|
||||
throw new IllegalArgumentException("position risk requires LONG or SHORT side");
|
||||
}
|
||||
|
||||
public BigDecimal positionPathRiskBpsFor(PositionSide side) {
|
||||
if (side == PositionSide.LONG) {
|
||||
return longPositionPathRiskBps;
|
||||
}
|
||||
if (side == PositionSide.SHORT) {
|
||||
return shortPositionPathRiskBps;
|
||||
}
|
||||
throw new IllegalArgumentException("position path risk requires LONG or SHORT side");
|
||||
}
|
||||
|
||||
private static Map<String, BigDecimal> checkedProbabilities(Map<String, BigDecimal> scores, String field) {
|
||||
Map<String, BigDecimal> source = scores == null ? Map.of() : scores;
|
||||
if (!source.keySet().containsAll(REQUIRED_REASON_KEYS)) {
|
||||
throw new IllegalArgumentException(field + " must contain " + REQUIRED_REASON_KEYS);
|
||||
}
|
||||
source.forEach((key, value) -> {
|
||||
requiredText(key, field + ".key");
|
||||
probability(value, field + "." + key);
|
||||
});
|
||||
return Map.copyOf(source);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,38 +2,28 @@ package com.quantai.trader.domain;
|
||||
|
||||
import static com.quantai.trader.util.TraderNumbers.*;
|
||||
|
||||
import java.math.BigDecimal;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
|
||||
public record TraderModelOutput(
|
||||
String modelOutputId,
|
||||
String runId,
|
||||
String cycleId,
|
||||
String modelBundleVersion,
|
||||
String calibrationBundleVersion,
|
||||
TraderModelOutputMetadata metadata,
|
||||
DirectionOutput direction,
|
||||
EntryOutput entry,
|
||||
ContinueOutput continuation,
|
||||
ExitOutput exit,
|
||||
RiskOutput risk,
|
||||
BigDecimal uncertainty,
|
||||
BigDecimal oodScore,
|
||||
Map<String, Object> explanation
|
||||
RiskOutput risk
|
||||
) {
|
||||
public TraderModelOutput {
|
||||
modelOutputId = requiredText(modelOutputId, "modelOutputId");
|
||||
runId = requiredText(runId, "runId");
|
||||
cycleId = requiredText(cycleId, "cycleId");
|
||||
modelBundleVersion = requiredText(modelBundleVersion, "modelBundleVersion");
|
||||
calibrationBundleVersion = requiredText(calibrationBundleVersion, "calibrationBundleVersion");
|
||||
metadata = Objects.requireNonNull(metadata, "metadata is required");
|
||||
direction = Objects.requireNonNull(direction, "direction is required");
|
||||
entry = Objects.requireNonNull(entry, "entry is required");
|
||||
continuation = Objects.requireNonNull(continuation, "continuation is required");
|
||||
exit = Objects.requireNonNull(exit, "exit is required");
|
||||
risk = Objects.requireNonNull(risk, "risk is required");
|
||||
uncertainty = probability(uncertainty, "uncertainty");
|
||||
oodScore = probability(oodScore, "oodScore");
|
||||
explanation = Map.copyOf(explanation == null ? Map.of() : explanation);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,39 @@
|
||||
package com.quantai.trader.domain;
|
||||
|
||||
import static com.quantai.trader.util.TraderNumbers.*;
|
||||
|
||||
import java.math.BigDecimal;
|
||||
import java.util.Map;
|
||||
|
||||
public record TraderModelOutputMetadata(
|
||||
String modelBundleVersion,
|
||||
String calibrationBundleVersion,
|
||||
Map<String, String> modelVersions,
|
||||
Map<String, String> calibrationVersions,
|
||||
String featureSchemaHash,
|
||||
String featureOrderHash,
|
||||
String outputSchemaHash,
|
||||
BigDecimal uncertainty,
|
||||
BigDecimal oodScore
|
||||
) {
|
||||
public TraderModelOutputMetadata {
|
||||
modelBundleVersion = requiredText(modelBundleVersion, "metadata.modelBundleVersion");
|
||||
calibrationBundleVersion = requiredText(calibrationBundleVersion, "metadata.calibrationBundleVersion");
|
||||
modelVersions = checkedTextMap(modelVersions, "metadata.modelVersions");
|
||||
calibrationVersions = checkedTextMap(calibrationVersions, "metadata.calibrationVersions");
|
||||
featureSchemaHash = requiredText(featureSchemaHash, "metadata.featureSchemaHash");
|
||||
featureOrderHash = requiredText(featureOrderHash, "metadata.featureOrderHash");
|
||||
outputSchemaHash = requiredText(outputSchemaHash, "metadata.outputSchemaHash");
|
||||
uncertainty = probability(uncertainty, "metadata.uncertainty");
|
||||
oodScore = probability(oodScore, "metadata.oodScore");
|
||||
}
|
||||
|
||||
private static Map<String, String> checkedTextMap(Map<String, String> values, String field) {
|
||||
Map<String, String> source = values == null ? Map.of() : values;
|
||||
source.forEach((key, value) -> {
|
||||
requiredText(key, field + ".key");
|
||||
requiredText(value, field + "." + key);
|
||||
});
|
||||
return Map.copyOf(source);
|
||||
}
|
||||
}
|
||||
@@ -81,22 +81,22 @@ public record TraderPmConfig(
|
||||
BigDecimal closePositionRiskProb,
|
||||
BigDecimal closeMarketRiskProb,
|
||||
BigDecimal closeContinueMax,
|
||||
BigDecimal reduceGivebackProb,
|
||||
BigDecimal reduceAdverseMoveProb,
|
||||
BigDecimal reduceContinueMin,
|
||||
BigDecimal reduceContinueMax,
|
||||
BigDecimal minProfitForReduceBps,
|
||||
BigDecimal maxExpectedShortfallBps
|
||||
BigDecimal maxPositionPathRiskBps
|
||||
) {
|
||||
public ExitRuleConfig {
|
||||
closeExitProb = probability(closeExitProb, "exit.closeExitProb");
|
||||
closePositionRiskProb = probability(closePositionRiskProb, "exit.closePositionRiskProb");
|
||||
closeMarketRiskProb = probability(closeMarketRiskProb, "exit.closeMarketRiskProb");
|
||||
closeContinueMax = probability(closeContinueMax, "exit.closeContinueMax");
|
||||
reduceGivebackProb = probability(reduceGivebackProb, "exit.reduceGivebackProb");
|
||||
reduceAdverseMoveProb = probability(reduceAdverseMoveProb, "exit.reduceAdverseMoveProb");
|
||||
reduceContinueMin = probability(reduceContinueMin, "exit.reduceContinueMin");
|
||||
reduceContinueMax = probability(reduceContinueMax, "exit.reduceContinueMax");
|
||||
minProfitForReduceBps = nonNegative(minProfitForReduceBps, "exit.minProfitForReduceBps");
|
||||
maxExpectedShortfallBps = nonNegative(maxExpectedShortfallBps, "exit.maxExpectedShortfallBps");
|
||||
maxPositionPathRiskBps = nonNegative(maxPositionPathRiskBps, "exit.maxPositionPathRiskBps");
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,25 @@
|
||||
package com.quantai.trader.domain;
|
||||
|
||||
import static com.quantai.trader.util.TraderNumbers.*;
|
||||
|
||||
import java.math.BigDecimal;
|
||||
|
||||
public record TraderPricePlanContext(
|
||||
String pricePlanId,
|
||||
String pricePlanConfigHash,
|
||||
BigDecimal stopDistanceBps,
|
||||
BigDecimal targetDistanceBps,
|
||||
int maxHoldMinutes,
|
||||
BigDecimal costBps
|
||||
) {
|
||||
public TraderPricePlanContext {
|
||||
pricePlanId = requiredText(pricePlanId, "pricePlan.pricePlanId");
|
||||
pricePlanConfigHash = requiredText(pricePlanConfigHash, "pricePlan.pricePlanConfigHash");
|
||||
stopDistanceBps = positive(stopDistanceBps, "pricePlan.stopDistanceBps");
|
||||
targetDistanceBps = positive(targetDistanceBps, "pricePlan.targetDistanceBps");
|
||||
if (maxHoldMinutes <= 0) {
|
||||
throw new IllegalArgumentException("pricePlan.maxHoldMinutes must be > 0");
|
||||
}
|
||||
costBps = nonNegative(costBps, "pricePlan.costBps");
|
||||
}
|
||||
}
|
||||
@@ -9,8 +9,11 @@ public enum TraderErrorCode {
|
||||
TRADER_RISK_BLOCKED,
|
||||
TRADER_EXECUTION_BLOCKED,
|
||||
TRADER_FEEDBACK_INVALID,
|
||||
TRADER_REQUEST_INVALID,
|
||||
TRADER_P0_MODE_BLOCKED,
|
||||
TRADER_KILL_SWITCH_ACTIVE,
|
||||
TRADER_RUNTIME_CONTROL_BLOCKED,
|
||||
TRADER_OUTBOX_BLOCKED,
|
||||
TRADER_ACTIVE_POINTER_MISMATCH,
|
||||
TRADER_PERSISTENCE_FAILED
|
||||
}
|
||||
|
||||
@@ -0,0 +1,144 @@
|
||||
package com.quantai.trader.feature;
|
||||
|
||||
import com.fasterxml.jackson.databind.JsonNode;
|
||||
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||
import com.quantai.trader.artifact.TraderArtifactBundle;
|
||||
import com.quantai.trader.artifact.TraderModelManifest;
|
||||
import com.quantai.trader.config.TraderProperties;
|
||||
import com.quantai.trader.domain.TraderException;
|
||||
import com.quantai.trader.domain.TraderMarketSnapshot;
|
||||
import com.quantai.trader.enums.TraderErrorCode;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
import org.springframework.stereotype.Component;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.math.BigDecimal;
|
||||
import java.nio.file.Files;
|
||||
import java.nio.file.Path;
|
||||
import java.util.LinkedHashSet;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Set;
|
||||
import java.util.concurrent.ConcurrentHashMap;
|
||||
import java.util.concurrent.ConcurrentMap;
|
||||
import java.util.stream.StreamSupport;
|
||||
|
||||
@Component
|
||||
public class TraderFeatureVectorBuilder {
|
||||
private static final Logger log = LoggerFactory.getLogger(TraderFeatureVectorBuilder.class);
|
||||
private static final int REQUIRED_FEATURE_COUNT = 39;
|
||||
|
||||
private final TraderProperties properties;
|
||||
private final ObjectMapper objectMapper;
|
||||
private final ConcurrentMap<String, List<String>> featureOrderCache = new ConcurrentHashMap<>();
|
||||
|
||||
public TraderFeatureVectorBuilder(TraderProperties properties, ObjectMapper objectMapper) {
|
||||
this.properties = properties;
|
||||
this.objectMapper = objectMapper;
|
||||
}
|
||||
|
||||
public float[] build(TraderMarketSnapshot snapshot, TraderArtifactBundle bundle) {
|
||||
TraderModelManifest referenceManifest = referenceManifest(bundle);
|
||||
if (!referenceManifest.featureVersion().equals(snapshot.featureVersion())) {
|
||||
log.warn("event=trader.features.version_mismatch runId={} cycleId={} snapshotFeatureVersion={} modelFeatureVersion={}",
|
||||
snapshot.runId(), snapshot.cycleId(), snapshot.featureVersion(), referenceManifest.featureVersion());
|
||||
throw modelInputException("snapshot feature version does not match model feature version: "
|
||||
+ snapshot.featureVersion() + " != " + referenceManifest.featureVersion());
|
||||
}
|
||||
List<String> featureOrder = featureOrder(referenceManifest);
|
||||
rejectUnexpectedFeatures(snapshot, featureOrder);
|
||||
float[] values = new float[featureOrder.size()];
|
||||
for (int index = 0; index < featureOrder.size(); index++) {
|
||||
String featureName = featureOrder.get(index);
|
||||
Object rawValue = snapshot.featureJson().get(featureName);
|
||||
values[index] = toFloat(snapshot, rawValue, featureName, index);
|
||||
}
|
||||
log.debug("event=trader.features.vector_built runId={} cycleId={} featureVersion={} featureCount={} featureOrderHash={}",
|
||||
snapshot.runId(), snapshot.cycleId(), snapshot.featureVersion(), values.length, referenceManifest.featureOrderHash());
|
||||
return values;
|
||||
}
|
||||
|
||||
public List<String> featureOrder(TraderArtifactBundle bundle) {
|
||||
return featureOrder(referenceManifest(bundle));
|
||||
}
|
||||
|
||||
private TraderModelManifest referenceManifest(TraderArtifactBundle bundle) {
|
||||
return bundle.modelManifests().stream()
|
||||
.filter(manifest -> "DIRECTION".equals(manifest.modelType()))
|
||||
.findFirst()
|
||||
.orElseThrow(() -> modelInputException("DIRECTION model manifest is required for feature order"));
|
||||
}
|
||||
|
||||
private List<String> featureOrder(TraderModelManifest manifest) {
|
||||
String cacheKey = manifest.featureOrderHash() + "|" + manifest.featureOrderPath();
|
||||
// 特征顺序是模型包契约的一部分,按 hash 缓存,避免每轮重复读文件。
|
||||
return featureOrderCache.computeIfAbsent(cacheKey, ignored -> readFeatureOrder(manifest));
|
||||
}
|
||||
|
||||
private List<String> readFeatureOrder(TraderModelManifest manifest) {
|
||||
Path path = Path.of(properties.artifact().artifactRoot()).resolve(manifest.featureOrderPath());
|
||||
if (!Files.isRegularFile(path)) {
|
||||
throw modelInputException("feature_order.json is missing: " + path);
|
||||
}
|
||||
try {
|
||||
JsonNode root = objectMapper.readTree(path.toFile());
|
||||
if (!root.isArray() || root.size() != REQUIRED_FEATURE_COUNT) {
|
||||
throw modelInputException("feature_order.json must contain exactly " + REQUIRED_FEATURE_COUNT + " fields: " + path);
|
||||
}
|
||||
List<String> order = StreamSupport.stream(root.spliterator(), false)
|
||||
.map(JsonNode::asText)
|
||||
.toList();
|
||||
Set<String> unique = new LinkedHashSet<>(order);
|
||||
if (unique.size() != order.size() || order.stream().anyMatch(String::isBlank)) {
|
||||
throw modelInputException("feature_order.json contains duplicate or blank feature names: " + path);
|
||||
}
|
||||
log.info("event=trader.features.order_loaded featureOrderPath={} featureOrderHash={} featureCount={}",
|
||||
manifest.featureOrderPath(), manifest.featureOrderHash(), order.size());
|
||||
return order;
|
||||
} catch (IOException exception) {
|
||||
throw modelInputException("feature_order.json cannot be read: " + path);
|
||||
}
|
||||
}
|
||||
|
||||
private void rejectUnexpectedFeatures(TraderMarketSnapshot snapshot, List<String> featureOrder) {
|
||||
Set<String> allowed = Set.copyOf(featureOrder);
|
||||
List<String> unexpected = snapshot.featureJson().keySet().stream()
|
||||
.filter(key -> !allowed.contains(key))
|
||||
.sorted()
|
||||
.toList();
|
||||
if (!unexpected.isEmpty()) {
|
||||
log.warn("event=trader.features.unexpected_fields runId={} cycleId={} unexpectedFields={}",
|
||||
snapshot.runId(), snapshot.cycleId(), unexpected);
|
||||
throw modelInputException("snapshot featureJson contains fields outside feature_order.json: " + unexpected);
|
||||
}
|
||||
}
|
||||
|
||||
private float toFloat(TraderMarketSnapshot snapshot, Object rawValue, String featureName, int index) {
|
||||
if (rawValue == null) {
|
||||
log.warn("event=trader.features.missing runId={} cycleId={} featureIndex={} featureName={}",
|
||||
snapshot.runId(), snapshot.cycleId(), index + 1, featureName);
|
||||
throw modelInputException("snapshot feature is missing: " + featureName);
|
||||
}
|
||||
double value;
|
||||
if (rawValue instanceof BigDecimal decimal) {
|
||||
value = decimal.doubleValue();
|
||||
} else if (rawValue instanceof Number number) {
|
||||
value = number.doubleValue();
|
||||
} else {
|
||||
log.warn("event=trader.features.non_numeric runId={} cycleId={} featureIndex={} featureName={} valueType={}",
|
||||
snapshot.runId(), snapshot.cycleId(), index + 1, featureName, rawValue.getClass().getName());
|
||||
throw modelInputException("snapshot feature must be numeric: " + featureName);
|
||||
}
|
||||
if (!Double.isFinite(value)) {
|
||||
log.warn("event=trader.features.non_finite runId={} cycleId={} featureIndex={} featureName={} value={}",
|
||||
snapshot.runId(), snapshot.cycleId(), index + 1, featureName, value);
|
||||
throw modelInputException("snapshot feature must be finite: " + featureName);
|
||||
}
|
||||
return (float) value;
|
||||
}
|
||||
|
||||
private static TraderException modelInputException(String message) {
|
||||
return new TraderException(TraderErrorCode.TRADER_MODEL_ARTIFACT_MISSING, message);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,38 @@
|
||||
package com.quantai.trader.feedback;
|
||||
|
||||
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||
import com.quantai.trader.domain.TraderAppFeedback;
|
||||
import com.quantai.trader.persistence.TraderJsonCodec;
|
||||
import org.springframework.jdbc.core.JdbcTemplate;
|
||||
import org.springframework.stereotype.Repository;
|
||||
|
||||
import java.sql.Timestamp;
|
||||
|
||||
@Repository
|
||||
public class JdbcTraderFeedbackRepository implements TraderFeedbackRepository {
|
||||
private final JdbcTemplate jdbcTemplate;
|
||||
private final TraderJsonCodec jsonCodec;
|
||||
|
||||
public JdbcTraderFeedbackRepository(JdbcTemplate jdbcTemplate, ObjectMapper objectMapper) {
|
||||
this.jdbcTemplate = jdbcTemplate;
|
||||
this.jsonCodec = new TraderJsonCodec(objectMapper);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void insert(TraderAppFeedback feedback) {
|
||||
jdbcTemplate.update("""
|
||||
insert into trader_app_feedback
|
||||
(run_id, cycle_id, feedback_id, action_id, feedback_source, is_real_fill,
|
||||
order_id, order_status, app_received_time, exchange_ack_time, filled_time,
|
||||
filled_price, filled_quantity, fee, slippage_bps, reject_reason, raw_feedback_json)
|
||||
values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
""",
|
||||
feedback.runId(), feedback.cycleId(), feedback.feedbackId(), feedback.actionId(),
|
||||
feedback.feedbackSource().name(), feedback.realFill(), feedback.orderId(), feedback.orderStatus(),
|
||||
feedback.appReceivedTime() == null ? null : Timestamp.from(feedback.appReceivedTime()),
|
||||
feedback.exchangeAckTime() == null ? null : Timestamp.from(feedback.exchangeAckTime()),
|
||||
feedback.filledTime() == null ? null : Timestamp.from(feedback.filledTime()),
|
||||
feedback.filledPrice(), feedback.filledQuantity(), feedback.fee(), feedback.slippageBps(),
|
||||
feedback.rejectReason(), jsonCodec.toJson(feedback.rawFeedbackJson()));
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,7 @@
|
||||
package com.quantai.trader.feedback;
|
||||
|
||||
import com.quantai.trader.domain.TraderAppFeedback;
|
||||
|
||||
public interface TraderFeedbackRepository {
|
||||
void insert(TraderAppFeedback feedback);
|
||||
}
|
||||
@@ -1,72 +0,0 @@
|
||||
package com.quantai.trader.model;
|
||||
|
||||
import com.quantai.trader.artifact.TraderArtifactBundle;
|
||||
import com.quantai.trader.artifact.TraderArtifactModelPolicy;
|
||||
import com.quantai.trader.domain.*;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
import java.math.BigDecimal;
|
||||
import java.util.Map;
|
||||
|
||||
@Service
|
||||
public class ArtifactTraderModelService implements TraderModelService {
|
||||
@Override
|
||||
public TraderModelOutput evaluate(TraderMarketSnapshot snapshot, TraderArtifactBundle bundle) {
|
||||
TraderArtifactModelPolicy policy = bundle.modelPolicy();
|
||||
DirectionOutput direction = direction(snapshot, bundle, policy.direction());
|
||||
return new TraderModelOutput(
|
||||
"model_output_" + snapshot.cycleId(),
|
||||
snapshot.runId(),
|
||||
snapshot.cycleId(),
|
||||
bundle.modelBundleVersion(),
|
||||
bundle.calibrationBundleVersion(),
|
||||
direction,
|
||||
entry(bundle, policy.entry()),
|
||||
continuation(bundle, policy.continuation()),
|
||||
exit(bundle, policy.exit()),
|
||||
risk(snapshot, bundle, policy.risk()),
|
||||
policy.uncertainty(),
|
||||
policy.oodScore(),
|
||||
Map.of("artifactPolicy", bundle.bundleHashSha256()));
|
||||
}
|
||||
|
||||
private DirectionOutput direction(TraderMarketSnapshot snapshot, TraderArtifactBundle bundle, TraderArtifactModelPolicy.DirectionPolicy policy) {
|
||||
BigDecimal longProb = snapshot.markPrice().compareTo(snapshot.indexPrice()) >= 0
|
||||
? policy.longProbWhenMarkGteIndex()
|
||||
: policy.longProbWhenMarkLtIndex();
|
||||
BigDecimal shortProb = BigDecimal.ONE.subtract(longProb).subtract(policy.neutralProb()).max(BigDecimal.ZERO);
|
||||
BigDecimal neutralProb = BigDecimal.ONE.subtract(longProb).subtract(shortProb);
|
||||
return new DirectionOutput(longProb, shortProb, neutralProb, longProb.max(shortProb),
|
||||
longProb.subtract(shortProb).abs(), policy.expectedReturnBps(), policy.horizonMinutes(),
|
||||
policy.modelVersion(), bundle.calibrationBundleVersion(), Map.of("source", "artifact_policy"));
|
||||
}
|
||||
|
||||
private EntryOutput entry(TraderArtifactBundle bundle, TraderArtifactModelPolicy.EntryPolicy policy) {
|
||||
return new EntryOutput(policy.longEntryProb(), policy.shortEntryProb(), policy.entryQualityScore(),
|
||||
policy.expectedEdgeBps(), policy.pricePlanId(), policy.pricePlanConfigHash(),
|
||||
policy.stopDistanceBps(), policy.targetDistanceBps(), policy.maxHoldMinutes(), policy.costBps(),
|
||||
policy.modelVersion(), bundle.calibrationBundleVersion(), Map.of("source", "artifact_policy"));
|
||||
}
|
||||
|
||||
private ContinueOutput continuation(TraderArtifactBundle bundle, TraderArtifactModelPolicy.ContinuePolicy policy) {
|
||||
return new ContinueOutput(policy.longContinueProb(), policy.shortContinueProb(), policy.trendPersistenceProb(),
|
||||
policy.holdEdgeBps(), policy.continueVsExitEdgeBps(),
|
||||
policy.modelVersion(), bundle.calibrationBundleVersion(), Map.of("source", "artifact_policy"));
|
||||
}
|
||||
|
||||
private ExitOutput exit(TraderArtifactBundle bundle, TraderArtifactModelPolicy.ExitPolicy policy) {
|
||||
return new ExitOutput(policy.longExitProb(), policy.shortExitProb(), policy.profitGivebackProb(),
|
||||
policy.reversalProb(), policy.stopRiskProb(), policy.stagnationProb(), policy.expectedGivebackBps(),
|
||||
policy.modelVersion(), bundle.calibrationBundleVersion(), Map.of("source", "artifact_policy"));
|
||||
}
|
||||
|
||||
private RiskOutput risk(TraderMarketSnapshot snapshot, TraderArtifactBundle bundle, TraderArtifactModelPolicy.RiskPolicy policy) {
|
||||
BigDecimal liquidityCapacity = snapshot.dataReady()
|
||||
? policy.liquidityCapacityRatioWhenReady()
|
||||
: policy.liquidityCapacityRatioWhenNotReady();
|
||||
return new RiskOutput(policy.marketRiskProb(), policy.positionRiskProb(), policy.marketRiskSeverityBps(),
|
||||
policy.positionRiskSeverityBps(), policy.drawdownProb(), policy.expectedShortfallBps(),
|
||||
policy.volatilityExpansionProb(), policy.spikeProb(), policy.liquidityRiskProb(), liquidityCapacity,
|
||||
policy.modelVersion(), bundle.calibrationBundleVersion(), Map.of("source", "artifact_policy"));
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,267 @@
|
||||
package com.quantai.trader.model;
|
||||
|
||||
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||
import com.quantai.trader.artifact.TraderArtifactBundle;
|
||||
import com.quantai.trader.artifact.TraderCalibrationManifest;
|
||||
import com.quantai.trader.artifact.TraderModelManifest;
|
||||
import com.quantai.trader.config.TraderProperties;
|
||||
import com.quantai.trader.domain.*;
|
||||
import com.quantai.trader.enums.TraderErrorCode;
|
||||
import com.quantai.trader.feature.TraderFeatureVectorBuilder;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
import java.math.BigDecimal;
|
||||
import java.nio.file.Path;
|
||||
import java.util.LinkedHashMap;
|
||||
import java.util.Map;
|
||||
import java.util.Set;
|
||||
import java.util.regex.Matcher;
|
||||
import java.util.regex.Pattern;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
@Service
|
||||
public class OnnxTraderModelService implements TraderModelService {
|
||||
private static final Logger log = LoggerFactory.getLogger(OnnxTraderModelService.class);
|
||||
private static final Pattern OUTPUT_REFERENCE = Pattern.compile("^([A-Za-z0-9_]+)\\[(\\d+)]$");
|
||||
private static final Set<String> REQUIRED_TYPES = Set.of("DIRECTION", "ENTRY", "CONTINUE", "EXIT", "RISK");
|
||||
|
||||
private final TraderProperties properties;
|
||||
private final ObjectMapper objectMapper;
|
||||
private final TraderFeatureVectorBuilder featureVectorBuilder;
|
||||
private final TraderOnnxInferenceClient inferenceClient;
|
||||
|
||||
public OnnxTraderModelService(TraderProperties properties, ObjectMapper objectMapper,
|
||||
TraderFeatureVectorBuilder featureVectorBuilder,
|
||||
TraderOnnxInferenceClient inferenceClient) {
|
||||
this.properties = properties;
|
||||
this.objectMapper = objectMapper;
|
||||
this.featureVectorBuilder = featureVectorBuilder;
|
||||
this.inferenceClient = inferenceClient;
|
||||
}
|
||||
|
||||
@Override
|
||||
public TraderModelOutput evaluate(TraderMarketSnapshot snapshot, TraderArtifactBundle bundle) {
|
||||
float[] features = featureVectorBuilder.build(snapshot, bundle);
|
||||
Map<String, TraderModelManifest> manifests = manifestsByType(bundle);
|
||||
Map<String, Map<String, BigDecimal>> rawOutputs = new LinkedHashMap<>();
|
||||
Path artifactRoot = Path.of(properties.artifact().artifactRoot());
|
||||
log.info("event=trader.model.onnx_started runId={} cycleId={} modelBundleVersion={} calibrationBundleVersion={} featureCount={}",
|
||||
snapshot.runId(), snapshot.cycleId(), bundle.modelBundleVersion(), bundle.calibrationBundleVersion(), features.length);
|
||||
for (TraderModelManifest manifest : manifests.values()) {
|
||||
Path modelPath = artifactRoot.resolve(manifest.artifactPath());
|
||||
log.debug("event=trader.model.onnx_model_started runId={} cycleId={} modelName={} modelType={} modelPath={}",
|
||||
snapshot.runId(), snapshot.cycleId(), manifest.modelName(), manifest.modelType(), modelPath);
|
||||
Map<String, float[]> tensors = inferenceClient.infer(manifest, modelPath, features);
|
||||
Map<String, BigDecimal> mapped = mappedOutputs(manifest, tensors);
|
||||
rawOutputs.put(manifest.modelType(), mapped);
|
||||
log.debug("event=trader.model.onnx_model_mapped runId={} cycleId={} modelName={} outputFieldCount={}",
|
||||
snapshot.runId(), snapshot.cycleId(), manifest.modelName(), mapped.size());
|
||||
}
|
||||
|
||||
DirectionOutput direction = direction(rawOutputs.get("DIRECTION"), calibrator(bundle, manifests.get("DIRECTION")));
|
||||
EntryOutput entry = entry(rawOutputs.get("ENTRY"), calibrator(bundle, manifests.get("ENTRY")),
|
||||
bounds(manifests.get("ENTRY")));
|
||||
ContinueOutput continuation = continuation(rawOutputs.get("CONTINUE"), calibrator(bundle, manifests.get("CONTINUE")),
|
||||
bounds(manifests.get("CONTINUE")));
|
||||
ExitOutput exit = exit(rawOutputs.get("EXIT"), calibrator(bundle, manifests.get("EXIT")),
|
||||
bounds(manifests.get("EXIT")));
|
||||
RiskOutput risk = risk(rawOutputs.get("RISK"), calibrator(bundle, manifests.get("RISK")),
|
||||
bounds(manifests.get("RISK")));
|
||||
|
||||
TraderModelOutput output = new TraderModelOutput(
|
||||
"model_output_" + snapshot.cycleId(),
|
||||
snapshot.runId(),
|
||||
snapshot.cycleId(),
|
||||
metadata(bundle, manifests, direction, snapshot),
|
||||
direction,
|
||||
entry,
|
||||
continuation,
|
||||
exit,
|
||||
risk);
|
||||
log.info("event=trader.model.onnx_evaluated runId={} cycleId={} modelBundleVersion={} featureCount={} directionLong={} directionShort={} marketRisk={}",
|
||||
snapshot.runId(), snapshot.cycleId(), bundle.modelBundleVersion(), features.length,
|
||||
direction.longProb(), direction.shortProb(), risk.marketRiskProb());
|
||||
return output;
|
||||
}
|
||||
|
||||
private Map<String, TraderModelManifest> manifestsByType(TraderArtifactBundle bundle) {
|
||||
Map<String, TraderModelManifest> manifests = new LinkedHashMap<>();
|
||||
for (TraderModelManifest manifest : bundle.modelManifests()) {
|
||||
TraderModelManifest previous = manifests.putIfAbsent(manifest.modelType(), manifest);
|
||||
if (previous != null) {
|
||||
throw modelException("model_manifest.json must contain exactly one manifest per V4 model type: "
|
||||
+ manifest.modelType());
|
||||
}
|
||||
}
|
||||
if (!manifests.keySet().containsAll(REQUIRED_TYPES)) {
|
||||
throw modelException("model_manifest.json does not contain all five V4 model types");
|
||||
}
|
||||
return Map.copyOf(manifests);
|
||||
}
|
||||
|
||||
private Map<String, BigDecimal> mappedOutputs(TraderModelManifest manifest, Map<String, float[]> tensors) {
|
||||
Map<String, BigDecimal> mapped = new LinkedHashMap<>();
|
||||
for (Map.Entry<String, Object> entry : manifest.outputMapping().entrySet()) {
|
||||
// manifest 只允许 tensor[index] 显式映射,避免 Java 根据位置临时猜字段。
|
||||
if (!(entry.getValue() instanceof String referenceText)) {
|
||||
throw modelException("output_mapping_json value must be tensor[index]: " + manifest.modelName());
|
||||
}
|
||||
Matcher matcher = OUTPUT_REFERENCE.matcher(referenceText);
|
||||
if (!matcher.matches()) {
|
||||
throw modelException("output_mapping_json value must be tensor[index]: " + referenceText);
|
||||
}
|
||||
String tensorName = matcher.group(1);
|
||||
int index = Integer.parseInt(matcher.group(2));
|
||||
float[] values = tensors.get(tensorName);
|
||||
if (values == null) {
|
||||
throw modelException("ONNX output tensor is missing: model=" + manifest.modelName()
|
||||
+ ", tensor=" + tensorName);
|
||||
}
|
||||
if (index < 0 || index >= values.length) {
|
||||
throw modelException("ONNX output tensor index is out of range: model=" + manifest.modelName()
|
||||
+ ", mapping=" + referenceText);
|
||||
}
|
||||
mapped.put(entry.getKey(), BigDecimal.valueOf(values[index]));
|
||||
}
|
||||
return Map.copyOf(mapped);
|
||||
}
|
||||
|
||||
private DirectionOutput direction(Map<String, BigDecimal> raw, TraderProbabilityCalibrator calibrator) {
|
||||
return new DirectionOutput(
|
||||
probability(raw, calibrator, "long_prob", "longProb"),
|
||||
probability(raw, calibrator, "short_prob", "shortProb"),
|
||||
probability(raw, calibrator, "neutral_prob", "neutralProb"));
|
||||
}
|
||||
|
||||
private EntryOutput entry(Map<String, BigDecimal> raw, TraderProbabilityCalibrator calibrator,
|
||||
TraderOutputSchemaBounds bounds) {
|
||||
return new EntryOutput(
|
||||
probability(raw, calibrator, "long_entry_prob", "longEntryProb"),
|
||||
probability(raw, calibrator, "short_entry_prob", "shortEntryProb"),
|
||||
bps(raw, bounds, "long_expected_net_edge_bps", "longExpectedNetEdgeBps"),
|
||||
bps(raw, bounds, "short_expected_net_edge_bps", "shortExpectedNetEdgeBps"));
|
||||
}
|
||||
|
||||
private ContinueOutput continuation(Map<String, BigDecimal> raw, TraderProbabilityCalibrator calibrator,
|
||||
TraderOutputSchemaBounds bounds) {
|
||||
return new ContinueOutput(
|
||||
probability(raw, calibrator, "long_continue_prob", "longContinueProb"),
|
||||
probability(raw, calibrator, "short_continue_prob", "shortContinueProb"),
|
||||
bps(raw, bounds, "long_expected_continue_edge_bps", "longExpectedContinueEdgeBps"),
|
||||
bps(raw, bounds, "short_expected_continue_edge_bps", "shortExpectedContinueEdgeBps"));
|
||||
}
|
||||
|
||||
private ExitOutput exit(Map<String, BigDecimal> raw, TraderProbabilityCalibrator calibrator,
|
||||
TraderOutputSchemaBounds bounds) {
|
||||
return new ExitOutput(
|
||||
probability(raw, calibrator, "long_exit_prob", "longExitProb"),
|
||||
probability(raw, calibrator, "short_exit_prob", "shortExitProb"),
|
||||
bps(raw, bounds, "long_adverse_move_bps", "longAdverseMoveBps"),
|
||||
bps(raw, bounds, "short_adverse_move_bps", "shortAdverseMoveBps"),
|
||||
Map.of(
|
||||
"adverse_move_prob", probability(raw, calibrator, "adverse_move_prob", "adverse_move_prob"),
|
||||
"reversal_prob", probability(raw, calibrator, "reversal_prob", "reversal_prob"),
|
||||
"stop_hit_prob", probability(raw, calibrator, "stop_hit_prob", "stop_hit_prob"),
|
||||
"stagnation_prob", probability(raw, calibrator, "stagnation_prob", "stagnation_prob")));
|
||||
}
|
||||
|
||||
private RiskOutput risk(Map<String, BigDecimal> raw, TraderProbabilityCalibrator calibrator,
|
||||
TraderOutputSchemaBounds bounds) {
|
||||
return new RiskOutput(
|
||||
probability(raw, calibrator, "market_risk_prob", "marketRiskProb"),
|
||||
probability(raw, calibrator, "long_position_risk_prob", "longPositionRiskProb"),
|
||||
probability(raw, calibrator, "short_position_risk_prob", "shortPositionRiskProb"),
|
||||
bps(raw, bounds, "market_path_risk_bps", "marketPathRiskBps"),
|
||||
bps(raw, bounds, "long_position_path_risk_bps", "longPositionPathRiskBps"),
|
||||
bps(raw, bounds, "short_position_path_risk_bps", "shortPositionPathRiskBps"),
|
||||
Map.of(
|
||||
"market_drawdown_prob", probability(raw, calibrator, "market_drawdown_prob", "market_drawdown_prob"),
|
||||
"volatility_expansion_prob", probability(raw, calibrator, "volatility_expansion_prob", "volatility_expansion_prob"),
|
||||
"spike_prob", probability(raw, calibrator, "spike_prob", "spike_prob"),
|
||||
"liquidity_deterioration_prob", probability(raw, calibrator, "liquidity_deterioration_prob", "liquidity_deterioration_prob"),
|
||||
"position_drawdown_prob", probability(raw, calibrator, "position_drawdown_prob", "position_drawdown_prob")));
|
||||
}
|
||||
|
||||
private TraderModelOutputMetadata metadata(TraderArtifactBundle bundle, Map<String, TraderModelManifest> manifests,
|
||||
DirectionOutput direction, TraderMarketSnapshot snapshot) {
|
||||
Map<String, String> modelVersions = manifests.values().stream()
|
||||
.collect(Collectors.toUnmodifiableMap(TraderModelManifest::modelName, TraderModelManifest::sourceHash));
|
||||
Map<String, String> calibrationVersions = bundle.calibrationManifests().stream()
|
||||
.collect(Collectors.toUnmodifiableMap(TraderCalibrationManifest::modelName, TraderCalibrationManifest::calibratorVersion));
|
||||
TraderModelManifest reference = manifests.get("DIRECTION");
|
||||
// 第一版不再单独输出 uncertainty 子模型,用方向最大概率的反面表示本轮不确定性。
|
||||
BigDecimal uncertainty = BigDecimal.ONE.subtract(direction.confidence());
|
||||
BigDecimal oodScore = dataQualityProbability(snapshot, "ood_score");
|
||||
return new TraderModelOutputMetadata(
|
||||
bundle.modelBundleVersion(),
|
||||
bundle.calibrationBundleVersion(),
|
||||
modelVersions,
|
||||
calibrationVersions,
|
||||
reference.featureSchemaHash(),
|
||||
reference.featureOrderHash(),
|
||||
reference.outputSchemaHash(),
|
||||
uncertainty,
|
||||
oodScore);
|
||||
}
|
||||
|
||||
private TraderProbabilityCalibrator calibrator(TraderArtifactBundle bundle, TraderModelManifest manifest) {
|
||||
TraderCalibrationManifest calibrationManifest = bundle.calibrationManifests().stream()
|
||||
.filter(item -> item.modelName().equals(manifest.modelName()))
|
||||
.findFirst()
|
||||
.orElseThrow(() -> modelException("calibration manifest is missing for model: " + manifest.modelName()));
|
||||
Path path = Path.of(properties.artifact().artifactRoot()).resolve(calibrationManifest.calibratorPath());
|
||||
log.debug("event=trader.model.calibrator_loaded modelName={} calibratorVersion={} calibratorPath={}",
|
||||
manifest.modelName(), calibrationManifest.calibratorVersion(), path);
|
||||
return TraderProbabilityCalibrator.read(objectMapper, path, manifest.modelName());
|
||||
}
|
||||
|
||||
private TraderOutputSchemaBounds bounds(TraderModelManifest manifest) {
|
||||
return TraderOutputSchemaBounds.read(objectMapper,
|
||||
Path.of(properties.artifact().artifactRoot()).resolve(manifest.outputSchemaPath()));
|
||||
}
|
||||
|
||||
private BigDecimal probability(Map<String, BigDecimal> raw, TraderProbabilityCalibrator calibrator,
|
||||
String outputKey, String targetName) {
|
||||
BigDecimal value = requiredRaw(raw, outputKey);
|
||||
return calibrator.calibrate(targetName, value);
|
||||
}
|
||||
|
||||
private BigDecimal bps(Map<String, BigDecimal> raw, TraderOutputSchemaBounds bounds,
|
||||
String outputKey, String fieldName) {
|
||||
return bounds.clip(fieldName, requiredRaw(raw, outputKey));
|
||||
}
|
||||
|
||||
private BigDecimal requiredRaw(Map<String, BigDecimal> raw, String outputKey) {
|
||||
BigDecimal value = raw.get(outputKey);
|
||||
if (value == null) {
|
||||
throw modelException("ONNX mapped output is missing: " + outputKey);
|
||||
}
|
||||
return value;
|
||||
}
|
||||
|
||||
private BigDecimal dataQualityProbability(TraderMarketSnapshot snapshot, String field) {
|
||||
Object raw = snapshot.dataQualityJson().get(field);
|
||||
BigDecimal value;
|
||||
if (raw instanceof BigDecimal decimal) {
|
||||
value = decimal;
|
||||
} else if (raw instanceof Number number) {
|
||||
value = BigDecimal.valueOf(number.doubleValue());
|
||||
} else {
|
||||
log.warn("event=trader.model.ood_missing runId={} cycleId={} field={}",
|
||||
snapshot.runId(), snapshot.cycleId(), field);
|
||||
throw modelException("snapshot dataQualityJson must contain numeric field: " + field);
|
||||
}
|
||||
if (value.compareTo(BigDecimal.ZERO) < 0 || value.compareTo(BigDecimal.ONE) > 0) {
|
||||
log.warn("event=trader.model.ood_out_of_range runId={} cycleId={} field={} value={}",
|
||||
snapshot.runId(), snapshot.cycleId(), field, value);
|
||||
throw modelException("snapshot dataQualityJson probability is outside [0,1]: " + field);
|
||||
}
|
||||
return value;
|
||||
}
|
||||
|
||||
private static TraderException modelException(String message) {
|
||||
return new TraderException(TraderErrorCode.TRADER_MODEL_ARTIFACT_MISSING, message);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,78 @@
|
||||
package com.quantai.trader.model;
|
||||
|
||||
import ai.onnxruntime.OnnxTensor;
|
||||
import ai.onnxruntime.OnnxValue;
|
||||
import ai.onnxruntime.OrtEnvironment;
|
||||
import ai.onnxruntime.OrtException;
|
||||
import ai.onnxruntime.OrtSession;
|
||||
import com.quantai.trader.artifact.TraderModelManifest;
|
||||
import com.quantai.trader.domain.TraderException;
|
||||
import com.quantai.trader.enums.TraderErrorCode;
|
||||
import org.springframework.beans.factory.DisposableBean;
|
||||
import org.springframework.stereotype.Component;
|
||||
|
||||
import java.nio.file.Path;
|
||||
import java.util.Arrays;
|
||||
import java.util.LinkedHashMap;
|
||||
import java.util.Map;
|
||||
import java.util.concurrent.ConcurrentHashMap;
|
||||
import java.util.concurrent.ConcurrentMap;
|
||||
|
||||
@Component
|
||||
public class OrtTraderOnnxInferenceClient implements TraderOnnxInferenceClient, DisposableBean {
|
||||
private final OrtEnvironment environment = OrtEnvironment.getEnvironment();
|
||||
private final ConcurrentMap<Path, OrtSession> sessions = new ConcurrentHashMap<>();
|
||||
|
||||
@Override
|
||||
public Map<String, float[]> infer(TraderModelManifest manifest, Path modelPath, float[] features) {
|
||||
try {
|
||||
OrtSession session = sessions.computeIfAbsent(modelPath.toAbsolutePath().normalize(), this::createSession);
|
||||
try (OnnxTensor input = OnnxTensor.createTensor(environment, new float[][]{features});
|
||||
OrtSession.Result result = session.run(Map.of(manifest.inputTensorName(), input))) {
|
||||
Map<String, float[]> outputs = new LinkedHashMap<>();
|
||||
for (Map.Entry<String, OnnxValue> entry : result) {
|
||||
outputs.put(entry.getKey(), flatten(entry.getValue()));
|
||||
}
|
||||
return Map.copyOf(outputs);
|
||||
}
|
||||
} catch (OrtException exception) {
|
||||
throw new TraderException(TraderErrorCode.TRADER_MODEL_ARTIFACT_MISSING,
|
||||
"ONNX inference failed for model: " + manifest.modelName());
|
||||
}
|
||||
}
|
||||
|
||||
private OrtSession createSession(Path path) {
|
||||
try {
|
||||
// ONNX session 建立成本高,外层按模型文件路径缓存,进程退出时统一关闭。
|
||||
return environment.createSession(path.toString(), new OrtSession.SessionOptions());
|
||||
} catch (OrtException exception) {
|
||||
throw new TraderException(TraderErrorCode.TRADER_MODEL_ARTIFACT_MISSING,
|
||||
"ONNX model cannot be loaded: " + path);
|
||||
}
|
||||
}
|
||||
|
||||
private float[] flatten(OnnxValue value) {
|
||||
try {
|
||||
Object raw = value.getValue();
|
||||
if (raw instanceof float[] values) {
|
||||
return Arrays.copyOf(values, values.length);
|
||||
}
|
||||
if (raw instanceof float[][] matrix && matrix.length == 1) {
|
||||
return Arrays.copyOf(matrix[0], matrix[0].length);
|
||||
}
|
||||
throw new TraderException(TraderErrorCode.TRADER_MODEL_ARTIFACT_MISSING,
|
||||
"ONNX output tensor must be FLOAT32 with shape [n] or [1,n]");
|
||||
} catch (OrtException exception) {
|
||||
throw new TraderException(TraderErrorCode.TRADER_MODEL_ARTIFACT_MISSING,
|
||||
"ONNX output tensor cannot be read");
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void destroy() throws Exception {
|
||||
for (OrtSession session : sessions.values()) {
|
||||
session.close();
|
||||
}
|
||||
sessions.clear();
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,94 @@
|
||||
package com.quantai.trader.model;
|
||||
|
||||
import com.quantai.trader.artifact.TraderArtifactBundle;
|
||||
import com.quantai.trader.artifact.TraderModelManifest;
|
||||
import com.quantai.trader.artifact.TraderReplayModelFixture;
|
||||
import com.quantai.trader.config.TraderProperties;
|
||||
import com.quantai.trader.domain.*;
|
||||
import com.quantai.trader.enums.TraderErrorCode;
|
||||
import com.quantai.trader.enums.TraderRunMode;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
import java.math.BigDecimal;
|
||||
import java.util.Map;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
@Service
|
||||
public class ReplayFixtureTraderModelService implements TraderModelService {
|
||||
private final TraderRunMode runMode;
|
||||
|
||||
public ReplayFixtureTraderModelService(TraderProperties properties) {
|
||||
this.runMode = properties.runMode();
|
||||
}
|
||||
|
||||
public ReplayFixtureTraderModelService() {
|
||||
this.runMode = TraderRunMode.REPLAY_SIM;
|
||||
}
|
||||
|
||||
@Override
|
||||
public TraderModelOutput evaluate(TraderMarketSnapshot snapshot, TraderArtifactBundle bundle) {
|
||||
if (runMode != TraderRunMode.REPLAY_SIM) {
|
||||
throw new TraderException(TraderErrorCode.TRADER_MODEL_ARTIFACT_MISSING,
|
||||
"replay fixture model service only allows REPLAY_SIM");
|
||||
}
|
||||
TraderReplayModelFixture fixture = bundle.requireReplayModelFixture();
|
||||
DirectionOutput direction = direction(snapshot, fixture.direction());
|
||||
return new TraderModelOutput(
|
||||
"model_output_" + snapshot.cycleId(),
|
||||
snapshot.runId(),
|
||||
snapshot.cycleId(),
|
||||
metadata(bundle, fixture),
|
||||
direction,
|
||||
entry(fixture.entry()),
|
||||
continuation(fixture.continuation()),
|
||||
exit(fixture.exit()),
|
||||
risk(fixture.risk()));
|
||||
}
|
||||
|
||||
private TraderModelOutputMetadata metadata(TraderArtifactBundle bundle, TraderReplayModelFixture fixture) {
|
||||
Map<String, String> modelVersions = bundle.modelManifests().stream()
|
||||
.collect(Collectors.toUnmodifiableMap(TraderModelManifest::modelName, ignored -> bundle.modelBundleVersion()));
|
||||
Map<String, String> calibrationVersions = bundle.modelManifests().stream()
|
||||
.collect(Collectors.toUnmodifiableMap(TraderModelManifest::modelName, ignored -> bundle.calibrationBundleVersion()));
|
||||
return new TraderModelOutputMetadata(
|
||||
bundle.modelBundleVersion(),
|
||||
bundle.calibrationBundleVersion(),
|
||||
modelVersions,
|
||||
calibrationVersions,
|
||||
fixture.featureSchemaHash(),
|
||||
fixture.featureOrderHash(),
|
||||
fixture.outputSchemaHash(),
|
||||
fixture.uncertainty(),
|
||||
fixture.oodScore());
|
||||
}
|
||||
|
||||
private DirectionOutput direction(TraderMarketSnapshot snapshot, TraderReplayModelFixture.DirectionFixture fixture) {
|
||||
BigDecimal longProb = snapshot.markPrice().compareTo(snapshot.indexPrice()) >= 0
|
||||
? fixture.longProbWhenMarkGteIndex()
|
||||
: fixture.longProbWhenMarkLtIndex();
|
||||
BigDecimal shortProb = BigDecimal.ONE.subtract(longProb).subtract(fixture.neutralProb()).max(BigDecimal.ZERO);
|
||||
BigDecimal neutralProb = BigDecimal.ONE.subtract(longProb).subtract(shortProb);
|
||||
return new DirectionOutput(longProb, shortProb, neutralProb);
|
||||
}
|
||||
|
||||
private EntryOutput entry(TraderReplayModelFixture.EntryFixture fixture) {
|
||||
return new EntryOutput(fixture.longEntryProb(), fixture.shortEntryProb(),
|
||||
fixture.longExpectedNetEdgeBps(), fixture.shortExpectedNetEdgeBps());
|
||||
}
|
||||
|
||||
private ContinueOutput continuation(TraderReplayModelFixture.ContinueFixture fixture) {
|
||||
return new ContinueOutput(fixture.longContinueProb(), fixture.shortContinueProb(),
|
||||
fixture.longExpectedContinueEdgeBps(), fixture.shortExpectedContinueEdgeBps());
|
||||
}
|
||||
|
||||
private ExitOutput exit(TraderReplayModelFixture.ExitFixture fixture) {
|
||||
return new ExitOutput(fixture.longExitProb(), fixture.shortExitProb(),
|
||||
fixture.longAdverseMoveBps(), fixture.shortAdverseMoveBps(), fixture.exitReasonScores());
|
||||
}
|
||||
|
||||
private RiskOutput risk(TraderReplayModelFixture.RiskFixture fixture) {
|
||||
return new RiskOutput(fixture.marketRiskProb(), fixture.longPositionRiskProb(), fixture.shortPositionRiskProb(),
|
||||
fixture.marketPathRiskBps(), fixture.longPositionPathRiskBps(), fixture.shortPositionPathRiskBps(),
|
||||
fixture.riskReasonScores());
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,49 @@
|
||||
package com.quantai.trader.model;
|
||||
|
||||
import com.quantai.trader.artifact.TraderArtifactBundle;
|
||||
import com.quantai.trader.config.TraderProperties;
|
||||
import com.quantai.trader.domain.TraderException;
|
||||
import com.quantai.trader.domain.TraderMarketSnapshot;
|
||||
import com.quantai.trader.domain.TraderModelOutput;
|
||||
import com.quantai.trader.enums.TraderErrorCode;
|
||||
import com.quantai.trader.enums.TraderRunMode;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
import org.springframework.context.annotation.Primary;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
@Primary
|
||||
@Service
|
||||
public class RoutingTraderModelService implements TraderModelService {
|
||||
private static final Logger log = LoggerFactory.getLogger(RoutingTraderModelService.class);
|
||||
|
||||
private final TraderProperties properties;
|
||||
private final ReplayFixtureTraderModelService replayFixtureTraderModelService;
|
||||
private final OnnxTraderModelService onnxTraderModelService;
|
||||
|
||||
public RoutingTraderModelService(TraderProperties properties,
|
||||
ReplayFixtureTraderModelService replayFixtureTraderModelService,
|
||||
OnnxTraderModelService onnxTraderModelService) {
|
||||
this.properties = properties;
|
||||
this.replayFixtureTraderModelService = replayFixtureTraderModelService;
|
||||
this.onnxTraderModelService = onnxTraderModelService;
|
||||
}
|
||||
|
||||
@Override
|
||||
public TraderModelOutput evaluate(TraderMarketSnapshot snapshot, TraderArtifactBundle bundle) {
|
||||
if (properties.runMode() == TraderRunMode.REPLAY_SIM) {
|
||||
log.info("event=trader.model.route runId={} cycleId={} runMode=REPLAY_SIM modelService=ReplayFixtureTraderModelService",
|
||||
snapshot.runId(), snapshot.cycleId());
|
||||
return replayFixtureTraderModelService.evaluate(snapshot, bundle);
|
||||
}
|
||||
if (properties.runMode() == TraderRunMode.SHADOW) {
|
||||
log.info("event=trader.model.route runId={} cycleId={} runMode=SHADOW modelService=OnnxTraderModelService",
|
||||
snapshot.runId(), snapshot.cycleId());
|
||||
return onnxTraderModelService.evaluate(snapshot, bundle);
|
||||
}
|
||||
log.warn("event=trader.model.route_rejected runId={} cycleId={} runMode={}",
|
||||
snapshot.runId(), snapshot.cycleId(), properties.runMode());
|
||||
throw new TraderException(TraderErrorCode.TRADER_MODEL_ARTIFACT_MISSING,
|
||||
"P0 model service only supports REPLAY_SIM and SHADOW");
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,10 @@
|
||||
package com.quantai.trader.model;
|
||||
|
||||
import com.quantai.trader.artifact.TraderModelManifest;
|
||||
|
||||
import java.nio.file.Path;
|
||||
import java.util.Map;
|
||||
|
||||
public interface TraderOnnxInferenceClient {
|
||||
Map<String, float[]> infer(TraderModelManifest manifest, Path modelPath, float[] features);
|
||||
}
|
||||
@@ -0,0 +1,84 @@
|
||||
package com.quantai.trader.model;
|
||||
|
||||
import com.fasterxml.jackson.databind.JsonNode;
|
||||
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||
import com.quantai.trader.domain.TraderException;
|
||||
import com.quantai.trader.enums.TraderErrorCode;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.math.BigDecimal;
|
||||
import java.nio.file.Files;
|
||||
import java.nio.file.Path;
|
||||
import java.util.HashMap;
|
||||
import java.util.Iterator;
|
||||
import java.util.Map;
|
||||
|
||||
final class TraderOutputSchemaBounds {
|
||||
private static final Logger log = LoggerFactory.getLogger(TraderOutputSchemaBounds.class);
|
||||
|
||||
private final Map<String, Range> ranges;
|
||||
|
||||
private TraderOutputSchemaBounds(Map<String, Range> ranges) {
|
||||
this.ranges = Map.copyOf(ranges);
|
||||
}
|
||||
|
||||
static TraderOutputSchemaBounds read(ObjectMapper objectMapper, Path path) {
|
||||
if (!Files.isRegularFile(path)) {
|
||||
throw modelException("output_schema.json is missing: " + path);
|
||||
}
|
||||
try {
|
||||
JsonNode root = objectMapper.readTree(path.toFile());
|
||||
Map<String, Range> ranges = new HashMap<>();
|
||||
collectRanges(root, ranges);
|
||||
if (ranges.isEmpty()) {
|
||||
throw modelException("output_schema.json must define field ranges: " + path);
|
||||
}
|
||||
log.debug("event=trader.output_schema.loaded path={} rangedFieldCount={}", path, ranges.size());
|
||||
return new TraderOutputSchemaBounds(ranges);
|
||||
} catch (IOException exception) {
|
||||
throw modelException("output_schema.json cannot be read: " + path);
|
||||
}
|
||||
}
|
||||
|
||||
BigDecimal clip(String fieldName, BigDecimal value) {
|
||||
Range range = ranges.get(fieldName);
|
||||
if (range == null) {
|
||||
log.warn("event=trader.output_schema.range_missing field={}", fieldName);
|
||||
throw modelException("output_schema.json does not define range for field: " + fieldName);
|
||||
}
|
||||
if (value.compareTo(range.min()) < 0) {
|
||||
log.debug("event=trader.output_schema.clipped field={} raw={} clipped={}", fieldName, value, range.min());
|
||||
return range.min();
|
||||
}
|
||||
if (value.compareTo(range.max()) > 0) {
|
||||
log.debug("event=trader.output_schema.clipped field={} raw={} clipped={}", fieldName, value, range.max());
|
||||
return range.max();
|
||||
}
|
||||
return value;
|
||||
}
|
||||
|
||||
private static void collectRanges(JsonNode node, Map<String, Range> ranges) {
|
||||
if (!node.isObject()) {
|
||||
return;
|
||||
}
|
||||
Iterator<Map.Entry<String, JsonNode>> iterator = node.properties().iterator();
|
||||
while (iterator.hasNext()) {
|
||||
Map.Entry<String, JsonNode> entry = iterator.next();
|
||||
JsonNode child = entry.getValue();
|
||||
JsonNode rangeNode = child.path("range");
|
||||
if (rangeNode.isArray() && rangeNode.size() == 2 && rangeNode.get(0).isNumber() && rangeNode.get(1).isNumber()) {
|
||||
ranges.put(entry.getKey(), new Range(rangeNode.get(0).decimalValue(), rangeNode.get(1).decimalValue()));
|
||||
}
|
||||
collectRanges(child, ranges);
|
||||
}
|
||||
}
|
||||
|
||||
private static TraderException modelException(String message) {
|
||||
return new TraderException(TraderErrorCode.TRADER_MODEL_ARTIFACT_MISSING, message);
|
||||
}
|
||||
|
||||
private record Range(BigDecimal min, BigDecimal max) {
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,129 @@
|
||||
package com.quantai.trader.model;
|
||||
|
||||
import com.fasterxml.jackson.databind.JsonNode;
|
||||
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||
import com.quantai.trader.domain.TraderException;
|
||||
import com.quantai.trader.enums.TraderErrorCode;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.math.BigDecimal;
|
||||
import java.nio.file.Files;
|
||||
import java.nio.file.Path;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.stream.StreamSupport;
|
||||
|
||||
final class TraderProbabilityCalibrator {
|
||||
private static final Logger log = LoggerFactory.getLogger(TraderProbabilityCalibrator.class);
|
||||
|
||||
private final String modelName;
|
||||
private final String method;
|
||||
private final BigDecimal clipMin;
|
||||
private final BigDecimal clipMax;
|
||||
private final Map<String, List<Bin>> targets;
|
||||
|
||||
private TraderProbabilityCalibrator(String modelName, String method, BigDecimal clipMin, BigDecimal clipMax,
|
||||
Map<String, List<Bin>> targets) {
|
||||
this.modelName = modelName;
|
||||
this.method = method;
|
||||
this.clipMin = clipMin;
|
||||
this.clipMax = clipMax;
|
||||
this.targets = targets;
|
||||
}
|
||||
|
||||
static TraderProbabilityCalibrator read(ObjectMapper objectMapper, Path path, String modelName) {
|
||||
if (!Files.isRegularFile(path)) {
|
||||
throw modelException("calibrator is missing: " + path);
|
||||
}
|
||||
try {
|
||||
JsonNode root = objectMapper.readTree(path.toFile());
|
||||
String method = requiredText(root, "method", path);
|
||||
if (!"BINNING".equals(method)) {
|
||||
throw modelException("probability calibrator method must be BINNING for model " + modelName);
|
||||
}
|
||||
JsonNode clip = root.path("clip");
|
||||
BigDecimal clipMin = decimal(clip.path("min"), "clip.min", path);
|
||||
BigDecimal clipMax = decimal(clip.path("max"), "clip.max", path);
|
||||
JsonNode targetsNode = root.path("targets");
|
||||
if (!targetsNode.isObject() || targetsNode.isEmpty()) {
|
||||
throw modelException("calibrator targets must not be empty: " + path);
|
||||
}
|
||||
Map<String, List<Bin>> targets = StreamSupport.stream(targetsNode.properties().spliterator(), false)
|
||||
.collect(java.util.stream.Collectors.toUnmodifiableMap(
|
||||
Map.Entry::getKey,
|
||||
entry -> bins(entry.getValue().path("bins"), entry.getKey(), path)));
|
||||
log.debug("event=trader.calibrator.loaded modelName={} method={} targetCount={} path={}",
|
||||
modelName, method, targets.size(), path);
|
||||
return new TraderProbabilityCalibrator(modelName, method, clipMin, clipMax, targets);
|
||||
} catch (IOException exception) {
|
||||
throw modelException("calibrator cannot be read: " + path);
|
||||
}
|
||||
}
|
||||
|
||||
BigDecimal calibrate(String targetName, BigDecimal rawProbability) {
|
||||
if (!"BINNING".equals(method)) {
|
||||
throw modelException("unsupported calibrator method for model " + modelName);
|
||||
}
|
||||
List<Bin> bins = targets.get(targetName);
|
||||
if (bins == null || bins.isEmpty()) {
|
||||
log.warn("event=trader.calibrator.target_missing modelName={} target={}", modelName, targetName);
|
||||
throw modelException("calibrator target is missing: model=" + modelName + ", target=" + targetName);
|
||||
}
|
||||
for (Bin bin : bins) {
|
||||
if (rawProbability.compareTo(bin.min()) >= 0 && rawProbability.compareTo(bin.max()) <= 0) {
|
||||
BigDecimal calibrated = bin.calibrated();
|
||||
if (calibrated.compareTo(clipMin) < 0 || calibrated.compareTo(clipMax) > 0
|
||||
|| calibrated.compareTo(BigDecimal.ZERO) < 0 || calibrated.compareTo(BigDecimal.ONE) > 0) {
|
||||
log.warn("event=trader.calibrator.output_out_of_range modelName={} target={} raw={} calibrated={} clipMin={} clipMax={}",
|
||||
modelName, targetName, rawProbability, calibrated, clipMin, clipMax);
|
||||
throw modelException("calibrated probability is outside allowed range: model="
|
||||
+ modelName + ", target=" + targetName);
|
||||
}
|
||||
return calibrated;
|
||||
}
|
||||
}
|
||||
log.warn("event=trader.calibrator.bin_missing modelName={} target={} raw={}",
|
||||
modelName, targetName, rawProbability);
|
||||
throw modelException("raw probability does not match any calibrator bin: model="
|
||||
+ modelName + ", target=" + targetName + ", value=" + rawProbability);
|
||||
}
|
||||
|
||||
private static List<Bin> bins(JsonNode binsNode, String targetName, Path path) {
|
||||
if (!binsNode.isArray() || binsNode.isEmpty()) {
|
||||
throw modelException("calibrator target has no bins: " + targetName + " in " + path);
|
||||
}
|
||||
List<Bin> bins = new ArrayList<>();
|
||||
for (JsonNode node : binsNode) {
|
||||
bins.add(new Bin(
|
||||
decimal(node.path("min"), targetName + ".min", path),
|
||||
decimal(node.path("max"), targetName + ".max", path),
|
||||
decimal(node.path("calibrated"), targetName + ".calibrated", path)));
|
||||
}
|
||||
return List.copyOf(bins);
|
||||
}
|
||||
|
||||
private static String requiredText(JsonNode node, String field, Path path) {
|
||||
String value = node.path(field).asText("");
|
||||
if (value.isBlank()) {
|
||||
throw modelException("calibrator field is required: " + field + " in " + path);
|
||||
}
|
||||
return value;
|
||||
}
|
||||
|
||||
private static BigDecimal decimal(JsonNode node, String field, Path path) {
|
||||
if (!node.isNumber()) {
|
||||
throw modelException("calibrator numeric field is required: " + field + " in " + path);
|
||||
}
|
||||
return node.decimalValue();
|
||||
}
|
||||
|
||||
private static TraderException modelException(String message) {
|
||||
return new TraderException(TraderErrorCode.TRADER_MODEL_ARTIFACT_MISSING, message);
|
||||
}
|
||||
|
||||
private record Bin(BigDecimal min, BigDecimal max, BigDecimal calibrated) {
|
||||
}
|
||||
}
|
||||
@@ -1,6 +1,7 @@
|
||||
package com.quantai.trader.outbox;
|
||||
|
||||
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||
import com.quantai.trader.enums.PositionSide;
|
||||
import com.quantai.trader.persistence.TraderJsonCodec;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
@@ -35,4 +36,31 @@ public class JdbcTraderOutboxRepository implements TraderOutboxRepository {
|
||||
log.info("event=trader.outbox.inserted runId={} cycleId={} destination={} aggregateType={} aggregateId={} status={}",
|
||||
event.runId(), event.cycleId(), event.destination(), event.aggregateType(), event.aggregateId(), event.status());
|
||||
}
|
||||
|
||||
@Override
|
||||
public void markSent(String outboxId) {
|
||||
jdbcTemplate.update("""
|
||||
update trader_outbox
|
||||
set status = 'SENT', updated_at = current_timestamp(3)
|
||||
where outbox_id = ?
|
||||
""",
|
||||
outboxId);
|
||||
log.info("event=trader.outbox.sent outboxId={}", outboxId);
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean hasUnsentExposureIncrease(String runId, String symbol, PositionSide side) {
|
||||
Integer count = jdbcTemplate.queryForObject("""
|
||||
select count(*)
|
||||
from trader_outbox o
|
||||
join trader_action a on a.run_id = o.run_id and a.action_id = o.aggregate_id
|
||||
where o.run_id = ?
|
||||
and a.symbol = ?
|
||||
and a.side = ?
|
||||
and a.action_type in ('OPEN_LONG','OPEN_SHORT','ADD_LONG','ADD_SHORT')
|
||||
and o.status in ('PENDING','SENDING','FAILED')
|
||||
""",
|
||||
Integer.class, runId, symbol, side.name());
|
||||
return count != null && count > 0;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,76 @@
|
||||
package com.quantai.trader.outbox;
|
||||
|
||||
import com.quantai.trader.domain.TraderAction;
|
||||
import com.quantai.trader.domain.TraderAppFeedback;
|
||||
import com.quantai.trader.domain.TraderMarketSnapshot;
|
||||
import com.quantai.trader.enums.FeedbackSource;
|
||||
import com.quantai.trader.feedback.TraderFeedbackRepository;
|
||||
import com.quantai.trader.replay.state.TraderPostActionStateRepository;
|
||||
import com.quantai.trader.replay.state.TraderReplayState;
|
||||
import com.quantai.trader.replay.state.TraderReplayStateStore;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
import org.springframework.stereotype.Component;
|
||||
import org.springframework.transaction.annotation.Transactional;
|
||||
|
||||
import java.time.Instant;
|
||||
import java.util.Map;
|
||||
|
||||
@Component
|
||||
public class P0OutboxDispatcher implements TraderOutboxDispatcher {
|
||||
private static final Logger log = LoggerFactory.getLogger(P0OutboxDispatcher.class);
|
||||
|
||||
private final TraderOutboxRepository outboxRepository;
|
||||
private final TraderFeedbackRepository feedbackRepository;
|
||||
private final TraderReplayStateStore stateStore;
|
||||
private final TraderPostActionStateRepository postActionStateRepository;
|
||||
|
||||
public P0OutboxDispatcher(TraderOutboxRepository outboxRepository,
|
||||
TraderFeedbackRepository feedbackRepository,
|
||||
TraderReplayStateStore stateStore,
|
||||
TraderPostActionStateRepository postActionStateRepository) {
|
||||
this.outboxRepository = outboxRepository;
|
||||
this.feedbackRepository = feedbackRepository;
|
||||
this.stateStore = stateStore;
|
||||
this.postActionStateRepository = postActionStateRepository;
|
||||
}
|
||||
|
||||
@Override
|
||||
@Transactional
|
||||
public void dispatch(TraderAction action, TraderReplayState currentState, TraderMarketSnapshot snapshot, String destination) {
|
||||
TraderAppFeedback feedback = p0Feedback(action, snapshot, destination);
|
||||
feedbackRepository.insert(feedback);
|
||||
outboxRepository.markSent("outbox_" + action.actionId());
|
||||
TraderReplayState nextState = stateStore.advance(currentState, action, snapshot);
|
||||
postActionStateRepository.insertPostActionState(nextState);
|
||||
log.info("event=trader.outbox.dispatched runId={} cycleId={} action={} destination={} feedbackId={}",
|
||||
action.runId(), action.cycleId(), action.actionType(), destination, feedback.feedbackId());
|
||||
}
|
||||
|
||||
private TraderAppFeedback p0Feedback(TraderAction action, TraderMarketSnapshot snapshot, String destination) {
|
||||
FeedbackSource source = switch (destination) {
|
||||
case "REPLAY_SIM_EXECUTION" -> FeedbackSource.REPLAY_SIMULATOR;
|
||||
case "SHADOW_RECORDER" -> FeedbackSource.SHADOW_APP;
|
||||
default -> throw new IllegalArgumentException("P0 outbox destination is not allowed: " + destination);
|
||||
};
|
||||
Instant now = snapshot.snapshotTime();
|
||||
return new TraderAppFeedback(
|
||||
"feedback_" + action.actionId(),
|
||||
action.runId(),
|
||||
action.cycleId(),
|
||||
action.actionId(),
|
||||
source,
|
||||
false,
|
||||
null,
|
||||
"RECORDED",
|
||||
now,
|
||||
now,
|
||||
null,
|
||||
null,
|
||||
null,
|
||||
null,
|
||||
null,
|
||||
null,
|
||||
Map.of("destination", destination, "actionType", action.actionType().name()));
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,9 @@
|
||||
package com.quantai.trader.outbox;
|
||||
|
||||
import com.quantai.trader.domain.TraderAction;
|
||||
import com.quantai.trader.domain.TraderMarketSnapshot;
|
||||
import com.quantai.trader.replay.state.TraderReplayState;
|
||||
|
||||
public interface TraderOutboxDispatcher {
|
||||
void dispatch(TraderAction action, TraderReplayState currentState, TraderMarketSnapshot snapshot, String destination);
|
||||
}
|
||||
@@ -1,5 +1,11 @@
|
||||
package com.quantai.trader.outbox;
|
||||
|
||||
import com.quantai.trader.enums.PositionSide;
|
||||
|
||||
public interface TraderOutboxRepository {
|
||||
void insert(TraderOutboxEvent event);
|
||||
|
||||
void markSent(String outboxId);
|
||||
|
||||
boolean hasUnsentExposureIncrease(String runId, String symbol, PositionSide side);
|
||||
}
|
||||
|
||||
@@ -30,10 +30,16 @@ public class JdbcTraderDecisionTraceWriter implements TraderDecisionTraceWriter
|
||||
TraderMarketSnapshot snapshot,
|
||||
TraderModelOutput modelOutput,
|
||||
TraderPositionState positionState,
|
||||
TraderAccountState accountState,
|
||||
TraderExecutionState executionState,
|
||||
TraderPositionManagerDecision pmDecision,
|
||||
TraderRiskDecision riskDecision,
|
||||
TraderAction action) {
|
||||
upsertRun(cycle);
|
||||
insertMarketSnapshot(snapshot);
|
||||
insertPositionState(positionState, "PM_INPUT");
|
||||
insertAccountState(accountState, "PM_INPUT");
|
||||
insertExecutionState(executionState, "PM_INPUT");
|
||||
insertCycle(cycle, positionState, riskDecision);
|
||||
insertModelOutput(modelOutput, snapshot);
|
||||
insertPmDecision(pmDecision);
|
||||
@@ -66,6 +72,63 @@ public class JdbcTraderDecisionTraceWriter implements TraderDecisionTraceWriter
|
||||
Timestamp.from(cycle.cycleTime()));
|
||||
}
|
||||
|
||||
void insertMarketSnapshot(TraderMarketSnapshot snapshot) {
|
||||
jdbcTemplate.update("""
|
||||
insert into trader_market_snapshot
|
||||
(run_id, cycle_id, snapshot_id, symbol, snapshot_time, feature_version,
|
||||
mark_price, index_price, spread_bps, funding_rate_bps,
|
||||
depth_notional_5bps, depth_notional_10bps, depth_notional_25bps,
|
||||
data_ready, feature_json, data_quality_json)
|
||||
values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
""",
|
||||
snapshot.runId(), snapshot.cycleId(), snapshot.snapshotId(), snapshot.symbol(),
|
||||
Timestamp.from(snapshot.snapshotTime()), snapshot.featureVersion(), snapshot.markPrice(),
|
||||
snapshot.indexPrice(), snapshot.spreadBps(), snapshot.fundingRateBps(),
|
||||
snapshot.depthNotional5Bps(), snapshot.depthNotional10Bps(), snapshot.depthNotional25Bps(),
|
||||
snapshot.dataReady(), jsonCodec.toJson(snapshot.featureJson()), jsonCodec.toJson(snapshot.dataQualityJson()));
|
||||
}
|
||||
|
||||
void insertPositionState(TraderPositionState state, String role) {
|
||||
jdbcTemplate.update("""
|
||||
insert into trader_position_state
|
||||
(run_id, cycle_id, position_state_id, state_role, symbol, side, position_ratio,
|
||||
average_entry_price, current_price, unrealized_pnl_bps, liquidation_buffer_bps,
|
||||
add_count, remaining_add_capacity, last_add_time)
|
||||
values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
""",
|
||||
state.runId(), state.cycleId(), state.positionStateId(), role, state.symbol(), state.side().name(),
|
||||
state.positionRatio(), state.averageEntryPrice(), state.currentPrice(), state.unrealizedPnlBps(),
|
||||
state.liquidationBufferBps(), state.addCount(), state.remainingAddCapacity(),
|
||||
state.lastAddTime() == null ? null : Timestamp.from(state.lastAddTime()));
|
||||
}
|
||||
|
||||
void insertAccountState(TraderAccountState state, String role) {
|
||||
jdbcTemplate.update("""
|
||||
insert into trader_account_state
|
||||
(run_id, cycle_id, account_state_id, state_role, daily_drawdown_bps,
|
||||
portfolio_exposure_ratio, remaining_symbol_capacity_ratio, consecutive_losses)
|
||||
values (?, ?, ?, ?, ?, ?, ?, ?)
|
||||
""",
|
||||
state.runId(), state.cycleId(), state.accountStateId(), role, state.dailyDrawdownBps(),
|
||||
state.portfolioExposureRatio(), state.remainingSymbolCapacityRatio(), state.consecutiveLosses());
|
||||
}
|
||||
|
||||
void insertExecutionState(TraderExecutionState state, String role) {
|
||||
jdbcTemplate.update("""
|
||||
insert into trader_execution_state
|
||||
(run_id, cycle_id, execution_state_id, state_role, symbol, open_orders_json,
|
||||
expected_slippage_bps, exchange_latency_ms, api_error_count, maker_fee_bps,
|
||||
taker_fee_bps, min_notional, price_tick_size, lot_size_step_size,
|
||||
market_lot_size_step_size, liquidity_capacity_ratio)
|
||||
values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
""",
|
||||
state.runId(), state.cycleId(), state.executionStateId(), role, state.symbol(),
|
||||
jsonCodec.toJson(state.openOrders()), state.expectedSlippageBps(), state.exchangeLatencyMs(),
|
||||
state.apiErrorCount(), state.makerFeeBps(), state.takerFeeBps(), state.minNotional(),
|
||||
state.priceTickSize(), state.lotSizeStepSize(), state.marketLotSizeStepSize(),
|
||||
state.liquidityCapacityRatio());
|
||||
}
|
||||
|
||||
void insertCycle(TraderDecisionCycle cycle, TraderPositionState positionState, TraderRiskDecision riskDecision) {
|
||||
jdbcTemplate.update("""
|
||||
insert into trader_decision_cycle
|
||||
@@ -83,14 +146,15 @@ public class JdbcTraderDecisionTraceWriter implements TraderDecisionTraceWriter
|
||||
insert into trader_model_output
|
||||
(run_id, cycle_id, model_output_id, model_bundle_version, calibration_bundle_version,
|
||||
direction_json, entry_json, continue_json, exit_json, risk_json,
|
||||
uncertainty, ood_score, usable, blocker)
|
||||
values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
metadata_json, uncertainty, ood_score, usable, blocker)
|
||||
values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
""",
|
||||
modelOutput.runId(), modelOutput.cycleId(), modelOutput.modelOutputId(),
|
||||
modelOutput.modelBundleVersion(), modelOutput.calibrationBundleVersion(),
|
||||
modelOutput.metadata().modelBundleVersion(), modelOutput.metadata().calibrationBundleVersion(),
|
||||
jsonCodec.toJson(modelOutput.direction()), jsonCodec.toJson(modelOutput.entry()),
|
||||
jsonCodec.toJson(modelOutput.continuation()), jsonCodec.toJson(modelOutput.exit()),
|
||||
jsonCodec.toJson(modelOutput.risk()), modelOutput.uncertainty(), modelOutput.oodScore(),
|
||||
jsonCodec.toJson(modelOutput.risk()), jsonCodec.toJson(modelOutput.metadata()),
|
||||
modelOutput.metadata().uncertainty(), modelOutput.metadata().oodScore(),
|
||||
snapshot.dataReady(), snapshot.dataReady() ? null : "DATA_NOT_READY");
|
||||
}
|
||||
|
||||
|
||||
@@ -7,6 +7,8 @@ public interface TraderDecisionTraceWriter {
|
||||
TraderMarketSnapshot snapshot,
|
||||
TraderModelOutput modelOutput,
|
||||
TraderPositionState positionState,
|
||||
TraderAccountState accountState,
|
||||
TraderExecutionState executionState,
|
||||
TraderPositionManagerDecision pmDecision,
|
||||
TraderRiskDecision riskDecision,
|
||||
TraderAction action);
|
||||
|
||||
@@ -33,7 +33,8 @@ public class TraderPositionManager {
|
||||
TraderPmConfig.SizingConfig sizing = input.pmConfig().sizing();
|
||||
EntryOutput entry = input.modelOutput().entry();
|
||||
RiskOutput risk = input.modelOutput().risk();
|
||||
BigDecimal expectedEdge = entry.expectedEdgeBps().max(BigDecimal.ZERO);
|
||||
TraderPricePlanContext pricePlan = input.pricePlanContext();
|
||||
BigDecimal expectedEdge = entry.netEdgeBpsFor(side).max(BigDecimal.ZERO);
|
||||
if (expectedEdge.compareTo(sizing.minEdgeBps()) < 0) {
|
||||
return BigDecimal.ZERO;
|
||||
}
|
||||
@@ -41,8 +42,8 @@ public class TraderPositionManager {
|
||||
? input.modelOutput().direction().longProb()
|
||||
: input.modelOutput().direction().shortProb();
|
||||
BigDecimal entryProb = side.isLong() ? entry.longEntryProb() : entry.shortEntryProb();
|
||||
BigDecimal stopLossBudget = entry.stopDistanceBps().add(entry.costBps()).max(BigDecimal.ONE);
|
||||
BigDecimal uncertaintyPenalty = BigDecimal.ONE.subtract(input.modelOutput().uncertainty()
|
||||
BigDecimal stopLossBudget = pricePlan.stopDistanceBps().add(pricePlan.costBps()).max(BigDecimal.ONE);
|
||||
BigDecimal uncertaintyPenalty = BigDecimal.ONE.subtract(input.modelOutput().metadata().uncertainty()
|
||||
.multiply(sizing.uncertaintyPenaltyMultiplier())).max(BigDecimal.ZERO);
|
||||
BigDecimal edgeRiskBudget = TraderNumbers.safeDivide(expectedEdge, stopLossBudget)
|
||||
.multiply(directionStrength)
|
||||
@@ -50,7 +51,7 @@ public class TraderPositionManager {
|
||||
.multiply(BigDecimal.ONE.subtract(risk.marketRiskProb()))
|
||||
.multiply(uncertaintyPenalty);
|
||||
BigDecimal raw = sizing.baseRatio().multiply(edgeRiskBudget);
|
||||
BigDecimal liquidityCap = risk.liquidityCapacityRatio().multiply(sizing.maxLiquidityUsageRatio());
|
||||
BigDecimal liquidityCap = input.executionState().liquidityCapacityRatio().multiply(sizing.maxLiquidityUsageRatio());
|
||||
BigDecimal lossBudgetCap = TraderNumbers.safeDivide(sizing.maxLossPerTradeBps(), stopLossBudget);
|
||||
BigDecimal hardCap = min(sizing.maxSingleLegRatio(), min(input.accountState().remainingSymbolCapacityRatio(), min(liquidityCap, lossBudgetCap)));
|
||||
if (hardCap.compareTo(sizing.minInitialRatio()) < 0) {
|
||||
@@ -61,7 +62,8 @@ public class TraderPositionManager {
|
||||
|
||||
public BigDecimal calculateAddRatio(PositionManagerInput input) {
|
||||
BigDecimal raw = input.pmConfig().sizing().baseRatio()
|
||||
.multiply(input.modelOutput().continuation().continueVsExitEdgeBps().max(BigDecimal.ZERO))
|
||||
.multiply(input.modelOutput().continuation()
|
||||
.continueEdgeBpsFor(input.positionState().side()).max(BigDecimal.ZERO))
|
||||
.divide(new BigDecimal("100"), java.math.MathContext.DECIMAL64);
|
||||
return TraderNumbers.clamp(raw, input.pmConfig().sizing().minAddRatio(),
|
||||
min(input.pmConfig().sizing().maxAddRatio(), input.positionState().remainingAddCapacity()));
|
||||
@@ -73,20 +75,20 @@ public class TraderPositionManager {
|
||||
RiskOutput risk = input.modelOutput().risk();
|
||||
TraderPmConfig.OpenRuleConfig open = input.pmConfig().open();
|
||||
boolean longPass = direction.longProb().compareTo(open.longOpenProb()) > 0
|
||||
&& direction.directionMargin().compareTo(open.minDirectionMargin()) > 0
|
||||
&& direction.margin().compareTo(open.minDirectionMargin()) > 0
|
||||
&& entry.longEntryProb().compareTo(open.minLongEntryProb()) > 0
|
||||
&& entry.expectedEdgeBps().compareTo(open.minExpectedEdgeBps()) > 0
|
||||
&& entry.longExpectedNetEdgeBps().compareTo(open.minExpectedEdgeBps()) > 0
|
||||
&& risk.marketRiskProb().compareTo(open.maxMarketRiskProb()) < 0
|
||||
&& risk.liquidityCapacityRatio().compareTo(open.minLiquidityCapacityRatio()) >= 0
|
||||
&& input.modelOutput().oodScore().compareTo(open.maxOodScore()) <= 0
|
||||
&& input.executionState().liquidityCapacityRatio().compareTo(open.minLiquidityCapacityRatio()) >= 0
|
||||
&& input.modelOutput().metadata().oodScore().compareTo(open.maxOodScore()) <= 0
|
||||
&& input.executionState().openOrders().isEmpty();
|
||||
boolean shortPass = direction.shortProb().compareTo(open.shortOpenProb()) > 0
|
||||
&& direction.directionMargin().compareTo(open.minDirectionMargin()) > 0
|
||||
&& direction.margin().compareTo(open.minDirectionMargin()) > 0
|
||||
&& entry.shortEntryProb().compareTo(open.minShortEntryProb()) > 0
|
||||
&& entry.expectedEdgeBps().compareTo(open.minExpectedEdgeBps()) > 0
|
||||
&& entry.shortExpectedNetEdgeBps().compareTo(open.minExpectedEdgeBps()) > 0
|
||||
&& risk.marketRiskProb().compareTo(open.maxMarketRiskProb()) < 0
|
||||
&& risk.liquidityCapacityRatio().compareTo(open.minLiquidityCapacityRatio()) >= 0
|
||||
&& input.modelOutput().oodScore().compareTo(open.maxOodScore()) <= 0
|
||||
&& input.executionState().liquidityCapacityRatio().compareTo(open.minLiquidityCapacityRatio()) >= 0
|
||||
&& input.modelOutput().metadata().oodScore().compareTo(open.maxOodScore()) <= 0
|
||||
&& input.executionState().openOrders().isEmpty();
|
||||
BigDecimal longRatio = longPass ? calculateInitialRatio(input, PositionSide.LONG) : BigDecimal.ZERO;
|
||||
BigDecimal shortRatio = shortPass ? calculateInitialRatio(input, PositionSide.SHORT) : BigDecimal.ZERO;
|
||||
@@ -106,7 +108,7 @@ public class TraderPositionManager {
|
||||
}
|
||||
if (shouldReduce(input)) {
|
||||
TraderActionType action = input.positionState().side().isLong() ? TraderActionType.REDUCE_LONG : TraderActionType.REDUCE_SHORT;
|
||||
return decision(input, action, input.positionState().side(), null, null, new BigDecimal("0.50"), "PROFIT_GIVEBACK_REDUCE");
|
||||
return decision(input, action, input.positionState().side(), null, null, new BigDecimal("0.50"), "ADVERSE_MOVE_REDUCE");
|
||||
}
|
||||
if (shouldAdd(input)) {
|
||||
BigDecimal ratio = calculateAddRatio(input);
|
||||
@@ -130,17 +132,28 @@ public class TraderPositionManager {
|
||||
: input.modelOutput().continuation().shortContinueProb();
|
||||
return sidePass
|
||||
&& continueProb.compareTo(add.minContinueProb()) > 0
|
||||
&& input.modelOutput().continuation().continueVsExitEdgeBps().compareTo(add.minContinueVsExitEdgeBps()) > 0
|
||||
&& input.modelOutput().entry().expectedEdgeBps().compareTo(add.minExpectedEdgeBps()) > 0
|
||||
&& input.modelOutput().continuation().continueEdgeBpsFor(input.positionState().side())
|
||||
.compareTo(add.minContinueVsExitEdgeBps()) > 0
|
||||
&& input.modelOutput().entry().netEdgeBpsFor(input.positionState().side())
|
||||
.compareTo(add.minExpectedEdgeBps()) > 0
|
||||
&& input.modelOutput().risk().marketRiskProb().compareTo(add.maxMarketRiskProb()) < 0
|
||||
&& input.modelOutput().risk().positionRiskProb().compareTo(add.maxPositionRiskProb()) < 0
|
||||
&& input.modelOutput().risk().liquidityCapacityRatio().compareTo(add.minLiquidityCapacityRatio()) >= 0
|
||||
&& input.modelOutput().risk().sideRiskProbFor(input.positionState().side())
|
||||
.compareTo(add.maxPositionRiskProb()) < 0
|
||||
&& input.executionState().liquidityCapacityRatio().compareTo(add.minLiquidityCapacityRatio()) >= 0
|
||||
&& input.positionState().unrealizedPnlBps().compareTo(BigDecimal.ZERO) > 0
|
||||
&& input.positionState().addCount() < add.maxAddCount()
|
||||
&& addCooldownPassed(input, add)
|
||||
&& input.positionState().remainingAddCapacity().compareTo(BigDecimal.ZERO) > 0
|
||||
&& input.executionState().openOrders().isEmpty();
|
||||
}
|
||||
|
||||
private boolean addCooldownPassed(PositionManagerInput input, TraderPmConfig.AddRuleConfig add) {
|
||||
if (input.positionState().lastAddTime() == null) {
|
||||
return true;
|
||||
}
|
||||
return !input.cycle().cycleTime().isBefore(input.positionState().lastAddTime().plusSeconds(add.cooldownMinutes() * 60));
|
||||
}
|
||||
|
||||
private boolean shouldClose(PositionManagerInput input) {
|
||||
TraderPmConfig.ExitRuleConfig exit = input.pmConfig().exit();
|
||||
boolean exitProb = input.positionState().side().isLong()
|
||||
@@ -151,9 +164,11 @@ public class TraderPositionManager {
|
||||
: input.modelOutput().continuation().shortContinueProb();
|
||||
return exitProb
|
||||
|| continueProb.compareTo(exit.closeContinueMax()) < 0
|
||||
|| input.modelOutput().risk().positionRiskProb().compareTo(exit.closePositionRiskProb()) > 0
|
||||
|| input.modelOutput().risk().sideRiskProbFor(input.positionState().side())
|
||||
.compareTo(exit.closePositionRiskProb()) > 0
|
||||
|| input.modelOutput().risk().marketRiskProb().compareTo(exit.closeMarketRiskProb()) > 0
|
||||
|| input.modelOutput().risk().expectedShortfallBps().compareTo(exit.maxExpectedShortfallBps()) > 0;
|
||||
|| input.modelOutput().risk().positionPathRiskBpsFor(input.positionState().side())
|
||||
.compareTo(exit.maxPositionPathRiskBps()) > 0;
|
||||
}
|
||||
|
||||
private boolean shouldReduce(PositionManagerInput input) {
|
||||
@@ -161,7 +176,7 @@ public class TraderPositionManager {
|
||||
BigDecimal continueProb = input.positionState().side().isLong()
|
||||
? input.modelOutput().continuation().longContinueProb()
|
||||
: input.modelOutput().continuation().shortContinueProb();
|
||||
return input.modelOutput().exit().profitGivebackProb().compareTo(exit.reduceGivebackProb()) > 0
|
||||
return input.modelOutput().exit().reasonScore("adverse_move_prob").compareTo(exit.reduceAdverseMoveProb()) > 0
|
||||
&& continueProb.compareTo(exit.reduceContinueMin()) >= 0
|
||||
&& continueProb.compareTo(exit.reduceContinueMax()) <= 0
|
||||
&& input.positionState().unrealizedPnlBps().compareTo(exit.minProfitForReduceBps()) > 0;
|
||||
@@ -173,9 +188,9 @@ public class TraderPositionManager {
|
||||
|
||||
private TraderPositionManagerDecision decision(PositionManagerInput input, TraderActionType action, PositionSide side,
|
||||
BigDecimal targetRatio, BigDecimal addRatio, BigDecimal reduceRatio, String reason) {
|
||||
EntryOutput entry = input.modelOutput().entry();
|
||||
BigDecimal stopPrice = action.increasesExposure() ? priceFromBps(input.snapshot().markPrice(), entry.stopDistanceBps(), side, false) : null;
|
||||
BigDecimal targetPrice = action.increasesExposure() ? priceFromBps(input.snapshot().markPrice(), entry.targetDistanceBps(), side, true) : null;
|
||||
TraderPricePlanContext pricePlan = input.pricePlanContext();
|
||||
BigDecimal stopPrice = action.increasesExposure() ? priceFromBps(input.snapshot().markPrice(), pricePlan.stopDistanceBps(), side, false) : null;
|
||||
BigDecimal targetPrice = action.increasesExposure() ? priceFromBps(input.snapshot().markPrice(), pricePlan.targetDistanceBps(), side, true) : null;
|
||||
return new TraderPositionManagerDecision(
|
||||
"pm_" + input.cycle().cycleId(),
|
||||
input.cycle().runId(),
|
||||
@@ -186,15 +201,17 @@ public class TraderPositionManager {
|
||||
input.executionState().executionStateId(),
|
||||
action,
|
||||
side,
|
||||
action.increasesExposure() ? entry.pricePlanId() : null,
|
||||
action.increasesExposure() ? entry.pricePlanConfigHash() : null,
|
||||
action.increasesExposure() ? pricePlan.pricePlanId() : null,
|
||||
action.increasesExposure() ? pricePlan.pricePlanConfigHash() : null,
|
||||
targetRatio,
|
||||
addRatio,
|
||||
reduceRatio,
|
||||
stopPrice,
|
||||
targetPrice,
|
||||
reason,
|
||||
Map.of("pmConfigVersion", input.pmConfig().pmConfigVersion()));
|
||||
Map.of(
|
||||
"pmConfigVersion", input.pmConfig().pmConfigVersion(),
|
||||
"pricePlanMaxHoldMinutes", pricePlan.maxHoldMinutes()));
|
||||
}
|
||||
|
||||
private BigDecimal priceFromBps(BigDecimal markPrice, BigDecimal distanceBps, PositionSide side, boolean profitTarget) {
|
||||
|
||||
@@ -1,15 +1,44 @@
|
||||
package com.quantai.trader.replay;
|
||||
|
||||
import static com.quantai.trader.util.TraderNumbers.nonNegative;
|
||||
import static com.quantai.trader.util.TraderNumbers.positive;
|
||||
import static com.quantai.trader.util.TraderNumbers.required;
|
||||
import static com.quantai.trader.util.TraderNumbers.requiredText;
|
||||
|
||||
import java.math.BigDecimal;
|
||||
import java.time.Instant;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
|
||||
public record ReplayMarketEvent(
|
||||
String runId,
|
||||
String symbol,
|
||||
Instant eventTime,
|
||||
String featureVersion,
|
||||
BigDecimal markPrice,
|
||||
BigDecimal indexPrice,
|
||||
BigDecimal spreadBps,
|
||||
BigDecimal depthNotional5Bps
|
||||
BigDecimal fundingRateBps,
|
||||
BigDecimal depthNotional5Bps,
|
||||
BigDecimal depthNotional10Bps,
|
||||
BigDecimal depthNotional25Bps,
|
||||
boolean dataReady,
|
||||
Map<String, Object> featureJson,
|
||||
Map<String, Object> dataQualityJson
|
||||
) {
|
||||
public ReplayMarketEvent {
|
||||
runId = requiredText(runId, "runId");
|
||||
symbol = requiredText(symbol, "symbol");
|
||||
eventTime = Objects.requireNonNull(eventTime, "eventTime is required");
|
||||
featureVersion = requiredText(featureVersion, "featureVersion");
|
||||
markPrice = positive(markPrice, "markPrice");
|
||||
indexPrice = positive(indexPrice, "indexPrice");
|
||||
spreadBps = nonNegative(spreadBps, "spreadBps");
|
||||
fundingRateBps = required(fundingRateBps, "fundingRateBps");
|
||||
depthNotional5Bps = nonNegative(depthNotional5Bps, "depthNotional5Bps");
|
||||
depthNotional10Bps = nonNegative(depthNotional10Bps, "depthNotional10Bps");
|
||||
depthNotional25Bps = nonNegative(depthNotional25Bps, "depthNotional25Bps");
|
||||
featureJson = Map.copyOf(featureJson == null ? Map.of() : featureJson);
|
||||
dataQualityJson = Map.copyOf(dataQualityJson == null ? Map.of() : dataQualityJson);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package com.quantai.trader.replay;
|
||||
|
||||
import com.quantai.trader.artifact.TraderArtifactBundle;
|
||||
import com.quantai.trader.artifact.TraderArtifactManifestRepository;
|
||||
import com.quantai.trader.artifact.TraderArtifactLoader;
|
||||
import com.quantai.trader.config.TraderProperties;
|
||||
import com.quantai.trader.domain.*;
|
||||
@@ -9,6 +10,7 @@ import com.quantai.trader.evidence.EvidenceAppender;
|
||||
import com.quantai.trader.model.TraderModelService;
|
||||
import com.quantai.trader.outbox.TraderOutboxRepository;
|
||||
import com.quantai.trader.outbox.TraderOutboxEvent;
|
||||
import com.quantai.trader.outbox.TraderOutboxDispatcher;
|
||||
import com.quantai.trader.persistence.TraderDecisionTraceWriter;
|
||||
import com.quantai.trader.position.TraderPositionManager;
|
||||
import com.quantai.trader.replay.state.TraderReplayState;
|
||||
@@ -16,11 +18,15 @@ import com.quantai.trader.replay.state.TraderReplayStateStore;
|
||||
import com.quantai.trader.risk.RiskGateInput;
|
||||
import com.quantai.trader.risk.RiskLimits;
|
||||
import com.quantai.trader.risk.TraderRiskGate;
|
||||
import com.quantai.trader.runtime.TraderRuntimeControlDecision;
|
||||
import com.quantai.trader.runtime.TraderRuntimeControlService;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
import org.springframework.stereotype.Service;
|
||||
import org.springframework.transaction.annotation.Transactional;
|
||||
import org.springframework.transaction.support.TransactionSynchronization;
|
||||
import org.springframework.transaction.support.TransactionSynchronizationManager;
|
||||
|
||||
import java.math.BigDecimal;
|
||||
import java.time.Instant;
|
||||
import java.util.Map;
|
||||
|
||||
@@ -36,8 +42,11 @@ public class TraderP0CycleRunner {
|
||||
private final TraderActionFactory actionFactory;
|
||||
private final EvidenceAppender evidenceAppender;
|
||||
private final TraderDecisionTraceWriter traceWriter;
|
||||
private final TraderArtifactManifestRepository artifactManifestRepository;
|
||||
private final TraderOutboxRepository outboxRepository;
|
||||
private final TraderOutboxDispatcher outboxDispatcher;
|
||||
private final TraderReplayStateStore stateStore;
|
||||
private final TraderRuntimeControlService runtimeControlService;
|
||||
|
||||
public TraderP0CycleRunner(TraderProperties properties,
|
||||
TraderArtifactLoader artifactLoader,
|
||||
@@ -47,8 +56,11 @@ public class TraderP0CycleRunner {
|
||||
TraderActionFactory actionFactory,
|
||||
EvidenceAppender evidenceAppender,
|
||||
TraderDecisionTraceWriter traceWriter,
|
||||
TraderArtifactManifestRepository artifactManifestRepository,
|
||||
TraderOutboxRepository outboxRepository,
|
||||
TraderReplayStateStore stateStore) {
|
||||
TraderOutboxDispatcher outboxDispatcher,
|
||||
TraderReplayStateStore stateStore,
|
||||
TraderRuntimeControlService runtimeControlService) {
|
||||
this.properties = properties;
|
||||
this.artifactLoader = artifactLoader;
|
||||
this.modelService = modelService;
|
||||
@@ -57,47 +69,88 @@ public class TraderP0CycleRunner {
|
||||
this.actionFactory = actionFactory;
|
||||
this.evidenceAppender = evidenceAppender;
|
||||
this.traceWriter = traceWriter;
|
||||
this.artifactManifestRepository = artifactManifestRepository;
|
||||
this.outboxRepository = outboxRepository;
|
||||
this.outboxDispatcher = outboxDispatcher;
|
||||
this.stateStore = stateStore;
|
||||
this.runtimeControlService = runtimeControlService;
|
||||
}
|
||||
|
||||
@Transactional
|
||||
public TraderCycleResult runCycle(ReplayMarketEvent event) {
|
||||
String cycleId = "cycle_" + event.runId() + "_" + event.eventTime().toEpochMilli();
|
||||
TraderArtifactBundle bundle = artifactLoader.loadActiveBundle();
|
||||
artifactManifestRepository.upsertActiveBundle(bundle);
|
||||
TraderDecisionCycle cycle = new TraderDecisionCycle(event.runId(), cycleId, event.symbol(), event.eventTime(),
|
||||
properties.runMode(), bundle.modelBundleVersion(), bundle.calibrationBundleVersion(), bundle.pmConfigVersion());
|
||||
TraderMarketSnapshot snapshot = snapshot(event, cycleId);
|
||||
TraderReplayState state = stateStore.load(cycle, snapshot);
|
||||
evidenceAppender.append(cycle.runId(), cycle.cycleId(), "MARKET_SNAPSHOT", snapshot.dataReady(), "SNAPSHOT_BUILT", null, Map.of());
|
||||
TraderModelOutput modelOutput = modelService.evaluate(snapshot, bundle);
|
||||
evidenceAppender.append(cycle.runId(), cycle.cycleId(), "MODEL_OUTPUT", true, "MODEL_EVALUATED", null, Map.of("modelOutputId", modelOutput.modelOutputId()));
|
||||
PositionManagerInput pmInput = new PositionManagerInput(cycle, snapshot, modelOutput,
|
||||
state.positionState(), state.accountState(), state.executionState(), bundle.pmConfig());
|
||||
TraderPositionManagerDecision pmDecision = positionManager.decide(pmInput);
|
||||
evidenceAppender.append(cycle.runId(), cycle.cycleId(), "PM_DECISION", true, pmDecision.reason(), null, Map.of("action", pmDecision.candidateAction().name()));
|
||||
TraderRiskDecision riskDecision = riskGate.evaluate(new RiskGateInput(pmDecision, pmInput.positionState(), pmInput.accountState(),
|
||||
pmInput.executionState(), snapshot, riskLimits()));
|
||||
evidenceAppender.append(cycle.runId(), cycle.cycleId(), "RISK_DECISION", riskDecision.allowAction(), riskDecision.allowAction() ? "RISK_PASS" : riskDecision.blocker(), riskDecision.blocker(), Map.of());
|
||||
TraderAction action = actionFactory.create(pmDecision, riskDecision, event.symbol());
|
||||
traceWriter.persistCycleTrace(cycle, snapshot, modelOutput, pmInput.positionState(), pmDecision, riskDecision, action);
|
||||
outboxRepository.insert(new TraderOutboxEvent("outbox_" + action.actionId(), action.runId(), action.cycleId(),
|
||||
"TRADER_ACTION", action.actionId(), "ACTION_CREATED", properties.runMode().name() + "_RECORDER",
|
||||
Map.of("actionType", action.actionType().name()), action.idempotencyKey(), "PENDING", Instant.now()));
|
||||
stateStore.advance(state, action, snapshot);
|
||||
log.info("event=trader.cycle.completed runId={} cycleId={} action={} outbox=PENDING", action.runId(), action.cycleId(), action.actionType());
|
||||
return new TraderCycleResult(cycle.runId(), cycle.cycleId(), pmDecision, riskDecision, action);
|
||||
runtimeControlService.acquireCycleLock(cycle);
|
||||
try {
|
||||
TraderMarketSnapshot snapshot = snapshot(event, cycleId);
|
||||
TraderReplayState state = stateStore.load(cycle, snapshot);
|
||||
evidenceAppender.append(cycle.runId(), cycle.cycleId(), "MARKET_SNAPSHOT", snapshot.dataReady(), "SNAPSHOT_BUILT", null, Map.of());
|
||||
TraderModelOutput modelOutput = modelService.evaluate(snapshot, bundle);
|
||||
evidenceAppender.append(cycle.runId(), cycle.cycleId(), "MODEL_OUTPUT", true, "MODEL_EVALUATED", null, Map.of("modelOutputId", modelOutput.modelOutputId()));
|
||||
PositionManagerInput pmInput = new PositionManagerInput(cycle, snapshot, modelOutput,
|
||||
bundle.pricePlanContext(), state.positionState(), state.accountState(), state.executionState(), bundle.pmConfig());
|
||||
TraderPositionManagerDecision pmDecision = positionManager.decide(pmInput);
|
||||
evidenceAppender.append(cycle.runId(), cycle.cycleId(), "PM_DECISION", true, pmDecision.reason(), null, Map.of("action", pmDecision.candidateAction().name()));
|
||||
TraderRuntimeControlDecision runtimeDecision = runtimeControlService.validateExposureIncrease(cycle, pmDecision);
|
||||
if (runtimeDecision.allowed()
|
||||
&& pmDecision.candidateAction().increasesExposure()
|
||||
&& outboxRepository.hasUnsentExposureIncrease(cycle.runId(), cycle.symbol(), pmDecision.side())) {
|
||||
runtimeDecision = TraderRuntimeControlDecision.block("OUTBOX_PENDING_OPEN_ADD");
|
||||
}
|
||||
TraderRiskDecision riskDecision = riskGate.evaluate(new RiskGateInput(pmDecision, pmInput.positionState(), pmInput.accountState(),
|
||||
pmInput.executionState(), snapshot, riskLimits(runtimeDecision)));
|
||||
evidenceAppender.append(cycle.runId(), cycle.cycleId(), "RISK_DECISION", riskDecision.allowAction(), riskDecision.allowAction() ? "RISK_PASS" : riskDecision.blocker(), riskDecision.blocker(), Map.of());
|
||||
TraderAction action = actionFactory.create(pmDecision, riskDecision, event.symbol());
|
||||
traceWriter.persistCycleTrace(cycle, snapshot, modelOutput, pmInput.positionState(),
|
||||
pmInput.accountState(), pmInput.executionState(), pmDecision, riskDecision, action);
|
||||
String destination = outboxDestination();
|
||||
outboxRepository.insert(new TraderOutboxEvent("outbox_" + action.actionId(), action.runId(), action.cycleId(),
|
||||
"TRADER_ACTION", action.actionId(), "ACTION_CREATED", destination,
|
||||
Map.of("actionType", action.actionType().name()), action.idempotencyKey(), "PENDING", Instant.now()));
|
||||
dispatchAfterCommit(action, state, snapshot, destination);
|
||||
log.info("event=trader.cycle.completed runId={} cycleId={} action={} outbox=PENDING runtimeBlocker={}",
|
||||
action.runId(), action.cycleId(), action.actionType(), runtimeDecision.blocker());
|
||||
return new TraderCycleResult(cycle.runId(), cycle.cycleId(), pmDecision, riskDecision, action);
|
||||
} finally {
|
||||
runtimeControlService.releaseCycleLock(cycle);
|
||||
}
|
||||
}
|
||||
|
||||
private TraderMarketSnapshot snapshot(ReplayMarketEvent event, String cycleId) {
|
||||
return new TraderMarketSnapshot("snapshot_" + cycleId, event.runId(), cycleId, event.symbol(), event.eventTime(),
|
||||
"feature-v4-p0", event.markPrice(), event.indexPrice(), event.spreadBps(), BigDecimal.ZERO,
|
||||
event.depthNotional5Bps(), event.depthNotional5Bps(), event.depthNotional5Bps(),
|
||||
event.depthNotional5Bps().compareTo(BigDecimal.ZERO) > 0, Map.of(), Map.of());
|
||||
event.featureVersion(), event.markPrice(), event.indexPrice(), event.spreadBps(), event.fundingRateBps(),
|
||||
event.depthNotional5Bps(), event.depthNotional10Bps(), event.depthNotional25Bps(),
|
||||
event.dataReady(), event.featureJson(), event.dataQualityJson());
|
||||
}
|
||||
|
||||
private RiskLimits riskLimits() {
|
||||
private RiskLimits riskLimits(TraderRuntimeControlDecision runtimeDecision) {
|
||||
return new RiskLimits(properties.risk().maxDailyLossBps(), properties.risk().maxTotalExposureRatio(),
|
||||
properties.risk().minLiquidationBufferBps(), properties.execution().maxApiErrorCount(),
|
||||
properties.execution().maxExchangeLatencyMs(), false, false);
|
||||
properties.execution().maxExchangeLatencyMs(), false, !runtimeDecision.allowed(), runtimeDecision.blocker());
|
||||
}
|
||||
|
||||
private String outboxDestination() {
|
||||
return switch (properties.execution().mode()) {
|
||||
case REPLAY_SIM -> "REPLAY_SIM_EXECUTION";
|
||||
case SHADOW -> "SHADOW_RECORDER";
|
||||
case PAPER, REAL -> throw new IllegalStateException("P0 runtime guard must reject PAPER/REAL execution mode");
|
||||
};
|
||||
}
|
||||
|
||||
private void dispatchAfterCommit(TraderAction action, TraderReplayState state,
|
||||
TraderMarketSnapshot snapshot, String destination) {
|
||||
if (!TransactionSynchronizationManager.isSynchronizationActive()) {
|
||||
outboxDispatcher.dispatch(action, state, snapshot, destination);
|
||||
return;
|
||||
}
|
||||
TransactionSynchronizationManager.registerSynchronization(new TransactionSynchronization() {
|
||||
@Override
|
||||
public void afterCommit() {
|
||||
outboxDispatcher.dispatch(action, state, snapshot, destination);
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
+70
@@ -0,0 +1,70 @@
|
||||
package com.quantai.trader.replay.state;
|
||||
|
||||
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||
import com.quantai.trader.domain.TraderAccountState;
|
||||
import com.quantai.trader.domain.TraderExecutionState;
|
||||
import com.quantai.trader.domain.TraderPositionState;
|
||||
import com.quantai.trader.persistence.TraderJsonCodec;
|
||||
import org.springframework.jdbc.core.JdbcTemplate;
|
||||
import org.springframework.stereotype.Repository;
|
||||
|
||||
import java.sql.Timestamp;
|
||||
|
||||
@Repository
|
||||
public class JdbcTraderPostActionStateRepository implements TraderPostActionStateRepository {
|
||||
private final JdbcTemplate jdbcTemplate;
|
||||
private final TraderJsonCodec jsonCodec;
|
||||
|
||||
public JdbcTraderPostActionStateRepository(JdbcTemplate jdbcTemplate, ObjectMapper objectMapper) {
|
||||
this.jdbcTemplate = jdbcTemplate;
|
||||
this.jsonCodec = new TraderJsonCodec(objectMapper);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void insertPostActionState(TraderReplayState state) {
|
||||
insertPositionState(state.positionState());
|
||||
insertAccountState(state.accountState());
|
||||
insertExecutionState(state.executionState());
|
||||
}
|
||||
|
||||
private void insertPositionState(TraderPositionState state) {
|
||||
jdbcTemplate.update("""
|
||||
insert into trader_position_state
|
||||
(run_id, cycle_id, position_state_id, state_role, symbol, side, position_ratio,
|
||||
average_entry_price, current_price, unrealized_pnl_bps, liquidation_buffer_bps,
|
||||
add_count, remaining_add_capacity, last_add_time)
|
||||
values (?, ?, ?, 'POST_ACTION', ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
""",
|
||||
state.runId(), state.cycleId(), state.positionStateId() + "_post", state.symbol(),
|
||||
state.side().name(), state.positionRatio(), state.averageEntryPrice(), state.currentPrice(),
|
||||
state.unrealizedPnlBps(), state.liquidationBufferBps(), state.addCount(),
|
||||
state.remainingAddCapacity(), state.lastAddTime() == null ? null : Timestamp.from(state.lastAddTime()));
|
||||
}
|
||||
|
||||
private void insertAccountState(TraderAccountState state) {
|
||||
jdbcTemplate.update("""
|
||||
insert into trader_account_state
|
||||
(run_id, cycle_id, account_state_id, state_role, daily_drawdown_bps,
|
||||
portfolio_exposure_ratio, remaining_symbol_capacity_ratio, consecutive_losses)
|
||||
values (?, ?, ?, 'POST_ACTION', ?, ?, ?, ?)
|
||||
""",
|
||||
state.runId(), state.cycleId(), state.accountStateId() + "_post", state.dailyDrawdownBps(),
|
||||
state.portfolioExposureRatio(), state.remainingSymbolCapacityRatio(), state.consecutiveLosses());
|
||||
}
|
||||
|
||||
private void insertExecutionState(TraderExecutionState state) {
|
||||
jdbcTemplate.update("""
|
||||
insert into trader_execution_state
|
||||
(run_id, cycle_id, execution_state_id, state_role, symbol, open_orders_json,
|
||||
expected_slippage_bps, exchange_latency_ms, api_error_count, maker_fee_bps,
|
||||
taker_fee_bps, min_notional, price_tick_size, lot_size_step_size,
|
||||
market_lot_size_step_size, liquidity_capacity_ratio)
|
||||
values (?, ?, ?, 'POST_ACTION', ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
""",
|
||||
state.runId(), state.cycleId(), state.executionStateId() + "_post", state.symbol(),
|
||||
jsonCodec.toJson(state.openOrders()), state.expectedSlippageBps(), state.exchangeLatencyMs(),
|
||||
state.apiErrorCount(), state.makerFeeBps(), state.takerFeeBps(), state.minNotional(),
|
||||
state.priceTickSize(), state.lotSizeStepSize(), state.marketLotSizeStepSize(),
|
||||
state.liquidityCapacityRatio());
|
||||
}
|
||||
}
|
||||
@@ -11,7 +11,6 @@ import org.springframework.stereotype.Component;
|
||||
|
||||
import java.math.BigDecimal;
|
||||
import java.math.MathContext;
|
||||
import java.time.Instant;
|
||||
import java.util.List;
|
||||
import java.util.concurrent.ConcurrentHashMap;
|
||||
|
||||
@@ -88,7 +87,7 @@ public class P0ReplayStateStore implements TraderReplayStateStore {
|
||||
BigDecimal weightedEntry = weightedEntry(current.averageEntryPrice(), current.positionRatio(), snapshot.markPrice(), addRatio);
|
||||
return new TraderPositionState("position_state_" + action.cycleId(), action.runId(), action.cycleId(), action.symbol(),
|
||||
current.side(), newRatio, weightedEntry, snapshot.markPrice(), pnlBps(current.side(), weightedEntry, snapshot.markPrice()),
|
||||
current.liquidationBufferBps(), current.addCount() + 1, remainingCapacity(newRatio), Instant.now());
|
||||
current.liquidationBufferBps(), current.addCount() + 1, remainingCapacity(newRatio), snapshot.snapshotTime());
|
||||
}
|
||||
|
||||
private TraderPositionState reducePosition(TraderPositionState current, TraderAction action, TraderMarketSnapshot snapshot) {
|
||||
|
||||
@@ -0,0 +1,5 @@
|
||||
package com.quantai.trader.replay.state;
|
||||
|
||||
public interface TraderPostActionStateRepository {
|
||||
void insertPostActionState(TraderReplayState state);
|
||||
}
|
||||
@@ -9,6 +9,7 @@ public record RiskLimits(
|
||||
int maxApiErrorCount,
|
||||
long maxExchangeLatencyMs,
|
||||
boolean killSwitchActive,
|
||||
boolean executionBlocked
|
||||
boolean executionBlocked,
|
||||
String executionBlocker
|
||||
) {
|
||||
}
|
||||
|
||||
@@ -17,7 +17,7 @@ public class TraderRiskGate {
|
||||
if (input.riskLimits().killSwitchActive() && input.pmDecision().candidateAction().increasesExposure()) {
|
||||
decision = block(input, "KILL_SWITCH_ACTIVE");
|
||||
} else if (input.riskLimits().executionBlocked()) {
|
||||
decision = block(input, "EXECUTION_BLOCKED");
|
||||
decision = block(input, input.riskLimits().executionBlocker());
|
||||
} else if (input.accountState().dailyDrawdownBps().compareTo(input.riskLimits().maxDailyLossBps()) >= 0) {
|
||||
decision = block(input, "MAX_DAILY_LOSS");
|
||||
} else if (input.accountState().portfolioExposureRatio().compareTo(input.riskLimits().maxTotalExposureRatio()) >= 0
|
||||
@@ -46,6 +46,9 @@ public class TraderRiskGate {
|
||||
}
|
||||
|
||||
private TraderRiskDecision block(RiskGateInput input, String blocker) {
|
||||
if (blocker == null || blocker.isBlank()) {
|
||||
blocker = "EXECUTION_BLOCKED";
|
||||
}
|
||||
TraderActionType finalAction = input.pmDecision().candidateAction().increasesExposure() ? TraderActionType.WAIT : input.pmDecision().candidateAction();
|
||||
return new TraderRiskDecision("risk_" + input.pmDecision().cycleId(), input.pmDecision().runId(), input.pmDecision().cycleId(),
|
||||
input.pmDecision().pmDecisionId(), false, input.pmDecision().candidateAction(), finalAction, blocker,
|
||||
|
||||
@@ -0,0 +1,93 @@
|
||||
package com.quantai.trader.runtime;
|
||||
|
||||
import com.quantai.trader.config.TraderProperties;
|
||||
import com.quantai.trader.domain.TraderDecisionCycle;
|
||||
import com.quantai.trader.domain.TraderException;
|
||||
import com.quantai.trader.domain.TraderPositionManagerDecision;
|
||||
import com.quantai.trader.enums.TraderErrorCode;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
import org.springframework.data.redis.core.StringRedisTemplate;
|
||||
import org.springframework.stereotype.Component;
|
||||
|
||||
import java.time.Duration;
|
||||
|
||||
@Component
|
||||
public class RedisTraderRuntimeControlService implements TraderRuntimeControlService {
|
||||
private static final Logger log = LoggerFactory.getLogger(RedisTraderRuntimeControlService.class);
|
||||
|
||||
private final TraderProperties properties;
|
||||
private final StringRedisTemplate redisTemplate;
|
||||
|
||||
public RedisTraderRuntimeControlService(TraderProperties properties, StringRedisTemplate redisTemplate) {
|
||||
this.properties = properties;
|
||||
this.redisTemplate = redisTemplate;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void acquireCycleLock(TraderDecisionCycle cycle) {
|
||||
String key = key("cycle:lock:" + cycle.symbol() + ":" + cycle.cycleTime().toEpochMilli());
|
||||
try {
|
||||
Boolean acquired = redisTemplate.opsForValue().setIfAbsent(key, cycle.cycleId(), Duration.ofSeconds(30));
|
||||
if (!Boolean.TRUE.equals(acquired)) {
|
||||
throw new TraderException(TraderErrorCode.TRADER_RUNTIME_CONTROL_BLOCKED,
|
||||
"cycle lock already exists: " + key);
|
||||
}
|
||||
log.info("event=trader.runtime.lock_acquired runId={} cycleId={} key={}",
|
||||
cycle.runId(), cycle.cycleId(), key);
|
||||
} catch (TraderException exception) {
|
||||
throw exception;
|
||||
} catch (RuntimeException exception) {
|
||||
throw new TraderException(TraderErrorCode.TRADER_RUNTIME_CONTROL_BLOCKED,
|
||||
"redis unavailable while acquiring cycle lock");
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void releaseCycleLock(TraderDecisionCycle cycle) {
|
||||
String key = key("cycle:lock:" + cycle.symbol() + ":" + cycle.cycleTime().toEpochMilli());
|
||||
try {
|
||||
redisTemplate.delete(key);
|
||||
log.info("event=trader.runtime.lock_released runId={} cycleId={} key={}",
|
||||
cycle.runId(), cycle.cycleId(), key);
|
||||
} catch (RuntimeException exception) {
|
||||
log.warn("event=trader.runtime.lock_release_failed runId={} cycleId={} key={} blocker=REDIS_UNAVAILABLE",
|
||||
cycle.runId(), cycle.cycleId(), key);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public TraderRuntimeControlDecision validateExposureIncrease(TraderDecisionCycle cycle,
|
||||
TraderPositionManagerDecision decision) {
|
||||
if (!decision.candidateAction().increasesExposure()) {
|
||||
return TraderRuntimeControlDecision.allow();
|
||||
}
|
||||
try {
|
||||
String activePointer = redisTemplate.opsForValue().get(key("model:active:" + cycle.symbol()));
|
||||
String expectedPointer = cycle.modelBundleVersion() + "|" + cycle.calibrationBundleVersion() + "|" + cycle.pmConfigVersion();
|
||||
if (properties.release().activePointerCheckEnabled() && activePointer == null) {
|
||||
return TraderRuntimeControlDecision.block("ACTIVE_POINTER_MISSING");
|
||||
}
|
||||
if (properties.release().activePointerCheckEnabled() && !expectedPointer.equals(activePointer)) {
|
||||
return TraderRuntimeControlDecision.block("ACTIVE_POINTER_MISMATCH");
|
||||
}
|
||||
if (Boolean.TRUE.equals(redisTemplate.hasKey(key("risk:kill-switch:global")))) {
|
||||
return TraderRuntimeControlDecision.block("KILL_SWITCH_ACTIVE");
|
||||
}
|
||||
if (Boolean.TRUE.equals(redisTemplate.hasKey(key("risk:kill-switch:symbol:" + cycle.symbol())))) {
|
||||
return TraderRuntimeControlDecision.block("KILL_SWITCH_ACTIVE");
|
||||
}
|
||||
if (Boolean.TRUE.equals(redisTemplate.hasKey(key("risk:execution:close-only:" + cycle.symbol())))) {
|
||||
return TraderRuntimeControlDecision.block("CLOSE_ONLY_ACTIVE");
|
||||
}
|
||||
redisTemplate.opsForValue().set(key("runtime:open-add-probe:" + cycle.cycleId()), "1", Duration.ofSeconds(30));
|
||||
return TraderRuntimeControlDecision.allow();
|
||||
} catch (RuntimeException exception) {
|
||||
return TraderRuntimeControlDecision.block("REDIS_UNAVAILABLE");
|
||||
}
|
||||
}
|
||||
|
||||
private String key(String suffix) {
|
||||
return properties.runtime().redisKeyPrefix() + ":" + suffix;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,14 @@
|
||||
package com.quantai.trader.runtime;
|
||||
|
||||
public record TraderRuntimeControlDecision(
|
||||
boolean allowed,
|
||||
String blocker
|
||||
) {
|
||||
public static TraderRuntimeControlDecision allow() {
|
||||
return new TraderRuntimeControlDecision(true, null);
|
||||
}
|
||||
|
||||
public static TraderRuntimeControlDecision block(String blocker) {
|
||||
return new TraderRuntimeControlDecision(false, blocker);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,13 @@
|
||||
package com.quantai.trader.runtime;
|
||||
|
||||
import com.quantai.trader.domain.TraderDecisionCycle;
|
||||
import com.quantai.trader.domain.TraderPositionManagerDecision;
|
||||
|
||||
public interface TraderRuntimeControlService {
|
||||
void acquireCycleLock(TraderDecisionCycle cycle);
|
||||
|
||||
void releaseCycleLock(TraderDecisionCycle cycle);
|
||||
|
||||
TraderRuntimeControlDecision validateExposureIncrease(TraderDecisionCycle cycle,
|
||||
TraderPositionManagerDecision decision);
|
||||
}
|
||||
@@ -3,7 +3,6 @@ package com.quantai.trader.util;
|
||||
import java.math.BigDecimal;
|
||||
import java.math.MathContext;
|
||||
import java.util.Collection;
|
||||
import java.util.Objects;
|
||||
|
||||
public final class TraderNumbers {
|
||||
public static final BigDecimal ZERO = BigDecimal.ZERO;
|
||||
@@ -14,7 +13,10 @@ public final class TraderNumbers {
|
||||
}
|
||||
|
||||
public static BigDecimal required(BigDecimal value, String field) {
|
||||
return Objects.requireNonNull(value, field + " is required");
|
||||
if (value == null) {
|
||||
throw new IllegalArgumentException(field + " is required");
|
||||
}
|
||||
return value;
|
||||
}
|
||||
|
||||
public static String requiredText(String value, String field) {
|
||||
@@ -68,6 +70,9 @@ public final class TraderNumbers {
|
||||
}
|
||||
|
||||
public static <T> Collection<T> requiredCollection(Collection<T> value, String field) {
|
||||
return Objects.requireNonNull(value, field + " is required");
|
||||
if (value == null) {
|
||||
throw new IllegalArgumentException(field + " is required");
|
||||
}
|
||||
return value;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -33,22 +33,111 @@ create table trader_decision_cycle (
|
||||
);
|
||||
|
||||
create table trader_model_bundle_manifest (
|
||||
id bigint primary key auto_increment,
|
||||
model_bundle_version varchar(96) not null,
|
||||
calibration_bundle_version varchar(96) not null,
|
||||
feature_version varchar(96) not null,
|
||||
label_version varchar(96) not null,
|
||||
split_version varchar(96) not null,
|
||||
required_models_json json not null,
|
||||
provided_models_json json not null,
|
||||
missing_models_json json not null,
|
||||
bundle_hash_sha256 varchar(64) not null,
|
||||
complete boolean not null,
|
||||
status varchar(32) not null,
|
||||
created_at datetime(3) not null default current_timestamp(3),
|
||||
id bigint primary key auto_increment comment '自增主键',
|
||||
manifest_schema_version varchar(32) not null comment '清单格式版本,代码启动时必须校验',
|
||||
model_bundle_version varchar(96) not null comment '五个小模型组成的一整包模型版本',
|
||||
calibration_bundle_version varchar(96) not null comment '这一包模型对应的校准器版本',
|
||||
feature_version varchar(96) not null comment '训练和运行使用的特征版本',
|
||||
label_version varchar(96) not null comment '训练标签版本',
|
||||
split_version varchar(96) not null comment '训练、验证、测试切分版本',
|
||||
training_run_id varchar(128) not null comment '训练运行编号,用来回查训练日志',
|
||||
training_export_id varchar(128) not null comment '训练产物导出编号',
|
||||
backtest_manifest_id varchar(128) not null comment '对应的回测报告编号',
|
||||
required_models_json json not null comment '必须具备的模型类型',
|
||||
provided_models_json json not null comment '实际提供的模型类型',
|
||||
missing_models_json json not null comment '缺失的模型类型,运行时必须为空数组',
|
||||
allowed_run_modes_json json not null comment '允许使用的运行模式,P0 只能包含 REPLAY_SIM、SHADOW',
|
||||
bundle_hash_sha256 varchar(64) not null comment '整包文件指纹',
|
||||
complete boolean not null comment '模型包是否完整',
|
||||
status varchar(32) not null comment '状态:候选、启用、拒绝、下线',
|
||||
created_at datetime(3) not null default current_timestamp(3) comment '创建时间',
|
||||
updated_at datetime(3) not null default current_timestamp(3) on update current_timestamp(3) comment '更新时间',
|
||||
unique key uk_trader_model_bundle (model_bundle_version, calibration_bundle_version),
|
||||
key idx_trader_model_bundle_status (status, created_at),
|
||||
key idx_trader_model_bundle_training (training_run_id, training_export_id),
|
||||
constraint chk_trader_model_bundle_status check (status in ('CANDIDATE','ACTIVE','REJECTED','RETIRED'))
|
||||
);
|
||||
) comment='Trader 模型整包清单,控制五个小模型是否齐全、是否允许进入运行';
|
||||
|
||||
create table trader_model_manifest (
|
||||
id bigint primary key auto_increment comment '自增主键',
|
||||
model_bundle_version varchar(96) not null comment '所属模型整包版本',
|
||||
calibration_bundle_version varchar(96) not null comment '对应校准器整包版本',
|
||||
model_name varchar(96) not null comment '小模型稳定名字',
|
||||
model_type varchar(32) not null comment '小模型类型,只允许 DIRECTION、ENTRY、CONTINUE、EXIT、RISK',
|
||||
side varchar(16) not null comment '适用方向:LONG、SHORT、BOTH',
|
||||
symbol_scope_json json not null comment '适用交易对范围',
|
||||
bar_interval varchar(16) not null comment '输入行情周期',
|
||||
horizon_minutes int not null comment '预测周期,单位分钟',
|
||||
model_format varchar(32) not null comment '模型文件格式,首版固定 ONNX',
|
||||
model_runtime varchar(64) not null comment 'Java 侧加载方式,首版固定 ONNX_RUNTIME_JAVA',
|
||||
model_runtime_version varchar(64) not null comment '运行库版本',
|
||||
onnx_opset_version int not null comment 'ONNX 算子版本',
|
||||
producer_name varchar(128) not null comment '模型导出工具名称',
|
||||
producer_version varchar(64) not null comment '模型导出工具版本',
|
||||
artifact_path varchar(512) not null comment '模型文件相对 artifact root 的路径',
|
||||
artifact_hash_sha256 varchar(64) not null comment '模型文件指纹',
|
||||
source_hash varchar(64) not null comment '训练代码和关键训练配置指纹',
|
||||
feature_version varchar(96) not null comment '特征版本',
|
||||
feature_schema_path varchar(512) not null comment '输入特征清单文件路径',
|
||||
feature_schema_hash varchar(64) not null comment '输入特征清单指纹',
|
||||
feature_order_path varchar(512) not null comment '输入特征顺序文件路径',
|
||||
feature_order_hash varchar(64) not null comment '输入特征顺序指纹',
|
||||
input_tensor_name varchar(128) not null comment 'ONNX 输入张量名',
|
||||
input_dtype varchar(32) not null comment '输入数字类型',
|
||||
input_shape_json json not null comment '输入形状',
|
||||
input_example_path varchar(512) not null comment '固定样例输入文件,用于启动冒烟检查',
|
||||
output_schema_path varchar(512) not null comment '输出字段说明文件路径',
|
||||
output_schema_hash varchar(64) not null comment '输出字段说明指纹',
|
||||
output_tensor_names_json json not null comment 'ONNX 原始输出张量名',
|
||||
output_mapping_json json not null comment '原始输出到业务字段的映射',
|
||||
output_value_rules_json json not null comment '输出合法范围',
|
||||
label_version varchar(96) not null comment '标签版本',
|
||||
split_version varchar(96) not null comment '训练切分版本',
|
||||
training_fold varchar(64) not null comment '训练折编号',
|
||||
train_start datetime(3) not null comment '训练数据开始时间',
|
||||
train_end datetime(3) not null comment '训练数据结束时间',
|
||||
validation_start datetime(3) not null comment '验证数据开始时间',
|
||||
validation_end datetime(3) not null comment '验证数据结束时间',
|
||||
test_start datetime(3) not null comment '测试数据开始时间',
|
||||
test_end datetime(3) not null comment '测试数据结束时间',
|
||||
metrics_json json not null comment '模型评估结果',
|
||||
status varchar(32) not null comment '状态:候选、启用、拒绝、下线',
|
||||
created_at datetime(3) not null default current_timestamp(3) comment '创建时间',
|
||||
updated_at datetime(3) not null default current_timestamp(3) on update current_timestamp(3) comment '更新时间',
|
||||
unique key uk_trader_model_manifest (model_bundle_version, calibration_bundle_version, model_name, side, horizon_minutes),
|
||||
key idx_trader_model_manifest_type (model_type, status),
|
||||
key idx_trader_model_manifest_bundle (model_bundle_version, calibration_bundle_version),
|
||||
constraint chk_trader_model_type check (model_type in ('DIRECTION','ENTRY','CONTINUE','EXIT','RISK')),
|
||||
constraint chk_trader_model_side check (side in ('LONG','SHORT','BOTH')),
|
||||
constraint chk_trader_model_format check (model_format in ('ONNX')),
|
||||
constraint chk_trader_model_runtime check (model_runtime in ('ONNX_RUNTIME_JAVA')),
|
||||
constraint chk_trader_model_input_dtype check (input_dtype in ('FLOAT32')),
|
||||
constraint chk_trader_model_horizon check (horizon_minutes > 0),
|
||||
constraint chk_trader_model_manifest_status check (status in ('CANDIDATE','ACTIVE','REJECTED','RETIRED'))
|
||||
) comment='Trader 单个小模型清单,说明模型如何加载、输入是什么、输出是什么';
|
||||
|
||||
create table trader_calibration_manifest (
|
||||
id bigint primary key auto_increment comment '自增主键',
|
||||
calibration_bundle_version varchar(96) not null comment '校准器整包版本',
|
||||
model_bundle_version varchar(96) not null comment '对应模型整包版本',
|
||||
model_name varchar(96) not null comment '对应小模型名字',
|
||||
calibrator_version varchar(96) not null comment '单个校准器版本',
|
||||
calibration_method varchar(64) not null comment '校准方法',
|
||||
calibrator_path varchar(512) not null comment '校准文件路径',
|
||||
calibrator_hash_sha256 varchar(64) not null comment '校准文件指纹',
|
||||
calibration_window_from datetime(3) not null comment '校准数据开始时间',
|
||||
calibration_window_to datetime(3) not null comment '校准数据结束时间',
|
||||
calibration_metrics_json json not null comment '校准效果指标',
|
||||
bucket_metrics_json json not null comment '分桶校准明细',
|
||||
output_after_calibration_schema_hash varchar(64) not null comment '校准后输出格式指纹',
|
||||
status varchar(32) not null comment '状态:候选、启用、拒绝、下线',
|
||||
created_at datetime(3) not null default current_timestamp(3) comment '创建时间',
|
||||
updated_at datetime(3) not null default current_timestamp(3) on update current_timestamp(3) comment '更新时间',
|
||||
unique key uk_trader_calibration_manifest (calibration_bundle_version, model_bundle_version, model_name),
|
||||
key idx_trader_calibration_bundle (model_bundle_version, calibration_bundle_version),
|
||||
constraint chk_trader_calibration_method check (calibration_method in ('ISOTONIC','PLATT','BINNING','NONE')),
|
||||
constraint chk_trader_calibration_manifest_status check (status in ('CANDIDATE','ACTIVE','REJECTED','RETIRED'))
|
||||
) comment='Trader 校准器清单,说明每个小模型的原始输出如何变成可用概率或分数';
|
||||
|
||||
create table trader_pm_config_manifest (
|
||||
id bigint primary key auto_increment,
|
||||
@@ -65,6 +154,90 @@ create table trader_pm_config_manifest (
|
||||
constraint chk_trader_pm_config_status check (status in ('CANDIDATE','ACTIVE','REJECTED','RETIRED'))
|
||||
);
|
||||
|
||||
create table trader_market_snapshot (
|
||||
id bigint primary key auto_increment,
|
||||
run_id varchar(64) not null,
|
||||
cycle_id varchar(128) not null,
|
||||
snapshot_id varchar(128) not null,
|
||||
symbol varchar(32) not null,
|
||||
snapshot_time datetime(3) not null,
|
||||
feature_version varchar(96) not null,
|
||||
mark_price decimal(28,10) not null,
|
||||
index_price decimal(28,10) not null,
|
||||
spread_bps decimal(18,8) not null,
|
||||
funding_rate_bps decimal(18,8) not null,
|
||||
depth_notional_5bps decimal(28,10) not null,
|
||||
depth_notional_10bps decimal(28,10) not null,
|
||||
depth_notional_25bps decimal(28,10) not null,
|
||||
data_ready boolean not null,
|
||||
feature_json json not null,
|
||||
data_quality_json json not null,
|
||||
created_at datetime(3) not null default current_timestamp(3),
|
||||
unique key uk_trader_market_snapshot (run_id, snapshot_id),
|
||||
key idx_trader_market_snapshot_cycle (run_id, cycle_id)
|
||||
);
|
||||
|
||||
create table trader_position_state (
|
||||
id bigint primary key auto_increment,
|
||||
run_id varchar(64) not null,
|
||||
cycle_id varchar(128) not null,
|
||||
position_state_id varchar(128) not null,
|
||||
state_role varchar(32) not null,
|
||||
symbol varchar(32) not null,
|
||||
side varchar(16) not null,
|
||||
position_ratio decimal(18,8) not null,
|
||||
average_entry_price decimal(28,10) null,
|
||||
current_price decimal(28,10) not null,
|
||||
unrealized_pnl_bps decimal(18,8) not null,
|
||||
liquidation_buffer_bps decimal(18,8) not null,
|
||||
add_count int not null,
|
||||
remaining_add_capacity decimal(18,8) not null,
|
||||
last_add_time datetime(3) null,
|
||||
created_at datetime(3) not null default current_timestamp(3),
|
||||
unique key uk_trader_position_state (run_id, position_state_id),
|
||||
key idx_trader_position_latest (run_id, symbol, state_role, created_at),
|
||||
constraint chk_trader_position_side check (side in ('NONE','LONG','SHORT')),
|
||||
constraint chk_trader_position_role check (state_role in ('PM_INPUT','POST_ACTION','FEEDBACK_UPDATE'))
|
||||
);
|
||||
|
||||
create table trader_account_state (
|
||||
id bigint primary key auto_increment,
|
||||
run_id varchar(64) not null,
|
||||
cycle_id varchar(128) not null,
|
||||
account_state_id varchar(128) not null,
|
||||
state_role varchar(32) not null,
|
||||
daily_drawdown_bps decimal(18,8) not null,
|
||||
portfolio_exposure_ratio decimal(18,8) not null,
|
||||
remaining_symbol_capacity_ratio decimal(18,8) not null,
|
||||
consecutive_losses int not null,
|
||||
created_at datetime(3) not null default current_timestamp(3),
|
||||
unique key uk_trader_account_state (run_id, account_state_id),
|
||||
constraint chk_trader_account_role check (state_role in ('PM_INPUT','POST_ACTION','FEEDBACK_UPDATE'))
|
||||
);
|
||||
|
||||
create table trader_execution_state (
|
||||
id bigint primary key auto_increment,
|
||||
run_id varchar(64) not null,
|
||||
cycle_id varchar(128) not null,
|
||||
execution_state_id varchar(128) not null,
|
||||
state_role varchar(32) not null,
|
||||
symbol varchar(32) not null,
|
||||
open_orders_json json not null,
|
||||
expected_slippage_bps decimal(18,8) not null,
|
||||
exchange_latency_ms bigint not null,
|
||||
api_error_count int not null,
|
||||
maker_fee_bps decimal(18,8) not null,
|
||||
taker_fee_bps decimal(18,8) not null,
|
||||
min_notional decimal(28,10) not null,
|
||||
price_tick_size decimal(28,10) not null,
|
||||
lot_size_step_size decimal(28,10) not null,
|
||||
market_lot_size_step_size decimal(28,10) not null,
|
||||
liquidity_capacity_ratio decimal(18,8) not null,
|
||||
created_at datetime(3) not null default current_timestamp(3),
|
||||
unique key uk_trader_execution_state (run_id, execution_state_id),
|
||||
constraint chk_trader_execution_role check (state_role in ('PM_INPUT','POST_ACTION','FEEDBACK_UPDATE'))
|
||||
);
|
||||
|
||||
create table trader_model_output (
|
||||
id bigint primary key auto_increment,
|
||||
run_id varchar(64) not null,
|
||||
@@ -77,6 +250,7 @@ create table trader_model_output (
|
||||
continue_json json not null,
|
||||
exit_json json not null,
|
||||
risk_json json not null,
|
||||
metadata_json json not null,
|
||||
uncertainty decimal(18,8) not null,
|
||||
ood_score decimal(18,8) not null,
|
||||
usable boolean not null,
|
||||
|
||||
@@ -8,13 +8,17 @@ import com.quantai.trader.enums.PositionSide;
|
||||
import com.quantai.trader.enums.TraderActionType;
|
||||
import com.quantai.trader.enums.TraderExecutionMode;
|
||||
import com.quantai.trader.enums.TraderRunMode;
|
||||
import com.quantai.trader.replay.ReplayMarketEvent;
|
||||
import com.quantai.trader.risk.RiskLimits;
|
||||
|
||||
import java.math.BigDecimal;
|
||||
import java.io.IOException;
|
||||
import java.nio.file.Files;
|
||||
import java.nio.file.Path;
|
||||
import java.security.MessageDigest;
|
||||
import java.security.NoSuchAlgorithmException;
|
||||
import java.time.Instant;
|
||||
import java.util.LinkedHashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
@@ -37,14 +41,19 @@ public final class TestFixtures {
|
||||
}
|
||||
|
||||
public static TraderProperties propertiesWithArtifactRoot(Path artifactRoot) {
|
||||
return propertiesWithArtifactRoot(artifactRoot, TraderRunMode.SHADOW, TraderExecutionMode.SHADOW);
|
||||
}
|
||||
|
||||
public static TraderProperties propertiesWithArtifactRoot(Path artifactRoot, TraderRunMode runMode,
|
||||
TraderExecutionMode executionMode) {
|
||||
TraderProperties base = properties();
|
||||
return new TraderProperties(
|
||||
base.serviceName(),
|
||||
base.runMode(),
|
||||
runMode,
|
||||
base.symbol(),
|
||||
new TraderProperties.Artifact("trader-v4-btc-p0", "cal-v4-btc-p0", "pm-v4-btc-p0", artifactRoot.toString()),
|
||||
base.feedback(),
|
||||
base.execution(),
|
||||
new TraderProperties.Execution(executionMode, base.execution().maxApiErrorCount(), base.execution().maxExchangeLatencyMs()),
|
||||
base.runtime(),
|
||||
base.outbox(),
|
||||
base.release(),
|
||||
@@ -60,7 +69,7 @@ public final class TestFixtures {
|
||||
"BTC-USDT-PERP",
|
||||
new TraderProperties.Artifact("trader-v4-btc-p0", "cal-v4-btc-p0", "pm-v4-btc-p0", "/tmp/trader-v4-p0"),
|
||||
new TraderProperties.Feedback(feedbackHttpEnabled),
|
||||
new TraderProperties.Execution(executionMode, 3, 1500),
|
||||
new TraderProperties.Execution(executionMode, 3, 1500L),
|
||||
new TraderProperties.Runtime("trader:v4:test", true, tradingEnabled),
|
||||
new TraderProperties.Outbox(true, 5),
|
||||
new TraderProperties.Release(true, true, true),
|
||||
@@ -98,27 +107,85 @@ public final class TestFixtures {
|
||||
public static TraderMarketSnapshot snapshot(boolean dataReady, String depthNotional5Bps) {
|
||||
return new TraderMarketSnapshot("snapshot-1", "run-1", "cycle-1", "BTC-USDT-PERP", T0,
|
||||
"feature-v4-p0", bd("100"), bd("99.5"), bd("1.2"), BigDecimal.ZERO,
|
||||
bd(depthNotional5Bps), bd(depthNotional5Bps), bd(depthNotional5Bps), dataReady, Map.of(), Map.of());
|
||||
bd(depthNotional5Bps), bd(depthNotional5Bps), bd(depthNotional5Bps), dataReady,
|
||||
featureJson(), dataQualityJson());
|
||||
}
|
||||
|
||||
public static ReplayMarketEvent replayEvent(String runId, Instant eventTime, String markPrice,
|
||||
String indexPrice, String depthNotional) {
|
||||
BigDecimal depth = bd(depthNotional);
|
||||
return new ReplayMarketEvent(runId, "BTC-USDT-PERP", eventTime, "feature-v4-p0",
|
||||
bd(markPrice), bd(indexPrice), bd("1.2"), bd("0.5"),
|
||||
depth, depth.multiply(new BigDecimal("1.4")), depth.multiply(new BigDecimal("2.2")),
|
||||
depth.compareTo(BigDecimal.ZERO) > 0, featureJson(), dataQualityJson());
|
||||
}
|
||||
|
||||
public static Map<String, Object> dataQualityJson() {
|
||||
return Map.of("ood_score", bd("0.05"), "data_quality_flag", "OK");
|
||||
}
|
||||
|
||||
public static Map<String, Object> featureJson() {
|
||||
Map<String, Object> features = new LinkedHashMap<>();
|
||||
features.put("ret_1m_bps", bd("1.1"));
|
||||
features.put("ret_5m_bps", bd("4.5"));
|
||||
features.put("ret_15m_bps", bd("8.2"));
|
||||
features.put("ret_60m_bps", bd("18.6"));
|
||||
features.put("ret_240m_bps", bd("42.0"));
|
||||
features.put("realized_vol_15m_bps", bd("9.4"));
|
||||
features.put("realized_vol_60m_bps", bd("12.7"));
|
||||
features.put("vol_ratio_15m_60m", bd("0.74"));
|
||||
features.put("range_15m_bps", bd("21.2"));
|
||||
features.put("range_60m_bps", bd("55.4"));
|
||||
features.put("volume_zscore_60m", bd("0.35"));
|
||||
features.put("trend_consistency_15m", bd("0.20"));
|
||||
features.put("channel_position_60m_pct", bd("0.62"));
|
||||
features.put("upper_breakout_60m_bps", bd("0.0"));
|
||||
features.put("lower_breakout_60m_bps", bd("0.0"));
|
||||
features.put("upper_failed_break_reclaim_15m_bps", bd("0.0"));
|
||||
features.put("lower_failed_break_reclaim_15m_bps", bd("0.0"));
|
||||
features.put("sweep_up_15m_bps", bd("6.8"));
|
||||
features.put("sweep_down_15m_bps", bd("5.1"));
|
||||
features.put("compression_score_4h_pct", bd("0.30"));
|
||||
features.put("compression_release_15m_bps", bd("2.4"));
|
||||
features.put("taker_imbalance_1m", bd("0.08"));
|
||||
features.put("taker_imbalance_5m", bd("0.12"));
|
||||
features.put("taker_imbalance_15m", bd("0.04"));
|
||||
features.put("level1_ofi_1m", bd("0.10"));
|
||||
features.put("spread_bps", bd("1.2"));
|
||||
features.put("spread_rank_24h_pct", bd("0.40"));
|
||||
features.put("oi_delta_15m_bps", bd("3.2"));
|
||||
features.put("oi_delta_60m_bps", bd("6.5"));
|
||||
features.put("funding_bps", bd("0.5"));
|
||||
features.put("mark_index_basis_bps", bd("5.0"));
|
||||
features.put("liquidation_buy_notional_1m", bd("1000"));
|
||||
features.put("liquidation_sell_notional_1m", bd("800"));
|
||||
features.put("liquidation_imbalance_15m", bd("0.10"));
|
||||
features.put("liquidation_notional_zscore_15m", bd("0.25"));
|
||||
features.put("liquidation_available", bd("1"));
|
||||
features.put("minute_of_day_sin", bd("0.0"));
|
||||
features.put("minute_of_day_cos", bd("1.0"));
|
||||
features.put("minutes_to_next_funding", bd("120"));
|
||||
return Map.copyOf(features);
|
||||
}
|
||||
|
||||
public static TraderModelOutput modelOutput() {
|
||||
return modelOutput("0.70", "0.20", "0.70", "0.30", "0.66", "0.34",
|
||||
"0.20", "0.50", "0.18", "0.12", "1.0", "12", "4", "0.10", "0.05", "20");
|
||||
"0.20", "0.50", "0.18", "0.12", "12", "4", "0.10", "0.05", "20");
|
||||
}
|
||||
|
||||
public static TraderModelOutput shortModelOutput() {
|
||||
return modelOutput("0.20", "0.70", "0.30", "0.70", "0.34", "0.66",
|
||||
"0.50", "0.20", "0.18", "0.12", "1.0", "12", "4", "0.10", "0.05", "20");
|
||||
"0.50", "0.20", "0.18", "0.12", "12", "4", "0.10", "0.05", "20");
|
||||
}
|
||||
|
||||
public static TraderModelOutput modelOutput(String longProb, String shortProb,
|
||||
String longEntryProb, String shortEntryProb,
|
||||
String longContinueProb, String shortContinueProb,
|
||||
String longExitProb, String shortExitProb,
|
||||
String marketRiskProb, String positionRiskProb,
|
||||
String liquidityCapacityRatio, String expectedEdgeBps,
|
||||
String continueVsExitEdgeBps, String uncertainty,
|
||||
String oodScore, String expectedShortfallBps) {
|
||||
String marketRiskProb, String sidePositionRiskProb,
|
||||
String netEdgeBps, String continueEdgeBps,
|
||||
String uncertainty, String oodScore,
|
||||
String positionPathRiskBps) {
|
||||
BigDecimal longP = bd(longProb);
|
||||
BigDecimal shortP = bd(shortProb);
|
||||
BigDecimal neutral = BigDecimal.ONE.subtract(longP).subtract(shortP);
|
||||
@@ -126,23 +193,53 @@ public final class TestFixtures {
|
||||
"model-output-1",
|
||||
"run-1",
|
||||
"cycle-1",
|
||||
"trader-v4-btc-p0",
|
||||
"cal-v4-btc-p0",
|
||||
new DirectionOutput(longP, shortP, neutral, longP.max(shortP), longP.subtract(shortP).abs(),
|
||||
bd("8.0"), 45, "direction-p0", "cal-v4-btc-p0", Map.of()),
|
||||
new EntryOutput(bd(longEntryProb), bd(shortEntryProb), bd("0.70"), bd(expectedEdgeBps),
|
||||
"p0-plan-atr-2r", "p0-price-plan-hash", bd("35"), bd("70"), 45, bd("4.0"),
|
||||
"entry-p0", "cal-v4-btc-p0", Map.of()),
|
||||
new ContinueOutput(bd(longContinueProb), bd(shortContinueProb), bd("0.60"), bd("5.0"),
|
||||
bd(continueVsExitEdgeBps), "continue-p0", "cal-v4-btc-p0", Map.of()),
|
||||
new ExitOutput(bd(longExitProb), bd(shortExitProb), bd("0.20"), bd("0.25"), bd("0.22"),
|
||||
bd("0.20"), bd("10"), "exit-p0", "cal-v4-btc-p0", Map.of()),
|
||||
new RiskOutput(bd(marketRiskProb), bd(positionRiskProb), bd("20"), bd("18"), bd("0.15"),
|
||||
bd(expectedShortfallBps), bd("0.20"), bd("0.10"), bd("0.12"),
|
||||
bd(liquidityCapacityRatio), "risk-p0", "cal-v4-btc-p0", Map.of()),
|
||||
bd(uncertainty),
|
||||
bd(oodScore),
|
||||
Map.of("fixture", "p0"));
|
||||
modelMetadata(uncertainty, oodScore),
|
||||
new DirectionOutput(longP, shortP, neutral),
|
||||
new EntryOutput(bd(longEntryProb), bd(shortEntryProb), bd(netEdgeBps), bd(netEdgeBps)),
|
||||
new ContinueOutput(bd(longContinueProb), bd(shortContinueProb), bd(continueEdgeBps), bd(continueEdgeBps)),
|
||||
new ExitOutput(bd(longExitProb), bd(shortExitProb), bd("10"), bd("10"), exitReasonScores()),
|
||||
new RiskOutput(bd(marketRiskProb), bd(sidePositionRiskProb), bd(sidePositionRiskProb), bd("20"),
|
||||
bd(positionPathRiskBps), bd(positionPathRiskBps), riskReasonScores()));
|
||||
}
|
||||
|
||||
public static TraderModelOutputMetadata modelMetadata(String uncertainty, String oodScore) {
|
||||
Map<String, String> versions = Map.of(
|
||||
"DIRECTION", "trader-v4-btc-p0",
|
||||
"ENTRY", "trader-v4-btc-p0",
|
||||
"CONTINUE", "trader-v4-btc-p0",
|
||||
"EXIT", "trader-v4-btc-p0",
|
||||
"RISK", "trader-v4-btc-p0");
|
||||
Map<String, String> calibrationVersions = Map.of(
|
||||
"DIRECTION", "cal-v4-btc-p0",
|
||||
"ENTRY", "cal-v4-btc-p0",
|
||||
"CONTINUE", "cal-v4-btc-p0",
|
||||
"EXIT", "cal-v4-btc-p0",
|
||||
"RISK", "cal-v4-btc-p0");
|
||||
return new TraderModelOutputMetadata("trader-v4-btc-p0", "cal-v4-btc-p0",
|
||||
versions, calibrationVersions,
|
||||
"feature-schema-hash", "feature-order-hash", "output-schema-hash",
|
||||
bd(uncertainty), bd(oodScore));
|
||||
}
|
||||
|
||||
public static TraderPricePlanContext pricePlan() {
|
||||
return new TraderPricePlanContext("p0-plan-atr-2r", "p0-price-plan-hash", bd("35"), bd("70"), 45, bd("4.0"));
|
||||
}
|
||||
|
||||
public static Map<String, BigDecimal> exitReasonScores() {
|
||||
return Map.of(
|
||||
"adverse_move_prob", bd("0.20"),
|
||||
"reversal_prob", bd("0.25"),
|
||||
"stop_hit_prob", bd("0.22"),
|
||||
"stagnation_prob", bd("0.20"));
|
||||
}
|
||||
|
||||
public static Map<String, BigDecimal> riskReasonScores() {
|
||||
return Map.of(
|
||||
"market_drawdown_prob", bd("0.15"),
|
||||
"volatility_expansion_prob", bd("0.20"),
|
||||
"spike_prob", bd("0.10"),
|
||||
"liquidity_deterioration_prob", bd("0.12"),
|
||||
"position_drawdown_prob", bd("0.14"));
|
||||
}
|
||||
|
||||
public static PositionManagerInput pmInput(TraderModelOutput modelOutput, TraderPositionState positionState) {
|
||||
@@ -151,7 +248,7 @@ public final class TestFixtures {
|
||||
|
||||
public static PositionManagerInput pmInput(TraderModelOutput modelOutput, TraderPositionState positionState,
|
||||
TraderAccountState accountState, TraderExecutionState executionState) {
|
||||
return new PositionManagerInput(cycle(), snapshot(), modelOutput, positionState, accountState, executionState, pmConfig());
|
||||
return new PositionManagerInput(cycle(), snapshot(), modelOutput, pricePlan(), positionState, accountState, executionState, pmConfig());
|
||||
}
|
||||
|
||||
public static TraderPositionState flatPosition() {
|
||||
@@ -192,9 +289,14 @@ public final class TestFixtures {
|
||||
}
|
||||
|
||||
public static TraderExecutionState execution(List<OpenOrderState> openOrders, long latencyMs, int apiErrorCount) {
|
||||
return execution(openOrders, latencyMs, apiErrorCount, "1.0");
|
||||
}
|
||||
|
||||
public static TraderExecutionState execution(List<OpenOrderState> openOrders, long latencyMs, int apiErrorCount,
|
||||
String liquidityCapacityRatio) {
|
||||
return new TraderExecutionState("execution-state-1", "run-1", "cycle-1", "BTC-USDT-PERP",
|
||||
openOrders, bd("1.5"), latencyMs, apiErrorCount, bd("1"), bd("4"), bd("5"),
|
||||
bd("0.1"), bd("0.001"), bd("0.001"), BigDecimal.ONE);
|
||||
bd("0.1"), bd("0.001"), bd("0.001"), bd(liquidityCapacityRatio));
|
||||
}
|
||||
|
||||
public static TraderPositionManagerDecision pmDecision(TraderActionType action, PositionSide side) {
|
||||
@@ -213,31 +315,243 @@ public final class TestFixtures {
|
||||
}
|
||||
|
||||
public static RiskLimits riskLimits() {
|
||||
return new RiskLimits(bd("200"), BigDecimal.ONE, bd("500"), 3, 1500, false, false);
|
||||
return new RiskLimits(bd("200"), BigDecimal.ONE, bd("500"), 3, 1500, false, false, null);
|
||||
}
|
||||
|
||||
public static void writeArtifactBundle(Path artifactRoot) throws IOException {
|
||||
Files.createDirectories(artifactRoot.resolve("manifests"));
|
||||
Files.createDirectories(artifactRoot.resolve("models"));
|
||||
Files.createDirectories(artifactRoot.resolve("calibrators"));
|
||||
Files.createDirectories(artifactRoot.resolve("schemas"));
|
||||
Files.createDirectories(artifactRoot.resolve("examples"));
|
||||
String directionHash = writeArtifact(artifactRoot.resolve("models/direction.json"), "{\"model\":\"DIRECTION\"}");
|
||||
String entryHash = writeArtifact(artifactRoot.resolve("models/entry.json"), "{\"model\":\"ENTRY\"}");
|
||||
String continueHash = writeArtifact(artifactRoot.resolve("models/continue.json"), "{\"model\":\"CONTINUE\"}");
|
||||
String exitHash = writeArtifact(artifactRoot.resolve("models/exit.json"), "{\"model\":\"EXIT\"}");
|
||||
String riskHash = writeArtifact(artifactRoot.resolve("models/risk.json"), "{\"model\":\"RISK\"}");
|
||||
String featureSchemaHash = writeArtifact(artifactRoot.resolve("schemas/features.json"), "{\"features\":[\"mark_price\",\"index_price\"]}");
|
||||
String featureOrderHash = writeArtifact(artifactRoot.resolve("schemas/feature_order.json"), """
|
||||
[
|
||||
"ret_1m_bps","ret_5m_bps","ret_15m_bps","ret_60m_bps","ret_240m_bps",
|
||||
"realized_vol_15m_bps","realized_vol_60m_bps","vol_ratio_15m_60m",
|
||||
"range_15m_bps","range_60m_bps","volume_zscore_60m","trend_consistency_15m",
|
||||
"channel_position_60m_pct","upper_breakout_60m_bps","lower_breakout_60m_bps",
|
||||
"upper_failed_break_reclaim_15m_bps","lower_failed_break_reclaim_15m_bps",
|
||||
"sweep_up_15m_bps","sweep_down_15m_bps","compression_score_4h_pct",
|
||||
"compression_release_15m_bps","taker_imbalance_1m","taker_imbalance_5m",
|
||||
"taker_imbalance_15m","level1_ofi_1m","spread_bps","spread_rank_24h_pct",
|
||||
"oi_delta_15m_bps","oi_delta_60m_bps","funding_bps","mark_index_basis_bps",
|
||||
"liquidation_buy_notional_1m","liquidation_sell_notional_1m",
|
||||
"liquidation_imbalance_15m","liquidation_notional_zscore_15m",
|
||||
"liquidation_available","minute_of_day_sin","minute_of_day_cos",
|
||||
"minutes_to_next_funding"
|
||||
]
|
||||
""");
|
||||
String outputSchemaHash = writeArtifact(artifactRoot.resolve("schemas/outputs.json"), """
|
||||
{
|
||||
"direction": {
|
||||
"longProb": {"type": "decimal", "range": [0.0, 1.0]},
|
||||
"shortProb": {"type": "decimal", "range": [0.0, 1.0]},
|
||||
"neutralProb": {"type": "decimal", "range": [0.0, 1.0]}
|
||||
},
|
||||
"entry": {
|
||||
"longEntryProb": {"type": "decimal", "range": [0.0, 1.0]},
|
||||
"shortEntryProb": {"type": "decimal", "range": [0.0, 1.0]},
|
||||
"longExpectedNetEdgeBps": {"type": "decimal", "range": [-1000.0, 1000.0]},
|
||||
"shortExpectedNetEdgeBps": {"type": "decimal", "range": [-1000.0, 1000.0]}
|
||||
},
|
||||
"continue": {
|
||||
"longContinueProb": {"type": "decimal", "range": [0.0, 1.0]},
|
||||
"shortContinueProb": {"type": "decimal", "range": [0.0, 1.0]},
|
||||
"longExpectedContinueEdgeBps": {"type": "decimal", "range": [-1000.0, 1000.0]},
|
||||
"shortExpectedContinueEdgeBps": {"type": "decimal", "range": [-1000.0, 1000.0]}
|
||||
},
|
||||
"exit": {
|
||||
"longExitProb": {"type": "decimal", "range": [0.0, 1.0]},
|
||||
"shortExitProb": {"type": "decimal", "range": [0.0, 1.0]},
|
||||
"longAdverseMoveBps": {"type": "decimal", "range": [0.0, 1000.0]},
|
||||
"shortAdverseMoveBps": {"type": "decimal", "range": [0.0, 1000.0]},
|
||||
"exitReasonScores": {
|
||||
"adverse_move_prob": {"type": "decimal", "range": [0.0, 1.0]},
|
||||
"reversal_prob": {"type": "decimal", "range": [0.0, 1.0]},
|
||||
"stop_hit_prob": {"type": "decimal", "range": [0.0, 1.0]},
|
||||
"stagnation_prob": {"type": "decimal", "range": [0.0, 1.0]}
|
||||
}
|
||||
},
|
||||
"risk": {
|
||||
"marketRiskProb": {"type": "decimal", "range": [0.0, 1.0]},
|
||||
"longPositionRiskProb": {"type": "decimal", "range": [0.0, 1.0]},
|
||||
"shortPositionRiskProb": {"type": "decimal", "range": [0.0, 1.0]},
|
||||
"marketPathRiskBps": {"type": "decimal", "range": [0.0, 1000.0]},
|
||||
"longPositionPathRiskBps": {"type": "decimal", "range": [0.0, 1000.0]},
|
||||
"shortPositionPathRiskBps": {"type": "decimal", "range": [0.0, 1000.0]},
|
||||
"riskReasonScores": {
|
||||
"market_drawdown_prob": {"type": "decimal", "range": [0.0, 1.0]},
|
||||
"volatility_expansion_prob": {"type": "decimal", "range": [0.0, 1.0]},
|
||||
"spike_prob": {"type": "decimal", "range": [0.0, 1.0]},
|
||||
"liquidity_deterioration_prob": {"type": "decimal", "range": [0.0, 1.0]},
|
||||
"position_drawdown_prob": {"type": "decimal", "range": [0.0, 1.0]}
|
||||
}
|
||||
}
|
||||
}
|
||||
""");
|
||||
writeArtifact(artifactRoot.resolve("examples/input.json"), "{\"mark_price\":100,\"index_price\":99.5}");
|
||||
String directionCalibratorHash = writeArtifact(artifactRoot.resolve("calibrators/direction.json"),
|
||||
calibratorJson("direction-cal-p0", Map.of("longProb", "0.62", "shortProb", "0.28", "neutralProb", "0.10")));
|
||||
String entryCalibratorHash = writeArtifact(artifactRoot.resolve("calibrators/entry.json"),
|
||||
calibratorJson("entry-cal-p0", Map.of("longEntryProb", "0.63", "shortEntryProb", "0.42")));
|
||||
String continueCalibratorHash = writeArtifact(artifactRoot.resolve("calibrators/continue.json"),
|
||||
calibratorJson("continue-cal-p0", Map.of("longContinueProb", "0.61", "shortContinueProb", "0.39")));
|
||||
String exitCalibratorHash = writeArtifact(artifactRoot.resolve("calibrators/exit.json"),
|
||||
calibratorJson("exit-cal-p0", Map.of(
|
||||
"longExitProb", "0.24", "shortExitProb", "0.48",
|
||||
"adverse_move_prob", "0.20", "reversal_prob", "0.25",
|
||||
"stop_hit_prob", "0.22", "stagnation_prob", "0.20")));
|
||||
String riskCalibratorHash = writeArtifact(artifactRoot.resolve("calibrators/risk.json"),
|
||||
calibratorJson("risk-cal-p0", Map.of(
|
||||
"marketRiskProb", "0.20", "longPositionRiskProb", "0.18", "shortPositionRiskProb", "0.28",
|
||||
"market_drawdown_prob", "0.15", "volatility_expansion_prob", "0.20",
|
||||
"spike_prob", "0.10", "liquidity_deterioration_prob", "0.12",
|
||||
"position_drawdown_prob", "0.14")));
|
||||
Files.writeString(artifactRoot.resolve("manifests/model_bundle_manifest.json"), """
|
||||
{
|
||||
"manifest_schema_version": "trader-model-bundle-v1",
|
||||
"model_bundle_version": "trader-v4-btc-p0",
|
||||
"calibration_bundle_version": "cal-v4-btc-p0",
|
||||
"feature_version": "feature-v4-p0",
|
||||
"label_version": "label-v4-p0",
|
||||
"split_version": "split-v4-p0",
|
||||
"training_run_id": "train-run-p0",
|
||||
"training_export_id": "export-p0",
|
||||
"backtest_manifest_id": "backtest-p0",
|
||||
"required_models_json": ["DIRECTION", "ENTRY", "CONTINUE", "EXIT", "RISK"],
|
||||
"provided_models_json": ["DIRECTION", "ENTRY", "CONTINUE", "EXIT", "RISK"],
|
||||
"missing_models_json": [],
|
||||
"allowed_run_modes_json": ["REPLAY_SIM", "SHADOW"],
|
||||
"bundle_hash_sha256": "00000000000000000000000000000000000000000000000000000000000000a1",
|
||||
"complete": true,
|
||||
"status": "ACTIVE"
|
||||
}
|
||||
""");
|
||||
Files.writeString(artifactRoot.resolve("manifests/model_manifest.json"), """
|
||||
[
|
||||
{
|
||||
"model_bundle_version":"trader-v4-btc-p0","calibration_bundle_version":"cal-v4-btc-p0",
|
||||
"model_name":"DIRECTION","model_type":"DIRECTION","side":"BOTH",
|
||||
"symbol_scope_json":["BTC-USDT-PERP"],"bar_interval":"1m","horizon_minutes":45,
|
||||
"model_format":"ONNX","model_runtime":"ONNX_RUNTIME_JAVA","model_runtime_version":"1.22.0",
|
||||
"onnx_opset_version":17,"producer_name":"test-exporter","producer_version":"p0",
|
||||
"feature_version":"feature-v4-p0","feature_schema_path":"schemas/features.json",
|
||||
"feature_schema_hash":"%6$s","feature_order_path":"schemas/feature_order.json","feature_order_hash":"%8$s",
|
||||
"input_tensor_name":"features","input_dtype":"FLOAT32","input_shape_json":{"batch":1,"features":39},
|
||||
"input_example_path":"examples/input.json","output_schema_path":"schemas/outputs.json",
|
||||
"output_schema_hash":"%7$s","output_tensor_names_json":["probabilities"],
|
||||
"output_mapping_json":{"long_prob":"probabilities[0]","short_prob":"probabilities[1]","neutral_prob":"probabilities[2]"},
|
||||
"output_value_rules_json":{"probability":"[0,1]"},
|
||||
"label_version":"label-v4-p0","split_version":"split-v4-p0","training_fold":"fold-0",
|
||||
"train_start":"2026-01-01T00:00:00Z","train_end":"2026-03-01T00:00:00Z",
|
||||
"validation_start":"2026-03-01T00:00:00Z","validation_end":"2026-04-01T00:00:00Z",
|
||||
"test_start":"2026-04-01T00:00:00Z","test_end":"2026-05-01T00:00:00Z",
|
||||
"metrics_json":{"sample_count":1000},"artifact_path":"models/direction.json",
|
||||
"artifact_hash_sha256":"%1$s","source_hash":"source-direction","status":"ACTIVE"
|
||||
},
|
||||
{
|
||||
"model_bundle_version":"trader-v4-btc-p0","calibration_bundle_version":"cal-v4-btc-p0",
|
||||
"model_name":"ENTRY","model_type":"ENTRY","side":"BOTH",
|
||||
"symbol_scope_json":["BTC-USDT-PERP"],"bar_interval":"1m","horizon_minutes":45,
|
||||
"model_format":"ONNX","model_runtime":"ONNX_RUNTIME_JAVA","model_runtime_version":"1.22.0",
|
||||
"onnx_opset_version":17,"producer_name":"test-exporter","producer_version":"p0",
|
||||
"feature_version":"feature-v4-p0","feature_schema_path":"schemas/features.json",
|
||||
"feature_schema_hash":"%6$s","feature_order_path":"schemas/feature_order.json","feature_order_hash":"%8$s",
|
||||
"input_tensor_name":"features","input_dtype":"FLOAT32","input_shape_json":{"batch":1,"features":39},
|
||||
"input_example_path":"examples/input.json","output_schema_path":"schemas/outputs.json",
|
||||
"output_schema_hash":"%7$s","output_tensor_names_json":["entry"],
|
||||
"output_mapping_json":{"long_entry_prob":"entry[0]","short_entry_prob":"entry[1]","long_expected_net_edge_bps":"entry[2]","short_expected_net_edge_bps":"entry[3]"},
|
||||
"output_value_rules_json":{"probability":"[0,1]"},
|
||||
"label_version":"label-v4-p0","split_version":"split-v4-p0","training_fold":"fold-0",
|
||||
"train_start":"2026-01-01T00:00:00Z","train_end":"2026-03-01T00:00:00Z",
|
||||
"validation_start":"2026-03-01T00:00:00Z","validation_end":"2026-04-01T00:00:00Z",
|
||||
"test_start":"2026-04-01T00:00:00Z","test_end":"2026-05-01T00:00:00Z",
|
||||
"metrics_json":{"sample_count":1000},"artifact_path":"models/entry.json",
|
||||
"artifact_hash_sha256":"%2$s","source_hash":"source-entry","status":"ACTIVE"
|
||||
},
|
||||
{
|
||||
"model_bundle_version":"trader-v4-btc-p0","calibration_bundle_version":"cal-v4-btc-p0",
|
||||
"model_name":"CONTINUE","model_type":"CONTINUE","side":"BOTH",
|
||||
"symbol_scope_json":["BTC-USDT-PERP"],"bar_interval":"1m","horizon_minutes":45,
|
||||
"model_format":"ONNX","model_runtime":"ONNX_RUNTIME_JAVA","model_runtime_version":"1.22.0",
|
||||
"onnx_opset_version":17,"producer_name":"test-exporter","producer_version":"p0",
|
||||
"feature_version":"feature-v4-p0","feature_schema_path":"schemas/features.json",
|
||||
"feature_schema_hash":"%6$s","feature_order_path":"schemas/feature_order.json","feature_order_hash":"%8$s",
|
||||
"input_tensor_name":"features","input_dtype":"FLOAT32","input_shape_json":{"batch":1,"features":39},
|
||||
"input_example_path":"examples/input.json","output_schema_path":"schemas/outputs.json",
|
||||
"output_schema_hash":"%7$s","output_tensor_names_json":["continue"],
|
||||
"output_mapping_json":{"long_continue_prob":"continue[0]","short_continue_prob":"continue[1]","long_expected_continue_edge_bps":"continue[2]","short_expected_continue_edge_bps":"continue[3]"},
|
||||
"output_value_rules_json":{"probability":"[0,1]"},
|
||||
"label_version":"label-v4-p0","split_version":"split-v4-p0","training_fold":"fold-0",
|
||||
"train_start":"2026-01-01T00:00:00Z","train_end":"2026-03-01T00:00:00Z",
|
||||
"validation_start":"2026-03-01T00:00:00Z","validation_end":"2026-04-01T00:00:00Z",
|
||||
"test_start":"2026-04-01T00:00:00Z","test_end":"2026-05-01T00:00:00Z",
|
||||
"metrics_json":{"sample_count":1000},"artifact_path":"models/continue.json",
|
||||
"artifact_hash_sha256":"%3$s","source_hash":"source-continue","status":"ACTIVE"
|
||||
},
|
||||
{
|
||||
"model_bundle_version":"trader-v4-btc-p0","calibration_bundle_version":"cal-v4-btc-p0",
|
||||
"model_name":"EXIT","model_type":"EXIT","side":"BOTH",
|
||||
"symbol_scope_json":["BTC-USDT-PERP"],"bar_interval":"1m","horizon_minutes":45,
|
||||
"model_format":"ONNX","model_runtime":"ONNX_RUNTIME_JAVA","model_runtime_version":"1.22.0",
|
||||
"onnx_opset_version":17,"producer_name":"test-exporter","producer_version":"p0",
|
||||
"feature_version":"feature-v4-p0","feature_schema_path":"schemas/features.json",
|
||||
"feature_schema_hash":"%6$s","feature_order_path":"schemas/feature_order.json","feature_order_hash":"%8$s",
|
||||
"input_tensor_name":"features","input_dtype":"FLOAT32","input_shape_json":{"batch":1,"features":39},
|
||||
"input_example_path":"examples/input.json","output_schema_path":"schemas/outputs.json",
|
||||
"output_schema_hash":"%7$s","output_tensor_names_json":["exit"],
|
||||
"output_mapping_json":{"long_exit_prob":"exit[0]","short_exit_prob":"exit[1]","long_adverse_move_bps":"exit[2]","short_adverse_move_bps":"exit[3]","adverse_move_prob":"exit[4]","reversal_prob":"exit[5]","stop_hit_prob":"exit[6]","stagnation_prob":"exit[7]"},
|
||||
"output_value_rules_json":{"probability":"[0,1]"},
|
||||
"label_version":"label-v4-p0","split_version":"split-v4-p0","training_fold":"fold-0",
|
||||
"train_start":"2026-01-01T00:00:00Z","train_end":"2026-03-01T00:00:00Z",
|
||||
"validation_start":"2026-03-01T00:00:00Z","validation_end":"2026-04-01T00:00:00Z",
|
||||
"test_start":"2026-04-01T00:00:00Z","test_end":"2026-05-01T00:00:00Z",
|
||||
"metrics_json":{"sample_count":1000},"artifact_path":"models/exit.json",
|
||||
"artifact_hash_sha256":"%4$s","source_hash":"source-exit","status":"ACTIVE"
|
||||
},
|
||||
{
|
||||
"model_bundle_version":"trader-v4-btc-p0","calibration_bundle_version":"cal-v4-btc-p0",
|
||||
"model_name":"RISK","model_type":"RISK","side":"BOTH",
|
||||
"symbol_scope_json":["BTC-USDT-PERP"],"bar_interval":"1m","horizon_minutes":45,
|
||||
"model_format":"ONNX","model_runtime":"ONNX_RUNTIME_JAVA","model_runtime_version":"1.22.0",
|
||||
"onnx_opset_version":17,"producer_name":"test-exporter","producer_version":"p0",
|
||||
"feature_version":"feature-v4-p0","feature_schema_path":"schemas/features.json",
|
||||
"feature_schema_hash":"%6$s","feature_order_path":"schemas/feature_order.json","feature_order_hash":"%8$s",
|
||||
"input_tensor_name":"features","input_dtype":"FLOAT32","input_shape_json":{"batch":1,"features":39},
|
||||
"input_example_path":"examples/input.json","output_schema_path":"schemas/outputs.json",
|
||||
"output_schema_hash":"%7$s","output_tensor_names_json":["risk"],
|
||||
"output_mapping_json":{"market_risk_prob":"risk[0]","long_position_risk_prob":"risk[1]","short_position_risk_prob":"risk[2]","market_path_risk_bps":"risk[3]","long_position_path_risk_bps":"risk[4]","short_position_path_risk_bps":"risk[5]","market_drawdown_prob":"risk[6]","volatility_expansion_prob":"risk[7]","spike_prob":"risk[8]","liquidity_deterioration_prob":"risk[9]","position_drawdown_prob":"risk[10]"},
|
||||
"output_value_rules_json":{"probability":"[0,1]"},
|
||||
"label_version":"label-v4-p0","split_version":"split-v4-p0","training_fold":"fold-0",
|
||||
"train_start":"2026-01-01T00:00:00Z","train_end":"2026-03-01T00:00:00Z",
|
||||
"validation_start":"2026-03-01T00:00:00Z","validation_end":"2026-04-01T00:00:00Z",
|
||||
"test_start":"2026-04-01T00:00:00Z","test_end":"2026-05-01T00:00:00Z",
|
||||
"metrics_json":{"sample_count":1000},"artifact_path":"models/risk.json",
|
||||
"artifact_hash_sha256":"%5$s","source_hash":"source-risk","status":"ACTIVE"
|
||||
}
|
||||
]
|
||||
""".formatted(directionHash, entryHash, continueHash, exitHash, riskHash, featureSchemaHash, outputSchemaHash, featureOrderHash));
|
||||
Files.writeString(artifactRoot.resolve("manifests/calibration_manifest.json"), """
|
||||
[
|
||||
{"calibration_bundle_version":"cal-v4-btc-p0","model_bundle_version":"trader-v4-btc-p0","model_name":"DIRECTION","calibrator_version":"direction-cal-p0","calibration_method":"BINNING","calibrator_path":"calibrators/direction.json","calibrator_hash_sha256":"%s","calibration_window_from":"2026-03-01T00:00:00Z","calibration_window_to":"2026-04-01T00:00:00Z","calibration_metrics_json":{},"bucket_metrics_json":{},"output_after_calibration_schema_hash":"output-cal-hash","status":"ACTIVE"},
|
||||
{"calibration_bundle_version":"cal-v4-btc-p0","model_bundle_version":"trader-v4-btc-p0","model_name":"ENTRY","calibrator_version":"entry-cal-p0","calibration_method":"BINNING","calibrator_path":"calibrators/entry.json","calibrator_hash_sha256":"%s","calibration_window_from":"2026-03-01T00:00:00Z","calibration_window_to":"2026-04-01T00:00:00Z","calibration_metrics_json":{},"bucket_metrics_json":{},"output_after_calibration_schema_hash":"output-cal-hash","status":"ACTIVE"},
|
||||
{"calibration_bundle_version":"cal-v4-btc-p0","model_bundle_version":"trader-v4-btc-p0","model_name":"CONTINUE","calibrator_version":"continue-cal-p0","calibration_method":"BINNING","calibrator_path":"calibrators/continue.json","calibrator_hash_sha256":"%s","calibration_window_from":"2026-03-01T00:00:00Z","calibration_window_to":"2026-04-01T00:00:00Z","calibration_metrics_json":{},"bucket_metrics_json":{},"output_after_calibration_schema_hash":"output-cal-hash","status":"ACTIVE"},
|
||||
{"calibration_bundle_version":"cal-v4-btc-p0","model_bundle_version":"trader-v4-btc-p0","model_name":"EXIT","calibrator_version":"exit-cal-p0","calibration_method":"BINNING","calibrator_path":"calibrators/exit.json","calibrator_hash_sha256":"%s","calibration_window_from":"2026-03-01T00:00:00Z","calibration_window_to":"2026-04-01T00:00:00Z","calibration_metrics_json":{},"bucket_metrics_json":{},"output_after_calibration_schema_hash":"output-cal-hash","status":"ACTIVE"},
|
||||
{"calibration_bundle_version":"cal-v4-btc-p0","model_bundle_version":"trader-v4-btc-p0","model_name":"RISK","calibrator_version":"risk-cal-p0","calibration_method":"BINNING","calibrator_path":"calibrators/risk.json","calibrator_hash_sha256":"%s","calibration_window_from":"2026-03-01T00:00:00Z","calibration_window_to":"2026-04-01T00:00:00Z","calibration_metrics_json":{},"bucket_metrics_json":{},"output_after_calibration_schema_hash":"output-cal-hash","status":"ACTIVE"}
|
||||
]
|
||||
""".formatted(directionCalibratorHash, entryCalibratorHash, continueCalibratorHash, exitCalibratorHash, riskCalibratorHash));
|
||||
Files.writeString(artifactRoot.resolve("manifests/position_manager_manifest.json"), """
|
||||
{
|
||||
"pm_config_version": "pm-v4-btc-p0",
|
||||
"model_bundle_version": "trader-v4-btc-p0",
|
||||
"calibration_bundle_version": "cal-v4-btc-p0",
|
||||
"threshold_stability_json": {},
|
||||
"allowed_run_modes_json": ["REPLAY_SIM", "SHADOW"],
|
||||
"config_hash_sha256": "00000000000000000000000000000000000000000000000000000000000000b1",
|
||||
"status": "ACTIVE",
|
||||
@@ -274,11 +588,11 @@ public final class TestFixtures {
|
||||
"closePositionRiskProb": 0.70,
|
||||
"closeMarketRiskProb": 0.70,
|
||||
"closeContinueMax": 0.25,
|
||||
"reduceGivebackProb": 0.62,
|
||||
"reduceAdverseMoveProb": 0.62,
|
||||
"reduceContinueMin": 0.35,
|
||||
"reduceContinueMax": 0.70,
|
||||
"minProfitForReduceBps": 5.0,
|
||||
"maxExpectedShortfallBps": 80
|
||||
"maxPositionPathRiskBps": 80
|
||||
},
|
||||
"sizing": {
|
||||
"baseRatio": 0.80,
|
||||
@@ -296,64 +610,110 @@ public final class TestFixtures {
|
||||
}
|
||||
}
|
||||
""");
|
||||
Files.writeString(artifactRoot.resolve("model_output_policy.json"), """
|
||||
Files.writeString(artifactRoot.resolve("price_plan_context.json"), """
|
||||
{
|
||||
"pricePlanId": "p0-plan-atr-2r",
|
||||
"pricePlanConfigHash": "p0-price-plan-hash",
|
||||
"stopDistanceBps": 35,
|
||||
"targetDistanceBps": 70,
|
||||
"maxHoldMinutes": 45,
|
||||
"costBps": 4.0
|
||||
}
|
||||
""");
|
||||
Files.writeString(artifactRoot.resolve("replay_model_fixture.json"), """
|
||||
{
|
||||
"direction": {
|
||||
"longProbWhenMarkGteIndex": 0.62,
|
||||
"longProbWhenMarkLtIndex": 0.32,
|
||||
"neutralProb": 0.10,
|
||||
"expectedReturnBps": 8.0,
|
||||
"horizonMinutes": 45,
|
||||
"modelVersion": "direction-p0"
|
||||
"neutralProb": 0.10
|
||||
},
|
||||
"entry": {
|
||||
"longEntryProb": 0.63,
|
||||
"shortEntryProb": 0.42,
|
||||
"entryQualityScore": 0.64,
|
||||
"expectedEdgeBps": 12.0,
|
||||
"pricePlanId": "p0-plan-atr-2r",
|
||||
"pricePlanConfigHash": "p0-price-plan-hash",
|
||||
"stopDistanceBps": 35,
|
||||
"targetDistanceBps": 70,
|
||||
"maxHoldMinutes": 45,
|
||||
"costBps": 4.0,
|
||||
"modelVersion": "entry-p0"
|
||||
"longExpectedNetEdgeBps": 12.0,
|
||||
"shortExpectedNetEdgeBps": 8.0
|
||||
},
|
||||
"continuation": {
|
||||
"longContinueProb": 0.61,
|
||||
"shortContinueProb": 0.39,
|
||||
"trendPersistenceProb": 0.58,
|
||||
"holdEdgeBps": 5.0,
|
||||
"continueVsExitEdgeBps": 3.0,
|
||||
"modelVersion": "continue-p0"
|
||||
"longExpectedContinueEdgeBps": 3.0,
|
||||
"shortExpectedContinueEdgeBps": 1.5
|
||||
},
|
||||
"exit": {
|
||||
"longExitProb": 0.24,
|
||||
"shortExitProb": 0.48,
|
||||
"profitGivebackProb": 0.20,
|
||||
"reversalProb": 0.25,
|
||||
"stopRiskProb": 0.22,
|
||||
"stagnationProb": 0.20,
|
||||
"expectedGivebackBps": 10,
|
||||
"modelVersion": "exit-p0"
|
||||
"longAdverseMoveBps": 10,
|
||||
"shortAdverseMoveBps": 18,
|
||||
"exitReasonScores": {
|
||||
"adverse_move_prob": 0.20,
|
||||
"reversal_prob": 0.25,
|
||||
"stop_hit_prob": 0.22,
|
||||
"stagnation_prob": 0.20
|
||||
}
|
||||
},
|
||||
"risk": {
|
||||
"marketRiskProb": 0.20,
|
||||
"positionRiskProb": 0.18,
|
||||
"marketRiskSeverityBps": 20,
|
||||
"positionRiskSeverityBps": 18,
|
||||
"drawdownProb": 0.15,
|
||||
"expectedShortfallBps": 20,
|
||||
"volatilityExpansionProb": 0.20,
|
||||
"spikeProb": 0.10,
|
||||
"liquidityRiskProb": 0.12,
|
||||
"liquidityCapacityRatioWhenReady": 1.0,
|
||||
"liquidityCapacityRatioWhenNotReady": 0,
|
||||
"modelVersion": "risk-p0"
|
||||
"longPositionRiskProb": 0.18,
|
||||
"shortPositionRiskProb": 0.28,
|
||||
"marketPathRiskBps": 20,
|
||||
"longPositionPathRiskBps": 20,
|
||||
"shortPositionPathRiskBps": 30,
|
||||
"riskReasonScores": {
|
||||
"market_drawdown_prob": 0.15,
|
||||
"volatility_expansion_prob": 0.20,
|
||||
"spike_prob": 0.10,
|
||||
"liquidity_deterioration_prob": 0.12,
|
||||
"position_drawdown_prob": 0.14
|
||||
}
|
||||
},
|
||||
"uncertainty": 0.10,
|
||||
"oodScore": 0.05
|
||||
"oodScore": 0.05,
|
||||
"featureSchemaHash": "feature-schema-hash",
|
||||
"featureOrderHash": "feature-order-hash",
|
||||
"outputSchemaHash": "output-schema-hash"
|
||||
}
|
||||
""");
|
||||
}
|
||||
|
||||
private static String calibratorJson(String version, Map<String, String> targets) {
|
||||
StringBuilder targetJson = new StringBuilder();
|
||||
int index = 0;
|
||||
for (Map.Entry<String, String> entry : targets.entrySet()) {
|
||||
if (index++ > 0) {
|
||||
targetJson.append(",");
|
||||
}
|
||||
targetJson.append("\"").append(entry.getKey()).append("\":")
|
||||
.append("{\"bins\":[{\"min\":0.0,\"max\":1.0,\"calibrated\":")
|
||||
.append(entry.getValue())
|
||||
.append("}]}");
|
||||
}
|
||||
return """
|
||||
{
|
||||
"calibrator_version": "%s",
|
||||
"method": "BINNING",
|
||||
"targets": {%s},
|
||||
"clip": {"min": 0.0, "max": 1.0},
|
||||
"fail_policy": "FAIL_FAST"
|
||||
}
|
||||
""".formatted(version, targetJson);
|
||||
}
|
||||
|
||||
private static String writeArtifact(Path path, String content) throws IOException {
|
||||
Files.writeString(path, content);
|
||||
return sha256(content);
|
||||
}
|
||||
|
||||
private static String sha256(String content) {
|
||||
try {
|
||||
MessageDigest digest = MessageDigest.getInstance("SHA-256");
|
||||
byte[] hash = digest.digest(content.getBytes(java.nio.charset.StandardCharsets.UTF_8));
|
||||
StringBuilder builder = new StringBuilder(hash.length * 2);
|
||||
for (byte value : hash) {
|
||||
builder.append(String.format("%02x", value & 0xff));
|
||||
}
|
||||
return builder.toString();
|
||||
} catch (NoSuchAlgorithmException exception) {
|
||||
throw new IllegalStateException("SHA-256 is not available", exception);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
+35
@@ -0,0 +1,35 @@
|
||||
package com.quantai.trader.artifact;
|
||||
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.api.io.TempDir;
|
||||
import org.springframework.jdbc.core.JdbcTemplate;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.nio.file.Path;
|
||||
|
||||
import static com.quantai.trader.TestFixtures.*;
|
||||
import static org.mockito.ArgumentMatchers.any;
|
||||
import static org.mockito.ArgumentMatchers.contains;
|
||||
import static org.mockito.Mockito.*;
|
||||
|
||||
class JdbcTraderArtifactManifestRepositoryTest {
|
||||
@TempDir
|
||||
Path artifactRoot;
|
||||
|
||||
@Test
|
||||
void upsertsModelCalibrationAndPmManifestsForActiveBundle() throws IOException {
|
||||
writeArtifactBundle(artifactRoot);
|
||||
TraderArtifactBundle bundle = new TraderArtifactLoader(propertiesWithArtifactRoot(artifactRoot), objectMapper())
|
||||
.loadActiveBundle();
|
||||
JdbcTemplate jdbcTemplate = mock(JdbcTemplate.class);
|
||||
JdbcTraderArtifactManifestRepository repository = new JdbcTraderArtifactManifestRepository(jdbcTemplate, objectMapper());
|
||||
|
||||
repository.upsertActiveBundle(bundle);
|
||||
|
||||
verify(jdbcTemplate).update(contains("insert into trader_model_bundle_manifest"), any(Object[].class));
|
||||
verify(jdbcTemplate, times(5)).update(contains("insert into trader_model_manifest"), any(Object[].class));
|
||||
verify(jdbcTemplate, times(5)).update(contains("insert into trader_calibration_manifest"), any(Object[].class));
|
||||
verify(jdbcTemplate).update(contains("insert into trader_pm_config_manifest"), any(Object[].class));
|
||||
verifyNoMoreInteractions(jdbcTemplate);
|
||||
}
|
||||
}
|
||||
@@ -1,10 +1,12 @@
|
||||
package com.quantai.trader.artifact;
|
||||
|
||||
import com.quantai.trader.config.TraderProperties;
|
||||
import com.quantai.trader.model.ReplayFixtureTraderModelService;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.api.io.TempDir;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.nio.file.Files;
|
||||
import java.nio.file.Path;
|
||||
import java.util.Set;
|
||||
|
||||
@@ -24,7 +26,25 @@ class TraderArtifactLoaderTest {
|
||||
|
||||
assertThat(bundle.providedModels()).containsExactlyInAnyOrder("DIRECTION", "ENTRY", "CONTINUE", "EXIT", "RISK");
|
||||
assertThat(bundle.pmConfig().pmConfigVersion()).isEqualTo("pm-v4-btc-p0");
|
||||
assertThat(bundle.modelPolicy().entry().pricePlanConfigHash()).isEqualTo("p0-price-plan-hash");
|
||||
assertThat(bundle.pricePlanContext().pricePlanConfigHash()).isEqualTo("p0-price-plan-hash");
|
||||
assertThat(bundle.modelManifests()).allSatisfy(manifest -> {
|
||||
assertThat(manifest.featureOrderPath()).isEqualTo("schemas/feature_order.json");
|
||||
assertThat(manifest.inputShapeJson()).containsEntry("features", 39);
|
||||
assertThat(manifest.onnxOpsetVersion()).isEqualTo(17);
|
||||
});
|
||||
assertThat(bundle.requireReplayModelFixture().entry().longExpectedNetEdgeBps()).isEqualByComparingTo("12.0");
|
||||
assertThat(bundle.requireReplayModelFixture().entry().shortExpectedNetEdgeBps()).isEqualByComparingTo("8.0");
|
||||
}
|
||||
|
||||
@Test
|
||||
void shadowModeCannotUseReplayFixtureInference() throws IOException {
|
||||
writeArtifactBundle(artifactRoot);
|
||||
TraderProperties properties = propertiesWithArtifactRoot(artifactRoot);
|
||||
TraderArtifactBundle bundle = new TraderArtifactLoader(properties, objectMapper()).loadActiveBundle();
|
||||
|
||||
assertThatThrownBy(() -> new ReplayFixtureTraderModelService(properties).evaluate(snapshot(), bundle))
|
||||
.isInstanceOf(com.quantai.trader.domain.TraderException.class)
|
||||
.hasMessageContaining("only allows REPLAY_SIM");
|
||||
}
|
||||
|
||||
@Test
|
||||
@@ -41,7 +61,9 @@ class TraderArtifactLoaderTest {
|
||||
|
||||
assertThatThrownBy(() -> new TraderArtifactBundle(
|
||||
bundle.modelBundleVersion(), bundle.calibrationBundleVersion(), bundle.pmConfigVersion(),
|
||||
bundle.bundleHashSha256(), Set.of("DIRECTION", "ENTRY", "CONTINUE", "RISK"), bundle.pmConfig(), bundle.modelPolicy()))
|
||||
bundle.bundleHashSha256(), Set.of("DIRECTION", "ENTRY", "CONTINUE", "RISK"),
|
||||
bundle.modelBundleManifest(), bundle.modelManifests(), bundle.calibrationManifests(),
|
||||
bundle.pmConfigManifest(), bundle.pmConfig(), bundle.pricePlanContext(), bundle.replayModelFixture()))
|
||||
.isInstanceOf(IllegalArgumentException.class)
|
||||
.hasMessageContaining("all five");
|
||||
}
|
||||
@@ -52,4 +74,37 @@ class TraderArtifactLoaderTest {
|
||||
.isInstanceOf(com.quantai.trader.domain.TraderException.class)
|
||||
.hasMessageContaining("artifact file is missing");
|
||||
}
|
||||
|
||||
@Test
|
||||
void rejectsMissingFeatureOrderArtifact() throws IOException {
|
||||
writeArtifactBundle(artifactRoot);
|
||||
Files.delete(artifactRoot.resolve("schemas/feature_order.json"));
|
||||
|
||||
assertThatThrownBy(() -> new TraderArtifactLoader(propertiesWithArtifactRoot(artifactRoot), objectMapper()).loadActiveBundle())
|
||||
.isInstanceOf(com.quantai.trader.domain.TraderException.class)
|
||||
.hasMessageContaining("artifact referenced by manifest is missing");
|
||||
}
|
||||
|
||||
@Test
|
||||
void rejectsNonV4InputShape() throws IOException {
|
||||
writeArtifactBundle(artifactRoot);
|
||||
Path manifest = artifactRoot.resolve("manifests/model_manifest.json");
|
||||
Files.writeString(manifest, Files.readString(manifest).replace("\"features\":39", "\"features\":38"));
|
||||
|
||||
assertThatThrownBy(() -> new TraderArtifactLoader(propertiesWithArtifactRoot(artifactRoot), objectMapper()).loadActiveBundle())
|
||||
.isInstanceOf(com.quantai.trader.domain.TraderException.class)
|
||||
.hasMessageContaining("runtime contract is invalid");
|
||||
}
|
||||
|
||||
@Test
|
||||
void rejectsLegacyOutputMappingKeys() throws IOException {
|
||||
writeArtifactBundle(artifactRoot);
|
||||
Path manifest = artifactRoot.resolve("manifests/model_manifest.json");
|
||||
Files.writeString(manifest, Files.readString(manifest)
|
||||
.replace("\"long_expected_net_edge_bps\":\"entry[2]\"", "\"expected_net_edge_bps\":\"entry[2]\""));
|
||||
|
||||
assertThatThrownBy(() -> new TraderArtifactLoader(propertiesWithArtifactRoot(artifactRoot), objectMapper()).loadActiveBundle())
|
||||
.isInstanceOf(com.quantai.trader.domain.TraderException.class)
|
||||
.hasMessageContaining("rejected legacy key");
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,17 @@
|
||||
package com.quantai.trader.config;
|
||||
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import java.time.Instant;
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
|
||||
class JacksonConfigTest {
|
||||
@Test
|
||||
void createsObjectMapperWithJavaTimeSupport() throws Exception {
|
||||
var mapper = new JacksonConfig().traderObjectMapper();
|
||||
|
||||
Instant value = mapper.readValue("\"2026-06-26T00:00:00Z\"", Instant.class);
|
||||
|
||||
assertThat(value).isEqualTo(Instant.parse("2026-06-26T00:00:00Z"));
|
||||
}
|
||||
}
|
||||
@@ -7,6 +7,7 @@ import com.quantai.trader.enums.FeedbackSource;
|
||||
import com.quantai.trader.enums.TraderErrorCode;
|
||||
import com.quantai.trader.enums.TraderExecutionMode;
|
||||
import com.quantai.trader.enums.TraderRunMode;
|
||||
import com.quantai.trader.feedback.TraderFeedbackRepository;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.springframework.http.ResponseEntity;
|
||||
|
||||
@@ -21,7 +22,8 @@ import static org.assertj.core.api.Assertions.assertThatThrownBy;
|
||||
class TraderControllerTest {
|
||||
@Test
|
||||
void feedbackEndpointRejectsWhenHttpFeedbackIsDisabled() {
|
||||
TraderFeedbackController controller = new TraderFeedbackController(properties(), new FeedbackValidator());
|
||||
RecordingFeedbackRepository repository = new RecordingFeedbackRepository();
|
||||
TraderFeedbackController controller = new TraderFeedbackController(properties(), new FeedbackValidator(), repository);
|
||||
|
||||
assertThatThrownBy(() -> controller.feedback(feedback(FeedbackSource.SHADOW_APP, false)))
|
||||
.isInstanceOf(TraderException.class)
|
||||
@@ -31,7 +33,8 @@ class TraderControllerTest {
|
||||
@Test
|
||||
void feedbackEndpointRejectsPaperAppSourceInP0EvenWhenHttpIsEnabledForTest() {
|
||||
TraderFeedbackController controller = new TraderFeedbackController(
|
||||
properties(TraderRunMode.SHADOW, TraderExecutionMode.SHADOW, false, true), new FeedbackValidator());
|
||||
properties(TraderRunMode.SHADOW, TraderExecutionMode.SHADOW, false, true), new FeedbackValidator(),
|
||||
new RecordingFeedbackRepository());
|
||||
|
||||
assertThatThrownBy(() -> controller.feedback(feedback(FeedbackSource.PAPER_APP, true)))
|
||||
.isInstanceOf(TraderException.class)
|
||||
@@ -40,12 +43,15 @@ class TraderControllerTest {
|
||||
|
||||
@Test
|
||||
void feedbackEndpointAcceptsShadowRecorderFeedbackWhenExplicitlyEnabled() {
|
||||
RecordingFeedbackRepository repository = new RecordingFeedbackRepository();
|
||||
TraderFeedbackController controller = new TraderFeedbackController(
|
||||
properties(TraderRunMode.SHADOW, TraderExecutionMode.SHADOW, false, true), new FeedbackValidator());
|
||||
properties(TraderRunMode.SHADOW, TraderExecutionMode.SHADOW, false, true), new FeedbackValidator(),
|
||||
repository);
|
||||
|
||||
Map<String, Object> result = controller.feedback(feedback(FeedbackSource.SHADOW_APP, false));
|
||||
|
||||
assertThat(result).containsEntry("accepted", true).containsEntry("feedbackId", "feedback-1");
|
||||
assertThat(repository.items()).containsExactly(feedback(FeedbackSource.SHADOW_APP, false));
|
||||
}
|
||||
|
||||
@Test
|
||||
@@ -83,4 +89,17 @@ class TraderControllerTest {
|
||||
null,
|
||||
Map.of());
|
||||
}
|
||||
|
||||
private static final class RecordingFeedbackRepository implements TraderFeedbackRepository {
|
||||
private final java.util.List<TraderAppFeedback> items = new java.util.ArrayList<>();
|
||||
|
||||
@Override
|
||||
public void insert(TraderAppFeedback feedback) {
|
||||
items.add(feedback);
|
||||
}
|
||||
|
||||
java.util.List<TraderAppFeedback> items() {
|
||||
return items;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,45 @@
|
||||
package com.quantai.trader.controller;
|
||||
|
||||
import com.quantai.trader.domain.TraderAction;
|
||||
import com.quantai.trader.domain.TraderRiskDecision;
|
||||
import com.quantai.trader.enums.PositionSide;
|
||||
import com.quantai.trader.enums.TraderActionType;
|
||||
import com.quantai.trader.replay.ReplayMarketEvent;
|
||||
import com.quantai.trader.replay.TraderCycleResult;
|
||||
import com.quantai.trader.replay.TraderP0CycleRunner;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import java.util.Map;
|
||||
|
||||
import static com.quantai.trader.TestFixtures.pmDecision;
|
||||
import static com.quantai.trader.TestFixtures.replayEvent;
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
import static org.mockito.Mockito.mock;
|
||||
import static org.mockito.Mockito.verify;
|
||||
import static org.mockito.Mockito.when;
|
||||
|
||||
class TraderReplayControllerTest {
|
||||
@Test
|
||||
void delegatesReplayCycleRequestToRunner() {
|
||||
TraderP0CycleRunner runner = mock(TraderP0CycleRunner.class);
|
||||
TraderReplayController controller = new TraderReplayController(runner);
|
||||
ReplayMarketEvent event = replayEvent("run-1", com.quantai.trader.TestFixtures.T0, "100", "99.5", "1000");
|
||||
TraderCycleResult expected = result();
|
||||
when(runner.runCycle(event)).thenReturn(expected);
|
||||
|
||||
TraderCycleResult actual = controller.runOneCycle(event);
|
||||
|
||||
assertThat(actual).isSameAs(expected);
|
||||
verify(runner).runCycle(event);
|
||||
}
|
||||
|
||||
private TraderCycleResult result() {
|
||||
TraderRiskDecision riskDecision = new TraderRiskDecision("risk-1", "run-1", "cycle-1",
|
||||
"pm-cycle-1", true, TraderActionType.WAIT, TraderActionType.WAIT, null, Map.of());
|
||||
TraderAction action = new TraderAction("action-1", "run-1", "cycle-1", "model-output-1",
|
||||
"pm-cycle-1", "risk-1", TraderActionType.WAIT, "BTC-USDT-PERP", PositionSide.NONE,
|
||||
null, null, null, null, null, null, false, "idem-1", "WAIT", Map.of());
|
||||
return new TraderCycleResult("run-1", "cycle-1", pmDecision(TraderActionType.WAIT, PositionSide.NONE),
|
||||
riskDecision, action);
|
||||
}
|
||||
}
|
||||
@@ -16,8 +16,7 @@ class ModelOutputContractTest {
|
||||
@Test
|
||||
void rejectsProbabilityOutsideClosedUnitRange() {
|
||||
assertThatThrownBy(() -> new DirectionOutput(
|
||||
bd("1.01"), BigDecimal.ZERO, BigDecimal.ZERO, BigDecimal.ONE, BigDecimal.ONE,
|
||||
bd("1"), 45, "direction-p0", "cal-v4-btc-p0", Map.of()))
|
||||
bd("1.01"), BigDecimal.ZERO, BigDecimal.ZERO))
|
||||
.isInstanceOf(IllegalArgumentException.class)
|
||||
.hasMessageContaining("direction.longProb");
|
||||
}
|
||||
|
||||
@@ -0,0 +1,29 @@
|
||||
package com.quantai.trader.evidence;
|
||||
|
||||
import com.quantai.trader.domain.TraderEvidence;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.springframework.jdbc.core.JdbcTemplate;
|
||||
|
||||
import java.time.Instant;
|
||||
import java.util.Map;
|
||||
|
||||
import static com.quantai.trader.TestFixtures.objectMapper;
|
||||
import static org.mockito.ArgumentMatchers.any;
|
||||
import static org.mockito.ArgumentMatchers.anyString;
|
||||
import static org.mockito.Mockito.mock;
|
||||
import static org.mockito.Mockito.verify;
|
||||
|
||||
class JdbcTraderEvidenceRepositoryTest {
|
||||
@Test
|
||||
void insertsEvidenceWithJsonDetails() {
|
||||
JdbcTemplate jdbcTemplate = mock(JdbcTemplate.class);
|
||||
JdbcTraderEvidenceRepository repository = new JdbcTraderEvidenceRepository(jdbcTemplate, objectMapper());
|
||||
TraderEvidence evidence = new TraderEvidence("evidence-1", "run-1", "cycle-1", "MODEL_OUTPUT",
|
||||
true, "MODEL_EVALUATED", null, Instant.parse("2026-06-26T00:00:00Z"),
|
||||
Map.of("modelOutputId", "model-output-1"));
|
||||
|
||||
repository.insert(evidence);
|
||||
|
||||
verify(jdbcTemplate).update(anyString(), any(Object[].class));
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,68 @@
|
||||
package com.quantai.trader.feature;
|
||||
|
||||
import com.quantai.trader.artifact.TraderArtifactLoader;
|
||||
import com.quantai.trader.domain.TraderMarketSnapshot;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.api.io.TempDir;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.nio.file.Path;
|
||||
import java.util.LinkedHashMap;
|
||||
import java.util.Map;
|
||||
|
||||
import static com.quantai.trader.TestFixtures.*;
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
import static org.assertj.core.api.Assertions.assertThatThrownBy;
|
||||
|
||||
class TraderFeatureVectorBuilderTest {
|
||||
@TempDir
|
||||
Path artifactRoot;
|
||||
|
||||
@Test
|
||||
void buildsFeatureVectorByArtifactFeatureOrder() throws IOException {
|
||||
writeArtifactBundle(artifactRoot);
|
||||
var properties = propertiesWithArtifactRoot(artifactRoot);
|
||||
var bundle = new TraderArtifactLoader(properties, objectMapper()).loadActiveBundle();
|
||||
TraderFeatureVectorBuilder builder = new TraderFeatureVectorBuilder(properties, objectMapper());
|
||||
|
||||
float[] values = builder.build(snapshot(), bundle);
|
||||
|
||||
assertThat(values).hasSize(39);
|
||||
assertThat(values[0]).isEqualTo(1.1f);
|
||||
assertThat(values[38]).isEqualTo(120.0f);
|
||||
}
|
||||
|
||||
@Test
|
||||
void rejectsMissingFeatureInsteadOfFillingZero() throws IOException {
|
||||
writeArtifactBundle(artifactRoot);
|
||||
var properties = propertiesWithArtifactRoot(artifactRoot);
|
||||
var bundle = new TraderArtifactLoader(properties, objectMapper()).loadActiveBundle();
|
||||
TraderFeatureVectorBuilder builder = new TraderFeatureVectorBuilder(properties, objectMapper());
|
||||
Map<String, Object> features = new LinkedHashMap<>(featureJson());
|
||||
features.remove("ret_1m_bps");
|
||||
TraderMarketSnapshot badSnapshot = new TraderMarketSnapshot("snapshot-1", "run-1", "cycle-1",
|
||||
"BTC-USDT-PERP", T0, "feature-v4-p0", bd("100"), bd("99.5"), bd("1.2"),
|
||||
bd("0.5"), bd("1000"), bd("1400"), bd("2200"), true, features, dataQualityJson());
|
||||
|
||||
assertThatThrownBy(() -> builder.build(badSnapshot, bundle))
|
||||
.isInstanceOf(com.quantai.trader.domain.TraderException.class)
|
||||
.hasMessageContaining("snapshot feature is missing");
|
||||
}
|
||||
|
||||
@Test
|
||||
void rejectsRuntimeFieldsMixedIntoModelFeatures() throws IOException {
|
||||
writeArtifactBundle(artifactRoot);
|
||||
var properties = propertiesWithArtifactRoot(artifactRoot);
|
||||
var bundle = new TraderArtifactLoader(properties, objectMapper()).loadActiveBundle();
|
||||
TraderFeatureVectorBuilder builder = new TraderFeatureVectorBuilder(properties, objectMapper());
|
||||
Map<String, Object> features = new LinkedHashMap<>(featureJson());
|
||||
features.put("account_balance", bd("1000"));
|
||||
TraderMarketSnapshot badSnapshot = new TraderMarketSnapshot("snapshot-1", "run-1", "cycle-1",
|
||||
"BTC-USDT-PERP", T0, "feature-v4-p0", bd("100"), bd("99.5"), bd("1.2"),
|
||||
bd("0.5"), bd("1000"), bd("1400"), bd("2200"), true, features, dataQualityJson());
|
||||
|
||||
assertThatThrownBy(() -> builder.build(badSnapshot, bundle))
|
||||
.isInstanceOf(com.quantai.trader.domain.TraderException.class)
|
||||
.hasMessageContaining("outside feature_order");
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,40 @@
|
||||
package com.quantai.trader.feedback;
|
||||
|
||||
import com.quantai.trader.domain.TraderAppFeedback;
|
||||
import com.quantai.trader.enums.FeedbackSource;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.mockito.ArgumentCaptor;
|
||||
import org.springframework.jdbc.core.JdbcTemplate;
|
||||
|
||||
import java.sql.Timestamp;
|
||||
import java.util.Map;
|
||||
|
||||
import static com.quantai.trader.TestFixtures.T0;
|
||||
import static com.quantai.trader.TestFixtures.objectMapper;
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
import static org.mockito.ArgumentMatchers.contains;
|
||||
import static org.mockito.Mockito.mock;
|
||||
import static org.mockito.Mockito.verify;
|
||||
|
||||
class JdbcTraderFeedbackRepositoryTest {
|
||||
@Test
|
||||
void insertsFeedbackWithSourceFillFlagsAndTimestamps() {
|
||||
JdbcTemplate jdbcTemplate = mock(JdbcTemplate.class);
|
||||
JdbcTraderFeedbackRepository repository = new JdbcTraderFeedbackRepository(jdbcTemplate, objectMapper());
|
||||
TraderAppFeedback feedback = new TraderAppFeedback(
|
||||
"feedback-1", "run-1", "cycle-1", "action-1",
|
||||
FeedbackSource.SHADOW_APP, false, null, "RECORDED", T0, T0.plusMillis(10),
|
||||
null, null, null, null, null, null, Map.of("destination", "SHADOW_RECORDER"));
|
||||
|
||||
repository.insert(feedback);
|
||||
|
||||
ArgumentCaptor<Object[]> args = ArgumentCaptor.forClass(Object[].class);
|
||||
verify(jdbcTemplate).update(contains("insert into trader_app_feedback"), args.capture());
|
||||
assertThat(args.getValue()[4]).isEqualTo("SHADOW_APP");
|
||||
assertThat(args.getValue()[5]).isEqualTo(false);
|
||||
assertThat(args.getValue()[7]).isEqualTo("RECORDED");
|
||||
assertThat(args.getValue()[8]).isEqualTo(Timestamp.from(T0));
|
||||
assertThat(args.getValue()[9]).isEqualTo(Timestamp.from(T0.plusMillis(10)));
|
||||
assertThat(args.getValue()[16].toString()).contains("SHADOW_RECORDER");
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,77 @@
|
||||
package com.quantai.trader.model;
|
||||
|
||||
import com.quantai.trader.artifact.TraderArtifactLoader;
|
||||
import com.quantai.trader.artifact.TraderModelManifest;
|
||||
import com.quantai.trader.feature.TraderFeatureVectorBuilder;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.api.io.TempDir;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.nio.file.Path;
|
||||
import java.util.Map;
|
||||
|
||||
import static com.quantai.trader.TestFixtures.*;
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
import static org.assertj.core.api.Assertions.assertThatThrownBy;
|
||||
|
||||
class OnnxTraderModelServiceTest {
|
||||
@TempDir
|
||||
Path artifactRoot;
|
||||
|
||||
@Test
|
||||
void evaluatesFiveModelFamiliesWithFeatureOrderCalibrationAndOutputBounds() throws IOException {
|
||||
writeArtifactBundle(artifactRoot);
|
||||
var properties = propertiesWithArtifactRoot(artifactRoot);
|
||||
var bundle = new TraderArtifactLoader(properties, objectMapper()).loadActiveBundle();
|
||||
OnnxTraderModelService service = new OnnxTraderModelService(
|
||||
properties,
|
||||
objectMapper(),
|
||||
new TraderFeatureVectorBuilder(properties, objectMapper()),
|
||||
new FakeInferenceClient());
|
||||
|
||||
var output = service.evaluate(snapshot(), bundle);
|
||||
|
||||
assertThat(output.direction().longProb()).isEqualByComparingTo("0.62");
|
||||
assertThat(output.entry().longExpectedNetEdgeBps()).isEqualByComparingTo("12");
|
||||
assertThat(output.continuation().shortExpectedContinueEdgeBps()).isEqualByComparingTo("1.5");
|
||||
assertThat(output.exit().reasonScore("reversal_prob")).isEqualByComparingTo("0.25");
|
||||
assertThat(output.risk().reasonScore("liquidity_deterioration_prob")).isEqualByComparingTo("0.12");
|
||||
assertThat(output.metadata().uncertainty()).isEqualByComparingTo("0.38");
|
||||
assertThat(output.metadata().oodScore()).isEqualByComparingTo("0.05");
|
||||
}
|
||||
|
||||
@Test
|
||||
void failsWhenShadowSnapshotHasNoOodScore() throws IOException {
|
||||
writeArtifactBundle(artifactRoot);
|
||||
var properties = propertiesWithArtifactRoot(artifactRoot);
|
||||
var bundle = new TraderArtifactLoader(properties, objectMapper()).loadActiveBundle();
|
||||
OnnxTraderModelService service = new OnnxTraderModelService(
|
||||
properties,
|
||||
objectMapper(),
|
||||
new TraderFeatureVectorBuilder(properties, objectMapper()),
|
||||
new FakeInferenceClient());
|
||||
|
||||
var snapshot = new com.quantai.trader.domain.TraderMarketSnapshot("snapshot-1", "run-1", "cycle-1",
|
||||
"BTC-USDT-PERP", T0, "feature-v4-p0", bd("100"), bd("99.5"), bd("1.2"),
|
||||
bd("0.5"), bd("1000"), bd("1400"), bd("2200"), true, featureJson(), Map.of());
|
||||
|
||||
assertThatThrownBy(() -> service.evaluate(snapshot, bundle))
|
||||
.isInstanceOf(com.quantai.trader.domain.TraderException.class)
|
||||
.hasMessageContaining("ood_score");
|
||||
}
|
||||
|
||||
private static final class FakeInferenceClient implements TraderOnnxInferenceClient {
|
||||
@Override
|
||||
public Map<String, float[]> infer(TraderModelManifest manifest, Path modelPath, float[] features) {
|
||||
return switch (manifest.modelType()) {
|
||||
case "DIRECTION" -> Map.of("probabilities", new float[]{0.5f, 0.5f, 0.5f});
|
||||
case "ENTRY" -> Map.of("entry", new float[]{0.5f, 0.5f, 12.0f, 8.0f});
|
||||
case "CONTINUE" -> Map.of("continue", new float[]{0.5f, 0.5f, 3.0f, 1.5f});
|
||||
case "EXIT" -> Map.of("exit", new float[]{0.5f, 0.5f, 10.0f, 18.0f, 0.5f, 0.5f, 0.5f, 0.5f});
|
||||
case "RISK" -> Map.of("risk", new float[]{0.5f, 0.5f, 0.5f, 20.0f, 20.0f, 30.0f,
|
||||
0.5f, 0.5f, 0.5f, 0.5f, 0.5f});
|
||||
default -> throw new IllegalArgumentException("unexpected model type: " + manifest.modelType());
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,32 @@
|
||||
package com.quantai.trader.model;
|
||||
|
||||
import com.quantai.trader.artifact.TraderArtifactLoader;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.api.io.TempDir;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.nio.file.Path;
|
||||
|
||||
import static com.quantai.trader.TestFixtures.*;
|
||||
import static org.assertj.core.api.Assertions.assertThatThrownBy;
|
||||
|
||||
class OrtTraderOnnxInferenceClientTest {
|
||||
@TempDir
|
||||
Path artifactRoot;
|
||||
|
||||
@Test
|
||||
void failsClearlyWhenOnnxFileIsNotLoadable() throws IOException {
|
||||
writeArtifactBundle(artifactRoot);
|
||||
var properties = propertiesWithArtifactRoot(artifactRoot);
|
||||
var bundle = new TraderArtifactLoader(properties, objectMapper()).loadActiveBundle();
|
||||
var manifest = bundle.modelManifests().stream()
|
||||
.filter(item -> item.modelType().equals("DIRECTION"))
|
||||
.findFirst()
|
||||
.orElseThrow();
|
||||
|
||||
assertThatThrownBy(() -> new OrtTraderOnnxInferenceClient()
|
||||
.infer(manifest, artifactRoot.resolve(manifest.artifactPath()), new float[39]))
|
||||
.isInstanceOf(com.quantai.trader.domain.TraderException.class)
|
||||
.hasMessageContaining("ONNX model cannot be loaded");
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,84 @@
|
||||
package com.quantai.trader.model;
|
||||
|
||||
import com.quantai.trader.artifact.TraderArtifactLoader;
|
||||
import com.quantai.trader.domain.TraderException;
|
||||
import com.quantai.trader.enums.TraderExecutionMode;
|
||||
import com.quantai.trader.enums.TraderRunMode;
|
||||
import com.quantai.trader.feature.TraderFeatureVectorBuilder;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.api.io.TempDir;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.nio.file.Path;
|
||||
|
||||
import static com.quantai.trader.TestFixtures.*;
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
import static org.assertj.core.api.Assertions.assertThatThrownBy;
|
||||
|
||||
class RoutingTraderModelServiceTest {
|
||||
@TempDir
|
||||
Path artifactRoot;
|
||||
|
||||
@Test
|
||||
void replaySimUsesReplayFixtureOnly() throws IOException {
|
||||
writeArtifactBundle(artifactRoot);
|
||||
var properties = propertiesWithArtifactRoot(artifactRoot, TraderRunMode.REPLAY_SIM, TraderExecutionMode.REPLAY_SIM);
|
||||
var bundle = new TraderArtifactLoader(properties, objectMapper()).loadActiveBundle();
|
||||
var service = new RoutingTraderModelService(
|
||||
properties,
|
||||
new ReplayFixtureTraderModelService(properties),
|
||||
new OnnxTraderModelService(properties, objectMapper(),
|
||||
new TraderFeatureVectorBuilder(properties, objectMapper()),
|
||||
(manifest, modelPath, features) -> {
|
||||
throw new AssertionError("REPLAY_SIM must not call ONNX");
|
||||
}));
|
||||
|
||||
var output = service.evaluate(snapshot(), bundle);
|
||||
|
||||
assertThat(output.direction().longProb()).isEqualByComparingTo("0.62");
|
||||
}
|
||||
|
||||
@Test
|
||||
void shadowUsesOnnxService() throws IOException {
|
||||
writeArtifactBundle(artifactRoot);
|
||||
var properties = propertiesWithArtifactRoot(artifactRoot, TraderRunMode.SHADOW, TraderExecutionMode.SHADOW);
|
||||
var bundle = new TraderArtifactLoader(properties, objectMapper()).loadActiveBundle();
|
||||
var service = new RoutingTraderModelService(
|
||||
properties,
|
||||
new ReplayFixtureTraderModelService(properties),
|
||||
new OnnxTraderModelService(properties, objectMapper(),
|
||||
new TraderFeatureVectorBuilder(properties, objectMapper()),
|
||||
(manifest, modelPath, features) -> switch (manifest.modelType()) {
|
||||
case "DIRECTION" -> java.util.Map.of("probabilities", new float[]{0.5f, 0.5f, 0.5f});
|
||||
case "ENTRY" -> java.util.Map.of("entry", new float[]{0.5f, 0.5f, 12.0f, 8.0f});
|
||||
case "CONTINUE" -> java.util.Map.of("continue", new float[]{0.5f, 0.5f, 3.0f, 1.5f});
|
||||
case "EXIT" -> java.util.Map.of("exit", new float[]{0.5f, 0.5f, 10.0f, 18.0f, 0.5f, 0.5f, 0.5f, 0.5f});
|
||||
case "RISK" -> java.util.Map.of("risk", new float[]{0.5f, 0.5f, 0.5f, 20.0f, 20.0f, 30.0f,
|
||||
0.5f, 0.5f, 0.5f, 0.5f, 0.5f});
|
||||
default -> throw new IllegalArgumentException("unexpected model type: " + manifest.modelType());
|
||||
}));
|
||||
|
||||
var output = service.evaluate(snapshot(), bundle);
|
||||
|
||||
assertThat(output.risk().marketRiskProb()).isEqualByComparingTo("0.20");
|
||||
}
|
||||
|
||||
@Test
|
||||
void rejectsRunModesOutsideP0Scope() throws IOException {
|
||||
writeArtifactBundle(artifactRoot);
|
||||
var bundle = new TraderArtifactLoader(
|
||||
propertiesWithArtifactRoot(artifactRoot, TraderRunMode.REPLAY_SIM, TraderExecutionMode.REPLAY_SIM),
|
||||
objectMapper()).loadActiveBundle();
|
||||
var properties = propertiesWithArtifactRoot(artifactRoot, TraderRunMode.PAPER, TraderExecutionMode.PAPER);
|
||||
var service = new RoutingTraderModelService(
|
||||
properties,
|
||||
new ReplayFixtureTraderModelService(properties),
|
||||
new OnnxTraderModelService(properties, objectMapper(),
|
||||
new TraderFeatureVectorBuilder(properties, objectMapper()),
|
||||
(manifest, modelPath, features) -> java.util.Map.of()));
|
||||
|
||||
assertThatThrownBy(() -> service.evaluate(snapshot(), bundle))
|
||||
.isInstanceOf(TraderException.class)
|
||||
.hasMessageContaining("only supports REPLAY_SIM and SHADOW");
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,55 @@
|
||||
package com.quantai.trader.model;
|
||||
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.api.io.TempDir;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.math.BigDecimal;
|
||||
import java.nio.file.Files;
|
||||
import java.nio.file.Path;
|
||||
|
||||
import static com.quantai.trader.TestFixtures.objectMapper;
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
import static org.assertj.core.api.Assertions.assertThatThrownBy;
|
||||
|
||||
class TraderOutputSchemaBoundsTest {
|
||||
@TempDir
|
||||
Path tempDir;
|
||||
|
||||
@Test
|
||||
void clipsBpsByOutputSchemaRange() throws IOException {
|
||||
Path path = tempDir.resolve("output_schema.json");
|
||||
Files.writeString(path, """
|
||||
{"entry":{"longExpectedNetEdgeBps":{"type":"decimal","range":[-10.0,10.0]}}}
|
||||
""");
|
||||
|
||||
TraderOutputSchemaBounds bounds = TraderOutputSchemaBounds.read(objectMapper(), path);
|
||||
|
||||
assertThat(bounds.clip("longExpectedNetEdgeBps", new BigDecimal("12"))).isEqualByComparingTo("10");
|
||||
assertThat(bounds.clip("longExpectedNetEdgeBps", new BigDecimal("-12"))).isEqualByComparingTo("-10");
|
||||
assertThat(bounds.clip("longExpectedNetEdgeBps", new BigDecimal("2"))).isEqualByComparingTo("2");
|
||||
}
|
||||
|
||||
@Test
|
||||
void rejectsMissingFieldRange() throws IOException {
|
||||
Path path = tempDir.resolve("output_schema.json");
|
||||
Files.writeString(path, """
|
||||
{"entry":{"longExpectedNetEdgeBps":{"type":"decimal","range":[-10.0,10.0]}}}
|
||||
""");
|
||||
TraderOutputSchemaBounds bounds = TraderOutputSchemaBounds.read(objectMapper(), path);
|
||||
|
||||
assertThatThrownBy(() -> bounds.clip("shortExpectedNetEdgeBps", BigDecimal.ONE))
|
||||
.isInstanceOf(com.quantai.trader.domain.TraderException.class)
|
||||
.hasMessageContaining("does not define range");
|
||||
}
|
||||
|
||||
@Test
|
||||
void rejectsSchemaWithoutRanges() throws IOException {
|
||||
Path path = tempDir.resolve("output_schema.json");
|
||||
Files.writeString(path, "{\"entry\":{}}");
|
||||
|
||||
assertThatThrownBy(() -> TraderOutputSchemaBounds.read(objectMapper(), path))
|
||||
.isInstanceOf(com.quantai.trader.domain.TraderException.class)
|
||||
.hasMessageContaining("must define field ranges");
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,58 @@
|
||||
package com.quantai.trader.model;
|
||||
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.api.io.TempDir;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.math.BigDecimal;
|
||||
import java.nio.file.Files;
|
||||
import java.nio.file.Path;
|
||||
|
||||
import static com.quantai.trader.TestFixtures.objectMapper;
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
import static org.assertj.core.api.Assertions.assertThatThrownBy;
|
||||
|
||||
class TraderProbabilityCalibratorTest {
|
||||
@TempDir
|
||||
Path tempDir;
|
||||
|
||||
@Test
|
||||
void appliesBinningCalibration() throws IOException {
|
||||
Path path = tempDir.resolve("calibrator.json");
|
||||
Files.writeString(path, """
|
||||
{"method":"BINNING","targets":{"longProb":{"bins":[{"min":0.0,"max":1.0,"calibrated":0.62}]}},"clip":{"min":0.0,"max":1.0}}
|
||||
""");
|
||||
|
||||
TraderProbabilityCalibrator calibrator = TraderProbabilityCalibrator.read(objectMapper(), path, "DIRECTION");
|
||||
|
||||
assertThat(calibrator.calibrate("longProb", new BigDecimal("0.40"))).isEqualByComparingTo("0.62");
|
||||
}
|
||||
|
||||
@Test
|
||||
void rejectsUnsupportedCalibrationMethod() throws IOException {
|
||||
Path path = tempDir.resolve("calibrator.json");
|
||||
Files.writeString(path, """
|
||||
{"method":"NONE","targets":{"longProb":{"bins":[{"min":0.0,"max":1.0,"calibrated":0.62}]}},"clip":{"min":0.0,"max":1.0}}
|
||||
""");
|
||||
|
||||
assertThatThrownBy(() -> TraderProbabilityCalibrator.read(objectMapper(), path, "DIRECTION"))
|
||||
.isInstanceOf(com.quantai.trader.domain.TraderException.class)
|
||||
.hasMessageContaining("must be BINNING");
|
||||
}
|
||||
|
||||
@Test
|
||||
void rejectsMissingTargetAndOutOfRangeBins() throws IOException {
|
||||
Path path = tempDir.resolve("calibrator.json");
|
||||
Files.writeString(path, """
|
||||
{"method":"BINNING","targets":{"longProb":{"bins":[{"min":0.0,"max":0.1,"calibrated":0.62}]}},"clip":{"min":0.0,"max":1.0}}
|
||||
""");
|
||||
TraderProbabilityCalibrator calibrator = TraderProbabilityCalibrator.read(objectMapper(), path, "DIRECTION");
|
||||
|
||||
assertThatThrownBy(() -> calibrator.calibrate("shortProb", new BigDecimal("0.05")))
|
||||
.isInstanceOf(com.quantai.trader.domain.TraderException.class)
|
||||
.hasMessageContaining("target is missing");
|
||||
assertThatThrownBy(() -> calibrator.calibrate("longProb", new BigDecimal("0.50")))
|
||||
.isInstanceOf(com.quantai.trader.domain.TraderException.class)
|
||||
.hasMessageContaining("does not match any");
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,82 @@
|
||||
package com.quantai.trader.outbox;
|
||||
|
||||
import com.quantai.trader.domain.*;
|
||||
import com.quantai.trader.enums.FeedbackSource;
|
||||
import com.quantai.trader.enums.PositionSide;
|
||||
import com.quantai.trader.enums.TraderActionType;
|
||||
import com.quantai.trader.feedback.TraderFeedbackRepository;
|
||||
import com.quantai.trader.replay.state.TraderPostActionStateRepository;
|
||||
import com.quantai.trader.replay.state.TraderReplayState;
|
||||
import com.quantai.trader.replay.state.TraderReplayStateStore;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.mockito.ArgumentCaptor;
|
||||
|
||||
import java.util.Map;
|
||||
|
||||
import static com.quantai.trader.TestFixtures.*;
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
import static org.assertj.core.api.Assertions.assertThatThrownBy;
|
||||
import static org.mockito.Mockito.*;
|
||||
|
||||
class P0OutboxDispatcherTest {
|
||||
private final TraderOutboxRepository outboxRepository = mock(TraderOutboxRepository.class);
|
||||
private final TraderFeedbackRepository feedbackRepository = mock(TraderFeedbackRepository.class);
|
||||
private final TraderReplayStateStore stateStore = mock(TraderReplayStateStore.class);
|
||||
private final TraderPostActionStateRepository postActionStateRepository = mock(TraderPostActionStateRepository.class);
|
||||
private final P0OutboxDispatcher dispatcher = new P0OutboxDispatcher(
|
||||
outboxRepository, feedbackRepository, stateStore, postActionStateRepository);
|
||||
|
||||
@Test
|
||||
void dispatchesShadowRecorderFeedbackMarksOutboxAndPersistsPostActionState() {
|
||||
TraderAction action = action(TraderActionType.OPEN_LONG, PositionSide.LONG);
|
||||
TraderMarketSnapshot snapshot = snapshot();
|
||||
TraderReplayState currentState = new TraderReplayState(flatPosition(), account(), execution());
|
||||
TraderReplayState nextState = new TraderReplayState(longPosition("0"), account("0", "0.20", "0.80"), execution());
|
||||
when(stateStore.advance(currentState, action, snapshot)).thenReturn(nextState);
|
||||
|
||||
dispatcher.dispatch(action, currentState, snapshot, "SHADOW_RECORDER");
|
||||
|
||||
ArgumentCaptor<TraderAppFeedback> feedback = ArgumentCaptor.forClass(TraderAppFeedback.class);
|
||||
verify(feedbackRepository).insert(feedback.capture());
|
||||
assertThat(feedback.getValue().feedbackSource()).isEqualTo(FeedbackSource.SHADOW_APP);
|
||||
assertThat(feedback.getValue().realFill()).isFalse();
|
||||
assertThat(feedback.getValue().orderStatus()).isEqualTo("RECORDED");
|
||||
assertThat(feedback.getValue().rawFeedbackJson())
|
||||
.containsEntry("destination", "SHADOW_RECORDER")
|
||||
.containsEntry("actionType", "OPEN_LONG");
|
||||
verify(outboxRepository).markSent("outbox_" + action.actionId());
|
||||
verify(postActionStateRepository).insertPostActionState(nextState);
|
||||
}
|
||||
|
||||
@Test
|
||||
void mapsReplaySimDestinationToReplaySimulatorFeedback() {
|
||||
TraderAction action = action(TraderActionType.OPEN_LONG, PositionSide.LONG);
|
||||
TraderReplayState state = new TraderReplayState(flatPosition(), account(), execution());
|
||||
when(stateStore.advance(state, action, snapshot())).thenReturn(state);
|
||||
|
||||
dispatcher.dispatch(action, state, snapshot(), "REPLAY_SIM_EXECUTION");
|
||||
|
||||
ArgumentCaptor<TraderAppFeedback> feedback = ArgumentCaptor.forClass(TraderAppFeedback.class);
|
||||
verify(feedbackRepository).insert(feedback.capture());
|
||||
assertThat(feedback.getValue().feedbackSource()).isEqualTo(FeedbackSource.REPLAY_SIMULATOR);
|
||||
}
|
||||
|
||||
@Test
|
||||
void rejectsPaperOrRealDestinationsInP0Dispatcher() {
|
||||
TraderAction action = action(TraderActionType.OPEN_LONG, PositionSide.LONG);
|
||||
TraderReplayState state = new TraderReplayState(flatPosition(), account(), execution());
|
||||
|
||||
assertThatThrownBy(() -> dispatcher.dispatch(action, state, snapshot(), "PAPER_EXECUTION"))
|
||||
.isInstanceOf(IllegalArgumentException.class)
|
||||
.hasMessageContaining("P0 outbox destination is not allowed");
|
||||
|
||||
verifyNoInteractions(feedbackRepository, outboxRepository, postActionStateRepository);
|
||||
}
|
||||
|
||||
private TraderAction action(TraderActionType type, PositionSide side) {
|
||||
TraderRiskDecision riskDecision = new TraderRiskDecision(
|
||||
"risk-1", "run-1", "cycle-1", "pm-cycle-1",
|
||||
true, type, type, null, Map.of());
|
||||
return new TraderActionFactory().create(pmDecision(type, side), riskDecision, "BTC-USDT-PERP");
|
||||
}
|
||||
}
|
||||
@@ -26,9 +26,9 @@ class JdbcTraderDecisionTraceWriterTest {
|
||||
"pm-cycle-1", true, TraderActionType.OPEN_LONG, TraderActionType.OPEN_LONG, null, Map.of());
|
||||
var action = new TraderActionFactory().create(pmDecision(TraderActionType.OPEN_LONG, PositionSide.LONG), riskDecision, "BTC-USDT-PERP");
|
||||
|
||||
writer.persistCycleTrace(cycle(), snapshot(), modelOutput(), flatPosition(),
|
||||
writer.persistCycleTrace(cycle(), snapshot(), modelOutput(), flatPosition(), account(), execution(),
|
||||
pmDecision(TraderActionType.OPEN_LONG, PositionSide.LONG), riskDecision, action);
|
||||
|
||||
verify(jdbcTemplate, times(6)).update(anyString(), any(Object[].class));
|
||||
verify(jdbcTemplate, times(10)).update(anyString(), any(Object[].class));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,7 +1,10 @@
|
||||
package com.quantai.trader.position;
|
||||
|
||||
import com.quantai.trader.domain.PositionManagerInput;
|
||||
import com.quantai.trader.domain.TraderModelOutput;
|
||||
import com.quantai.trader.domain.TraderPositionManagerDecision;
|
||||
import com.quantai.trader.domain.TraderPositionState;
|
||||
import com.quantai.trader.enums.PositionSide;
|
||||
import com.quantai.trader.enums.TraderActionType;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
@@ -36,12 +39,13 @@ class TraderPositionManagerTest {
|
||||
@Test
|
||||
void waitsWhenDataOrLiquidityIsInsufficient() {
|
||||
TraderModelOutput noLiquidity = modelOutput("0.70", "0.20", "0.70", "0.30", "0.66", "0.34",
|
||||
"0.20", "0.50", "0.18", "0.12", "0", "12", "4", "0.10", "0.05", "20");
|
||||
"0.20", "0.50", "0.18", "0.12", "12", "4", "0.10", "0.05", "20");
|
||||
|
||||
assertThat(positionManager.decide(pmInput(noLiquidity, flatPosition())).candidateAction())
|
||||
assertThat(positionManager.decide(pmInput(noLiquidity, flatPosition(), account(),
|
||||
execution(java.util.List.of(), 10, 0, "0"))).candidateAction())
|
||||
.isEqualTo(TraderActionType.WAIT);
|
||||
assertThat(positionManager.decide(new com.quantai.trader.domain.PositionManagerInput(
|
||||
cycle(), snapshot(false, "1000"), modelOutput(), flatPosition(), account(), execution(), pmConfig())).candidateAction())
|
||||
cycle(), snapshot(false, "1000"), modelOutput(), pricePlan(), flatPosition(), account(), execution(), pmConfig())).candidateAction())
|
||||
.isEqualTo(TraderActionType.WAIT);
|
||||
}
|
||||
|
||||
@@ -55,6 +59,20 @@ class TraderPositionManagerTest {
|
||||
assertThat(hold.candidateAction()).isEqualTo(TraderActionType.HOLD);
|
||||
}
|
||||
|
||||
@Test
|
||||
void holdsWhenAddCooldownHasNotElapsed() {
|
||||
TraderPositionState coolingPosition = new TraderPositionState(
|
||||
"position-state-1", "run-1", "cycle-1", "BTC-USDT-PERP",
|
||||
PositionSide.LONG, bd("0.30"), bd("100"), bd("101"), bd("10"), bd("1000"),
|
||||
1, bd("0.40"), T0);
|
||||
|
||||
TraderPositionManagerDecision decision = positionManager.decide(new PositionManagerInput(
|
||||
cycle(), snapshot(), modelOutput(), pricePlan(), coolingPosition, account(), execution(), pmConfig()));
|
||||
|
||||
assertThat(decision.candidateAction()).isEqualTo(TraderActionType.HOLD);
|
||||
assertThat(decision.reason()).isEqualTo("CONTINUE_HOLD");
|
||||
}
|
||||
|
||||
@Test
|
||||
void usesPositionSideContinuationProbabilityForShortAdd() {
|
||||
TraderPositionManagerDecision decision = positionManager.decide(pmInput(shortModelOutput(), shortPosition("10")));
|
||||
@@ -66,7 +84,7 @@ class TraderPositionManagerTest {
|
||||
@Test
|
||||
void closesExistingPositionWhenExitOrRiskSignalIsHigh() {
|
||||
TraderModelOutput exitHigh = modelOutput("0.70", "0.20", "0.70", "0.30", "0.66", "0.34",
|
||||
"0.80", "0.50", "0.18", "0.12", "1.0", "12", "4", "0.10", "0.05", "20");
|
||||
"0.80", "0.50", "0.18", "0.12", "12", "4", "0.10", "0.05", "20");
|
||||
|
||||
TraderPositionManagerDecision decision = positionManager.decide(pmInput(exitHigh, longPosition("10")));
|
||||
|
||||
@@ -77,7 +95,7 @@ class TraderPositionManagerTest {
|
||||
@Test
|
||||
void returnsZeroInitialRatioWhenExpectedEdgeIsBelowSizingFloor() {
|
||||
TraderModelOutput weakEdge = modelOutput("0.70", "0.20", "0.70", "0.30", "0.66", "0.34",
|
||||
"0.20", "0.50", "0.18", "0.12", "1.0", "0.5", "4", "0.10", "0.05", "20");
|
||||
"0.20", "0.50", "0.18", "0.12", "0.5", "4", "0.10", "0.05", "20");
|
||||
|
||||
BigDecimal ratio = positionManager.calculateInitialRatio(pmInput(weakEdge, flatPosition()), com.quantai.trader.enums.PositionSide.LONG);
|
||||
|
||||
|
||||
@@ -2,23 +2,28 @@ package com.quantai.trader.replay;
|
||||
|
||||
import com.quantai.trader.artifact.TraderArtifactLoader;
|
||||
import com.quantai.trader.domain.*;
|
||||
import com.quantai.trader.enums.TraderExecutionMode;
|
||||
import com.quantai.trader.enums.TraderActionType;
|
||||
import com.quantai.trader.enums.TraderRunMode;
|
||||
import com.quantai.trader.evidence.EvidenceAppender;
|
||||
import com.quantai.trader.evidence.TraderEvidenceRepository;
|
||||
import com.quantai.trader.model.ArtifactTraderModelService;
|
||||
import com.quantai.trader.model.ReplayFixtureTraderModelService;
|
||||
import com.quantai.trader.outbox.TraderOutboxEvent;
|
||||
import com.quantai.trader.outbox.TraderOutboxDispatcher;
|
||||
import com.quantai.trader.outbox.TraderOutboxRepository;
|
||||
import com.quantai.trader.persistence.TraderDecisionTraceWriter;
|
||||
import com.quantai.trader.position.TraderPositionManager;
|
||||
import com.quantai.trader.replay.state.P0ReplayStateStore;
|
||||
import com.quantai.trader.replay.state.TraderReplayState;
|
||||
import com.quantai.trader.risk.TraderRiskGate;
|
||||
import com.quantai.trader.runtime.TraderRuntimeControlDecision;
|
||||
import com.quantai.trader.runtime.TraderRuntimeControlService;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.api.io.TempDir;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.io.IOException;
|
||||
import java.math.BigDecimal;
|
||||
import java.nio.file.Path;
|
||||
|
||||
import static com.quantai.trader.TestFixtures.*;
|
||||
@@ -35,28 +40,31 @@ class TraderP0CycleRunnerTest {
|
||||
EvidenceAppender evidenceAppender = new EvidenceAppender(evidenceRepository);
|
||||
RecordingTraceWriter traceWriter = new RecordingTraceWriter();
|
||||
RecordingOutboxRepository outboxRepository = new RecordingOutboxRepository();
|
||||
var properties = propertiesWithArtifactRoot(artifactRoot);
|
||||
var properties = propertiesWithArtifactRoot(artifactRoot, TraderRunMode.REPLAY_SIM, TraderExecutionMode.REPLAY_SIM);
|
||||
P0ReplayStateStore stateStore = new P0ReplayStateStore(properties);
|
||||
TraderP0CycleRunner runner = new TraderP0CycleRunner(
|
||||
properties,
|
||||
new TraderArtifactLoader(properties, objectMapper()),
|
||||
new ArtifactTraderModelService(),
|
||||
new ReplayFixtureTraderModelService(properties),
|
||||
new TraderPositionManager(),
|
||||
new TraderRiskGate(),
|
||||
new TraderActionFactory(),
|
||||
evidenceAppender,
|
||||
traceWriter,
|
||||
bundle -> {
|
||||
},
|
||||
outboxRepository,
|
||||
new P0ReplayStateStore(properties));
|
||||
new ImmediateDispatcher(stateStore),
|
||||
stateStore,
|
||||
new AllowRuntimeControlService());
|
||||
|
||||
TraderCycleResult result = runner.runCycle(new ReplayMarketEvent(
|
||||
"run-1", "BTC-USDT-PERP", T0, new BigDecimal("100"), new BigDecimal("99.5"),
|
||||
new BigDecimal("1.2"), new BigDecimal("1000")));
|
||||
TraderCycleResult result = runner.runCycle(replayEvent("run-1", T0, "100", "99.5", "1000"));
|
||||
|
||||
assertThat(result.action().actionType()).isEqualTo(TraderActionType.OPEN_LONG);
|
||||
assertThat(result.action().reduceOnly()).isFalse();
|
||||
assertThat(traceWriter.actions()).containsExactly(result.action());
|
||||
assertThat(outboxRepository.events()).hasSize(1);
|
||||
assertThat(outboxRepository.events().getFirst().destination()).isEqualTo("SHADOW_RECORDER");
|
||||
assertThat(outboxRepository.events().getFirst().destination()).isEqualTo("REPLAY_SIM_EXECUTION");
|
||||
assertThat(evidenceRepository.items()).extracting("stage")
|
||||
.containsExactly("MARKET_SNAPSHOT", "MODEL_OUTPUT", "PM_DECISION", "RISK_DECISION");
|
||||
}
|
||||
@@ -68,22 +76,25 @@ class TraderP0CycleRunnerTest {
|
||||
EvidenceAppender evidenceAppender = new EvidenceAppender(evidenceRepository);
|
||||
RecordingTraceWriter traceWriter = new RecordingTraceWriter();
|
||||
RecordingOutboxRepository outboxRepository = new RecordingOutboxRepository();
|
||||
var properties = propertiesWithArtifactRoot(artifactRoot);
|
||||
var properties = propertiesWithArtifactRoot(artifactRoot, TraderRunMode.REPLAY_SIM, TraderExecutionMode.REPLAY_SIM);
|
||||
P0ReplayStateStore stateStore = new P0ReplayStateStore(properties);
|
||||
TraderP0CycleRunner runner = new TraderP0CycleRunner(
|
||||
properties,
|
||||
new TraderArtifactLoader(properties, objectMapper()),
|
||||
new ArtifactTraderModelService(),
|
||||
new ReplayFixtureTraderModelService(properties),
|
||||
new TraderPositionManager(),
|
||||
new TraderRiskGate(),
|
||||
new TraderActionFactory(),
|
||||
evidenceAppender,
|
||||
traceWriter,
|
||||
bundle -> {
|
||||
},
|
||||
outboxRepository,
|
||||
new P0ReplayStateStore(properties));
|
||||
new ImmediateDispatcher(stateStore),
|
||||
stateStore,
|
||||
new AllowRuntimeControlService());
|
||||
|
||||
TraderCycleResult result = runner.runCycle(new ReplayMarketEvent(
|
||||
"run-1", "BTC-USDT-PERP", T0.plusSeconds(60), new BigDecimal("100"), new BigDecimal("99.5"),
|
||||
new BigDecimal("1.2"), BigDecimal.ZERO));
|
||||
TraderCycleResult result = runner.runCycle(replayEvent("run-1", T0.plusSeconds(60), "100", "99.5", "0"));
|
||||
|
||||
assertThat(result.action().actionType()).isEqualTo(TraderActionType.WAIT);
|
||||
assertThat(result.action().pricePlanId()).isNull();
|
||||
@@ -99,25 +110,26 @@ class TraderP0CycleRunnerTest {
|
||||
EvidenceAppender evidenceAppender = new EvidenceAppender(evidenceRepository);
|
||||
RecordingTraceWriter traceWriter = new RecordingTraceWriter();
|
||||
RecordingOutboxRepository outboxRepository = new RecordingOutboxRepository();
|
||||
var properties = propertiesWithArtifactRoot(artifactRoot);
|
||||
var properties = propertiesWithArtifactRoot(artifactRoot, TraderRunMode.REPLAY_SIM, TraderExecutionMode.REPLAY_SIM);
|
||||
P0ReplayStateStore stateStore = new P0ReplayStateStore(properties);
|
||||
TraderP0CycleRunner runner = new TraderP0CycleRunner(
|
||||
properties,
|
||||
new TraderArtifactLoader(properties, objectMapper()),
|
||||
new ArtifactTraderModelService(),
|
||||
new ReplayFixtureTraderModelService(properties),
|
||||
new TraderPositionManager(),
|
||||
new TraderRiskGate(),
|
||||
new TraderActionFactory(),
|
||||
evidenceAppender,
|
||||
traceWriter,
|
||||
bundle -> {
|
||||
},
|
||||
outboxRepository,
|
||||
new P0ReplayStateStore(properties));
|
||||
new ImmediateDispatcher(stateStore),
|
||||
stateStore,
|
||||
new AllowRuntimeControlService());
|
||||
|
||||
TraderCycleResult first = runner.runCycle(new ReplayMarketEvent(
|
||||
"run-state-1", "BTC-USDT-PERP", T0, new BigDecimal("100"), new BigDecimal("99.5"),
|
||||
new BigDecimal("1.2"), new BigDecimal("1000")));
|
||||
TraderCycleResult second = runner.runCycle(new ReplayMarketEvent(
|
||||
"run-state-1", "BTC-USDT-PERP", T0.plusSeconds(60), new BigDecimal("101"), new BigDecimal("100.5"),
|
||||
new BigDecimal("1.2"), new BigDecimal("1000")));
|
||||
TraderCycleResult first = runner.runCycle(replayEvent("run-state-1", T0, "100", "99.5", "1000"));
|
||||
TraderCycleResult second = runner.runCycle(replayEvent("run-state-1", T0.plusSeconds(60), "101", "100.5", "1000"));
|
||||
|
||||
assertThat(first.action().actionType()).isEqualTo(TraderActionType.OPEN_LONG);
|
||||
assertThat(second.action().actionType()).isEqualTo(TraderActionType.ADD_LONG);
|
||||
@@ -147,18 +159,41 @@ class TraderP0CycleRunnerTest {
|
||||
events.add(event);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void markSent(String outboxId) {
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean hasUnsentExposureIncrease(String runId, String symbol, com.quantai.trader.enums.PositionSide side) {
|
||||
return false;
|
||||
}
|
||||
|
||||
List<TraderOutboxEvent> events() {
|
||||
return events;
|
||||
}
|
||||
}
|
||||
|
||||
private static final class ImmediateDispatcher implements TraderOutboxDispatcher {
|
||||
private final P0ReplayStateStore stateStore;
|
||||
|
||||
private ImmediateDispatcher(P0ReplayStateStore stateStore) {
|
||||
this.stateStore = stateStore;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void dispatch(TraderAction action, TraderReplayState currentState, TraderMarketSnapshot snapshot, String destination) {
|
||||
stateStore.advance(currentState, action, snapshot);
|
||||
}
|
||||
}
|
||||
|
||||
private static final class RecordingTraceWriter implements TraderDecisionTraceWriter {
|
||||
private final List<TraderAction> actions = new ArrayList<>();
|
||||
private final List<TraderPositionState> positionStates = new ArrayList<>();
|
||||
|
||||
@Override
|
||||
public void persistCycleTrace(TraderDecisionCycle cycle, TraderMarketSnapshot snapshot, TraderModelOutput modelOutput,
|
||||
TraderPositionState positionState, TraderPositionManagerDecision pmDecision,
|
||||
TraderPositionState positionState, TraderAccountState accountState,
|
||||
TraderExecutionState executionState, TraderPositionManagerDecision pmDecision,
|
||||
TraderRiskDecision riskDecision, TraderAction action) {
|
||||
actions.add(action);
|
||||
positionStates.add(positionState);
|
||||
@@ -172,4 +207,20 @@ class TraderP0CycleRunnerTest {
|
||||
return positionStates;
|
||||
}
|
||||
}
|
||||
|
||||
private static final class AllowRuntimeControlService implements TraderRuntimeControlService {
|
||||
@Override
|
||||
public void acquireCycleLock(TraderDecisionCycle cycle) {
|
||||
}
|
||||
|
||||
@Override
|
||||
public void releaseCycleLock(TraderDecisionCycle cycle) {
|
||||
}
|
||||
|
||||
@Override
|
||||
public TraderRuntimeControlDecision validateExposureIncrease(TraderDecisionCycle cycle,
|
||||
TraderPositionManagerDecision decision) {
|
||||
return TraderRuntimeControlDecision.allow();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
+37
@@ -0,0 +1,37 @@
|
||||
package com.quantai.trader.replay.state;
|
||||
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.mockito.ArgumentCaptor;
|
||||
import org.springframework.jdbc.core.JdbcTemplate;
|
||||
|
||||
import static com.quantai.trader.TestFixtures.*;
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
import static org.mockito.ArgumentMatchers.contains;
|
||||
import static org.mockito.Mockito.mock;
|
||||
import static org.mockito.Mockito.verify;
|
||||
import static org.mockito.Mockito.verifyNoMoreInteractions;
|
||||
|
||||
class JdbcTraderPostActionStateRepositoryTest {
|
||||
@Test
|
||||
void insertsPostActionPositionAccountAndExecutionStateWithPostIds() {
|
||||
JdbcTemplate jdbcTemplate = mock(JdbcTemplate.class);
|
||||
JdbcTraderPostActionStateRepository repository = new JdbcTraderPostActionStateRepository(jdbcTemplate, objectMapper());
|
||||
TraderReplayState state = new TraderReplayState(longPosition("12"), account(), executionWithOpenOrder());
|
||||
|
||||
repository.insertPostActionState(state);
|
||||
|
||||
ArgumentCaptor<Object[]> positionArgs = ArgumentCaptor.forClass(Object[].class);
|
||||
ArgumentCaptor<Object[]> accountArgs = ArgumentCaptor.forClass(Object[].class);
|
||||
ArgumentCaptor<Object[]> executionArgs = ArgumentCaptor.forClass(Object[].class);
|
||||
verify(jdbcTemplate).update(contains("insert into trader_position_state"), positionArgs.capture());
|
||||
verify(jdbcTemplate).update(contains("insert into trader_account_state"), accountArgs.capture());
|
||||
verify(jdbcTemplate).update(contains("insert into trader_execution_state"), executionArgs.capture());
|
||||
verifyNoMoreInteractions(jdbcTemplate);
|
||||
|
||||
assertThat(positionArgs.getValue()[2]).isEqualTo("position-state-1_post");
|
||||
assertThat(positionArgs.getValue()[4]).isEqualTo("LONG");
|
||||
assertThat(accountArgs.getValue()[2]).isEqualTo("account-state-1_post");
|
||||
assertThat(executionArgs.getValue()[2]).isEqualTo("execution-state-1_post");
|
||||
assertThat(executionArgs.getValue()[4].toString()).contains("order-1");
|
||||
}
|
||||
}
|
||||
@@ -14,7 +14,7 @@ class TraderRiskGateTest {
|
||||
|
||||
@Test
|
||||
void killSwitchBlocksOnlyExposureIncreasingActions() {
|
||||
RiskLimits limits = new RiskLimits(bd("200"), java.math.BigDecimal.ONE, bd("500"), 3, 1500, true, false);
|
||||
RiskLimits limits = new RiskLimits(bd("200"), java.math.BigDecimal.ONE, bd("500"), 3, 1500, true, false, null);
|
||||
|
||||
TraderRiskDecision open = riskGate.evaluate(new RiskGateInput(
|
||||
pmDecision(TraderActionType.OPEN_LONG, PositionSide.LONG), flatPosition(), account(), execution(), snapshot(), limits));
|
||||
|
||||
@@ -0,0 +1,95 @@
|
||||
package com.quantai.trader.runtime;
|
||||
|
||||
import com.quantai.trader.domain.TraderException;
|
||||
import com.quantai.trader.enums.PositionSide;
|
||||
import com.quantai.trader.enums.TraderActionType;
|
||||
import com.quantai.trader.enums.TraderErrorCode;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.springframework.data.redis.core.StringRedisTemplate;
|
||||
import org.springframework.data.redis.core.ValueOperations;
|
||||
|
||||
import java.time.Duration;
|
||||
|
||||
import static com.quantai.trader.TestFixtures.*;
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
import static org.assertj.core.api.Assertions.assertThatThrownBy;
|
||||
import static org.mockito.ArgumentMatchers.eq;
|
||||
import static org.mockito.Mockito.*;
|
||||
|
||||
class RedisTraderRuntimeControlServiceTest {
|
||||
private final StringRedisTemplate redisTemplate = mock(StringRedisTemplate.class);
|
||||
@SuppressWarnings("unchecked")
|
||||
private final ValueOperations<String, String> valueOperations = mock(ValueOperations.class);
|
||||
private final RedisTraderRuntimeControlService service = new RedisTraderRuntimeControlService(properties(), redisTemplate);
|
||||
|
||||
@Test
|
||||
void acquiresAndReleasesCycleLockWithDeterministicRedisKey() {
|
||||
when(redisTemplate.opsForValue()).thenReturn(valueOperations);
|
||||
String lockKey = "trader:v4:test:cycle:lock:BTC-USDT-PERP:" + T0.toEpochMilli();
|
||||
when(valueOperations.setIfAbsent(eq(lockKey), eq("cycle-1"), eq(Duration.ofSeconds(30)))).thenReturn(true);
|
||||
|
||||
service.acquireCycleLock(cycle());
|
||||
service.releaseCycleLock(cycle());
|
||||
|
||||
verify(valueOperations).setIfAbsent(lockKey, "cycle-1", Duration.ofSeconds(30));
|
||||
verify(redisTemplate).delete(lockKey);
|
||||
}
|
||||
|
||||
@Test
|
||||
void reportsExistingCycleLockAsRuntimeBlockNotRedisOutage() {
|
||||
when(redisTemplate.opsForValue()).thenReturn(valueOperations);
|
||||
when(valueOperations.setIfAbsent(anyString(), anyString(), any(Duration.class))).thenReturn(false);
|
||||
|
||||
assertThatThrownBy(() -> service.acquireCycleLock(cycle()))
|
||||
.isInstanceOfSatisfying(TraderException.class, exception -> {
|
||||
assertThat(exception.code()).isEqualTo(TraderErrorCode.TRADER_RUNTIME_CONTROL_BLOCKED);
|
||||
assertThat(exception.getMessage()).contains("cycle lock already exists");
|
||||
});
|
||||
}
|
||||
|
||||
@Test
|
||||
void allowsExposureIncreaseWhenActivePointerAndRuntimeSwitchesPass() {
|
||||
when(redisTemplate.opsForValue()).thenReturn(valueOperations);
|
||||
when(valueOperations.get("trader:v4:test:model:active:BTC-USDT-PERP"))
|
||||
.thenReturn("trader-v4-btc-p0|cal-v4-btc-p0|pm-v4-btc-p0");
|
||||
when(redisTemplate.hasKey(anyString())).thenReturn(false);
|
||||
|
||||
TraderRuntimeControlDecision decision = service.validateExposureIncrease(
|
||||
cycle(), pmDecision(TraderActionType.OPEN_LONG, PositionSide.LONG));
|
||||
|
||||
assertThat(decision.allowed()).isTrue();
|
||||
verify(valueOperations).set("trader:v4:test:runtime:open-add-probe:cycle-1", "1", Duration.ofSeconds(30));
|
||||
}
|
||||
|
||||
@Test
|
||||
void blocksExposureIncreaseWhenActivePointerIsMissing() {
|
||||
when(redisTemplate.opsForValue()).thenReturn(valueOperations);
|
||||
when(valueOperations.get("trader:v4:test:model:active:BTC-USDT-PERP")).thenReturn(null);
|
||||
|
||||
TraderRuntimeControlDecision decision = service.validateExposureIncrease(
|
||||
cycle(), pmDecision(TraderActionType.OPEN_LONG, PositionSide.LONG));
|
||||
|
||||
assertThat(decision.allowed()).isFalse();
|
||||
assertThat(decision.blocker()).isEqualTo("ACTIVE_POINTER_MISSING");
|
||||
}
|
||||
|
||||
@Test
|
||||
void skipsRedisChecksWhenActionDoesNotIncreaseExposure() {
|
||||
TraderRuntimeControlDecision decision = service.validateExposureIncrease(
|
||||
cycle(), pmDecision(TraderActionType.HOLD, PositionSide.LONG));
|
||||
|
||||
assertThat(decision.allowed()).isTrue();
|
||||
verifyNoInteractions(redisTemplate);
|
||||
}
|
||||
|
||||
@Test
|
||||
void blocksExposureIncreaseWhenRedisFailsDuringValidation() {
|
||||
when(redisTemplate.opsForValue()).thenThrow(new RuntimeException("redis down"));
|
||||
|
||||
TraderRuntimeControlDecision decision = service.validateExposureIncrease(
|
||||
cycle(), pmDecision(TraderActionType.OPEN_LONG, PositionSide.LONG));
|
||||
|
||||
assertThat(decision.allowed()).isFalse();
|
||||
assertThat(decision.blocker()).isEqualTo("REDIS_UNAVAILABLE");
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,18 @@
|
||||
package com.quantai.trader.runtime;
|
||||
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import static org.mockito.Mockito.mock;
|
||||
import static org.mockito.Mockito.verify;
|
||||
|
||||
class StartupValidationRunnerTest {
|
||||
@Test
|
||||
void delegatesStartupValidationToRuntimeGuard() {
|
||||
P0RuntimeGuard guard = mock(P0RuntimeGuard.class);
|
||||
StartupValidationRunner runner = new StartupValidationRunner(guard);
|
||||
|
||||
runner.run(null);
|
||||
|
||||
verify(guard).validateStartup();
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,33 @@
|
||||
# Trader V4 Training Pipeline
|
||||
|
||||
This directory contains the executable training chain for Trader V4.
|
||||
Large data stays under `/Users/zach/Desktop/quant-strategy-training-data`.
|
||||
|
||||
Run order:
|
||||
|
||||
```bash
|
||||
PY=/Users/zach/IdeaProjects/quant-trading-ai/quant-strategy-server/.venv/bin/python
|
||||
RUN_ID=btc-v4-p0-001
|
||||
ROOT=/Users/zach/Desktop/quant-strategy-training-data
|
||||
|
||||
$PY training/scripts/01_audit_source_data.py --run-id $RUN_ID --data-root $ROOT --symbol BTC-USDT-PERP --start-date 2025-06-20 --end-date 2026-06-19
|
||||
$PY training/scripts/02_build_replay_1m.py --run-id $RUN_ID --data-root $ROOT --symbol BTC-USDT-PERP --start-date 2025-06-20 --end-date 2026-06-19
|
||||
$PY training/scripts/03_build_splits.py --run-id $RUN_ID --data-root $ROOT
|
||||
$PY training/scripts/04_build_feature_frame.py --run-id $RUN_ID --data-root $ROOT
|
||||
$PY training/scripts/05_build_price_plan_context.py --run-id $RUN_ID --data-root $ROOT
|
||||
$PY training/scripts/06_build_direction_labels.py --run-id $RUN_ID --data-root $ROOT
|
||||
$PY training/scripts/07_build_entry_labels.py --run-id $RUN_ID --data-root $ROOT
|
||||
$PY training/scripts/08_build_position_state_samples.py --run-id $RUN_ID --data-root $ROOT
|
||||
$PY training/scripts/09_build_continue_exit_risk_labels.py --run-id $RUN_ID --data-root $ROOT
|
||||
$PY training/scripts/10_build_train_datasets.py --run-id $RUN_ID --data-root $ROOT
|
||||
$PY training/scripts/11_train_small_models.py --run-id $RUN_ID --data-root $ROOT
|
||||
$PY training/scripts/12_calibrate_models.py --run-id $RUN_ID --data-root $ROOT
|
||||
$PY training/scripts/13_search_pm_thresholds.py --run-id $RUN_ID --data-root $ROOT
|
||||
$PY training/scripts/14_integrated_backtest.py --run-id $RUN_ID --data-root $ROOT
|
||||
$PY training/scripts/15_export_artifact_bundle.py --run-id $RUN_ID --data-root $ROOT
|
||||
$PY training/scripts/16_validate_artifact_bundle.py --artifact-root $ROOT/trader-v4/runs/$RUN_ID/export/trader-model-bundle-$RUN_ID/artifact_bundle
|
||||
$PY training/scripts/17_promote_artifact_bundle.py --artifact-root $ROOT/trader-v4/runs/$RUN_ID/export/trader-model-bundle-$RUN_ID/artifact_bundle --reason "validation_locked and latest_stress passed for SHADOW"
|
||||
$PY training/scripts/16_validate_artifact_bundle.py --artifact-root $ROOT/trader-v4/runs/$RUN_ID/export/trader-model-bundle-$RUN_ID/artifact_bundle --require-active --run-onnx
|
||||
```
|
||||
|
||||
Java SHADOW 只加载 `ACTIVE` 包。15 号脚本永远只生成 `CANDIDATE`,16 号校验通过且上线门槛通过后,17 号脚本才允许把包提升为 `ACTIVE`。
|
||||
@@ -0,0 +1,7 @@
|
||||
pandas==2.2.3
|
||||
pyarrow==24.0.0
|
||||
numpy==2.4.6
|
||||
scikit-learn==1.7.0
|
||||
scipy==1.18.0
|
||||
onnx==1.22.0
|
||||
onnxruntime==1.27.0
|
||||
@@ -0,0 +1,23 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
|
||||
import _bootstrap # noqa: F401
|
||||
from trader_training.io_utils import add_common_args, setup_logging
|
||||
from trader_training.replay import write_audit_outputs
|
||||
|
||||
|
||||
def main() -> None:
|
||||
parser = argparse.ArgumentParser()
|
||||
add_common_args(parser)
|
||||
parser.add_argument("--symbol", default="BTC-USDT-PERP")
|
||||
parser.add_argument("--start-date")
|
||||
parser.add_argument("--end-date")
|
||||
parser.add_argument("--min-ready-days", type=int, default=250)
|
||||
args = parser.parse_args()
|
||||
setup_logging()
|
||||
write_audit_outputs(args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,26 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
|
||||
import _bootstrap # noqa: F401
|
||||
from trader_training.io_utils import add_common_args, setup_logging
|
||||
from trader_training.replay import build_replay_1m
|
||||
|
||||
|
||||
def main() -> None:
|
||||
parser = argparse.ArgumentParser()
|
||||
add_common_args(parser)
|
||||
parser.add_argument("--raw-root", type=Path)
|
||||
parser.add_argument("--symbol", default="BTC-USDT-PERP")
|
||||
parser.add_argument("--start-date")
|
||||
parser.add_argument("--end-date")
|
||||
parser.add_argument("--min-minutes-per-day", type=int, default=1400)
|
||||
parser.add_argument("--min-replay-ready-days", type=int, default=250)
|
||||
args = parser.parse_args()
|
||||
setup_logging()
|
||||
build_replay_1m(args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,31 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
|
||||
import _bootstrap # noqa: F401
|
||||
from trader_training.io_utils import add_common_args, setup_logging
|
||||
from trader_training.replay import build_splits
|
||||
|
||||
|
||||
def main() -> None:
|
||||
parser = argparse.ArgumentParser()
|
||||
add_common_args(parser)
|
||||
parser.add_argument("--replay-path", type=Path)
|
||||
parser.add_argument("--fit-inner-start", default="2025-06-20")
|
||||
parser.add_argument("--fit-inner-end", default="2026-01-15")
|
||||
parser.add_argument("--tune-inner-start", default="2026-01-16")
|
||||
parser.add_argument("--tune-inner-end", default="2026-02-28")
|
||||
parser.add_argument("--validation-locked-start", default="2026-03-01")
|
||||
parser.add_argument("--validation-locked-end", default="2026-04-30")
|
||||
parser.add_argument("--latest-stress-start", default="2026-05-01")
|
||||
parser.add_argument("--latest-stress-end", default="2026-06-19")
|
||||
parser.add_argument("--gap-minutes", type=int, default=60)
|
||||
parser.add_argument("--fold-count", type=int, default=3)
|
||||
args = parser.parse_args()
|
||||
setup_logging()
|
||||
build_splits(args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,23 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
|
||||
import _bootstrap # noqa: F401
|
||||
from trader_training.features import build_feature_frame
|
||||
from trader_training.io_utils import add_common_args, setup_logging
|
||||
|
||||
|
||||
def main() -> None:
|
||||
parser = argparse.ArgumentParser()
|
||||
add_common_args(parser)
|
||||
parser.add_argument("--replay-path", type=Path)
|
||||
parser.add_argument("--split-manifest-path", type=Path)
|
||||
parser.add_argument("--allow-incomplete-days", action="store_true")
|
||||
args = parser.parse_args()
|
||||
setup_logging()
|
||||
build_feature_frame(args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,23 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
|
||||
import _bootstrap # noqa: F401
|
||||
from trader_training.io_utils import add_common_args, setup_logging
|
||||
from trader_training.labels import write_price_plan_context
|
||||
|
||||
|
||||
def main() -> None:
|
||||
parser = argparse.ArgumentParser()
|
||||
add_common_args(parser)
|
||||
parser.add_argument("--label-config-path", type=Path)
|
||||
parser.add_argument("--cost-config-path", type=Path)
|
||||
parser.add_argument("--price-plan-id", default="btc-p0-plan-45m")
|
||||
args = parser.parse_args()
|
||||
setup_logging()
|
||||
write_price_plan_context(args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,23 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
|
||||
import _bootstrap # noqa: F401
|
||||
from trader_training.io_utils import add_common_args, setup_logging
|
||||
from trader_training.labels import build_direction_labels
|
||||
|
||||
|
||||
def main() -> None:
|
||||
parser = argparse.ArgumentParser()
|
||||
add_common_args(parser)
|
||||
parser.add_argument("--feature-path", type=Path)
|
||||
parser.add_argument("--replay-path", type=Path)
|
||||
parser.add_argument("--label-config-path", type=Path)
|
||||
args = parser.parse_args()
|
||||
setup_logging()
|
||||
build_direction_labels(args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,25 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
|
||||
import _bootstrap # noqa: F401
|
||||
from trader_training.io_utils import add_common_args, setup_logging
|
||||
from trader_training.labels import build_entry_labels
|
||||
|
||||
|
||||
def main() -> None:
|
||||
parser = argparse.ArgumentParser()
|
||||
add_common_args(parser)
|
||||
parser.add_argument("--feature-path", type=Path)
|
||||
parser.add_argument("--replay-path", type=Path)
|
||||
parser.add_argument("--label-config-path", type=Path)
|
||||
parser.add_argument("--cost-config-path", type=Path)
|
||||
parser.add_argument("--price-plan-context-path", type=Path)
|
||||
args = parser.parse_args()
|
||||
setup_logging()
|
||||
build_entry_labels(args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,21 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
|
||||
import _bootstrap # noqa: F401
|
||||
from trader_training.io_utils import add_common_args, setup_logging
|
||||
from trader_training.labels import build_position_state_samples
|
||||
|
||||
|
||||
def main() -> None:
|
||||
parser = argparse.ArgumentParser()
|
||||
add_common_args(parser)
|
||||
parser.add_argument("--entry-label-path", type=Path)
|
||||
args = parser.parse_args()
|
||||
setup_logging()
|
||||
build_position_state_samples(args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,25 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
|
||||
import _bootstrap # noqa: F401
|
||||
from trader_training.io_utils import add_common_args, setup_logging
|
||||
from trader_training.labels import build_continue_exit_risk_labels
|
||||
|
||||
|
||||
def main() -> None:
|
||||
parser = argparse.ArgumentParser()
|
||||
add_common_args(parser)
|
||||
parser.add_argument("--feature-path", type=Path)
|
||||
parser.add_argument("--replay-path", type=Path)
|
||||
parser.add_argument("--label-config-path", type=Path)
|
||||
parser.add_argument("--cost-config-path", type=Path)
|
||||
parser.add_argument("--price-plan-context-path", type=Path)
|
||||
args = parser.parse_args()
|
||||
setup_logging()
|
||||
build_continue_exit_risk_labels(args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,26 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
|
||||
import _bootstrap # noqa: F401
|
||||
from trader_training.datasets import build_train_datasets
|
||||
from trader_training.io_utils import add_common_args, setup_logging
|
||||
|
||||
|
||||
def main() -> None:
|
||||
parser = argparse.ArgumentParser()
|
||||
add_common_args(parser)
|
||||
parser.add_argument("--feature-path", type=Path)
|
||||
parser.add_argument("--direction-label-path", type=Path)
|
||||
parser.add_argument("--entry-label-path", type=Path)
|
||||
parser.add_argument("--continue-label-path", type=Path)
|
||||
parser.add_argument("--exit-label-path", type=Path)
|
||||
parser.add_argument("--risk-label-path", type=Path)
|
||||
args = parser.parse_args()
|
||||
setup_logging()
|
||||
build_train_datasets(args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,20 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
|
||||
import _bootstrap # noqa: F401
|
||||
from trader_training.io_utils import add_common_args, setup_logging
|
||||
from trader_training.training import train_small_models
|
||||
|
||||
|
||||
def main() -> None:
|
||||
parser = argparse.ArgumentParser()
|
||||
add_common_args(parser)
|
||||
parser.add_argument("--max-rows", type=int, default=0)
|
||||
args = parser.parse_args()
|
||||
setup_logging()
|
||||
train_small_models(args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,19 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
|
||||
import _bootstrap # noqa: F401
|
||||
from trader_training.io_utils import add_common_args, setup_logging
|
||||
from trader_training.training import build_calibrators
|
||||
|
||||
|
||||
def main() -> None:
|
||||
parser = argparse.ArgumentParser()
|
||||
add_common_args(parser)
|
||||
args = parser.parse_args()
|
||||
setup_logging()
|
||||
build_calibrators(args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,19 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
|
||||
import _bootstrap # noqa: F401
|
||||
from trader_training.io_utils import add_common_args, setup_logging
|
||||
from trader_training.pm import search_pm_thresholds
|
||||
|
||||
|
||||
def main() -> None:
|
||||
parser = argparse.ArgumentParser()
|
||||
add_common_args(parser)
|
||||
args = parser.parse_args()
|
||||
setup_logging()
|
||||
search_pm_thresholds(args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,19 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
|
||||
import _bootstrap # noqa: F401
|
||||
from trader_training.io_utils import add_common_args, setup_logging
|
||||
from trader_training.pm import integrated_backtest
|
||||
|
||||
|
||||
def main() -> None:
|
||||
parser = argparse.ArgumentParser()
|
||||
add_common_args(parser)
|
||||
args = parser.parse_args()
|
||||
setup_logging()
|
||||
integrated_backtest(args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,21 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
|
||||
import _bootstrap # noqa: F401
|
||||
from trader_training.exporter import export_artifact_bundle
|
||||
from trader_training.io_utils import add_common_args, setup_logging
|
||||
|
||||
|
||||
def main() -> None:
|
||||
parser = argparse.ArgumentParser()
|
||||
add_common_args(parser)
|
||||
parser.add_argument("--export-root", type=Path)
|
||||
args = parser.parse_args()
|
||||
setup_logging()
|
||||
export_artifact_bundle(args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,22 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
|
||||
import _bootstrap # noqa: F401
|
||||
from trader_training.io_utils import setup_logging
|
||||
from trader_training.validator import validate_artifact_bundle
|
||||
|
||||
|
||||
def main() -> None:
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--artifact-root", type=Path, required=True)
|
||||
parser.add_argument("--require-active", action="store_true")
|
||||
parser.add_argument("--run-onnx", action="store_true")
|
||||
args = parser.parse_args()
|
||||
setup_logging()
|
||||
validate_artifact_bundle(args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,21 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
|
||||
import _bootstrap # noqa: F401
|
||||
from trader_training.io_utils import setup_logging
|
||||
from trader_training.promote import promote_artifact_bundle
|
||||
|
||||
|
||||
def main() -> None:
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--artifact-root", type=Path, required=True)
|
||||
parser.add_argument("--reason", required=True)
|
||||
args = parser.parse_args()
|
||||
setup_logging()
|
||||
promote_artifact_bundle(args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,8 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
TRAINING_ROOT = Path(__file__).resolve().parents[1]
|
||||
if str(TRAINING_ROOT) not in sys.path:
|
||||
sys.path.insert(0, str(TRAINING_ROOT))
|
||||
@@ -0,0 +1,146 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
import tempfile
|
||||
import unittest
|
||||
from argparse import Namespace
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
TRAINING_ROOT = Path(__file__).resolve().parents[1]
|
||||
if str(TRAINING_ROOT) not in sys.path:
|
||||
sys.path.insert(0, str(TRAINING_ROOT))
|
||||
|
||||
from trader_training.onnx_export import LinearHead, export_heads
|
||||
from trader_training.io_utils import read_json, write_json
|
||||
from trader_training.promote import promote_artifact_bundle
|
||||
from trader_training.replay import build_splits
|
||||
from trader_training.schemas import FEATURE_ORDER, LATEST_STRESS_SPLIT, MODEL_OUTPUTS, OUTPUT_MAPPING, TRAINING_SPLITS, VALIDATION_LOCKED_SPLIT
|
||||
|
||||
|
||||
class TrainingContractTest(unittest.TestCase):
|
||||
def test_feature_order_is_v4_contract_size(self) -> None:
|
||||
self.assertEqual(39, len(FEATURE_ORDER))
|
||||
self.assertEqual(len(FEATURE_ORDER), len(set(FEATURE_ORDER)))
|
||||
self.assertEqual("ret_1m_bps", FEATURE_ORDER[0])
|
||||
self.assertEqual("minutes_to_next_funding", FEATURE_ORDER[-1])
|
||||
|
||||
def test_output_mapping_matches_model_outputs(self) -> None:
|
||||
for model_name, fields in MODEL_OUTPUTS.items():
|
||||
self.assertEqual(set(fields), set(OUTPUT_MAPPING[model_name]))
|
||||
self.assertEqual([f"prediction[{idx}]" for idx in range(len(fields))], [OUTPUT_MAPPING[model_name][field] for field in fields])
|
||||
|
||||
def test_split_builder_uses_locked_validation_contract(self) -> None:
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
data_root = Path(tmp)
|
||||
replay_path = data_root / "replay_1m.parquet"
|
||||
frame = pd.DataFrame(
|
||||
{
|
||||
"event_time": pd.date_range("2025-06-20", "2026-06-19", freq="D", tz="UTC"),
|
||||
"symbol": "BTC-USDT-PERP",
|
||||
}
|
||||
)
|
||||
frame.to_parquet(replay_path, index=False)
|
||||
|
||||
build_splits(
|
||||
Namespace(
|
||||
data_root=data_root,
|
||||
run_id="unit-split",
|
||||
replay_path=replay_path,
|
||||
fit_inner_start="2025-06-20",
|
||||
fit_inner_end="2026-01-15",
|
||||
tune_inner_start="2026-01-16",
|
||||
tune_inner_end="2026-02-28",
|
||||
validation_locked_start="2026-03-01",
|
||||
validation_locked_end="2026-04-30",
|
||||
latest_stress_start="2026-05-01",
|
||||
latest_stress_end="2026-06-19",
|
||||
gap_minutes=0,
|
||||
fold_count=2,
|
||||
)
|
||||
)
|
||||
|
||||
manifest = read_json(data_root / "trader-v4" / "runs" / "unit-split" / "split" / "split_manifest.json")
|
||||
self.assertEqual(set(TRAINING_SPLITS), {item["split_id"] for item in manifest["splits"]})
|
||||
self.assertEqual([VALIDATION_LOCKED_SPLIT, LATEST_STRESS_SPLIT], manifest["sealed_splits"])
|
||||
self.assertEqual("FINAL_GATE_ONLY", manifest["latest_stress_policy"])
|
||||
|
||||
def test_exported_onnx_accepts_java_feature_shape(self) -> None:
|
||||
import onnxruntime as ort
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
path = Path(tmp) / "direction.onnx"
|
||||
export_heads(
|
||||
path,
|
||||
[
|
||||
LinearHead(
|
||||
"direction",
|
||||
"softmax",
|
||||
np.zeros((39, 3), dtype=np.float32),
|
||||
np.array([0.1, 0.2, 0.3], dtype=np.float32),
|
||||
)
|
||||
],
|
||||
)
|
||||
session = ort.InferenceSession(str(path))
|
||||
output = session.run(None, {"features": np.zeros((1, 39), dtype=np.float32)})[0]
|
||||
self.assertEqual((1, 3), output.shape)
|
||||
self.assertAlmostEqual(1.0, float(output.sum()), places=6)
|
||||
|
||||
def test_promotion_requires_passed_validation_and_marks_all_manifests_active(self) -> None:
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
root = Path(tmp) / "artifact_bundle"
|
||||
manifest_dir = root / "manifests"
|
||||
manifest_dir.mkdir(parents=True)
|
||||
write_json(root.parent / "artifact_validation_result.json", {"status": "PASS", "release_gate_status": "PASS", "release_gate_reasons": []})
|
||||
write_json(manifest_dir / "model_bundle_manifest.json", {"status": "CANDIDATE"})
|
||||
write_json(manifest_dir / "model_manifest.json", [{"model_name": "DIRECTION", "status": "CANDIDATE"}])
|
||||
write_json(manifest_dir / "calibration_manifest.json", [{"model_name": "DIRECTION", "status": "CANDIDATE"}])
|
||||
write_json(manifest_dir / "position_manager_manifest.json", {"status": "CANDIDATE"})
|
||||
write_json(manifest_dir / "training_export_manifest.json", {"status": "CANDIDATE"})
|
||||
|
||||
promote_artifact_bundle(Namespace(artifact_root=root, reason="unit test"))
|
||||
|
||||
self.assertEqual("ACTIVE", read_json(manifest_dir / "model_bundle_manifest.json")["status"])
|
||||
self.assertEqual("ACTIVE", read_json(manifest_dir / "model_manifest.json")[0]["status"])
|
||||
self.assertEqual("ACTIVE", read_json(manifest_dir / "calibration_manifest.json")[0]["status"])
|
||||
self.assertEqual("ACTIVE", read_json(manifest_dir / "position_manager_manifest.json")["status"])
|
||||
self.assertEqual("ACTIVE", read_json(manifest_dir / "training_export_manifest.json")["status"])
|
||||
|
||||
def test_promotion_refuses_failed_validation(self) -> None:
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
root = Path(tmp) / "artifact_bundle"
|
||||
(root / "manifests").mkdir(parents=True)
|
||||
write_json(root.parent / "artifact_validation_result.json", {"status": "FAIL"})
|
||||
with self.assertRaises(SystemExit):
|
||||
promote_artifact_bundle(Namespace(artifact_root=root, reason="unit test"))
|
||||
result = read_json(root.parent / "artifact_promotion_result.json")
|
||||
self.assertEqual("REFUSED", result["status"])
|
||||
self.assertEqual("validation result is not PASS", result["message"])
|
||||
|
||||
def test_promotion_refuses_failed_release_gate_and_overwrites_stale_result(self) -> None:
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
root = Path(tmp) / "artifact_bundle"
|
||||
(root / "manifests").mkdir(parents=True)
|
||||
write_json(root.parent / "artifact_promotion_result.json", {"status": "ACTIVE"})
|
||||
write_json(
|
||||
root.parent / "artifact_validation_result.json",
|
||||
{
|
||||
"status": "PASS",
|
||||
"release_gate_status": "REJECTED",
|
||||
"release_gate_reasons": ["backtest_status=REJECTED"],
|
||||
},
|
||||
)
|
||||
|
||||
with self.assertRaises(SystemExit):
|
||||
promote_artifact_bundle(Namespace(artifact_root=root, reason="unit test"))
|
||||
|
||||
result = read_json(root.parent / "artifact_promotion_result.json")
|
||||
self.assertEqual("REFUSED", result["status"])
|
||||
self.assertEqual("release gate is not PASS", result["message"])
|
||||
self.assertEqual(["backtest_status=REJECTED"], result["release_gate_reasons"])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user