Implement Trader V4 training artifact pipeline

This commit is contained in:
Codex
2026-06-27 16:15:23 +08:00
parent dad6b831b4
commit e58e4a5572
113 changed files with 7959 additions and 477 deletions
@@ -0,0 +1,177 @@
package com.quantai.trader.artifact;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.quantai.trader.enums.TraderRunMode;
import com.quantai.trader.persistence.TraderJsonCodec;
import org.springframework.jdbc.core.JdbcTemplate;
import org.springframework.stereotype.Repository;
@Repository
public class JdbcTraderArtifactManifestRepository implements TraderArtifactManifestRepository {
private final JdbcTemplate jdbcTemplate;
private final TraderJsonCodec jsonCodec;
public JdbcTraderArtifactManifestRepository(JdbcTemplate jdbcTemplate, ObjectMapper objectMapper) {
this.jdbcTemplate = jdbcTemplate;
this.jsonCodec = new TraderJsonCodec(objectMapper);
}
@Override
public void upsertActiveBundle(TraderArtifactBundle bundle) {
upsertModelBundle(bundle.modelBundleManifest());
bundle.modelManifests().forEach(this::upsertModelManifest);
bundle.calibrationManifests().forEach(this::upsertCalibrationManifest);
upsertPmConfigManifest(bundle.pmConfigManifest());
}
private void upsertModelBundle(TraderModelBundleManifest manifest) {
jdbcTemplate.update("""
insert into trader_model_bundle_manifest
(manifest_schema_version, model_bundle_version, calibration_bundle_version,
feature_version, label_version, split_version, training_run_id, training_export_id,
backtest_manifest_id, required_models_json, provided_models_json, missing_models_json,
allowed_run_modes_json, bundle_hash_sha256, complete, status)
values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
on duplicate key update
manifest_schema_version = values(manifest_schema_version),
feature_version = values(feature_version),
label_version = values(label_version),
split_version = values(split_version),
training_run_id = values(training_run_id),
training_export_id = values(training_export_id),
backtest_manifest_id = values(backtest_manifest_id),
required_models_json = values(required_models_json),
provided_models_json = values(provided_models_json),
missing_models_json = values(missing_models_json),
allowed_run_modes_json = values(allowed_run_modes_json),
bundle_hash_sha256 = values(bundle_hash_sha256),
complete = values(complete),
status = values(status)
""",
manifest.manifestSchemaVersion(), manifest.modelBundleVersion(), manifest.calibrationBundleVersion(),
manifest.featureVersion(), manifest.labelVersion(), manifest.splitVersion(),
manifest.trainingRunId(), manifest.trainingExportId(), manifest.backtestManifestId(),
jsonCodec.toJson(manifest.requiredModels()),
jsonCodec.toJson(manifest.providedModels()), jsonCodec.toJson(manifest.missingModels()),
jsonCodec.toJson(manifest.allowedRunModes().stream().map(TraderRunMode::name).toList()),
manifest.bundleHashSha256(), manifest.complete(), manifest.status());
}
private void upsertModelManifest(TraderModelManifest manifest) {
jdbcTemplate.update("""
insert into trader_model_manifest
(model_bundle_version, calibration_bundle_version, model_name, model_type, side,
symbol_scope_json, bar_interval, horizon_minutes, model_format, model_runtime,
model_runtime_version, onnx_opset_version, producer_name, producer_version,
artifact_path, artifact_hash_sha256, source_hash, feature_version, feature_schema_path,
feature_schema_hash, feature_order_path, feature_order_hash, input_tensor_name, input_dtype, input_shape_json,
input_example_path, output_schema_path, output_schema_hash, output_tensor_names_json,
output_mapping_json, output_value_rules_json, label_version, split_version, training_fold,
train_start, train_end, validation_start, validation_end, test_start, test_end,
metrics_json, status)
values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
on duplicate key update
model_type = values(model_type),
symbol_scope_json = values(symbol_scope_json),
bar_interval = values(bar_interval),
model_format = values(model_format),
model_runtime = values(model_runtime),
model_runtime_version = values(model_runtime_version),
onnx_opset_version = values(onnx_opset_version),
producer_name = values(producer_name),
producer_version = values(producer_version),
artifact_path = values(artifact_path),
artifact_hash_sha256 = values(artifact_hash_sha256),
source_hash = values(source_hash),
feature_version = values(feature_version),
feature_schema_path = values(feature_schema_path),
feature_schema_hash = values(feature_schema_hash),
feature_order_path = values(feature_order_path),
feature_order_hash = values(feature_order_hash),
input_tensor_name = values(input_tensor_name),
input_dtype = values(input_dtype),
input_shape_json = values(input_shape_json),
input_example_path = values(input_example_path),
output_schema_path = values(output_schema_path),
output_schema_hash = values(output_schema_hash),
output_tensor_names_json = values(output_tensor_names_json),
output_mapping_json = values(output_mapping_json),
output_value_rules_json = values(output_value_rules_json),
label_version = values(label_version),
split_version = values(split_version),
training_fold = values(training_fold),
train_start = values(train_start),
train_end = values(train_end),
validation_start = values(validation_start),
validation_end = values(validation_end),
test_start = values(test_start),
test_end = values(test_end),
metrics_json = values(metrics_json),
status = values(status)
""",
manifest.modelBundleVersion(), manifest.calibrationBundleVersion(), manifest.modelName(),
manifest.modelType(), manifest.side(), jsonCodec.toJson(manifest.symbolScope()),
manifest.barInterval(), manifest.horizonMinutes(), manifest.modelFormat(), manifest.modelRuntime(),
manifest.modelRuntimeVersion(), manifest.onnxOpsetVersion(), manifest.producerName(),
manifest.producerVersion(), manifest.artifactPath(), manifest.artifactHashSha256(),
manifest.sourceHash(), manifest.featureVersion(), manifest.featureSchemaPath(),
manifest.featureSchemaHash(), manifest.featureOrderPath(), manifest.featureOrderHash(), manifest.inputTensorName(),
manifest.inputDtype(), jsonCodec.toJson(manifest.inputShapeJson()), manifest.inputExamplePath(),
manifest.outputSchemaPath(), manifest.outputSchemaHash(), jsonCodec.toJson(manifest.outputTensorNames()),
jsonCodec.toJson(manifest.outputMapping()), jsonCodec.toJson(manifest.outputValueRules()),
manifest.labelVersion(), manifest.splitVersion(), manifest.trainingFold(),
java.sql.Timestamp.from(manifest.trainStart()), java.sql.Timestamp.from(manifest.trainEnd()),
java.sql.Timestamp.from(manifest.validationStart()), java.sql.Timestamp.from(manifest.validationEnd()),
java.sql.Timestamp.from(manifest.testStart()), java.sql.Timestamp.from(manifest.testEnd()),
jsonCodec.toJson(manifest.metricsJson()), manifest.status());
}
private void upsertCalibrationManifest(TraderCalibrationManifest manifest) {
jdbcTemplate.update("""
insert into trader_calibration_manifest
(calibration_bundle_version, model_bundle_version, model_name, calibrator_version,
calibration_method, calibrator_path, calibrator_hash_sha256,
calibration_window_from, calibration_window_to, calibration_metrics_json,
bucket_metrics_json, output_after_calibration_schema_hash, status)
values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
on duplicate key update
calibrator_version = values(calibrator_version),
calibration_method = values(calibration_method),
calibrator_path = values(calibrator_path),
calibrator_hash_sha256 = values(calibrator_hash_sha256),
calibration_window_from = values(calibration_window_from),
calibration_window_to = values(calibration_window_to),
calibration_metrics_json = values(calibration_metrics_json),
bucket_metrics_json = values(bucket_metrics_json),
output_after_calibration_schema_hash = values(output_after_calibration_schema_hash),
status = values(status)
""",
manifest.calibrationBundleVersion(), manifest.modelBundleVersion(), manifest.modelName(),
manifest.calibratorVersion(), manifest.calibrationMethod(), manifest.calibratorPath(),
manifest.calibratorHashSha256(), java.sql.Timestamp.from(manifest.calibrationWindowFrom()),
java.sql.Timestamp.from(manifest.calibrationWindowTo()), jsonCodec.toJson(manifest.calibrationMetrics()),
jsonCodec.toJson(manifest.bucketMetricsJson()), manifest.outputAfterCalibrationSchemaHash(),
manifest.status());
}
private void upsertPmConfigManifest(TraderPmConfigManifest manifest) {
jdbcTemplate.update("""
insert into trader_pm_config_manifest
(pm_config_version, model_bundle_version, calibration_bundle_version, threshold_stability_json,
allowed_run_modes_json, config_json, config_hash_sha256, status)
values (?, ?, ?, ?, ?, ?, ?, ?)
on duplicate key update
model_bundle_version = values(model_bundle_version),
calibration_bundle_version = values(calibration_bundle_version),
threshold_stability_json = values(threshold_stability_json),
allowed_run_modes_json = values(allowed_run_modes_json),
config_json = values(config_json),
config_hash_sha256 = values(config_hash_sha256),
status = values(status)
""",
manifest.pmConfigVersion(), manifest.modelBundleVersion(), manifest.calibrationBundleVersion(),
jsonCodec.toJson(manifest.thresholdStabilityJson()),
jsonCodec.toJson(manifest.allowedRunModes().stream().map(TraderRunMode::name).toList()),
jsonCodec.toJson(manifest.config()), manifest.configHashSha256(), manifest.status());
}
}
@@ -1,7 +1,11 @@
package com.quantai.trader.artifact;
import com.quantai.trader.domain.TraderException;
import com.quantai.trader.domain.TraderPmConfig;
import com.quantai.trader.domain.TraderPricePlanContext;
import com.quantai.trader.enums.TraderErrorCode;
import java.util.List;
import java.util.Set;
public record TraderArtifactBundle(
@@ -10,14 +14,31 @@ public record TraderArtifactBundle(
String pmConfigVersion,
String bundleHashSha256,
Set<String> providedModels,
TraderModelBundleManifest modelBundleManifest,
List<TraderModelManifest> modelManifests,
List<TraderCalibrationManifest> calibrationManifests,
TraderPmConfigManifest pmConfigManifest,
TraderPmConfig pmConfig,
TraderArtifactModelPolicy modelPolicy
TraderPricePlanContext pricePlanContext,
TraderReplayModelFixture replayModelFixture
) {
public TraderArtifactBundle {
if (providedModels == null || !providedModels.containsAll(Set.of("DIRECTION", "ENTRY", "CONTINUE", "EXIT", "RISK"))) {
throw new IllegalArgumentException("artifact bundle must provide all five V4 models");
}
modelBundleManifest = java.util.Objects.requireNonNull(modelBundleManifest, "modelBundleManifest is required");
modelManifests = List.copyOf(modelManifests);
calibrationManifests = List.copyOf(calibrationManifests);
pmConfigManifest = java.util.Objects.requireNonNull(pmConfigManifest, "pmConfigManifest is required");
pmConfig = java.util.Objects.requireNonNull(pmConfig, "pmConfig is required");
modelPolicy = java.util.Objects.requireNonNull(modelPolicy, "modelPolicy is required");
pricePlanContext = java.util.Objects.requireNonNull(pricePlanContext, "pricePlanContext is required");
}
public TraderReplayModelFixture requireReplayModelFixture() {
if (replayModelFixture == null) {
throw new TraderException(TraderErrorCode.TRADER_MODEL_ARTIFACT_MISSING,
"replay model fixture is required for replay fixture inference");
}
return replayModelFixture;
}
}
@@ -2,9 +2,11 @@ package com.quantai.trader.artifact;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.core.type.TypeReference;
import com.quantai.trader.config.TraderProperties;
import com.quantai.trader.domain.TraderException;
import com.quantai.trader.domain.TraderPmConfig;
import com.quantai.trader.domain.TraderPricePlanContext;
import com.quantai.trader.enums.TraderRunMode;
import com.quantai.trader.enums.TraderErrorCode;
import org.slf4j.Logger;
@@ -14,6 +16,11 @@ import org.springframework.stereotype.Component;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.time.Instant;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.StreamSupport;
@@ -22,6 +29,27 @@ import java.util.stream.StreamSupport;
public class TraderArtifactLoader {
private static final Logger log = LoggerFactory.getLogger(TraderArtifactLoader.class);
private static final Set<String> REQUIRED_MODELS = Set.of("DIRECTION", "ENTRY", "CONTINUE", "EXIT", "RISK");
private static final int REQUIRED_FEATURE_COUNT = 39;
private static final int REQUIRED_ONNX_OPSET_VERSION = 17;
private static final Map<String, Set<String>> REQUIRED_OUTPUT_MAPPING_KEYS = Map.of(
"DIRECTION", Set.of("long_prob", "short_prob", "neutral_prob"),
"ENTRY", Set.of("long_entry_prob", "short_entry_prob", "long_expected_net_edge_bps", "short_expected_net_edge_bps"),
"CONTINUE", Set.of("long_continue_prob", "short_continue_prob", "long_expected_continue_edge_bps", "short_expected_continue_edge_bps"),
"EXIT", Set.of("long_exit_prob", "short_exit_prob", "long_adverse_move_bps", "short_adverse_move_bps",
"adverse_move_prob", "reversal_prob", "stop_hit_prob", "stagnation_prob"),
"RISK", Set.of("market_risk_prob", "long_position_risk_prob", "short_position_risk_prob",
"market_path_risk_bps", "long_position_path_risk_bps", "short_position_path_risk_bps",
"market_drawdown_prob", "volatility_expansion_prob", "spike_prob",
"liquidity_deterioration_prob", "position_drawdown_prob")
);
private static final Set<String> REJECTED_OUTPUT_MAPPING_KEYS = Set.of(
"expected_net_edge_bps",
"expected_continue_edge_bps",
"expected_giveback_bps",
"market_expected_shortfall_bps",
"position_expected_shortfall_bps",
"position_risk_target"
);
private final TraderProperties properties;
private final ObjectMapper objectMapper;
@@ -35,10 +63,15 @@ public class TraderArtifactLoader {
TraderProperties.Artifact artifact = properties.artifact();
Path root = Path.of(artifact.artifactRoot());
TraderModelBundleManifest modelManifest = readModelBundleManifest(root.resolve("manifests/model_bundle_manifest.json"));
List<TraderModelManifest> modelManifests = readModelManifests(root.resolve("manifests/model_manifest.json"));
List<TraderCalibrationManifest> calibrationManifests = readCalibrationManifests(root.resolve("manifests/calibration_manifest.json"));
TraderPmConfigManifest pmManifest = readPmConfigManifest(root.resolve("manifests/position_manager_manifest.json"));
TraderArtifactModelPolicy modelPolicy = readJson(root.resolve("model_output_policy.json"), TraderArtifactModelPolicy.class);
TraderPricePlanContext pricePlanContext = readJson(root.resolve("price_plan_context.json"), TraderPricePlanContext.class);
TraderReplayModelFixture replayModelFixture = readOptionalJson(root.resolve("replay_model_fixture.json"), TraderReplayModelFixture.class);
validateVersions(artifact, modelManifest, pmManifest);
validateModelManifest(modelManifest);
validateModelArtifacts(root, modelManifest, modelManifests);
validateCalibrationArtifacts(root, modelManifest, calibrationManifests);
validatePmManifest(pmManifest, properties.runMode());
TraderArtifactBundle bundle = new TraderArtifactBundle(
modelManifest.modelBundleVersion(),
@@ -46,24 +79,111 @@ public class TraderArtifactLoader {
pmManifest.pmConfigVersion(),
modelManifest.bundleHashSha256(),
modelManifest.providedModels(),
modelManifest,
modelManifests,
calibrationManifests,
pmManifest,
pmManifest.config(),
modelPolicy);
pricePlanContext,
replayModelFixture);
log.info("event=trader.artifact.loaded modelBundleVersion={} calibrationBundleVersion={} pmConfigVersion={} providedModels={}",
bundle.modelBundleVersion(), bundle.calibrationBundleVersion(), bundle.pmConfigVersion(), bundle.providedModels());
return bundle;
}
private List<TraderModelManifest> readModelManifests(Path path) {
JsonNode root = readJsonNode(path);
if (!root.isArray()) {
throw new TraderException(TraderErrorCode.TRADER_MODEL_ARTIFACT_MISSING,
"model manifest must be an array: " + path);
}
return StreamSupport.stream(root.spliterator(), false)
.map(node -> new TraderModelManifest(
requiredText(node, "model_bundle_version", path),
requiredText(node, "calibration_bundle_version", path),
requiredText(node, "model_name", path),
requiredText(node, "model_type", path),
requiredText(node, "side", path),
textSet(node, "symbol_scope_json", path),
requiredText(node, "bar_interval", path),
node.path("horizon_minutes").asInt(-1),
requiredText(node, "model_format", path),
requiredText(node, "model_runtime", path),
requiredText(node, "model_runtime_version", path),
node.path("onnx_opset_version").asInt(-1),
requiredText(node, "producer_name", path),
requiredText(node, "producer_version", path),
requiredText(node, "feature_version", path),
requiredText(node, "feature_schema_path", path),
requiredText(node, "feature_schema_hash", path),
requiredText(node, "feature_order_path", path),
requiredText(node, "feature_order_hash", path),
requiredText(node, "input_tensor_name", path),
requiredText(node, "input_dtype", path),
objectMap(node, "input_shape_json", path),
requiredText(node, "input_example_path", path),
requiredText(node, "output_schema_path", path),
requiredText(node, "output_schema_hash", path),
textSet(node, "output_tensor_names_json", path),
objectMap(node, "output_mapping_json", path),
objectMap(node, "output_value_rules_json", path),
requiredText(node, "label_version", path),
requiredText(node, "split_version", path),
requiredText(node, "training_fold", path),
instant(node, "train_start", path),
instant(node, "train_end", path),
instant(node, "validation_start", path),
instant(node, "validation_end", path),
instant(node, "test_start", path),
instant(node, "test_end", path),
objectMap(node, "metrics_json", path),
requiredText(node, "artifact_path", path),
requiredText(node, "artifact_hash_sha256", path),
requiredText(node, "source_hash", path),
requiredText(node, "status", path)))
.toList();
}
private List<TraderCalibrationManifest> readCalibrationManifests(Path path) {
JsonNode root = readJsonNode(path);
if (!root.isArray()) {
throw new TraderException(TraderErrorCode.TRADER_MODEL_ARTIFACT_MISSING,
"calibration manifest must be an array: " + path);
}
return StreamSupport.stream(root.spliterator(), false)
.map(node -> new TraderCalibrationManifest(
requiredText(node, "calibration_bundle_version", path),
requiredText(node, "model_bundle_version", path),
requiredText(node, "model_name", path),
requiredText(node, "calibrator_version", path),
requiredText(node, "calibration_method", path),
requiredText(node, "calibrator_path", path),
requiredText(node, "calibrator_hash_sha256", path),
instant(node, "calibration_window_from", path),
instant(node, "calibration_window_to", path),
objectMap(node, "calibration_metrics_json", path),
objectMap(node, "bucket_metrics_json", path),
requiredText(node, "output_after_calibration_schema_hash", path),
requiredText(node, "status", path)))
.toList();
}
private TraderModelBundleManifest readModelBundleManifest(Path path) {
JsonNode root = readJsonNode(path);
return new TraderModelBundleManifest(
requiredText(root, "manifest_schema_version", path),
requiredText(root, "model_bundle_version", path),
requiredText(root, "calibration_bundle_version", path),
requiredText(root, "feature_version", path),
requiredText(root, "label_version", path),
requiredText(root, "split_version", path),
requiredText(root, "training_run_id", path),
requiredText(root, "training_export_id", path),
requiredText(root, "backtest_manifest_id", path),
textSet(root, "required_models_json", path),
textSet(root, "provided_models_json", path),
textSet(root, "missing_models_json", path),
enumSet(root, "allowed_run_modes_json", TraderRunMode.class, path),
requiredText(root, "bundle_hash_sha256", path),
root.path("complete").asBoolean(false),
requiredText(root, "status", path));
@@ -75,6 +195,7 @@ public class TraderArtifactLoader {
requiredText(root, "pm_config_version", path),
requiredText(root, "model_bundle_version", path),
requiredText(root, "calibration_bundle_version", path),
objectMap(root, "threshold_stability_json", path),
enumSet(root, "allowed_run_modes_json", TraderRunMode.class, path),
convert(root.path("config_json"), TraderPmConfig.class, path),
requiredText(root, "config_hash_sha256", path),
@@ -106,6 +227,86 @@ public class TraderArtifactLoader {
throw new TraderException(TraderErrorCode.TRADER_MODEL_ARTIFACT_MISSING,
"model bundle must provide all five V4 models with no missing model");
}
if (!manifest.allowedRunModes().contains(properties.runMode())) {
throw new TraderException(TraderErrorCode.TRADER_MODEL_ARTIFACT_MISSING,
"model bundle does not allow current run mode");
}
}
private void validateModelArtifacts(Path root, TraderModelBundleManifest bundleManifest,
List<TraderModelManifest> manifests) {
Set<String> modelNames = manifests.stream().map(TraderModelManifest::modelName).collect(Collectors.toUnmodifiableSet());
if (!modelNames.containsAll(REQUIRED_MODELS)) {
throw new TraderException(TraderErrorCode.TRADER_MODEL_ARTIFACT_MISSING,
"model_manifest.json must contain all five V4 model families");
}
for (TraderModelManifest manifest : manifests) {
if (!bundleManifest.modelBundleVersion().equals(manifest.modelBundleVersion())
|| !bundleManifest.calibrationBundleVersion().equals(manifest.calibrationBundleVersion())) {
throw new TraderException(TraderErrorCode.TRADER_CALIBRATION_MISMATCH,
"model manifest version does not match bundle manifest");
}
if (!"ACTIVE".equals(manifest.status())) {
throw new TraderException(TraderErrorCode.TRADER_MODEL_ARTIFACT_MISSING,
"model manifest must be ACTIVE: " + manifest.modelName());
}
if (!REQUIRED_MODELS.contains(manifest.modelType())
|| !"ONNX".equals(manifest.modelFormat())
|| !"ONNX_RUNTIME_JAVA".equals(manifest.modelRuntime())
|| !"FLOAT32".equals(manifest.inputDtype())
|| manifest.horizonMinutes() <= 0
|| manifest.onnxOpsetVersion() != REQUIRED_ONNX_OPSET_VERSION
|| inputFeatureCount(manifest) != REQUIRED_FEATURE_COUNT
|| !"features".equals(manifest.inputTensorName())) {
throw new TraderException(TraderErrorCode.TRADER_MODEL_ARTIFACT_MISSING,
"model manifest runtime contract is invalid: " + manifest.modelName());
}
validateOutputMapping(manifest);
validateSha256(root.resolve(manifest.artifactPath()), manifest.artifactHashSha256());
validateSha256(root.resolve(manifest.featureSchemaPath()), manifest.featureSchemaHash());
validateSha256(root.resolve(manifest.featureOrderPath()), manifest.featureOrderHash());
validateFeatureOrder(root.resolve(manifest.featureOrderPath()));
validateSha256(root.resolve(manifest.outputSchemaPath()), manifest.outputSchemaHash());
if (!Files.isRegularFile(root.resolve(manifest.inputExamplePath()))) {
throw new TraderException(TraderErrorCode.TRADER_MODEL_ARTIFACT_MISSING,
"model input example is missing: " + manifest.inputExamplePath());
}
}
}
private void validateOutputMapping(TraderModelManifest manifest) {
if (manifest.outputMapping().keySet().stream().anyMatch(REJECTED_OUTPUT_MAPPING_KEYS::contains)) {
throw new TraderException(TraderErrorCode.TRADER_MODEL_ARTIFACT_MISSING,
"model output mapping contains rejected legacy key: " + manifest.modelName());
}
Set<String> requiredKeys = REQUIRED_OUTPUT_MAPPING_KEYS.get(manifest.modelType());
if (requiredKeys == null || !manifest.outputMapping().keySet().containsAll(requiredKeys)) {
throw new TraderException(TraderErrorCode.TRADER_MODEL_ARTIFACT_MISSING,
"model output mapping does not match V4 contract: " + manifest.modelName());
}
}
private void validateCalibrationArtifacts(Path root, TraderModelBundleManifest modelManifest,
List<TraderCalibrationManifest> calibrationManifests) {
Set<String> calibratedModels = calibrationManifests.stream()
.map(TraderCalibrationManifest::modelName)
.collect(Collectors.toUnmodifiableSet());
if (!calibratedModels.containsAll(REQUIRED_MODELS)) {
throw new TraderException(TraderErrorCode.TRADER_CALIBRATION_MISMATCH,
"calibration_manifest.json must contain all five V4 model families");
}
for (TraderCalibrationManifest calibrationManifest : calibrationManifests) {
if (!modelManifest.modelBundleVersion().equals(calibrationManifest.modelBundleVersion())
|| !modelManifest.calibrationBundleVersion().equals(calibrationManifest.calibrationBundleVersion())) {
throw new TraderException(TraderErrorCode.TRADER_CALIBRATION_MISMATCH,
"calibration manifest version does not match bundle manifest");
}
if (!"ACTIVE".equals(calibrationManifest.status())) {
throw new TraderException(TraderErrorCode.TRADER_CALIBRATION_MISMATCH,
"calibration manifest must be ACTIVE");
}
validateSha256(root.resolve(calibrationManifest.calibratorPath()), calibrationManifest.calibratorHashSha256());
}
}
private void validatePmManifest(TraderPmConfigManifest manifest, TraderRunMode runMode) {
@@ -119,6 +320,29 @@ public class TraderArtifactLoader {
}
}
private int inputFeatureCount(TraderModelManifest manifest) {
Object features = manifest.inputShapeJson().get("features");
if (features instanceof Number number) {
return number.intValue();
}
if (features instanceof String text) {
try {
return Integer.parseInt(text);
} catch (NumberFormatException ignored) {
return -1;
}
}
return -1;
}
private void validateFeatureOrder(Path path) {
JsonNode featureOrder = readJsonNode(path);
if (!featureOrder.isArray() || featureOrder.size() != REQUIRED_FEATURE_COUNT) {
throw new TraderException(TraderErrorCode.TRADER_MODEL_ARTIFACT_MISSING,
"feature_order.json must contain exactly " + REQUIRED_FEATURE_COUNT + " fields: " + path);
}
}
private JsonNode readJsonNode(Path path) {
if (!Files.isRegularFile(path)) {
throw new TraderException(TraderErrorCode.TRADER_MODEL_ARTIFACT_MISSING,
@@ -132,6 +356,38 @@ public class TraderArtifactLoader {
}
}
private void validateSha256(Path path, String expectedHash) {
if (!Files.isRegularFile(path)) {
throw new TraderException(TraderErrorCode.TRADER_MODEL_ARTIFACT_MISSING,
"artifact referenced by manifest is missing: " + path);
}
String actualHash;
try {
actualHash = sha256(path);
} catch (IOException exception) {
throw new TraderException(TraderErrorCode.TRADER_MODEL_ARTIFACT_MISSING,
"artifact referenced by manifest cannot be read: " + path);
}
if (!expectedHash.equals(actualHash)) {
throw new TraderException(TraderErrorCode.TRADER_MODEL_ARTIFACT_MISSING,
"artifact sha256 mismatch: " + path);
}
}
private String sha256(Path path) throws IOException {
try {
MessageDigest digest = MessageDigest.getInstance("SHA-256");
byte[] hash = digest.digest(Files.readAllBytes(path));
StringBuilder builder = new StringBuilder(hash.length * 2);
for (byte value : hash) {
builder.append(String.format("%02x", value & 0xff));
}
return builder.toString();
} catch (NoSuchAlgorithmException exception) {
throw new IllegalStateException("SHA-256 is not available", exception);
}
}
private <T> T readJson(Path path, Class<T> type) {
if (!Files.isRegularFile(path)) {
throw new TraderException(TraderErrorCode.TRADER_MODEL_ARTIFACT_MISSING,
@@ -145,6 +401,13 @@ public class TraderArtifactLoader {
}
}
private <T> T readOptionalJson(Path path, Class<T> type) {
if (!Files.isRegularFile(path)) {
return null;
}
return readJson(path, type);
}
private <T> T convert(JsonNode node, Class<T> type, Path path) {
if (node == null || node.isMissingNode() || node.isNull()) {
throw new TraderException(TraderErrorCode.TRADER_MODEL_ARTIFACT_MISSING,
@@ -167,6 +430,20 @@ public class TraderArtifactLoader {
return value;
}
private Instant instant(JsonNode node, String field, Path path) {
return Instant.parse(requiredText(node, field, path));
}
private Map<String, Object> objectMap(JsonNode node, String field, Path path) {
JsonNode value = node.path(field);
if (value.isMissingNode() || value.isNull() || !value.isObject()) {
throw new TraderException(TraderErrorCode.TRADER_MODEL_ARTIFACT_MISSING,
"artifact object field is required: " + path + "#" + field);
}
return objectMapper.convertValue(value, new TypeReference<>() {
});
}
private Set<String> textSet(JsonNode node, String field, Path path) {
JsonNode array = node.path(field);
if (!array.isArray()) {
@@ -0,0 +1,5 @@
package com.quantai.trader.artifact;
public interface TraderArtifactManifestRepository {
void upsertActiveBundle(TraderArtifactBundle bundle);
}
@@ -1,76 +0,0 @@
package com.quantai.trader.artifact;
import java.math.BigDecimal;
public record TraderArtifactModelPolicy(
DirectionPolicy direction,
EntryPolicy entry,
ContinuePolicy continuation,
ExitPolicy exit,
RiskPolicy risk,
BigDecimal uncertainty,
BigDecimal oodScore
) {
public record DirectionPolicy(
BigDecimal longProbWhenMarkGteIndex,
BigDecimal longProbWhenMarkLtIndex,
BigDecimal neutralProb,
BigDecimal expectedReturnBps,
int horizonMinutes,
String modelVersion
) {
}
public record EntryPolicy(
BigDecimal longEntryProb,
BigDecimal shortEntryProb,
BigDecimal entryQualityScore,
BigDecimal expectedEdgeBps,
String pricePlanId,
String pricePlanConfigHash,
BigDecimal stopDistanceBps,
BigDecimal targetDistanceBps,
int maxHoldMinutes,
BigDecimal costBps,
String modelVersion
) {
}
public record ContinuePolicy(
BigDecimal longContinueProb,
BigDecimal shortContinueProb,
BigDecimal trendPersistenceProb,
BigDecimal holdEdgeBps,
BigDecimal continueVsExitEdgeBps,
String modelVersion
) {
}
public record ExitPolicy(
BigDecimal longExitProb,
BigDecimal shortExitProb,
BigDecimal profitGivebackProb,
BigDecimal reversalProb,
BigDecimal stopRiskProb,
BigDecimal stagnationProb,
BigDecimal expectedGivebackBps,
String modelVersion
) {
}
public record RiskPolicy(
BigDecimal marketRiskProb,
BigDecimal positionRiskProb,
BigDecimal marketRiskSeverityBps,
BigDecimal positionRiskSeverityBps,
BigDecimal drawdownProb,
BigDecimal expectedShortfallBps,
BigDecimal volatilityExpansionProb,
BigDecimal spikeProb,
BigDecimal liquidityRiskProb,
BigDecimal liquidityCapacityRatioWhenReady,
BigDecimal liquidityCapacityRatioWhenNotReady,
String modelVersion
) {
}
}
@@ -0,0 +1,25 @@
package com.quantai.trader.artifact;
import java.time.Instant;
import java.util.Map;
public record TraderCalibrationManifest(
String calibrationBundleVersion,
String modelBundleVersion,
String modelName,
String calibratorVersion,
String calibrationMethod,
String calibratorPath,
String calibratorHashSha256,
Instant calibrationWindowFrom,
Instant calibrationWindowTo,
Map<String, Object> calibrationMetrics,
Map<String, Object> bucketMetricsJson,
String outputAfterCalibrationSchemaHash,
String status
) {
public TraderCalibrationManifest {
calibrationMetrics = Map.copyOf(calibrationMetrics == null ? Map.of() : calibrationMetrics);
bucketMetricsJson = Map.copyOf(bucketMetricsJson == null ? Map.of() : bucketMetricsJson);
}
}
@@ -1,16 +1,23 @@
package com.quantai.trader.artifact;
import com.quantai.trader.enums.TraderRunMode;
import java.util.Set;
public record TraderModelBundleManifest(
String manifestSchemaVersion,
String modelBundleVersion,
String calibrationBundleVersion,
String featureVersion,
String labelVersion,
String splitVersion,
String trainingRunId,
String trainingExportId,
String backtestManifestId,
Set<String> requiredModels,
Set<String> providedModels,
Set<String> missingModels,
Set<TraderRunMode> allowedRunModes,
String bundleHashSha256,
boolean complete,
String status
@@ -0,0 +1,59 @@
package com.quantai.trader.artifact;
import java.time.Instant;
import java.util.Map;
import java.util.Set;
public record TraderModelManifest(
String modelBundleVersion,
String calibrationBundleVersion,
String modelName,
String modelType,
String side,
Set<String> symbolScope,
String barInterval,
int horizonMinutes,
String modelFormat,
String modelRuntime,
String modelRuntimeVersion,
int onnxOpsetVersion,
String producerName,
String producerVersion,
String featureVersion,
String featureSchemaPath,
String featureSchemaHash,
String featureOrderPath,
String featureOrderHash,
String inputTensorName,
String inputDtype,
Map<String, Object> inputShapeJson,
String inputExamplePath,
String outputSchemaPath,
String outputSchemaHash,
Set<String> outputTensorNames,
Map<String, Object> outputMapping,
Map<String, Object> outputValueRules,
String labelVersion,
String splitVersion,
String trainingFold,
Instant trainStart,
Instant trainEnd,
Instant validationStart,
Instant validationEnd,
Instant testStart,
Instant testEnd,
Map<String, Object> metricsJson,
String artifactPath,
String artifactHashSha256,
String sourceHash,
String status
) {
public TraderModelManifest {
symbolScope = Set.copyOf(symbolScope == null ? Set.of() : symbolScope);
inputShapeJson = Map.copyOf(inputShapeJson == null ? Map.of() : inputShapeJson);
outputTensorNames = Set.copyOf(outputTensorNames == null ? Set.of() : outputTensorNames);
outputMapping = Map.copyOf(outputMapping == null ? Map.of() : outputMapping);
outputValueRules = Map.copyOf(outputValueRules == null ? Map.of() : outputValueRules);
metricsJson = Map.copyOf(metricsJson == null ? Map.of() : metricsJson);
}
}
@@ -3,15 +3,20 @@ package com.quantai.trader.artifact;
import com.quantai.trader.domain.TraderPmConfig;
import com.quantai.trader.enums.TraderRunMode;
import java.util.Map;
import java.util.Set;
public record TraderPmConfigManifest(
String pmConfigVersion,
String modelBundleVersion,
String calibrationBundleVersion,
Map<String, Object> thresholdStabilityJson,
Set<TraderRunMode> allowedRunModes,
TraderPmConfig config,
String configHashSha256,
String status
) {
public TraderPmConfigManifest {
thresholdStabilityJson = Map.copyOf(thresholdStabilityJson == null ? Map.of() : thresholdStabilityJson);
}
}
@@ -0,0 +1,66 @@
package com.quantai.trader.artifact;
import java.math.BigDecimal;
import java.util.Map;
public record TraderReplayModelFixture(
DirectionFixture direction,
EntryFixture entry,
ContinueFixture continuation,
ExitFixture exit,
RiskFixture risk,
BigDecimal uncertainty,
BigDecimal oodScore,
String featureSchemaHash,
String featureOrderHash,
String outputSchemaHash
) {
public record DirectionFixture(
BigDecimal longProbWhenMarkGteIndex,
BigDecimal longProbWhenMarkLtIndex,
BigDecimal neutralProb
) {
}
public record EntryFixture(
BigDecimal longEntryProb,
BigDecimal shortEntryProb,
BigDecimal longExpectedNetEdgeBps,
BigDecimal shortExpectedNetEdgeBps
) {
}
public record ContinueFixture(
BigDecimal longContinueProb,
BigDecimal shortContinueProb,
BigDecimal longExpectedContinueEdgeBps,
BigDecimal shortExpectedContinueEdgeBps
) {
}
public record ExitFixture(
BigDecimal longExitProb,
BigDecimal shortExitProb,
BigDecimal longAdverseMoveBps,
BigDecimal shortAdverseMoveBps,
Map<String, BigDecimal> exitReasonScores
) {
public ExitFixture {
exitReasonScores = Map.copyOf(exitReasonScores == null ? Map.of() : exitReasonScores);
}
}
public record RiskFixture(
BigDecimal marketRiskProb,
BigDecimal longPositionRiskProb,
BigDecimal shortPositionRiskProb,
BigDecimal marketPathRiskBps,
BigDecimal longPositionPathRiskBps,
BigDecimal shortPositionPathRiskBps,
Map<String, BigDecimal> riskReasonScores
) {
public RiskFixture {
riskReasonScores = Map.copyOf(riskReasonScores == null ? Map.of() : riskReasonScores);
}
}
}
@@ -6,6 +6,8 @@ import org.springframework.boot.context.properties.ConfigurationProperties;
import java.math.BigDecimal;
import static com.quantai.trader.util.TraderNumbers.nonNegative;
import static com.quantai.trader.util.TraderNumbers.positive;
import static com.quantai.trader.util.TraderNumbers.requiredText;
@ConfigurationProperties(prefix = "trader")
@@ -23,21 +25,24 @@ public record TraderProperties(
PositionManager positionManager
) {
public TraderProperties {
serviceName = defaultText(serviceName, "quant-trader-service");
runMode = runMode == null ? TraderRunMode.SHADOW : runMode;
symbol = defaultText(symbol, "BTC-USDT-PERP");
artifact = artifact == null ? new Artifact("trader-v4-btc-p0", "cal-v4-btc-p0", "pm-v4-btc-p0", ".") : artifact;
feedback = feedback == null ? new Feedback(false) : feedback;
execution = execution == null ? new Execution(TraderExecutionMode.SHADOW, 3, 1500) : execution;
runtime = runtime == null ? new Runtime("trader:v4", true, false) : runtime;
outbox = outbox == null ? new Outbox(true, 5) : outbox;
release = release == null ? new Release(true, true, true) : release;
risk = risk == null ? new Risk(new BigDecimal("200"), BigDecimal.ONE, new BigDecimal("500")) : risk;
positionManager = positionManager == null ? new PositionManager(BigDecimal.ONE, BigDecimal.ONE) : positionManager;
serviceName = requiredText(serviceName, "serviceName");
runMode = require(runMode, "runMode");
symbol = requiredText(symbol, "symbol");
artifact = require(artifact, "artifact");
feedback = require(feedback, "feedback");
execution = require(execution, "execution");
runtime = require(runtime, "runtime");
outbox = require(outbox, "outbox");
release = require(release, "release");
risk = require(risk, "risk");
positionManager = require(positionManager, "positionManager");
}
private static String defaultText(String value, String defaultValue) {
return value == null || value.isBlank() ? defaultValue : value;
private static <T> T require(T value, String field) {
if (value == null) {
throw new IllegalArgumentException(field + " is required");
}
return value;
}
public record Artifact(
@@ -57,19 +62,30 @@ public record TraderProperties(
public record Feedback(boolean httpEnabled) {
}
public record Execution(TraderExecutionMode mode, int maxApiErrorCount, long maxExchangeLatencyMs) {
public record Execution(TraderExecutionMode mode, Integer maxApiErrorCount, Long maxExchangeLatencyMs) {
public Execution {
mode = mode == null ? TraderExecutionMode.SHADOW : mode;
mode = require(mode, "execution.mode");
if (require(maxApiErrorCount, "execution.maxApiErrorCount") < 0) {
throw new IllegalArgumentException("execution.maxApiErrorCount must be >= 0");
}
if (require(maxExchangeLatencyMs, "execution.maxExchangeLatencyMs") <= 0) {
throw new IllegalArgumentException("execution.maxExchangeLatencyMs must be > 0");
}
}
}
public record Runtime(String redisKeyPrefix, boolean requireRedisForOpenAdd, boolean tradingEnabled) {
public Runtime {
redisKeyPrefix = defaultText(redisKeyPrefix, "trader:v4");
redisKeyPrefix = requiredText(redisKeyPrefix, "runtime.redisKeyPrefix");
}
}
public record Outbox(boolean enabled, int maxRetryCount) {
public record Outbox(boolean enabled, Integer maxRetryCount) {
public Outbox {
if (require(maxRetryCount, "outbox.maxRetryCount") < 0) {
throw new IllegalArgumentException("outbox.maxRetryCount must be >= 0");
}
}
}
public record Release(boolean requireReviewForPaper, boolean requireReviewForLiveProbe, boolean activePointerCheckEnabled) {
@@ -77,16 +93,16 @@ public record TraderProperties(
public record Risk(BigDecimal maxDailyLossBps, BigDecimal maxTotalExposureRatio, BigDecimal minLiquidationBufferBps) {
public Risk {
maxDailyLossBps = maxDailyLossBps == null ? new BigDecimal("200") : maxDailyLossBps;
maxTotalExposureRatio = maxTotalExposureRatio == null ? BigDecimal.ONE : maxTotalExposureRatio;
minLiquidationBufferBps = minLiquidationBufferBps == null ? new BigDecimal("500") : minLiquidationBufferBps;
maxDailyLossBps = nonNegative(maxDailyLossBps, "risk.maxDailyLossBps");
maxTotalExposureRatio = positive(maxTotalExposureRatio, "risk.maxTotalExposureRatio");
minLiquidationBufferBps = nonNegative(minLiquidationBufferBps, "risk.minLiquidationBufferBps");
}
}
public record PositionManager(BigDecimal maxSingleLegRatio, BigDecimal maxTotalPositionRatio) {
public PositionManager {
maxSingleLegRatio = maxSingleLegRatio == null ? BigDecimal.ONE : maxSingleLegRatio;
maxTotalPositionRatio = maxTotalPositionRatio == null ? BigDecimal.ONE : maxTotalPositionRatio;
maxSingleLegRatio = positive(maxSingleLegRatio, "positionManager.maxSingleLegRatio");
maxTotalPositionRatio = positive(maxTotalPositionRatio, "positionManager.maxTotalPositionRatio");
}
}
}
@@ -14,6 +14,6 @@ public class TraderApiExceptionHandler {
@ExceptionHandler(IllegalArgumentException.class)
ResponseEntity<TraderApiError> illegalArgument(IllegalArgumentException exception) {
return ResponseEntity.badRequest().body(new TraderApiError(com.quantai.trader.enums.TraderErrorCode.TRADER_MODEL_OUTPUT_INVALID, exception.getMessage()));
return ResponseEntity.badRequest().body(new TraderApiError(com.quantai.trader.enums.TraderErrorCode.TRADER_REQUEST_INVALID, exception.getMessage()));
}
}
@@ -5,6 +5,7 @@ import com.quantai.trader.domain.FeedbackValidator;
import com.quantai.trader.domain.TraderAppFeedback;
import com.quantai.trader.domain.TraderException;
import com.quantai.trader.enums.TraderErrorCode;
import com.quantai.trader.feedback.TraderFeedbackRepository;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.web.bind.annotation.PostMapping;
@@ -18,10 +19,13 @@ public class TraderFeedbackController {
private static final Logger log = LoggerFactory.getLogger(TraderFeedbackController.class);
private final TraderProperties properties;
private final FeedbackValidator feedbackValidator;
private final TraderFeedbackRepository feedbackRepository;
public TraderFeedbackController(TraderProperties properties, FeedbackValidator feedbackValidator) {
public TraderFeedbackController(TraderProperties properties, FeedbackValidator feedbackValidator,
TraderFeedbackRepository feedbackRepository) {
this.properties = properties;
this.feedbackValidator = feedbackValidator;
this.feedbackRepository = feedbackRepository;
}
@PostMapping("/api/trader/feedback")
@@ -30,6 +34,7 @@ public class TraderFeedbackController {
throw new TraderException(TraderErrorCode.TRADER_FEEDBACK_INVALID, "HTTP feedback is disabled in P0");
}
feedbackValidator.validateP0(feedback);
feedbackRepository.insert(feedback);
log.info("event=trader.feedback.accepted runId={} cycleId={} actionId={} source={}",
feedback.runId(), feedback.cycleId(), feedback.actionId(), feedback.feedbackSource());
return Map.of("accepted", true, "feedbackId", feedback.feedbackId());
@@ -2,27 +2,29 @@ package com.quantai.trader.domain;
import static com.quantai.trader.util.TraderNumbers.*;
import java.math.BigDecimal;
import java.util.Map;
import com.quantai.trader.enums.PositionSide;
import java.math.BigDecimal;
public record ContinueOutput(
BigDecimal longContinueProb,
BigDecimal shortContinueProb,
BigDecimal trendPersistenceProb,
BigDecimal holdEdgeBps,
BigDecimal continueVsExitEdgeBps,
String modelVersion,
String calibrationVersion,
Map<String, Object> explanation
BigDecimal longExpectedContinueEdgeBps,
BigDecimal shortExpectedContinueEdgeBps
) {
public ContinueOutput {
longContinueProb = probability(longContinueProb, "continue.longContinueProb");
shortContinueProb = probability(shortContinueProb, "continue.shortContinueProb");
trendPersistenceProb = probability(trendPersistenceProb, "continue.trendPersistenceProb");
holdEdgeBps = required(holdEdgeBps, "continue.holdEdgeBps");
continueVsExitEdgeBps = required(continueVsExitEdgeBps, "continue.continueVsExitEdgeBps");
modelVersion = requiredText(modelVersion, "continue.modelVersion");
calibrationVersion = requiredText(calibrationVersion, "continue.calibrationVersion");
explanation = Map.copyOf(explanation == null ? Map.of() : explanation);
longExpectedContinueEdgeBps = required(longExpectedContinueEdgeBps, "continue.longExpectedContinueEdgeBps");
shortExpectedContinueEdgeBps = required(shortExpectedContinueEdgeBps, "continue.shortExpectedContinueEdgeBps");
}
public BigDecimal continueEdgeBpsFor(PositionSide side) {
if (side == PositionSide.LONG) {
return longExpectedContinueEdgeBps;
}
if (side == PositionSide.SHORT) {
return shortExpectedContinueEdgeBps;
}
throw new IllegalArgumentException("continue edge requires LONG or SHORT side");
}
}
@@ -3,32 +3,27 @@ package com.quantai.trader.domain;
import static com.quantai.trader.util.TraderNumbers.*;
import java.math.BigDecimal;
import java.util.Map;
public record DirectionOutput(
BigDecimal longProb,
BigDecimal shortProb,
BigDecimal neutralProb,
BigDecimal directionConfidence,
BigDecimal directionMargin,
BigDecimal expectedReturnBps,
Integer horizonMinutes,
String modelVersion,
String calibrationVersion,
Map<String, Object> explanation
BigDecimal neutralProb
) {
public DirectionOutput {
longProb = probability(longProb, "direction.longProb");
shortProb = probability(shortProb, "direction.shortProb");
neutralProb = probability(neutralProb, "direction.neutralProb");
directionConfidence = probability(directionConfidence, "direction.directionConfidence");
directionMargin = nonNegative(directionMargin, "direction.directionMargin");
expectedReturnBps = required(expectedReturnBps, "direction.expectedReturnBps");
if (horizonMinutes == null || horizonMinutes <= 0) {
throw new IllegalArgumentException("direction.horizonMinutes must be > 0");
BigDecimal sum = longProb.add(shortProb).add(neutralProb);
if (sum.subtract(BigDecimal.ONE).abs().compareTo(new BigDecimal("0.000001")) > 0) {
throw new IllegalArgumentException("direction probabilities must sum to 1");
}
modelVersion = requiredText(modelVersion, "direction.modelVersion");
calibrationVersion = requiredText(calibrationVersion, "direction.calibrationVersion");
explanation = Map.copyOf(explanation == null ? Map.of() : explanation);
}
public BigDecimal margin() {
return longProb.subtract(shortProb).abs();
}
public BigDecimal confidence() {
return longProb.max(shortProb).max(neutralProb);
}
}
@@ -2,39 +2,29 @@ package com.quantai.trader.domain;
import static com.quantai.trader.util.TraderNumbers.*;
import java.math.BigDecimal;
import java.util.Map;
import com.quantai.trader.enums.PositionSide;
import java.math.BigDecimal;
public record EntryOutput(
BigDecimal longEntryProb,
BigDecimal shortEntryProb,
BigDecimal entryQualityScore,
BigDecimal expectedEdgeBps,
String pricePlanId,
String pricePlanConfigHash,
BigDecimal stopDistanceBps,
BigDecimal targetDistanceBps,
Integer maxHoldMinutes,
BigDecimal costBps,
String modelVersion,
String calibrationVersion,
Map<String, Object> explanation
BigDecimal longExpectedNetEdgeBps,
BigDecimal shortExpectedNetEdgeBps
) {
public EntryOutput {
longEntryProb = probability(longEntryProb, "entry.longEntryProb");
shortEntryProb = probability(shortEntryProb, "entry.shortEntryProb");
entryQualityScore = probability(entryQualityScore, "entry.entryQualityScore");
expectedEdgeBps = required(expectedEdgeBps, "entry.expectedEdgeBps");
pricePlanId = requiredText(pricePlanId, "entry.pricePlanId");
pricePlanConfigHash = requiredText(pricePlanConfigHash, "entry.pricePlanConfigHash");
stopDistanceBps = positive(stopDistanceBps, "entry.stopDistanceBps");
targetDistanceBps = positive(targetDistanceBps, "entry.targetDistanceBps");
if (maxHoldMinutes == null || maxHoldMinutes <= 0) {
throw new IllegalArgumentException("entry.maxHoldMinutes must be > 0");
longExpectedNetEdgeBps = required(longExpectedNetEdgeBps, "entry.longExpectedNetEdgeBps");
shortExpectedNetEdgeBps = required(shortExpectedNetEdgeBps, "entry.shortExpectedNetEdgeBps");
}
public BigDecimal netEdgeBpsFor(PositionSide side) {
if (side == PositionSide.LONG) {
return longExpectedNetEdgeBps;
}
costBps = nonNegative(costBps, "entry.costBps");
modelVersion = requiredText(modelVersion, "entry.modelVersion");
calibrationVersion = requiredText(calibrationVersion, "entry.calibrationVersion");
explanation = Map.copyOf(explanation == null ? Map.of() : explanation);
if (side == PositionSide.SHORT) {
return shortExpectedNetEdgeBps;
}
throw new IllegalArgumentException("entry edge requires LONG or SHORT side");
}
}
@@ -4,29 +4,39 @@ import static com.quantai.trader.util.TraderNumbers.*;
import java.math.BigDecimal;
import java.util.Map;
import java.util.Set;
public record ExitOutput(
BigDecimal longExitProb,
BigDecimal shortExitProb,
BigDecimal profitGivebackProb,
BigDecimal reversalProb,
BigDecimal stopRiskProb,
BigDecimal stagnationProb,
BigDecimal expectedGivebackBps,
String modelVersion,
String calibrationVersion,
Map<String, Object> explanation
BigDecimal longAdverseMoveBps,
BigDecimal shortAdverseMoveBps,
Map<String, BigDecimal> exitReasonScores
) {
private static final Set<String> REQUIRED_REASON_KEYS = Set.of(
"adverse_move_prob", "reversal_prob", "stop_hit_prob", "stagnation_prob");
public ExitOutput {
longExitProb = probability(longExitProb, "exit.longExitProb");
shortExitProb = probability(shortExitProb, "exit.shortExitProb");
profitGivebackProb = probability(profitGivebackProb, "exit.profitGivebackProb");
reversalProb = probability(reversalProb, "exit.reversalProb");
stopRiskProb = probability(stopRiskProb, "exit.stopRiskProb");
stagnationProb = probability(stagnationProb, "exit.stagnationProb");
expectedGivebackBps = nonNegative(expectedGivebackBps, "exit.expectedGivebackBps");
modelVersion = requiredText(modelVersion, "exit.modelVersion");
calibrationVersion = requiredText(calibrationVersion, "exit.calibrationVersion");
explanation = Map.copyOf(explanation == null ? Map.of() : explanation);
longAdverseMoveBps = nonNegative(longAdverseMoveBps, "exit.longAdverseMoveBps");
shortAdverseMoveBps = nonNegative(shortAdverseMoveBps, "exit.shortAdverseMoveBps");
exitReasonScores = checkedProbabilities(exitReasonScores, "exit.exitReasonScores");
}
public BigDecimal reasonScore(String key) {
return exitReasonScores.get(requiredText(key, "exit reason key"));
}
private static Map<String, BigDecimal> checkedProbabilities(Map<String, BigDecimal> scores, String field) {
Map<String, BigDecimal> source = scores == null ? Map.of() : scores;
if (!source.keySet().containsAll(REQUIRED_REASON_KEYS)) {
throw new IllegalArgumentException(field + " must contain " + REQUIRED_REASON_KEYS);
}
source.forEach((key, value) -> {
requiredText(key, field + ".key");
probability(value, field + "." + key);
});
return Map.copyOf(source);
}
}
@@ -6,6 +6,7 @@ public record PositionManagerInput(
TraderDecisionCycle cycle,
TraderMarketSnapshot snapshot,
TraderModelOutput modelOutput,
TraderPricePlanContext pricePlanContext,
TraderPositionState positionState,
TraderAccountState accountState,
TraderExecutionState executionState,
@@ -15,6 +16,7 @@ public record PositionManagerInput(
cycle = Objects.requireNonNull(cycle, "cycle is required");
snapshot = Objects.requireNonNull(snapshot, "snapshot is required");
modelOutput = Objects.requireNonNull(modelOutput, "modelOutput is required");
pricePlanContext = Objects.requireNonNull(pricePlanContext, "pricePlanContext is required");
positionState = Objects.requireNonNull(positionState, "positionState is required");
accountState = Objects.requireNonNull(accountState, "accountState is required");
executionState = Objects.requireNonNull(executionState, "executionState is required");
@@ -2,37 +2,68 @@ package com.quantai.trader.domain;
import static com.quantai.trader.util.TraderNumbers.*;
import com.quantai.trader.enums.PositionSide;
import java.math.BigDecimal;
import java.util.Map;
import java.util.Set;
public record RiskOutput(
BigDecimal marketRiskProb,
BigDecimal positionRiskProb,
BigDecimal marketRiskSeverityBps,
BigDecimal positionRiskSeverityBps,
BigDecimal drawdownProb,
BigDecimal expectedShortfallBps,
BigDecimal volatilityExpansionProb,
BigDecimal spikeProb,
BigDecimal liquidityRiskProb,
BigDecimal liquidityCapacityRatio,
String modelVersion,
String calibrationVersion,
Map<String, Object> explanation
BigDecimal longPositionRiskProb,
BigDecimal shortPositionRiskProb,
BigDecimal marketPathRiskBps,
BigDecimal longPositionPathRiskBps,
BigDecimal shortPositionPathRiskBps,
Map<String, BigDecimal> riskReasonScores
) {
private static final Set<String> REQUIRED_REASON_KEYS = Set.of(
"market_drawdown_prob", "volatility_expansion_prob", "spike_prob",
"liquidity_deterioration_prob", "position_drawdown_prob");
public RiskOutput {
marketRiskProb = probability(marketRiskProb, "risk.marketRiskProb");
positionRiskProb = probability(positionRiskProb, "risk.positionRiskProb");
marketRiskSeverityBps = nonNegative(marketRiskSeverityBps, "risk.marketRiskSeverityBps");
positionRiskSeverityBps = nonNegative(positionRiskSeverityBps, "risk.positionRiskSeverityBps");
drawdownProb = probability(drawdownProb, "risk.drawdownProb");
expectedShortfallBps = nonNegative(expectedShortfallBps, "risk.expectedShortfallBps");
volatilityExpansionProb = probability(volatilityExpansionProb, "risk.volatilityExpansionProb");
spikeProb = probability(spikeProb, "risk.spikeProb");
liquidityRiskProb = probability(liquidityRiskProb, "risk.liquidityRiskProb");
liquidityCapacityRatio = nonNegative(liquidityCapacityRatio, "risk.liquidityCapacityRatio");
modelVersion = requiredText(modelVersion, "risk.modelVersion");
calibrationVersion = requiredText(calibrationVersion, "risk.calibrationVersion");
explanation = Map.copyOf(explanation == null ? Map.of() : explanation);
longPositionRiskProb = probability(longPositionRiskProb, "risk.longPositionRiskProb");
shortPositionRiskProb = probability(shortPositionRiskProb, "risk.shortPositionRiskProb");
marketPathRiskBps = nonNegative(marketPathRiskBps, "risk.marketPathRiskBps");
longPositionPathRiskBps = nonNegative(longPositionPathRiskBps, "risk.longPositionPathRiskBps");
shortPositionPathRiskBps = nonNegative(shortPositionPathRiskBps, "risk.shortPositionPathRiskBps");
riskReasonScores = checkedProbabilities(riskReasonScores, "risk.riskReasonScores");
}
public BigDecimal reasonScore(String key) {
return riskReasonScores.get(requiredText(key, "risk reason key"));
}
public BigDecimal sideRiskProbFor(PositionSide side) {
if (side == PositionSide.LONG) {
return longPositionRiskProb;
}
if (side == PositionSide.SHORT) {
return shortPositionRiskProb;
}
throw new IllegalArgumentException("position risk requires LONG or SHORT side");
}
public BigDecimal positionPathRiskBpsFor(PositionSide side) {
if (side == PositionSide.LONG) {
return longPositionPathRiskBps;
}
if (side == PositionSide.SHORT) {
return shortPositionPathRiskBps;
}
throw new IllegalArgumentException("position path risk requires LONG or SHORT side");
}
private static Map<String, BigDecimal> checkedProbabilities(Map<String, BigDecimal> scores, String field) {
Map<String, BigDecimal> source = scores == null ? Map.of() : scores;
if (!source.keySet().containsAll(REQUIRED_REASON_KEYS)) {
throw new IllegalArgumentException(field + " must contain " + REQUIRED_REASON_KEYS);
}
source.forEach((key, value) -> {
requiredText(key, field + ".key");
probability(value, field + "." + key);
});
return Map.copyOf(source);
}
}
@@ -2,38 +2,28 @@ package com.quantai.trader.domain;
import static com.quantai.trader.util.TraderNumbers.*;
import java.math.BigDecimal;
import java.util.Map;
import java.util.Objects;
public record TraderModelOutput(
String modelOutputId,
String runId,
String cycleId,
String modelBundleVersion,
String calibrationBundleVersion,
TraderModelOutputMetadata metadata,
DirectionOutput direction,
EntryOutput entry,
ContinueOutput continuation,
ExitOutput exit,
RiskOutput risk,
BigDecimal uncertainty,
BigDecimal oodScore,
Map<String, Object> explanation
RiskOutput risk
) {
public TraderModelOutput {
modelOutputId = requiredText(modelOutputId, "modelOutputId");
runId = requiredText(runId, "runId");
cycleId = requiredText(cycleId, "cycleId");
modelBundleVersion = requiredText(modelBundleVersion, "modelBundleVersion");
calibrationBundleVersion = requiredText(calibrationBundleVersion, "calibrationBundleVersion");
metadata = Objects.requireNonNull(metadata, "metadata is required");
direction = Objects.requireNonNull(direction, "direction is required");
entry = Objects.requireNonNull(entry, "entry is required");
continuation = Objects.requireNonNull(continuation, "continuation is required");
exit = Objects.requireNonNull(exit, "exit is required");
risk = Objects.requireNonNull(risk, "risk is required");
uncertainty = probability(uncertainty, "uncertainty");
oodScore = probability(oodScore, "oodScore");
explanation = Map.copyOf(explanation == null ? Map.of() : explanation);
}
}
@@ -0,0 +1,39 @@
package com.quantai.trader.domain;
import static com.quantai.trader.util.TraderNumbers.*;
import java.math.BigDecimal;
import java.util.Map;
public record TraderModelOutputMetadata(
String modelBundleVersion,
String calibrationBundleVersion,
Map<String, String> modelVersions,
Map<String, String> calibrationVersions,
String featureSchemaHash,
String featureOrderHash,
String outputSchemaHash,
BigDecimal uncertainty,
BigDecimal oodScore
) {
public TraderModelOutputMetadata {
modelBundleVersion = requiredText(modelBundleVersion, "metadata.modelBundleVersion");
calibrationBundleVersion = requiredText(calibrationBundleVersion, "metadata.calibrationBundleVersion");
modelVersions = checkedTextMap(modelVersions, "metadata.modelVersions");
calibrationVersions = checkedTextMap(calibrationVersions, "metadata.calibrationVersions");
featureSchemaHash = requiredText(featureSchemaHash, "metadata.featureSchemaHash");
featureOrderHash = requiredText(featureOrderHash, "metadata.featureOrderHash");
outputSchemaHash = requiredText(outputSchemaHash, "metadata.outputSchemaHash");
uncertainty = probability(uncertainty, "metadata.uncertainty");
oodScore = probability(oodScore, "metadata.oodScore");
}
private static Map<String, String> checkedTextMap(Map<String, String> values, String field) {
Map<String, String> source = values == null ? Map.of() : values;
source.forEach((key, value) -> {
requiredText(key, field + ".key");
requiredText(value, field + "." + key);
});
return Map.copyOf(source);
}
}
@@ -81,22 +81,22 @@ public record TraderPmConfig(
BigDecimal closePositionRiskProb,
BigDecimal closeMarketRiskProb,
BigDecimal closeContinueMax,
BigDecimal reduceGivebackProb,
BigDecimal reduceAdverseMoveProb,
BigDecimal reduceContinueMin,
BigDecimal reduceContinueMax,
BigDecimal minProfitForReduceBps,
BigDecimal maxExpectedShortfallBps
BigDecimal maxPositionPathRiskBps
) {
public ExitRuleConfig {
closeExitProb = probability(closeExitProb, "exit.closeExitProb");
closePositionRiskProb = probability(closePositionRiskProb, "exit.closePositionRiskProb");
closeMarketRiskProb = probability(closeMarketRiskProb, "exit.closeMarketRiskProb");
closeContinueMax = probability(closeContinueMax, "exit.closeContinueMax");
reduceGivebackProb = probability(reduceGivebackProb, "exit.reduceGivebackProb");
reduceAdverseMoveProb = probability(reduceAdverseMoveProb, "exit.reduceAdverseMoveProb");
reduceContinueMin = probability(reduceContinueMin, "exit.reduceContinueMin");
reduceContinueMax = probability(reduceContinueMax, "exit.reduceContinueMax");
minProfitForReduceBps = nonNegative(minProfitForReduceBps, "exit.minProfitForReduceBps");
maxExpectedShortfallBps = nonNegative(maxExpectedShortfallBps, "exit.maxExpectedShortfallBps");
maxPositionPathRiskBps = nonNegative(maxPositionPathRiskBps, "exit.maxPositionPathRiskBps");
}
}
@@ -0,0 +1,25 @@
package com.quantai.trader.domain;
import static com.quantai.trader.util.TraderNumbers.*;
import java.math.BigDecimal;
public record TraderPricePlanContext(
String pricePlanId,
String pricePlanConfigHash,
BigDecimal stopDistanceBps,
BigDecimal targetDistanceBps,
int maxHoldMinutes,
BigDecimal costBps
) {
public TraderPricePlanContext {
pricePlanId = requiredText(pricePlanId, "pricePlan.pricePlanId");
pricePlanConfigHash = requiredText(pricePlanConfigHash, "pricePlan.pricePlanConfigHash");
stopDistanceBps = positive(stopDistanceBps, "pricePlan.stopDistanceBps");
targetDistanceBps = positive(targetDistanceBps, "pricePlan.targetDistanceBps");
if (maxHoldMinutes <= 0) {
throw new IllegalArgumentException("pricePlan.maxHoldMinutes must be > 0");
}
costBps = nonNegative(costBps, "pricePlan.costBps");
}
}
@@ -9,8 +9,11 @@ public enum TraderErrorCode {
TRADER_RISK_BLOCKED,
TRADER_EXECUTION_BLOCKED,
TRADER_FEEDBACK_INVALID,
TRADER_REQUEST_INVALID,
TRADER_P0_MODE_BLOCKED,
TRADER_KILL_SWITCH_ACTIVE,
TRADER_RUNTIME_CONTROL_BLOCKED,
TRADER_OUTBOX_BLOCKED,
TRADER_ACTIVE_POINTER_MISMATCH,
TRADER_PERSISTENCE_FAILED
}
@@ -0,0 +1,144 @@
package com.quantai.trader.feature;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.quantai.trader.artifact.TraderArtifactBundle;
import com.quantai.trader.artifact.TraderModelManifest;
import com.quantai.trader.config.TraderProperties;
import com.quantai.trader.domain.TraderException;
import com.quantai.trader.domain.TraderMarketSnapshot;
import com.quantai.trader.enums.TraderErrorCode;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.stereotype.Component;
import java.io.IOException;
import java.math.BigDecimal;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.stream.StreamSupport;
@Component
public class TraderFeatureVectorBuilder {
private static final Logger log = LoggerFactory.getLogger(TraderFeatureVectorBuilder.class);
private static final int REQUIRED_FEATURE_COUNT = 39;
private final TraderProperties properties;
private final ObjectMapper objectMapper;
private final ConcurrentMap<String, List<String>> featureOrderCache = new ConcurrentHashMap<>();
public TraderFeatureVectorBuilder(TraderProperties properties, ObjectMapper objectMapper) {
this.properties = properties;
this.objectMapper = objectMapper;
}
public float[] build(TraderMarketSnapshot snapshot, TraderArtifactBundle bundle) {
TraderModelManifest referenceManifest = referenceManifest(bundle);
if (!referenceManifest.featureVersion().equals(snapshot.featureVersion())) {
log.warn("event=trader.features.version_mismatch runId={} cycleId={} snapshotFeatureVersion={} modelFeatureVersion={}",
snapshot.runId(), snapshot.cycleId(), snapshot.featureVersion(), referenceManifest.featureVersion());
throw modelInputException("snapshot feature version does not match model feature version: "
+ snapshot.featureVersion() + " != " + referenceManifest.featureVersion());
}
List<String> featureOrder = featureOrder(referenceManifest);
rejectUnexpectedFeatures(snapshot, featureOrder);
float[] values = new float[featureOrder.size()];
for (int index = 0; index < featureOrder.size(); index++) {
String featureName = featureOrder.get(index);
Object rawValue = snapshot.featureJson().get(featureName);
values[index] = toFloat(snapshot, rawValue, featureName, index);
}
log.debug("event=trader.features.vector_built runId={} cycleId={} featureVersion={} featureCount={} featureOrderHash={}",
snapshot.runId(), snapshot.cycleId(), snapshot.featureVersion(), values.length, referenceManifest.featureOrderHash());
return values;
}
public List<String> featureOrder(TraderArtifactBundle bundle) {
return featureOrder(referenceManifest(bundle));
}
private TraderModelManifest referenceManifest(TraderArtifactBundle bundle) {
return bundle.modelManifests().stream()
.filter(manifest -> "DIRECTION".equals(manifest.modelType()))
.findFirst()
.orElseThrow(() -> modelInputException("DIRECTION model manifest is required for feature order"));
}
private List<String> featureOrder(TraderModelManifest manifest) {
String cacheKey = manifest.featureOrderHash() + "|" + manifest.featureOrderPath();
// 特征顺序是模型包契约的一部分,按 hash 缓存,避免每轮重复读文件。
return featureOrderCache.computeIfAbsent(cacheKey, ignored -> readFeatureOrder(manifest));
}
private List<String> readFeatureOrder(TraderModelManifest manifest) {
Path path = Path.of(properties.artifact().artifactRoot()).resolve(manifest.featureOrderPath());
if (!Files.isRegularFile(path)) {
throw modelInputException("feature_order.json is missing: " + path);
}
try {
JsonNode root = objectMapper.readTree(path.toFile());
if (!root.isArray() || root.size() != REQUIRED_FEATURE_COUNT) {
throw modelInputException("feature_order.json must contain exactly " + REQUIRED_FEATURE_COUNT + " fields: " + path);
}
List<String> order = StreamSupport.stream(root.spliterator(), false)
.map(JsonNode::asText)
.toList();
Set<String> unique = new LinkedHashSet<>(order);
if (unique.size() != order.size() || order.stream().anyMatch(String::isBlank)) {
throw modelInputException("feature_order.json contains duplicate or blank feature names: " + path);
}
log.info("event=trader.features.order_loaded featureOrderPath={} featureOrderHash={} featureCount={}",
manifest.featureOrderPath(), manifest.featureOrderHash(), order.size());
return order;
} catch (IOException exception) {
throw modelInputException("feature_order.json cannot be read: " + path);
}
}
private void rejectUnexpectedFeatures(TraderMarketSnapshot snapshot, List<String> featureOrder) {
Set<String> allowed = Set.copyOf(featureOrder);
List<String> unexpected = snapshot.featureJson().keySet().stream()
.filter(key -> !allowed.contains(key))
.sorted()
.toList();
if (!unexpected.isEmpty()) {
log.warn("event=trader.features.unexpected_fields runId={} cycleId={} unexpectedFields={}",
snapshot.runId(), snapshot.cycleId(), unexpected);
throw modelInputException("snapshot featureJson contains fields outside feature_order.json: " + unexpected);
}
}
private float toFloat(TraderMarketSnapshot snapshot, Object rawValue, String featureName, int index) {
if (rawValue == null) {
log.warn("event=trader.features.missing runId={} cycleId={} featureIndex={} featureName={}",
snapshot.runId(), snapshot.cycleId(), index + 1, featureName);
throw modelInputException("snapshot feature is missing: " + featureName);
}
double value;
if (rawValue instanceof BigDecimal decimal) {
value = decimal.doubleValue();
} else if (rawValue instanceof Number number) {
value = number.doubleValue();
} else {
log.warn("event=trader.features.non_numeric runId={} cycleId={} featureIndex={} featureName={} valueType={}",
snapshot.runId(), snapshot.cycleId(), index + 1, featureName, rawValue.getClass().getName());
throw modelInputException("snapshot feature must be numeric: " + featureName);
}
if (!Double.isFinite(value)) {
log.warn("event=trader.features.non_finite runId={} cycleId={} featureIndex={} featureName={} value={}",
snapshot.runId(), snapshot.cycleId(), index + 1, featureName, value);
throw modelInputException("snapshot feature must be finite: " + featureName);
}
return (float) value;
}
private static TraderException modelInputException(String message) {
return new TraderException(TraderErrorCode.TRADER_MODEL_ARTIFACT_MISSING, message);
}
}
@@ -0,0 +1,38 @@
package com.quantai.trader.feedback;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.quantai.trader.domain.TraderAppFeedback;
import com.quantai.trader.persistence.TraderJsonCodec;
import org.springframework.jdbc.core.JdbcTemplate;
import org.springframework.stereotype.Repository;
import java.sql.Timestamp;
@Repository
public class JdbcTraderFeedbackRepository implements TraderFeedbackRepository {
private final JdbcTemplate jdbcTemplate;
private final TraderJsonCodec jsonCodec;
public JdbcTraderFeedbackRepository(JdbcTemplate jdbcTemplate, ObjectMapper objectMapper) {
this.jdbcTemplate = jdbcTemplate;
this.jsonCodec = new TraderJsonCodec(objectMapper);
}
@Override
public void insert(TraderAppFeedback feedback) {
jdbcTemplate.update("""
insert into trader_app_feedback
(run_id, cycle_id, feedback_id, action_id, feedback_source, is_real_fill,
order_id, order_status, app_received_time, exchange_ack_time, filled_time,
filled_price, filled_quantity, fee, slippage_bps, reject_reason, raw_feedback_json)
values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
""",
feedback.runId(), feedback.cycleId(), feedback.feedbackId(), feedback.actionId(),
feedback.feedbackSource().name(), feedback.realFill(), feedback.orderId(), feedback.orderStatus(),
feedback.appReceivedTime() == null ? null : Timestamp.from(feedback.appReceivedTime()),
feedback.exchangeAckTime() == null ? null : Timestamp.from(feedback.exchangeAckTime()),
feedback.filledTime() == null ? null : Timestamp.from(feedback.filledTime()),
feedback.filledPrice(), feedback.filledQuantity(), feedback.fee(), feedback.slippageBps(),
feedback.rejectReason(), jsonCodec.toJson(feedback.rawFeedbackJson()));
}
}
@@ -0,0 +1,7 @@
package com.quantai.trader.feedback;
import com.quantai.trader.domain.TraderAppFeedback;
public interface TraderFeedbackRepository {
void insert(TraderAppFeedback feedback);
}
@@ -1,72 +0,0 @@
package com.quantai.trader.model;
import com.quantai.trader.artifact.TraderArtifactBundle;
import com.quantai.trader.artifact.TraderArtifactModelPolicy;
import com.quantai.trader.domain.*;
import org.springframework.stereotype.Service;
import java.math.BigDecimal;
import java.util.Map;
@Service
public class ArtifactTraderModelService implements TraderModelService {
@Override
public TraderModelOutput evaluate(TraderMarketSnapshot snapshot, TraderArtifactBundle bundle) {
TraderArtifactModelPolicy policy = bundle.modelPolicy();
DirectionOutput direction = direction(snapshot, bundle, policy.direction());
return new TraderModelOutput(
"model_output_" + snapshot.cycleId(),
snapshot.runId(),
snapshot.cycleId(),
bundle.modelBundleVersion(),
bundle.calibrationBundleVersion(),
direction,
entry(bundle, policy.entry()),
continuation(bundle, policy.continuation()),
exit(bundle, policy.exit()),
risk(snapshot, bundle, policy.risk()),
policy.uncertainty(),
policy.oodScore(),
Map.of("artifactPolicy", bundle.bundleHashSha256()));
}
private DirectionOutput direction(TraderMarketSnapshot snapshot, TraderArtifactBundle bundle, TraderArtifactModelPolicy.DirectionPolicy policy) {
BigDecimal longProb = snapshot.markPrice().compareTo(snapshot.indexPrice()) >= 0
? policy.longProbWhenMarkGteIndex()
: policy.longProbWhenMarkLtIndex();
BigDecimal shortProb = BigDecimal.ONE.subtract(longProb).subtract(policy.neutralProb()).max(BigDecimal.ZERO);
BigDecimal neutralProb = BigDecimal.ONE.subtract(longProb).subtract(shortProb);
return new DirectionOutput(longProb, shortProb, neutralProb, longProb.max(shortProb),
longProb.subtract(shortProb).abs(), policy.expectedReturnBps(), policy.horizonMinutes(),
policy.modelVersion(), bundle.calibrationBundleVersion(), Map.of("source", "artifact_policy"));
}
private EntryOutput entry(TraderArtifactBundle bundle, TraderArtifactModelPolicy.EntryPolicy policy) {
return new EntryOutput(policy.longEntryProb(), policy.shortEntryProb(), policy.entryQualityScore(),
policy.expectedEdgeBps(), policy.pricePlanId(), policy.pricePlanConfigHash(),
policy.stopDistanceBps(), policy.targetDistanceBps(), policy.maxHoldMinutes(), policy.costBps(),
policy.modelVersion(), bundle.calibrationBundleVersion(), Map.of("source", "artifact_policy"));
}
private ContinueOutput continuation(TraderArtifactBundle bundle, TraderArtifactModelPolicy.ContinuePolicy policy) {
return new ContinueOutput(policy.longContinueProb(), policy.shortContinueProb(), policy.trendPersistenceProb(),
policy.holdEdgeBps(), policy.continueVsExitEdgeBps(),
policy.modelVersion(), bundle.calibrationBundleVersion(), Map.of("source", "artifact_policy"));
}
private ExitOutput exit(TraderArtifactBundle bundle, TraderArtifactModelPolicy.ExitPolicy policy) {
return new ExitOutput(policy.longExitProb(), policy.shortExitProb(), policy.profitGivebackProb(),
policy.reversalProb(), policy.stopRiskProb(), policy.stagnationProb(), policy.expectedGivebackBps(),
policy.modelVersion(), bundle.calibrationBundleVersion(), Map.of("source", "artifact_policy"));
}
private RiskOutput risk(TraderMarketSnapshot snapshot, TraderArtifactBundle bundle, TraderArtifactModelPolicy.RiskPolicy policy) {
BigDecimal liquidityCapacity = snapshot.dataReady()
? policy.liquidityCapacityRatioWhenReady()
: policy.liquidityCapacityRatioWhenNotReady();
return new RiskOutput(policy.marketRiskProb(), policy.positionRiskProb(), policy.marketRiskSeverityBps(),
policy.positionRiskSeverityBps(), policy.drawdownProb(), policy.expectedShortfallBps(),
policy.volatilityExpansionProb(), policy.spikeProb(), policy.liquidityRiskProb(), liquidityCapacity,
policy.modelVersion(), bundle.calibrationBundleVersion(), Map.of("source", "artifact_policy"));
}
}
@@ -0,0 +1,267 @@
package com.quantai.trader.model;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.quantai.trader.artifact.TraderArtifactBundle;
import com.quantai.trader.artifact.TraderCalibrationManifest;
import com.quantai.trader.artifact.TraderModelManifest;
import com.quantai.trader.config.TraderProperties;
import com.quantai.trader.domain.*;
import com.quantai.trader.enums.TraderErrorCode;
import com.quantai.trader.feature.TraderFeatureVectorBuilder;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.stereotype.Service;
import java.math.BigDecimal;
import java.nio.file.Path;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.Set;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
@Service
public class OnnxTraderModelService implements TraderModelService {
private static final Logger log = LoggerFactory.getLogger(OnnxTraderModelService.class);
private static final Pattern OUTPUT_REFERENCE = Pattern.compile("^([A-Za-z0-9_]+)\\[(\\d+)]$");
private static final Set<String> REQUIRED_TYPES = Set.of("DIRECTION", "ENTRY", "CONTINUE", "EXIT", "RISK");
private final TraderProperties properties;
private final ObjectMapper objectMapper;
private final TraderFeatureVectorBuilder featureVectorBuilder;
private final TraderOnnxInferenceClient inferenceClient;
public OnnxTraderModelService(TraderProperties properties, ObjectMapper objectMapper,
TraderFeatureVectorBuilder featureVectorBuilder,
TraderOnnxInferenceClient inferenceClient) {
this.properties = properties;
this.objectMapper = objectMapper;
this.featureVectorBuilder = featureVectorBuilder;
this.inferenceClient = inferenceClient;
}
@Override
public TraderModelOutput evaluate(TraderMarketSnapshot snapshot, TraderArtifactBundle bundle) {
float[] features = featureVectorBuilder.build(snapshot, bundle);
Map<String, TraderModelManifest> manifests = manifestsByType(bundle);
Map<String, Map<String, BigDecimal>> rawOutputs = new LinkedHashMap<>();
Path artifactRoot = Path.of(properties.artifact().artifactRoot());
log.info("event=trader.model.onnx_started runId={} cycleId={} modelBundleVersion={} calibrationBundleVersion={} featureCount={}",
snapshot.runId(), snapshot.cycleId(), bundle.modelBundleVersion(), bundle.calibrationBundleVersion(), features.length);
for (TraderModelManifest manifest : manifests.values()) {
Path modelPath = artifactRoot.resolve(manifest.artifactPath());
log.debug("event=trader.model.onnx_model_started runId={} cycleId={} modelName={} modelType={} modelPath={}",
snapshot.runId(), snapshot.cycleId(), manifest.modelName(), manifest.modelType(), modelPath);
Map<String, float[]> tensors = inferenceClient.infer(manifest, modelPath, features);
Map<String, BigDecimal> mapped = mappedOutputs(manifest, tensors);
rawOutputs.put(manifest.modelType(), mapped);
log.debug("event=trader.model.onnx_model_mapped runId={} cycleId={} modelName={} outputFieldCount={}",
snapshot.runId(), snapshot.cycleId(), manifest.modelName(), mapped.size());
}
DirectionOutput direction = direction(rawOutputs.get("DIRECTION"), calibrator(bundle, manifests.get("DIRECTION")));
EntryOutput entry = entry(rawOutputs.get("ENTRY"), calibrator(bundle, manifests.get("ENTRY")),
bounds(manifests.get("ENTRY")));
ContinueOutput continuation = continuation(rawOutputs.get("CONTINUE"), calibrator(bundle, manifests.get("CONTINUE")),
bounds(manifests.get("CONTINUE")));
ExitOutput exit = exit(rawOutputs.get("EXIT"), calibrator(bundle, manifests.get("EXIT")),
bounds(manifests.get("EXIT")));
RiskOutput risk = risk(rawOutputs.get("RISK"), calibrator(bundle, manifests.get("RISK")),
bounds(manifests.get("RISK")));
TraderModelOutput output = new TraderModelOutput(
"model_output_" + snapshot.cycleId(),
snapshot.runId(),
snapshot.cycleId(),
metadata(bundle, manifests, direction, snapshot),
direction,
entry,
continuation,
exit,
risk);
log.info("event=trader.model.onnx_evaluated runId={} cycleId={} modelBundleVersion={} featureCount={} directionLong={} directionShort={} marketRisk={}",
snapshot.runId(), snapshot.cycleId(), bundle.modelBundleVersion(), features.length,
direction.longProb(), direction.shortProb(), risk.marketRiskProb());
return output;
}
private Map<String, TraderModelManifest> manifestsByType(TraderArtifactBundle bundle) {
Map<String, TraderModelManifest> manifests = new LinkedHashMap<>();
for (TraderModelManifest manifest : bundle.modelManifests()) {
TraderModelManifest previous = manifests.putIfAbsent(manifest.modelType(), manifest);
if (previous != null) {
throw modelException("model_manifest.json must contain exactly one manifest per V4 model type: "
+ manifest.modelType());
}
}
if (!manifests.keySet().containsAll(REQUIRED_TYPES)) {
throw modelException("model_manifest.json does not contain all five V4 model types");
}
return Map.copyOf(manifests);
}
private Map<String, BigDecimal> mappedOutputs(TraderModelManifest manifest, Map<String, float[]> tensors) {
Map<String, BigDecimal> mapped = new LinkedHashMap<>();
for (Map.Entry<String, Object> entry : manifest.outputMapping().entrySet()) {
// manifest 只允许 tensor[index] 显式映射,避免 Java 根据位置临时猜字段。
if (!(entry.getValue() instanceof String referenceText)) {
throw modelException("output_mapping_json value must be tensor[index]: " + manifest.modelName());
}
Matcher matcher = OUTPUT_REFERENCE.matcher(referenceText);
if (!matcher.matches()) {
throw modelException("output_mapping_json value must be tensor[index]: " + referenceText);
}
String tensorName = matcher.group(1);
int index = Integer.parseInt(matcher.group(2));
float[] values = tensors.get(tensorName);
if (values == null) {
throw modelException("ONNX output tensor is missing: model=" + manifest.modelName()
+ ", tensor=" + tensorName);
}
if (index < 0 || index >= values.length) {
throw modelException("ONNX output tensor index is out of range: model=" + manifest.modelName()
+ ", mapping=" + referenceText);
}
mapped.put(entry.getKey(), BigDecimal.valueOf(values[index]));
}
return Map.copyOf(mapped);
}
private DirectionOutput direction(Map<String, BigDecimal> raw, TraderProbabilityCalibrator calibrator) {
return new DirectionOutput(
probability(raw, calibrator, "long_prob", "longProb"),
probability(raw, calibrator, "short_prob", "shortProb"),
probability(raw, calibrator, "neutral_prob", "neutralProb"));
}
private EntryOutput entry(Map<String, BigDecimal> raw, TraderProbabilityCalibrator calibrator,
TraderOutputSchemaBounds bounds) {
return new EntryOutput(
probability(raw, calibrator, "long_entry_prob", "longEntryProb"),
probability(raw, calibrator, "short_entry_prob", "shortEntryProb"),
bps(raw, bounds, "long_expected_net_edge_bps", "longExpectedNetEdgeBps"),
bps(raw, bounds, "short_expected_net_edge_bps", "shortExpectedNetEdgeBps"));
}
private ContinueOutput continuation(Map<String, BigDecimal> raw, TraderProbabilityCalibrator calibrator,
TraderOutputSchemaBounds bounds) {
return new ContinueOutput(
probability(raw, calibrator, "long_continue_prob", "longContinueProb"),
probability(raw, calibrator, "short_continue_prob", "shortContinueProb"),
bps(raw, bounds, "long_expected_continue_edge_bps", "longExpectedContinueEdgeBps"),
bps(raw, bounds, "short_expected_continue_edge_bps", "shortExpectedContinueEdgeBps"));
}
private ExitOutput exit(Map<String, BigDecimal> raw, TraderProbabilityCalibrator calibrator,
TraderOutputSchemaBounds bounds) {
return new ExitOutput(
probability(raw, calibrator, "long_exit_prob", "longExitProb"),
probability(raw, calibrator, "short_exit_prob", "shortExitProb"),
bps(raw, bounds, "long_adverse_move_bps", "longAdverseMoveBps"),
bps(raw, bounds, "short_adverse_move_bps", "shortAdverseMoveBps"),
Map.of(
"adverse_move_prob", probability(raw, calibrator, "adverse_move_prob", "adverse_move_prob"),
"reversal_prob", probability(raw, calibrator, "reversal_prob", "reversal_prob"),
"stop_hit_prob", probability(raw, calibrator, "stop_hit_prob", "stop_hit_prob"),
"stagnation_prob", probability(raw, calibrator, "stagnation_prob", "stagnation_prob")));
}
private RiskOutput risk(Map<String, BigDecimal> raw, TraderProbabilityCalibrator calibrator,
TraderOutputSchemaBounds bounds) {
return new RiskOutput(
probability(raw, calibrator, "market_risk_prob", "marketRiskProb"),
probability(raw, calibrator, "long_position_risk_prob", "longPositionRiskProb"),
probability(raw, calibrator, "short_position_risk_prob", "shortPositionRiskProb"),
bps(raw, bounds, "market_path_risk_bps", "marketPathRiskBps"),
bps(raw, bounds, "long_position_path_risk_bps", "longPositionPathRiskBps"),
bps(raw, bounds, "short_position_path_risk_bps", "shortPositionPathRiskBps"),
Map.of(
"market_drawdown_prob", probability(raw, calibrator, "market_drawdown_prob", "market_drawdown_prob"),
"volatility_expansion_prob", probability(raw, calibrator, "volatility_expansion_prob", "volatility_expansion_prob"),
"spike_prob", probability(raw, calibrator, "spike_prob", "spike_prob"),
"liquidity_deterioration_prob", probability(raw, calibrator, "liquidity_deterioration_prob", "liquidity_deterioration_prob"),
"position_drawdown_prob", probability(raw, calibrator, "position_drawdown_prob", "position_drawdown_prob")));
}
private TraderModelOutputMetadata metadata(TraderArtifactBundle bundle, Map<String, TraderModelManifest> manifests,
DirectionOutput direction, TraderMarketSnapshot snapshot) {
Map<String, String> modelVersions = manifests.values().stream()
.collect(Collectors.toUnmodifiableMap(TraderModelManifest::modelName, TraderModelManifest::sourceHash));
Map<String, String> calibrationVersions = bundle.calibrationManifests().stream()
.collect(Collectors.toUnmodifiableMap(TraderCalibrationManifest::modelName, TraderCalibrationManifest::calibratorVersion));
TraderModelManifest reference = manifests.get("DIRECTION");
// 第一版不再单独输出 uncertainty 子模型,用方向最大概率的反面表示本轮不确定性。
BigDecimal uncertainty = BigDecimal.ONE.subtract(direction.confidence());
BigDecimal oodScore = dataQualityProbability(snapshot, "ood_score");
return new TraderModelOutputMetadata(
bundle.modelBundleVersion(),
bundle.calibrationBundleVersion(),
modelVersions,
calibrationVersions,
reference.featureSchemaHash(),
reference.featureOrderHash(),
reference.outputSchemaHash(),
uncertainty,
oodScore);
}
private TraderProbabilityCalibrator calibrator(TraderArtifactBundle bundle, TraderModelManifest manifest) {
TraderCalibrationManifest calibrationManifest = bundle.calibrationManifests().stream()
.filter(item -> item.modelName().equals(manifest.modelName()))
.findFirst()
.orElseThrow(() -> modelException("calibration manifest is missing for model: " + manifest.modelName()));
Path path = Path.of(properties.artifact().artifactRoot()).resolve(calibrationManifest.calibratorPath());
log.debug("event=trader.model.calibrator_loaded modelName={} calibratorVersion={} calibratorPath={}",
manifest.modelName(), calibrationManifest.calibratorVersion(), path);
return TraderProbabilityCalibrator.read(objectMapper, path, manifest.modelName());
}
private TraderOutputSchemaBounds bounds(TraderModelManifest manifest) {
return TraderOutputSchemaBounds.read(objectMapper,
Path.of(properties.artifact().artifactRoot()).resolve(manifest.outputSchemaPath()));
}
private BigDecimal probability(Map<String, BigDecimal> raw, TraderProbabilityCalibrator calibrator,
String outputKey, String targetName) {
BigDecimal value = requiredRaw(raw, outputKey);
return calibrator.calibrate(targetName, value);
}
private BigDecimal bps(Map<String, BigDecimal> raw, TraderOutputSchemaBounds bounds,
String outputKey, String fieldName) {
return bounds.clip(fieldName, requiredRaw(raw, outputKey));
}
private BigDecimal requiredRaw(Map<String, BigDecimal> raw, String outputKey) {
BigDecimal value = raw.get(outputKey);
if (value == null) {
throw modelException("ONNX mapped output is missing: " + outputKey);
}
return value;
}
private BigDecimal dataQualityProbability(TraderMarketSnapshot snapshot, String field) {
Object raw = snapshot.dataQualityJson().get(field);
BigDecimal value;
if (raw instanceof BigDecimal decimal) {
value = decimal;
} else if (raw instanceof Number number) {
value = BigDecimal.valueOf(number.doubleValue());
} else {
log.warn("event=trader.model.ood_missing runId={} cycleId={} field={}",
snapshot.runId(), snapshot.cycleId(), field);
throw modelException("snapshot dataQualityJson must contain numeric field: " + field);
}
if (value.compareTo(BigDecimal.ZERO) < 0 || value.compareTo(BigDecimal.ONE) > 0) {
log.warn("event=trader.model.ood_out_of_range runId={} cycleId={} field={} value={}",
snapshot.runId(), snapshot.cycleId(), field, value);
throw modelException("snapshot dataQualityJson probability is outside [0,1]: " + field);
}
return value;
}
private static TraderException modelException(String message) {
return new TraderException(TraderErrorCode.TRADER_MODEL_ARTIFACT_MISSING, message);
}
}
@@ -0,0 +1,78 @@
package com.quantai.trader.model;
import ai.onnxruntime.OnnxTensor;
import ai.onnxruntime.OnnxValue;
import ai.onnxruntime.OrtEnvironment;
import ai.onnxruntime.OrtException;
import ai.onnxruntime.OrtSession;
import com.quantai.trader.artifact.TraderModelManifest;
import com.quantai.trader.domain.TraderException;
import com.quantai.trader.enums.TraderErrorCode;
import org.springframework.beans.factory.DisposableBean;
import org.springframework.stereotype.Component;
import java.nio.file.Path;
import java.util.Arrays;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
@Component
public class OrtTraderOnnxInferenceClient implements TraderOnnxInferenceClient, DisposableBean {
private final OrtEnvironment environment = OrtEnvironment.getEnvironment();
private final ConcurrentMap<Path, OrtSession> sessions = new ConcurrentHashMap<>();
@Override
public Map<String, float[]> infer(TraderModelManifest manifest, Path modelPath, float[] features) {
try {
OrtSession session = sessions.computeIfAbsent(modelPath.toAbsolutePath().normalize(), this::createSession);
try (OnnxTensor input = OnnxTensor.createTensor(environment, new float[][]{features});
OrtSession.Result result = session.run(Map.of(manifest.inputTensorName(), input))) {
Map<String, float[]> outputs = new LinkedHashMap<>();
for (Map.Entry<String, OnnxValue> entry : result) {
outputs.put(entry.getKey(), flatten(entry.getValue()));
}
return Map.copyOf(outputs);
}
} catch (OrtException exception) {
throw new TraderException(TraderErrorCode.TRADER_MODEL_ARTIFACT_MISSING,
"ONNX inference failed for model: " + manifest.modelName());
}
}
private OrtSession createSession(Path path) {
try {
// ONNX session 建立成本高,外层按模型文件路径缓存,进程退出时统一关闭。
return environment.createSession(path.toString(), new OrtSession.SessionOptions());
} catch (OrtException exception) {
throw new TraderException(TraderErrorCode.TRADER_MODEL_ARTIFACT_MISSING,
"ONNX model cannot be loaded: " + path);
}
}
private float[] flatten(OnnxValue value) {
try {
Object raw = value.getValue();
if (raw instanceof float[] values) {
return Arrays.copyOf(values, values.length);
}
if (raw instanceof float[][] matrix && matrix.length == 1) {
return Arrays.copyOf(matrix[0], matrix[0].length);
}
throw new TraderException(TraderErrorCode.TRADER_MODEL_ARTIFACT_MISSING,
"ONNX output tensor must be FLOAT32 with shape [n] or [1,n]");
} catch (OrtException exception) {
throw new TraderException(TraderErrorCode.TRADER_MODEL_ARTIFACT_MISSING,
"ONNX output tensor cannot be read");
}
}
@Override
public void destroy() throws Exception {
for (OrtSession session : sessions.values()) {
session.close();
}
sessions.clear();
}
}
@@ -0,0 +1,94 @@
package com.quantai.trader.model;
import com.quantai.trader.artifact.TraderArtifactBundle;
import com.quantai.trader.artifact.TraderModelManifest;
import com.quantai.trader.artifact.TraderReplayModelFixture;
import com.quantai.trader.config.TraderProperties;
import com.quantai.trader.domain.*;
import com.quantai.trader.enums.TraderErrorCode;
import com.quantai.trader.enums.TraderRunMode;
import org.springframework.stereotype.Service;
import java.math.BigDecimal;
import java.util.Map;
import java.util.stream.Collectors;
@Service
public class ReplayFixtureTraderModelService implements TraderModelService {
private final TraderRunMode runMode;
public ReplayFixtureTraderModelService(TraderProperties properties) {
this.runMode = properties.runMode();
}
public ReplayFixtureTraderModelService() {
this.runMode = TraderRunMode.REPLAY_SIM;
}
@Override
public TraderModelOutput evaluate(TraderMarketSnapshot snapshot, TraderArtifactBundle bundle) {
if (runMode != TraderRunMode.REPLAY_SIM) {
throw new TraderException(TraderErrorCode.TRADER_MODEL_ARTIFACT_MISSING,
"replay fixture model service only allows REPLAY_SIM");
}
TraderReplayModelFixture fixture = bundle.requireReplayModelFixture();
DirectionOutput direction = direction(snapshot, fixture.direction());
return new TraderModelOutput(
"model_output_" + snapshot.cycleId(),
snapshot.runId(),
snapshot.cycleId(),
metadata(bundle, fixture),
direction,
entry(fixture.entry()),
continuation(fixture.continuation()),
exit(fixture.exit()),
risk(fixture.risk()));
}
private TraderModelOutputMetadata metadata(TraderArtifactBundle bundle, TraderReplayModelFixture fixture) {
Map<String, String> modelVersions = bundle.modelManifests().stream()
.collect(Collectors.toUnmodifiableMap(TraderModelManifest::modelName, ignored -> bundle.modelBundleVersion()));
Map<String, String> calibrationVersions = bundle.modelManifests().stream()
.collect(Collectors.toUnmodifiableMap(TraderModelManifest::modelName, ignored -> bundle.calibrationBundleVersion()));
return new TraderModelOutputMetadata(
bundle.modelBundleVersion(),
bundle.calibrationBundleVersion(),
modelVersions,
calibrationVersions,
fixture.featureSchemaHash(),
fixture.featureOrderHash(),
fixture.outputSchemaHash(),
fixture.uncertainty(),
fixture.oodScore());
}
private DirectionOutput direction(TraderMarketSnapshot snapshot, TraderReplayModelFixture.DirectionFixture fixture) {
BigDecimal longProb = snapshot.markPrice().compareTo(snapshot.indexPrice()) >= 0
? fixture.longProbWhenMarkGteIndex()
: fixture.longProbWhenMarkLtIndex();
BigDecimal shortProb = BigDecimal.ONE.subtract(longProb).subtract(fixture.neutralProb()).max(BigDecimal.ZERO);
BigDecimal neutralProb = BigDecimal.ONE.subtract(longProb).subtract(shortProb);
return new DirectionOutput(longProb, shortProb, neutralProb);
}
private EntryOutput entry(TraderReplayModelFixture.EntryFixture fixture) {
return new EntryOutput(fixture.longEntryProb(), fixture.shortEntryProb(),
fixture.longExpectedNetEdgeBps(), fixture.shortExpectedNetEdgeBps());
}
private ContinueOutput continuation(TraderReplayModelFixture.ContinueFixture fixture) {
return new ContinueOutput(fixture.longContinueProb(), fixture.shortContinueProb(),
fixture.longExpectedContinueEdgeBps(), fixture.shortExpectedContinueEdgeBps());
}
private ExitOutput exit(TraderReplayModelFixture.ExitFixture fixture) {
return new ExitOutput(fixture.longExitProb(), fixture.shortExitProb(),
fixture.longAdverseMoveBps(), fixture.shortAdverseMoveBps(), fixture.exitReasonScores());
}
private RiskOutput risk(TraderReplayModelFixture.RiskFixture fixture) {
return new RiskOutput(fixture.marketRiskProb(), fixture.longPositionRiskProb(), fixture.shortPositionRiskProb(),
fixture.marketPathRiskBps(), fixture.longPositionPathRiskBps(), fixture.shortPositionPathRiskBps(),
fixture.riskReasonScores());
}
}
@@ -0,0 +1,49 @@
package com.quantai.trader.model;
import com.quantai.trader.artifact.TraderArtifactBundle;
import com.quantai.trader.config.TraderProperties;
import com.quantai.trader.domain.TraderException;
import com.quantai.trader.domain.TraderMarketSnapshot;
import com.quantai.trader.domain.TraderModelOutput;
import com.quantai.trader.enums.TraderErrorCode;
import com.quantai.trader.enums.TraderRunMode;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.context.annotation.Primary;
import org.springframework.stereotype.Service;
@Primary
@Service
public class RoutingTraderModelService implements TraderModelService {
private static final Logger log = LoggerFactory.getLogger(RoutingTraderModelService.class);
private final TraderProperties properties;
private final ReplayFixtureTraderModelService replayFixtureTraderModelService;
private final OnnxTraderModelService onnxTraderModelService;
public RoutingTraderModelService(TraderProperties properties,
ReplayFixtureTraderModelService replayFixtureTraderModelService,
OnnxTraderModelService onnxTraderModelService) {
this.properties = properties;
this.replayFixtureTraderModelService = replayFixtureTraderModelService;
this.onnxTraderModelService = onnxTraderModelService;
}
@Override
public TraderModelOutput evaluate(TraderMarketSnapshot snapshot, TraderArtifactBundle bundle) {
if (properties.runMode() == TraderRunMode.REPLAY_SIM) {
log.info("event=trader.model.route runId={} cycleId={} runMode=REPLAY_SIM modelService=ReplayFixtureTraderModelService",
snapshot.runId(), snapshot.cycleId());
return replayFixtureTraderModelService.evaluate(snapshot, bundle);
}
if (properties.runMode() == TraderRunMode.SHADOW) {
log.info("event=trader.model.route runId={} cycleId={} runMode=SHADOW modelService=OnnxTraderModelService",
snapshot.runId(), snapshot.cycleId());
return onnxTraderModelService.evaluate(snapshot, bundle);
}
log.warn("event=trader.model.route_rejected runId={} cycleId={} runMode={}",
snapshot.runId(), snapshot.cycleId(), properties.runMode());
throw new TraderException(TraderErrorCode.TRADER_MODEL_ARTIFACT_MISSING,
"P0 model service only supports REPLAY_SIM and SHADOW");
}
}
@@ -0,0 +1,10 @@
package com.quantai.trader.model;
import com.quantai.trader.artifact.TraderModelManifest;
import java.nio.file.Path;
import java.util.Map;
public interface TraderOnnxInferenceClient {
Map<String, float[]> infer(TraderModelManifest manifest, Path modelPath, float[] features);
}
@@ -0,0 +1,84 @@
package com.quantai.trader.model;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.quantai.trader.domain.TraderException;
import com.quantai.trader.enums.TraderErrorCode;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.IOException;
import java.math.BigDecimal;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;
final class TraderOutputSchemaBounds {
private static final Logger log = LoggerFactory.getLogger(TraderOutputSchemaBounds.class);
private final Map<String, Range> ranges;
private TraderOutputSchemaBounds(Map<String, Range> ranges) {
this.ranges = Map.copyOf(ranges);
}
static TraderOutputSchemaBounds read(ObjectMapper objectMapper, Path path) {
if (!Files.isRegularFile(path)) {
throw modelException("output_schema.json is missing: " + path);
}
try {
JsonNode root = objectMapper.readTree(path.toFile());
Map<String, Range> ranges = new HashMap<>();
collectRanges(root, ranges);
if (ranges.isEmpty()) {
throw modelException("output_schema.json must define field ranges: " + path);
}
log.debug("event=trader.output_schema.loaded path={} rangedFieldCount={}", path, ranges.size());
return new TraderOutputSchemaBounds(ranges);
} catch (IOException exception) {
throw modelException("output_schema.json cannot be read: " + path);
}
}
BigDecimal clip(String fieldName, BigDecimal value) {
Range range = ranges.get(fieldName);
if (range == null) {
log.warn("event=trader.output_schema.range_missing field={}", fieldName);
throw modelException("output_schema.json does not define range for field: " + fieldName);
}
if (value.compareTo(range.min()) < 0) {
log.debug("event=trader.output_schema.clipped field={} raw={} clipped={}", fieldName, value, range.min());
return range.min();
}
if (value.compareTo(range.max()) > 0) {
log.debug("event=trader.output_schema.clipped field={} raw={} clipped={}", fieldName, value, range.max());
return range.max();
}
return value;
}
private static void collectRanges(JsonNode node, Map<String, Range> ranges) {
if (!node.isObject()) {
return;
}
Iterator<Map.Entry<String, JsonNode>> iterator = node.properties().iterator();
while (iterator.hasNext()) {
Map.Entry<String, JsonNode> entry = iterator.next();
JsonNode child = entry.getValue();
JsonNode rangeNode = child.path("range");
if (rangeNode.isArray() && rangeNode.size() == 2 && rangeNode.get(0).isNumber() && rangeNode.get(1).isNumber()) {
ranges.put(entry.getKey(), new Range(rangeNode.get(0).decimalValue(), rangeNode.get(1).decimalValue()));
}
collectRanges(child, ranges);
}
}
private static TraderException modelException(String message) {
return new TraderException(TraderErrorCode.TRADER_MODEL_ARTIFACT_MISSING, message);
}
private record Range(BigDecimal min, BigDecimal max) {
}
}
@@ -0,0 +1,129 @@
package com.quantai.trader.model;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.quantai.trader.domain.TraderException;
import com.quantai.trader.enums.TraderErrorCode;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.IOException;
import java.math.BigDecimal;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.stream.StreamSupport;
final class TraderProbabilityCalibrator {
private static final Logger log = LoggerFactory.getLogger(TraderProbabilityCalibrator.class);
private final String modelName;
private final String method;
private final BigDecimal clipMin;
private final BigDecimal clipMax;
private final Map<String, List<Bin>> targets;
private TraderProbabilityCalibrator(String modelName, String method, BigDecimal clipMin, BigDecimal clipMax,
Map<String, List<Bin>> targets) {
this.modelName = modelName;
this.method = method;
this.clipMin = clipMin;
this.clipMax = clipMax;
this.targets = targets;
}
static TraderProbabilityCalibrator read(ObjectMapper objectMapper, Path path, String modelName) {
if (!Files.isRegularFile(path)) {
throw modelException("calibrator is missing: " + path);
}
try {
JsonNode root = objectMapper.readTree(path.toFile());
String method = requiredText(root, "method", path);
if (!"BINNING".equals(method)) {
throw modelException("probability calibrator method must be BINNING for model " + modelName);
}
JsonNode clip = root.path("clip");
BigDecimal clipMin = decimal(clip.path("min"), "clip.min", path);
BigDecimal clipMax = decimal(clip.path("max"), "clip.max", path);
JsonNode targetsNode = root.path("targets");
if (!targetsNode.isObject() || targetsNode.isEmpty()) {
throw modelException("calibrator targets must not be empty: " + path);
}
Map<String, List<Bin>> targets = StreamSupport.stream(targetsNode.properties().spliterator(), false)
.collect(java.util.stream.Collectors.toUnmodifiableMap(
Map.Entry::getKey,
entry -> bins(entry.getValue().path("bins"), entry.getKey(), path)));
log.debug("event=trader.calibrator.loaded modelName={} method={} targetCount={} path={}",
modelName, method, targets.size(), path);
return new TraderProbabilityCalibrator(modelName, method, clipMin, clipMax, targets);
} catch (IOException exception) {
throw modelException("calibrator cannot be read: " + path);
}
}
BigDecimal calibrate(String targetName, BigDecimal rawProbability) {
if (!"BINNING".equals(method)) {
throw modelException("unsupported calibrator method for model " + modelName);
}
List<Bin> bins = targets.get(targetName);
if (bins == null || bins.isEmpty()) {
log.warn("event=trader.calibrator.target_missing modelName={} target={}", modelName, targetName);
throw modelException("calibrator target is missing: model=" + modelName + ", target=" + targetName);
}
for (Bin bin : bins) {
if (rawProbability.compareTo(bin.min()) >= 0 && rawProbability.compareTo(bin.max()) <= 0) {
BigDecimal calibrated = bin.calibrated();
if (calibrated.compareTo(clipMin) < 0 || calibrated.compareTo(clipMax) > 0
|| calibrated.compareTo(BigDecimal.ZERO) < 0 || calibrated.compareTo(BigDecimal.ONE) > 0) {
log.warn("event=trader.calibrator.output_out_of_range modelName={} target={} raw={} calibrated={} clipMin={} clipMax={}",
modelName, targetName, rawProbability, calibrated, clipMin, clipMax);
throw modelException("calibrated probability is outside allowed range: model="
+ modelName + ", target=" + targetName);
}
return calibrated;
}
}
log.warn("event=trader.calibrator.bin_missing modelName={} target={} raw={}",
modelName, targetName, rawProbability);
throw modelException("raw probability does not match any calibrator bin: model="
+ modelName + ", target=" + targetName + ", value=" + rawProbability);
}
private static List<Bin> bins(JsonNode binsNode, String targetName, Path path) {
if (!binsNode.isArray() || binsNode.isEmpty()) {
throw modelException("calibrator target has no bins: " + targetName + " in " + path);
}
List<Bin> bins = new ArrayList<>();
for (JsonNode node : binsNode) {
bins.add(new Bin(
decimal(node.path("min"), targetName + ".min", path),
decimal(node.path("max"), targetName + ".max", path),
decimal(node.path("calibrated"), targetName + ".calibrated", path)));
}
return List.copyOf(bins);
}
private static String requiredText(JsonNode node, String field, Path path) {
String value = node.path(field).asText("");
if (value.isBlank()) {
throw modelException("calibrator field is required: " + field + " in " + path);
}
return value;
}
private static BigDecimal decimal(JsonNode node, String field, Path path) {
if (!node.isNumber()) {
throw modelException("calibrator numeric field is required: " + field + " in " + path);
}
return node.decimalValue();
}
private static TraderException modelException(String message) {
return new TraderException(TraderErrorCode.TRADER_MODEL_ARTIFACT_MISSING, message);
}
private record Bin(BigDecimal min, BigDecimal max, BigDecimal calibrated) {
}
}
@@ -1,6 +1,7 @@
package com.quantai.trader.outbox;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.quantai.trader.enums.PositionSide;
import com.quantai.trader.persistence.TraderJsonCodec;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -35,4 +36,31 @@ public class JdbcTraderOutboxRepository implements TraderOutboxRepository {
log.info("event=trader.outbox.inserted runId={} cycleId={} destination={} aggregateType={} aggregateId={} status={}",
event.runId(), event.cycleId(), event.destination(), event.aggregateType(), event.aggregateId(), event.status());
}
@Override
public void markSent(String outboxId) {
jdbcTemplate.update("""
update trader_outbox
set status = 'SENT', updated_at = current_timestamp(3)
where outbox_id = ?
""",
outboxId);
log.info("event=trader.outbox.sent outboxId={}", outboxId);
}
@Override
public boolean hasUnsentExposureIncrease(String runId, String symbol, PositionSide side) {
Integer count = jdbcTemplate.queryForObject("""
select count(*)
from trader_outbox o
join trader_action a on a.run_id = o.run_id and a.action_id = o.aggregate_id
where o.run_id = ?
and a.symbol = ?
and a.side = ?
and a.action_type in ('OPEN_LONG','OPEN_SHORT','ADD_LONG','ADD_SHORT')
and o.status in ('PENDING','SENDING','FAILED')
""",
Integer.class, runId, symbol, side.name());
return count != null && count > 0;
}
}
@@ -0,0 +1,76 @@
package com.quantai.trader.outbox;
import com.quantai.trader.domain.TraderAction;
import com.quantai.trader.domain.TraderAppFeedback;
import com.quantai.trader.domain.TraderMarketSnapshot;
import com.quantai.trader.enums.FeedbackSource;
import com.quantai.trader.feedback.TraderFeedbackRepository;
import com.quantai.trader.replay.state.TraderPostActionStateRepository;
import com.quantai.trader.replay.state.TraderReplayState;
import com.quantai.trader.replay.state.TraderReplayStateStore;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.stereotype.Component;
import org.springframework.transaction.annotation.Transactional;
import java.time.Instant;
import java.util.Map;
@Component
public class P0OutboxDispatcher implements TraderOutboxDispatcher {
private static final Logger log = LoggerFactory.getLogger(P0OutboxDispatcher.class);
private final TraderOutboxRepository outboxRepository;
private final TraderFeedbackRepository feedbackRepository;
private final TraderReplayStateStore stateStore;
private final TraderPostActionStateRepository postActionStateRepository;
public P0OutboxDispatcher(TraderOutboxRepository outboxRepository,
TraderFeedbackRepository feedbackRepository,
TraderReplayStateStore stateStore,
TraderPostActionStateRepository postActionStateRepository) {
this.outboxRepository = outboxRepository;
this.feedbackRepository = feedbackRepository;
this.stateStore = stateStore;
this.postActionStateRepository = postActionStateRepository;
}
@Override
@Transactional
public void dispatch(TraderAction action, TraderReplayState currentState, TraderMarketSnapshot snapshot, String destination) {
TraderAppFeedback feedback = p0Feedback(action, snapshot, destination);
feedbackRepository.insert(feedback);
outboxRepository.markSent("outbox_" + action.actionId());
TraderReplayState nextState = stateStore.advance(currentState, action, snapshot);
postActionStateRepository.insertPostActionState(nextState);
log.info("event=trader.outbox.dispatched runId={} cycleId={} action={} destination={} feedbackId={}",
action.runId(), action.cycleId(), action.actionType(), destination, feedback.feedbackId());
}
private TraderAppFeedback p0Feedback(TraderAction action, TraderMarketSnapshot snapshot, String destination) {
FeedbackSource source = switch (destination) {
case "REPLAY_SIM_EXECUTION" -> FeedbackSource.REPLAY_SIMULATOR;
case "SHADOW_RECORDER" -> FeedbackSource.SHADOW_APP;
default -> throw new IllegalArgumentException("P0 outbox destination is not allowed: " + destination);
};
Instant now = snapshot.snapshotTime();
return new TraderAppFeedback(
"feedback_" + action.actionId(),
action.runId(),
action.cycleId(),
action.actionId(),
source,
false,
null,
"RECORDED",
now,
now,
null,
null,
null,
null,
null,
null,
Map.of("destination", destination, "actionType", action.actionType().name()));
}
}
@@ -0,0 +1,9 @@
package com.quantai.trader.outbox;
import com.quantai.trader.domain.TraderAction;
import com.quantai.trader.domain.TraderMarketSnapshot;
import com.quantai.trader.replay.state.TraderReplayState;
public interface TraderOutboxDispatcher {
void dispatch(TraderAction action, TraderReplayState currentState, TraderMarketSnapshot snapshot, String destination);
}
@@ -1,5 +1,11 @@
package com.quantai.trader.outbox;
import com.quantai.trader.enums.PositionSide;
public interface TraderOutboxRepository {
void insert(TraderOutboxEvent event);
void markSent(String outboxId);
boolean hasUnsentExposureIncrease(String runId, String symbol, PositionSide side);
}
@@ -30,10 +30,16 @@ public class JdbcTraderDecisionTraceWriter implements TraderDecisionTraceWriter
TraderMarketSnapshot snapshot,
TraderModelOutput modelOutput,
TraderPositionState positionState,
TraderAccountState accountState,
TraderExecutionState executionState,
TraderPositionManagerDecision pmDecision,
TraderRiskDecision riskDecision,
TraderAction action) {
upsertRun(cycle);
insertMarketSnapshot(snapshot);
insertPositionState(positionState, "PM_INPUT");
insertAccountState(accountState, "PM_INPUT");
insertExecutionState(executionState, "PM_INPUT");
insertCycle(cycle, positionState, riskDecision);
insertModelOutput(modelOutput, snapshot);
insertPmDecision(pmDecision);
@@ -66,6 +72,63 @@ public class JdbcTraderDecisionTraceWriter implements TraderDecisionTraceWriter
Timestamp.from(cycle.cycleTime()));
}
void insertMarketSnapshot(TraderMarketSnapshot snapshot) {
jdbcTemplate.update("""
insert into trader_market_snapshot
(run_id, cycle_id, snapshot_id, symbol, snapshot_time, feature_version,
mark_price, index_price, spread_bps, funding_rate_bps,
depth_notional_5bps, depth_notional_10bps, depth_notional_25bps,
data_ready, feature_json, data_quality_json)
values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
""",
snapshot.runId(), snapshot.cycleId(), snapshot.snapshotId(), snapshot.symbol(),
Timestamp.from(snapshot.snapshotTime()), snapshot.featureVersion(), snapshot.markPrice(),
snapshot.indexPrice(), snapshot.spreadBps(), snapshot.fundingRateBps(),
snapshot.depthNotional5Bps(), snapshot.depthNotional10Bps(), snapshot.depthNotional25Bps(),
snapshot.dataReady(), jsonCodec.toJson(snapshot.featureJson()), jsonCodec.toJson(snapshot.dataQualityJson()));
}
void insertPositionState(TraderPositionState state, String role) {
jdbcTemplate.update("""
insert into trader_position_state
(run_id, cycle_id, position_state_id, state_role, symbol, side, position_ratio,
average_entry_price, current_price, unrealized_pnl_bps, liquidation_buffer_bps,
add_count, remaining_add_capacity, last_add_time)
values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
""",
state.runId(), state.cycleId(), state.positionStateId(), role, state.symbol(), state.side().name(),
state.positionRatio(), state.averageEntryPrice(), state.currentPrice(), state.unrealizedPnlBps(),
state.liquidationBufferBps(), state.addCount(), state.remainingAddCapacity(),
state.lastAddTime() == null ? null : Timestamp.from(state.lastAddTime()));
}
void insertAccountState(TraderAccountState state, String role) {
jdbcTemplate.update("""
insert into trader_account_state
(run_id, cycle_id, account_state_id, state_role, daily_drawdown_bps,
portfolio_exposure_ratio, remaining_symbol_capacity_ratio, consecutive_losses)
values (?, ?, ?, ?, ?, ?, ?, ?)
""",
state.runId(), state.cycleId(), state.accountStateId(), role, state.dailyDrawdownBps(),
state.portfolioExposureRatio(), state.remainingSymbolCapacityRatio(), state.consecutiveLosses());
}
void insertExecutionState(TraderExecutionState state, String role) {
jdbcTemplate.update("""
insert into trader_execution_state
(run_id, cycle_id, execution_state_id, state_role, symbol, open_orders_json,
expected_slippage_bps, exchange_latency_ms, api_error_count, maker_fee_bps,
taker_fee_bps, min_notional, price_tick_size, lot_size_step_size,
market_lot_size_step_size, liquidity_capacity_ratio)
values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
""",
state.runId(), state.cycleId(), state.executionStateId(), role, state.symbol(),
jsonCodec.toJson(state.openOrders()), state.expectedSlippageBps(), state.exchangeLatencyMs(),
state.apiErrorCount(), state.makerFeeBps(), state.takerFeeBps(), state.minNotional(),
state.priceTickSize(), state.lotSizeStepSize(), state.marketLotSizeStepSize(),
state.liquidityCapacityRatio());
}
void insertCycle(TraderDecisionCycle cycle, TraderPositionState positionState, TraderRiskDecision riskDecision) {
jdbcTemplate.update("""
insert into trader_decision_cycle
@@ -83,14 +146,15 @@ public class JdbcTraderDecisionTraceWriter implements TraderDecisionTraceWriter
insert into trader_model_output
(run_id, cycle_id, model_output_id, model_bundle_version, calibration_bundle_version,
direction_json, entry_json, continue_json, exit_json, risk_json,
uncertainty, ood_score, usable, blocker)
values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
metadata_json, uncertainty, ood_score, usable, blocker)
values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
""",
modelOutput.runId(), modelOutput.cycleId(), modelOutput.modelOutputId(),
modelOutput.modelBundleVersion(), modelOutput.calibrationBundleVersion(),
modelOutput.metadata().modelBundleVersion(), modelOutput.metadata().calibrationBundleVersion(),
jsonCodec.toJson(modelOutput.direction()), jsonCodec.toJson(modelOutput.entry()),
jsonCodec.toJson(modelOutput.continuation()), jsonCodec.toJson(modelOutput.exit()),
jsonCodec.toJson(modelOutput.risk()), modelOutput.uncertainty(), modelOutput.oodScore(),
jsonCodec.toJson(modelOutput.risk()), jsonCodec.toJson(modelOutput.metadata()),
modelOutput.metadata().uncertainty(), modelOutput.metadata().oodScore(),
snapshot.dataReady(), snapshot.dataReady() ? null : "DATA_NOT_READY");
}
@@ -7,6 +7,8 @@ public interface TraderDecisionTraceWriter {
TraderMarketSnapshot snapshot,
TraderModelOutput modelOutput,
TraderPositionState positionState,
TraderAccountState accountState,
TraderExecutionState executionState,
TraderPositionManagerDecision pmDecision,
TraderRiskDecision riskDecision,
TraderAction action);
@@ -33,7 +33,8 @@ public class TraderPositionManager {
TraderPmConfig.SizingConfig sizing = input.pmConfig().sizing();
EntryOutput entry = input.modelOutput().entry();
RiskOutput risk = input.modelOutput().risk();
BigDecimal expectedEdge = entry.expectedEdgeBps().max(BigDecimal.ZERO);
TraderPricePlanContext pricePlan = input.pricePlanContext();
BigDecimal expectedEdge = entry.netEdgeBpsFor(side).max(BigDecimal.ZERO);
if (expectedEdge.compareTo(sizing.minEdgeBps()) < 0) {
return BigDecimal.ZERO;
}
@@ -41,8 +42,8 @@ public class TraderPositionManager {
? input.modelOutput().direction().longProb()
: input.modelOutput().direction().shortProb();
BigDecimal entryProb = side.isLong() ? entry.longEntryProb() : entry.shortEntryProb();
BigDecimal stopLossBudget = entry.stopDistanceBps().add(entry.costBps()).max(BigDecimal.ONE);
BigDecimal uncertaintyPenalty = BigDecimal.ONE.subtract(input.modelOutput().uncertainty()
BigDecimal stopLossBudget = pricePlan.stopDistanceBps().add(pricePlan.costBps()).max(BigDecimal.ONE);
BigDecimal uncertaintyPenalty = BigDecimal.ONE.subtract(input.modelOutput().metadata().uncertainty()
.multiply(sizing.uncertaintyPenaltyMultiplier())).max(BigDecimal.ZERO);
BigDecimal edgeRiskBudget = TraderNumbers.safeDivide(expectedEdge, stopLossBudget)
.multiply(directionStrength)
@@ -50,7 +51,7 @@ public class TraderPositionManager {
.multiply(BigDecimal.ONE.subtract(risk.marketRiskProb()))
.multiply(uncertaintyPenalty);
BigDecimal raw = sizing.baseRatio().multiply(edgeRiskBudget);
BigDecimal liquidityCap = risk.liquidityCapacityRatio().multiply(sizing.maxLiquidityUsageRatio());
BigDecimal liquidityCap = input.executionState().liquidityCapacityRatio().multiply(sizing.maxLiquidityUsageRatio());
BigDecimal lossBudgetCap = TraderNumbers.safeDivide(sizing.maxLossPerTradeBps(), stopLossBudget);
BigDecimal hardCap = min(sizing.maxSingleLegRatio(), min(input.accountState().remainingSymbolCapacityRatio(), min(liquidityCap, lossBudgetCap)));
if (hardCap.compareTo(sizing.minInitialRatio()) < 0) {
@@ -61,7 +62,8 @@ public class TraderPositionManager {
public BigDecimal calculateAddRatio(PositionManagerInput input) {
BigDecimal raw = input.pmConfig().sizing().baseRatio()
.multiply(input.modelOutput().continuation().continueVsExitEdgeBps().max(BigDecimal.ZERO))
.multiply(input.modelOutput().continuation()
.continueEdgeBpsFor(input.positionState().side()).max(BigDecimal.ZERO))
.divide(new BigDecimal("100"), java.math.MathContext.DECIMAL64);
return TraderNumbers.clamp(raw, input.pmConfig().sizing().minAddRatio(),
min(input.pmConfig().sizing().maxAddRatio(), input.positionState().remainingAddCapacity()));
@@ -73,20 +75,20 @@ public class TraderPositionManager {
RiskOutput risk = input.modelOutput().risk();
TraderPmConfig.OpenRuleConfig open = input.pmConfig().open();
boolean longPass = direction.longProb().compareTo(open.longOpenProb()) > 0
&& direction.directionMargin().compareTo(open.minDirectionMargin()) > 0
&& direction.margin().compareTo(open.minDirectionMargin()) > 0
&& entry.longEntryProb().compareTo(open.minLongEntryProb()) > 0
&& entry.expectedEdgeBps().compareTo(open.minExpectedEdgeBps()) > 0
&& entry.longExpectedNetEdgeBps().compareTo(open.minExpectedEdgeBps()) > 0
&& risk.marketRiskProb().compareTo(open.maxMarketRiskProb()) < 0
&& risk.liquidityCapacityRatio().compareTo(open.minLiquidityCapacityRatio()) >= 0
&& input.modelOutput().oodScore().compareTo(open.maxOodScore()) <= 0
&& input.executionState().liquidityCapacityRatio().compareTo(open.minLiquidityCapacityRatio()) >= 0
&& input.modelOutput().metadata().oodScore().compareTo(open.maxOodScore()) <= 0
&& input.executionState().openOrders().isEmpty();
boolean shortPass = direction.shortProb().compareTo(open.shortOpenProb()) > 0
&& direction.directionMargin().compareTo(open.minDirectionMargin()) > 0
&& direction.margin().compareTo(open.minDirectionMargin()) > 0
&& entry.shortEntryProb().compareTo(open.minShortEntryProb()) > 0
&& entry.expectedEdgeBps().compareTo(open.minExpectedEdgeBps()) > 0
&& entry.shortExpectedNetEdgeBps().compareTo(open.minExpectedEdgeBps()) > 0
&& risk.marketRiskProb().compareTo(open.maxMarketRiskProb()) < 0
&& risk.liquidityCapacityRatio().compareTo(open.minLiquidityCapacityRatio()) >= 0
&& input.modelOutput().oodScore().compareTo(open.maxOodScore()) <= 0
&& input.executionState().liquidityCapacityRatio().compareTo(open.minLiquidityCapacityRatio()) >= 0
&& input.modelOutput().metadata().oodScore().compareTo(open.maxOodScore()) <= 0
&& input.executionState().openOrders().isEmpty();
BigDecimal longRatio = longPass ? calculateInitialRatio(input, PositionSide.LONG) : BigDecimal.ZERO;
BigDecimal shortRatio = shortPass ? calculateInitialRatio(input, PositionSide.SHORT) : BigDecimal.ZERO;
@@ -106,7 +108,7 @@ public class TraderPositionManager {
}
if (shouldReduce(input)) {
TraderActionType action = input.positionState().side().isLong() ? TraderActionType.REDUCE_LONG : TraderActionType.REDUCE_SHORT;
return decision(input, action, input.positionState().side(), null, null, new BigDecimal("0.50"), "PROFIT_GIVEBACK_REDUCE");
return decision(input, action, input.positionState().side(), null, null, new BigDecimal("0.50"), "ADVERSE_MOVE_REDUCE");
}
if (shouldAdd(input)) {
BigDecimal ratio = calculateAddRatio(input);
@@ -130,17 +132,28 @@ public class TraderPositionManager {
: input.modelOutput().continuation().shortContinueProb();
return sidePass
&& continueProb.compareTo(add.minContinueProb()) > 0
&& input.modelOutput().continuation().continueVsExitEdgeBps().compareTo(add.minContinueVsExitEdgeBps()) > 0
&& input.modelOutput().entry().expectedEdgeBps().compareTo(add.minExpectedEdgeBps()) > 0
&& input.modelOutput().continuation().continueEdgeBpsFor(input.positionState().side())
.compareTo(add.minContinueVsExitEdgeBps()) > 0
&& input.modelOutput().entry().netEdgeBpsFor(input.positionState().side())
.compareTo(add.minExpectedEdgeBps()) > 0
&& input.modelOutput().risk().marketRiskProb().compareTo(add.maxMarketRiskProb()) < 0
&& input.modelOutput().risk().positionRiskProb().compareTo(add.maxPositionRiskProb()) < 0
&& input.modelOutput().risk().liquidityCapacityRatio().compareTo(add.minLiquidityCapacityRatio()) >= 0
&& input.modelOutput().risk().sideRiskProbFor(input.positionState().side())
.compareTo(add.maxPositionRiskProb()) < 0
&& input.executionState().liquidityCapacityRatio().compareTo(add.minLiquidityCapacityRatio()) >= 0
&& input.positionState().unrealizedPnlBps().compareTo(BigDecimal.ZERO) > 0
&& input.positionState().addCount() < add.maxAddCount()
&& addCooldownPassed(input, add)
&& input.positionState().remainingAddCapacity().compareTo(BigDecimal.ZERO) > 0
&& input.executionState().openOrders().isEmpty();
}
private boolean addCooldownPassed(PositionManagerInput input, TraderPmConfig.AddRuleConfig add) {
if (input.positionState().lastAddTime() == null) {
return true;
}
return !input.cycle().cycleTime().isBefore(input.positionState().lastAddTime().plusSeconds(add.cooldownMinutes() * 60));
}
private boolean shouldClose(PositionManagerInput input) {
TraderPmConfig.ExitRuleConfig exit = input.pmConfig().exit();
boolean exitProb = input.positionState().side().isLong()
@@ -151,9 +164,11 @@ public class TraderPositionManager {
: input.modelOutput().continuation().shortContinueProb();
return exitProb
|| continueProb.compareTo(exit.closeContinueMax()) < 0
|| input.modelOutput().risk().positionRiskProb().compareTo(exit.closePositionRiskProb()) > 0
|| input.modelOutput().risk().sideRiskProbFor(input.positionState().side())
.compareTo(exit.closePositionRiskProb()) > 0
|| input.modelOutput().risk().marketRiskProb().compareTo(exit.closeMarketRiskProb()) > 0
|| input.modelOutput().risk().expectedShortfallBps().compareTo(exit.maxExpectedShortfallBps()) > 0;
|| input.modelOutput().risk().positionPathRiskBpsFor(input.positionState().side())
.compareTo(exit.maxPositionPathRiskBps()) > 0;
}
private boolean shouldReduce(PositionManagerInput input) {
@@ -161,7 +176,7 @@ public class TraderPositionManager {
BigDecimal continueProb = input.positionState().side().isLong()
? input.modelOutput().continuation().longContinueProb()
: input.modelOutput().continuation().shortContinueProb();
return input.modelOutput().exit().profitGivebackProb().compareTo(exit.reduceGivebackProb()) > 0
return input.modelOutput().exit().reasonScore("adverse_move_prob").compareTo(exit.reduceAdverseMoveProb()) > 0
&& continueProb.compareTo(exit.reduceContinueMin()) >= 0
&& continueProb.compareTo(exit.reduceContinueMax()) <= 0
&& input.positionState().unrealizedPnlBps().compareTo(exit.minProfitForReduceBps()) > 0;
@@ -173,9 +188,9 @@ public class TraderPositionManager {
private TraderPositionManagerDecision decision(PositionManagerInput input, TraderActionType action, PositionSide side,
BigDecimal targetRatio, BigDecimal addRatio, BigDecimal reduceRatio, String reason) {
EntryOutput entry = input.modelOutput().entry();
BigDecimal stopPrice = action.increasesExposure() ? priceFromBps(input.snapshot().markPrice(), entry.stopDistanceBps(), side, false) : null;
BigDecimal targetPrice = action.increasesExposure() ? priceFromBps(input.snapshot().markPrice(), entry.targetDistanceBps(), side, true) : null;
TraderPricePlanContext pricePlan = input.pricePlanContext();
BigDecimal stopPrice = action.increasesExposure() ? priceFromBps(input.snapshot().markPrice(), pricePlan.stopDistanceBps(), side, false) : null;
BigDecimal targetPrice = action.increasesExposure() ? priceFromBps(input.snapshot().markPrice(), pricePlan.targetDistanceBps(), side, true) : null;
return new TraderPositionManagerDecision(
"pm_" + input.cycle().cycleId(),
input.cycle().runId(),
@@ -186,15 +201,17 @@ public class TraderPositionManager {
input.executionState().executionStateId(),
action,
side,
action.increasesExposure() ? entry.pricePlanId() : null,
action.increasesExposure() ? entry.pricePlanConfigHash() : null,
action.increasesExposure() ? pricePlan.pricePlanId() : null,
action.increasesExposure() ? pricePlan.pricePlanConfigHash() : null,
targetRatio,
addRatio,
reduceRatio,
stopPrice,
targetPrice,
reason,
Map.of("pmConfigVersion", input.pmConfig().pmConfigVersion()));
Map.of(
"pmConfigVersion", input.pmConfig().pmConfigVersion(),
"pricePlanMaxHoldMinutes", pricePlan.maxHoldMinutes()));
}
private BigDecimal priceFromBps(BigDecimal markPrice, BigDecimal distanceBps, PositionSide side, boolean profitTarget) {
@@ -1,15 +1,44 @@
package com.quantai.trader.replay;
import static com.quantai.trader.util.TraderNumbers.nonNegative;
import static com.quantai.trader.util.TraderNumbers.positive;
import static com.quantai.trader.util.TraderNumbers.required;
import static com.quantai.trader.util.TraderNumbers.requiredText;
import java.math.BigDecimal;
import java.time.Instant;
import java.util.Map;
import java.util.Objects;
public record ReplayMarketEvent(
String runId,
String symbol,
Instant eventTime,
String featureVersion,
BigDecimal markPrice,
BigDecimal indexPrice,
BigDecimal spreadBps,
BigDecimal depthNotional5Bps
BigDecimal fundingRateBps,
BigDecimal depthNotional5Bps,
BigDecimal depthNotional10Bps,
BigDecimal depthNotional25Bps,
boolean dataReady,
Map<String, Object> featureJson,
Map<String, Object> dataQualityJson
) {
public ReplayMarketEvent {
runId = requiredText(runId, "runId");
symbol = requiredText(symbol, "symbol");
eventTime = Objects.requireNonNull(eventTime, "eventTime is required");
featureVersion = requiredText(featureVersion, "featureVersion");
markPrice = positive(markPrice, "markPrice");
indexPrice = positive(indexPrice, "indexPrice");
spreadBps = nonNegative(spreadBps, "spreadBps");
fundingRateBps = required(fundingRateBps, "fundingRateBps");
depthNotional5Bps = nonNegative(depthNotional5Bps, "depthNotional5Bps");
depthNotional10Bps = nonNegative(depthNotional10Bps, "depthNotional10Bps");
depthNotional25Bps = nonNegative(depthNotional25Bps, "depthNotional25Bps");
featureJson = Map.copyOf(featureJson == null ? Map.of() : featureJson);
dataQualityJson = Map.copyOf(dataQualityJson == null ? Map.of() : dataQualityJson);
}
}
@@ -1,6 +1,7 @@
package com.quantai.trader.replay;
import com.quantai.trader.artifact.TraderArtifactBundle;
import com.quantai.trader.artifact.TraderArtifactManifestRepository;
import com.quantai.trader.artifact.TraderArtifactLoader;
import com.quantai.trader.config.TraderProperties;
import com.quantai.trader.domain.*;
@@ -9,6 +10,7 @@ import com.quantai.trader.evidence.EvidenceAppender;
import com.quantai.trader.model.TraderModelService;
import com.quantai.trader.outbox.TraderOutboxRepository;
import com.quantai.trader.outbox.TraderOutboxEvent;
import com.quantai.trader.outbox.TraderOutboxDispatcher;
import com.quantai.trader.persistence.TraderDecisionTraceWriter;
import com.quantai.trader.position.TraderPositionManager;
import com.quantai.trader.replay.state.TraderReplayState;
@@ -16,11 +18,15 @@ import com.quantai.trader.replay.state.TraderReplayStateStore;
import com.quantai.trader.risk.RiskGateInput;
import com.quantai.trader.risk.RiskLimits;
import com.quantai.trader.risk.TraderRiskGate;
import com.quantai.trader.runtime.TraderRuntimeControlDecision;
import com.quantai.trader.runtime.TraderRuntimeControlService;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional;
import org.springframework.transaction.support.TransactionSynchronization;
import org.springframework.transaction.support.TransactionSynchronizationManager;
import java.math.BigDecimal;
import java.time.Instant;
import java.util.Map;
@@ -36,8 +42,11 @@ public class TraderP0CycleRunner {
private final TraderActionFactory actionFactory;
private final EvidenceAppender evidenceAppender;
private final TraderDecisionTraceWriter traceWriter;
private final TraderArtifactManifestRepository artifactManifestRepository;
private final TraderOutboxRepository outboxRepository;
private final TraderOutboxDispatcher outboxDispatcher;
private final TraderReplayStateStore stateStore;
private final TraderRuntimeControlService runtimeControlService;
public TraderP0CycleRunner(TraderProperties properties,
TraderArtifactLoader artifactLoader,
@@ -47,8 +56,11 @@ public class TraderP0CycleRunner {
TraderActionFactory actionFactory,
EvidenceAppender evidenceAppender,
TraderDecisionTraceWriter traceWriter,
TraderArtifactManifestRepository artifactManifestRepository,
TraderOutboxRepository outboxRepository,
TraderReplayStateStore stateStore) {
TraderOutboxDispatcher outboxDispatcher,
TraderReplayStateStore stateStore,
TraderRuntimeControlService runtimeControlService) {
this.properties = properties;
this.artifactLoader = artifactLoader;
this.modelService = modelService;
@@ -57,47 +69,88 @@ public class TraderP0CycleRunner {
this.actionFactory = actionFactory;
this.evidenceAppender = evidenceAppender;
this.traceWriter = traceWriter;
this.artifactManifestRepository = artifactManifestRepository;
this.outboxRepository = outboxRepository;
this.outboxDispatcher = outboxDispatcher;
this.stateStore = stateStore;
this.runtimeControlService = runtimeControlService;
}
@Transactional
public TraderCycleResult runCycle(ReplayMarketEvent event) {
String cycleId = "cycle_" + event.runId() + "_" + event.eventTime().toEpochMilli();
TraderArtifactBundle bundle = artifactLoader.loadActiveBundle();
artifactManifestRepository.upsertActiveBundle(bundle);
TraderDecisionCycle cycle = new TraderDecisionCycle(event.runId(), cycleId, event.symbol(), event.eventTime(),
properties.runMode(), bundle.modelBundleVersion(), bundle.calibrationBundleVersion(), bundle.pmConfigVersion());
TraderMarketSnapshot snapshot = snapshot(event, cycleId);
TraderReplayState state = stateStore.load(cycle, snapshot);
evidenceAppender.append(cycle.runId(), cycle.cycleId(), "MARKET_SNAPSHOT", snapshot.dataReady(), "SNAPSHOT_BUILT", null, Map.of());
TraderModelOutput modelOutput = modelService.evaluate(snapshot, bundle);
evidenceAppender.append(cycle.runId(), cycle.cycleId(), "MODEL_OUTPUT", true, "MODEL_EVALUATED", null, Map.of("modelOutputId", modelOutput.modelOutputId()));
PositionManagerInput pmInput = new PositionManagerInput(cycle, snapshot, modelOutput,
state.positionState(), state.accountState(), state.executionState(), bundle.pmConfig());
TraderPositionManagerDecision pmDecision = positionManager.decide(pmInput);
evidenceAppender.append(cycle.runId(), cycle.cycleId(), "PM_DECISION", true, pmDecision.reason(), null, Map.of("action", pmDecision.candidateAction().name()));
TraderRiskDecision riskDecision = riskGate.evaluate(new RiskGateInput(pmDecision, pmInput.positionState(), pmInput.accountState(),
pmInput.executionState(), snapshot, riskLimits()));
evidenceAppender.append(cycle.runId(), cycle.cycleId(), "RISK_DECISION", riskDecision.allowAction(), riskDecision.allowAction() ? "RISK_PASS" : riskDecision.blocker(), riskDecision.blocker(), Map.of());
TraderAction action = actionFactory.create(pmDecision, riskDecision, event.symbol());
traceWriter.persistCycleTrace(cycle, snapshot, modelOutput, pmInput.positionState(), pmDecision, riskDecision, action);
outboxRepository.insert(new TraderOutboxEvent("outbox_" + action.actionId(), action.runId(), action.cycleId(),
"TRADER_ACTION", action.actionId(), "ACTION_CREATED", properties.runMode().name() + "_RECORDER",
Map.of("actionType", action.actionType().name()), action.idempotencyKey(), "PENDING", Instant.now()));
stateStore.advance(state, action, snapshot);
log.info("event=trader.cycle.completed runId={} cycleId={} action={} outbox=PENDING", action.runId(), action.cycleId(), action.actionType());
return new TraderCycleResult(cycle.runId(), cycle.cycleId(), pmDecision, riskDecision, action);
runtimeControlService.acquireCycleLock(cycle);
try {
TraderMarketSnapshot snapshot = snapshot(event, cycleId);
TraderReplayState state = stateStore.load(cycle, snapshot);
evidenceAppender.append(cycle.runId(), cycle.cycleId(), "MARKET_SNAPSHOT", snapshot.dataReady(), "SNAPSHOT_BUILT", null, Map.of());
TraderModelOutput modelOutput = modelService.evaluate(snapshot, bundle);
evidenceAppender.append(cycle.runId(), cycle.cycleId(), "MODEL_OUTPUT", true, "MODEL_EVALUATED", null, Map.of("modelOutputId", modelOutput.modelOutputId()));
PositionManagerInput pmInput = new PositionManagerInput(cycle, snapshot, modelOutput,
bundle.pricePlanContext(), state.positionState(), state.accountState(), state.executionState(), bundle.pmConfig());
TraderPositionManagerDecision pmDecision = positionManager.decide(pmInput);
evidenceAppender.append(cycle.runId(), cycle.cycleId(), "PM_DECISION", true, pmDecision.reason(), null, Map.of("action", pmDecision.candidateAction().name()));
TraderRuntimeControlDecision runtimeDecision = runtimeControlService.validateExposureIncrease(cycle, pmDecision);
if (runtimeDecision.allowed()
&& pmDecision.candidateAction().increasesExposure()
&& outboxRepository.hasUnsentExposureIncrease(cycle.runId(), cycle.symbol(), pmDecision.side())) {
runtimeDecision = TraderRuntimeControlDecision.block("OUTBOX_PENDING_OPEN_ADD");
}
TraderRiskDecision riskDecision = riskGate.evaluate(new RiskGateInput(pmDecision, pmInput.positionState(), pmInput.accountState(),
pmInput.executionState(), snapshot, riskLimits(runtimeDecision)));
evidenceAppender.append(cycle.runId(), cycle.cycleId(), "RISK_DECISION", riskDecision.allowAction(), riskDecision.allowAction() ? "RISK_PASS" : riskDecision.blocker(), riskDecision.blocker(), Map.of());
TraderAction action = actionFactory.create(pmDecision, riskDecision, event.symbol());
traceWriter.persistCycleTrace(cycle, snapshot, modelOutput, pmInput.positionState(),
pmInput.accountState(), pmInput.executionState(), pmDecision, riskDecision, action);
String destination = outboxDestination();
outboxRepository.insert(new TraderOutboxEvent("outbox_" + action.actionId(), action.runId(), action.cycleId(),
"TRADER_ACTION", action.actionId(), "ACTION_CREATED", destination,
Map.of("actionType", action.actionType().name()), action.idempotencyKey(), "PENDING", Instant.now()));
dispatchAfterCommit(action, state, snapshot, destination);
log.info("event=trader.cycle.completed runId={} cycleId={} action={} outbox=PENDING runtimeBlocker={}",
action.runId(), action.cycleId(), action.actionType(), runtimeDecision.blocker());
return new TraderCycleResult(cycle.runId(), cycle.cycleId(), pmDecision, riskDecision, action);
} finally {
runtimeControlService.releaseCycleLock(cycle);
}
}
private TraderMarketSnapshot snapshot(ReplayMarketEvent event, String cycleId) {
return new TraderMarketSnapshot("snapshot_" + cycleId, event.runId(), cycleId, event.symbol(), event.eventTime(),
"feature-v4-p0", event.markPrice(), event.indexPrice(), event.spreadBps(), BigDecimal.ZERO,
event.depthNotional5Bps(), event.depthNotional5Bps(), event.depthNotional5Bps(),
event.depthNotional5Bps().compareTo(BigDecimal.ZERO) > 0, Map.of(), Map.of());
event.featureVersion(), event.markPrice(), event.indexPrice(), event.spreadBps(), event.fundingRateBps(),
event.depthNotional5Bps(), event.depthNotional10Bps(), event.depthNotional25Bps(),
event.dataReady(), event.featureJson(), event.dataQualityJson());
}
private RiskLimits riskLimits() {
private RiskLimits riskLimits(TraderRuntimeControlDecision runtimeDecision) {
return new RiskLimits(properties.risk().maxDailyLossBps(), properties.risk().maxTotalExposureRatio(),
properties.risk().minLiquidationBufferBps(), properties.execution().maxApiErrorCount(),
properties.execution().maxExchangeLatencyMs(), false, false);
properties.execution().maxExchangeLatencyMs(), false, !runtimeDecision.allowed(), runtimeDecision.blocker());
}
private String outboxDestination() {
return switch (properties.execution().mode()) {
case REPLAY_SIM -> "REPLAY_SIM_EXECUTION";
case SHADOW -> "SHADOW_RECORDER";
case PAPER, REAL -> throw new IllegalStateException("P0 runtime guard must reject PAPER/REAL execution mode");
};
}
private void dispatchAfterCommit(TraderAction action, TraderReplayState state,
TraderMarketSnapshot snapshot, String destination) {
if (!TransactionSynchronizationManager.isSynchronizationActive()) {
outboxDispatcher.dispatch(action, state, snapshot, destination);
return;
}
TransactionSynchronizationManager.registerSynchronization(new TransactionSynchronization() {
@Override
public void afterCommit() {
outboxDispatcher.dispatch(action, state, snapshot, destination);
}
});
}
}
@@ -0,0 +1,70 @@
package com.quantai.trader.replay.state;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.quantai.trader.domain.TraderAccountState;
import com.quantai.trader.domain.TraderExecutionState;
import com.quantai.trader.domain.TraderPositionState;
import com.quantai.trader.persistence.TraderJsonCodec;
import org.springframework.jdbc.core.JdbcTemplate;
import org.springframework.stereotype.Repository;
import java.sql.Timestamp;
@Repository
public class JdbcTraderPostActionStateRepository implements TraderPostActionStateRepository {
private final JdbcTemplate jdbcTemplate;
private final TraderJsonCodec jsonCodec;
public JdbcTraderPostActionStateRepository(JdbcTemplate jdbcTemplate, ObjectMapper objectMapper) {
this.jdbcTemplate = jdbcTemplate;
this.jsonCodec = new TraderJsonCodec(objectMapper);
}
@Override
public void insertPostActionState(TraderReplayState state) {
insertPositionState(state.positionState());
insertAccountState(state.accountState());
insertExecutionState(state.executionState());
}
private void insertPositionState(TraderPositionState state) {
jdbcTemplate.update("""
insert into trader_position_state
(run_id, cycle_id, position_state_id, state_role, symbol, side, position_ratio,
average_entry_price, current_price, unrealized_pnl_bps, liquidation_buffer_bps,
add_count, remaining_add_capacity, last_add_time)
values (?, ?, ?, 'POST_ACTION', ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
""",
state.runId(), state.cycleId(), state.positionStateId() + "_post", state.symbol(),
state.side().name(), state.positionRatio(), state.averageEntryPrice(), state.currentPrice(),
state.unrealizedPnlBps(), state.liquidationBufferBps(), state.addCount(),
state.remainingAddCapacity(), state.lastAddTime() == null ? null : Timestamp.from(state.lastAddTime()));
}
private void insertAccountState(TraderAccountState state) {
jdbcTemplate.update("""
insert into trader_account_state
(run_id, cycle_id, account_state_id, state_role, daily_drawdown_bps,
portfolio_exposure_ratio, remaining_symbol_capacity_ratio, consecutive_losses)
values (?, ?, ?, 'POST_ACTION', ?, ?, ?, ?)
""",
state.runId(), state.cycleId(), state.accountStateId() + "_post", state.dailyDrawdownBps(),
state.portfolioExposureRatio(), state.remainingSymbolCapacityRatio(), state.consecutiveLosses());
}
private void insertExecutionState(TraderExecutionState state) {
jdbcTemplate.update("""
insert into trader_execution_state
(run_id, cycle_id, execution_state_id, state_role, symbol, open_orders_json,
expected_slippage_bps, exchange_latency_ms, api_error_count, maker_fee_bps,
taker_fee_bps, min_notional, price_tick_size, lot_size_step_size,
market_lot_size_step_size, liquidity_capacity_ratio)
values (?, ?, ?, 'POST_ACTION', ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
""",
state.runId(), state.cycleId(), state.executionStateId() + "_post", state.symbol(),
jsonCodec.toJson(state.openOrders()), state.expectedSlippageBps(), state.exchangeLatencyMs(),
state.apiErrorCount(), state.makerFeeBps(), state.takerFeeBps(), state.minNotional(),
state.priceTickSize(), state.lotSizeStepSize(), state.marketLotSizeStepSize(),
state.liquidityCapacityRatio());
}
}
@@ -11,7 +11,6 @@ import org.springframework.stereotype.Component;
import java.math.BigDecimal;
import java.math.MathContext;
import java.time.Instant;
import java.util.List;
import java.util.concurrent.ConcurrentHashMap;
@@ -88,7 +87,7 @@ public class P0ReplayStateStore implements TraderReplayStateStore {
BigDecimal weightedEntry = weightedEntry(current.averageEntryPrice(), current.positionRatio(), snapshot.markPrice(), addRatio);
return new TraderPositionState("position_state_" + action.cycleId(), action.runId(), action.cycleId(), action.symbol(),
current.side(), newRatio, weightedEntry, snapshot.markPrice(), pnlBps(current.side(), weightedEntry, snapshot.markPrice()),
current.liquidationBufferBps(), current.addCount() + 1, remainingCapacity(newRatio), Instant.now());
current.liquidationBufferBps(), current.addCount() + 1, remainingCapacity(newRatio), snapshot.snapshotTime());
}
private TraderPositionState reducePosition(TraderPositionState current, TraderAction action, TraderMarketSnapshot snapshot) {
@@ -0,0 +1,5 @@
package com.quantai.trader.replay.state;
public interface TraderPostActionStateRepository {
void insertPostActionState(TraderReplayState state);
}
@@ -9,6 +9,7 @@ public record RiskLimits(
int maxApiErrorCount,
long maxExchangeLatencyMs,
boolean killSwitchActive,
boolean executionBlocked
boolean executionBlocked,
String executionBlocker
) {
}
@@ -17,7 +17,7 @@ public class TraderRiskGate {
if (input.riskLimits().killSwitchActive() && input.pmDecision().candidateAction().increasesExposure()) {
decision = block(input, "KILL_SWITCH_ACTIVE");
} else if (input.riskLimits().executionBlocked()) {
decision = block(input, "EXECUTION_BLOCKED");
decision = block(input, input.riskLimits().executionBlocker());
} else if (input.accountState().dailyDrawdownBps().compareTo(input.riskLimits().maxDailyLossBps()) >= 0) {
decision = block(input, "MAX_DAILY_LOSS");
} else if (input.accountState().portfolioExposureRatio().compareTo(input.riskLimits().maxTotalExposureRatio()) >= 0
@@ -46,6 +46,9 @@ public class TraderRiskGate {
}
private TraderRiskDecision block(RiskGateInput input, String blocker) {
if (blocker == null || blocker.isBlank()) {
blocker = "EXECUTION_BLOCKED";
}
TraderActionType finalAction = input.pmDecision().candidateAction().increasesExposure() ? TraderActionType.WAIT : input.pmDecision().candidateAction();
return new TraderRiskDecision("risk_" + input.pmDecision().cycleId(), input.pmDecision().runId(), input.pmDecision().cycleId(),
input.pmDecision().pmDecisionId(), false, input.pmDecision().candidateAction(), finalAction, blocker,
@@ -0,0 +1,93 @@
package com.quantai.trader.runtime;
import com.quantai.trader.config.TraderProperties;
import com.quantai.trader.domain.TraderDecisionCycle;
import com.quantai.trader.domain.TraderException;
import com.quantai.trader.domain.TraderPositionManagerDecision;
import com.quantai.trader.enums.TraderErrorCode;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.data.redis.core.StringRedisTemplate;
import org.springframework.stereotype.Component;
import java.time.Duration;
@Component
public class RedisTraderRuntimeControlService implements TraderRuntimeControlService {
private static final Logger log = LoggerFactory.getLogger(RedisTraderRuntimeControlService.class);
private final TraderProperties properties;
private final StringRedisTemplate redisTemplate;
public RedisTraderRuntimeControlService(TraderProperties properties, StringRedisTemplate redisTemplate) {
this.properties = properties;
this.redisTemplate = redisTemplate;
}
@Override
public void acquireCycleLock(TraderDecisionCycle cycle) {
String key = key("cycle:lock:" + cycle.symbol() + ":" + cycle.cycleTime().toEpochMilli());
try {
Boolean acquired = redisTemplate.opsForValue().setIfAbsent(key, cycle.cycleId(), Duration.ofSeconds(30));
if (!Boolean.TRUE.equals(acquired)) {
throw new TraderException(TraderErrorCode.TRADER_RUNTIME_CONTROL_BLOCKED,
"cycle lock already exists: " + key);
}
log.info("event=trader.runtime.lock_acquired runId={} cycleId={} key={}",
cycle.runId(), cycle.cycleId(), key);
} catch (TraderException exception) {
throw exception;
} catch (RuntimeException exception) {
throw new TraderException(TraderErrorCode.TRADER_RUNTIME_CONTROL_BLOCKED,
"redis unavailable while acquiring cycle lock");
}
}
@Override
public void releaseCycleLock(TraderDecisionCycle cycle) {
String key = key("cycle:lock:" + cycle.symbol() + ":" + cycle.cycleTime().toEpochMilli());
try {
redisTemplate.delete(key);
log.info("event=trader.runtime.lock_released runId={} cycleId={} key={}",
cycle.runId(), cycle.cycleId(), key);
} catch (RuntimeException exception) {
log.warn("event=trader.runtime.lock_release_failed runId={} cycleId={} key={} blocker=REDIS_UNAVAILABLE",
cycle.runId(), cycle.cycleId(), key);
}
}
@Override
public TraderRuntimeControlDecision validateExposureIncrease(TraderDecisionCycle cycle,
TraderPositionManagerDecision decision) {
if (!decision.candidateAction().increasesExposure()) {
return TraderRuntimeControlDecision.allow();
}
try {
String activePointer = redisTemplate.opsForValue().get(key("model:active:" + cycle.symbol()));
String expectedPointer = cycle.modelBundleVersion() + "|" + cycle.calibrationBundleVersion() + "|" + cycle.pmConfigVersion();
if (properties.release().activePointerCheckEnabled() && activePointer == null) {
return TraderRuntimeControlDecision.block("ACTIVE_POINTER_MISSING");
}
if (properties.release().activePointerCheckEnabled() && !expectedPointer.equals(activePointer)) {
return TraderRuntimeControlDecision.block("ACTIVE_POINTER_MISMATCH");
}
if (Boolean.TRUE.equals(redisTemplate.hasKey(key("risk:kill-switch:global")))) {
return TraderRuntimeControlDecision.block("KILL_SWITCH_ACTIVE");
}
if (Boolean.TRUE.equals(redisTemplate.hasKey(key("risk:kill-switch:symbol:" + cycle.symbol())))) {
return TraderRuntimeControlDecision.block("KILL_SWITCH_ACTIVE");
}
if (Boolean.TRUE.equals(redisTemplate.hasKey(key("risk:execution:close-only:" + cycle.symbol())))) {
return TraderRuntimeControlDecision.block("CLOSE_ONLY_ACTIVE");
}
redisTemplate.opsForValue().set(key("runtime:open-add-probe:" + cycle.cycleId()), "1", Duration.ofSeconds(30));
return TraderRuntimeControlDecision.allow();
} catch (RuntimeException exception) {
return TraderRuntimeControlDecision.block("REDIS_UNAVAILABLE");
}
}
private String key(String suffix) {
return properties.runtime().redisKeyPrefix() + ":" + suffix;
}
}
@@ -0,0 +1,14 @@
package com.quantai.trader.runtime;
public record TraderRuntimeControlDecision(
boolean allowed,
String blocker
) {
public static TraderRuntimeControlDecision allow() {
return new TraderRuntimeControlDecision(true, null);
}
public static TraderRuntimeControlDecision block(String blocker) {
return new TraderRuntimeControlDecision(false, blocker);
}
}
@@ -0,0 +1,13 @@
package com.quantai.trader.runtime;
import com.quantai.trader.domain.TraderDecisionCycle;
import com.quantai.trader.domain.TraderPositionManagerDecision;
public interface TraderRuntimeControlService {
void acquireCycleLock(TraderDecisionCycle cycle);
void releaseCycleLock(TraderDecisionCycle cycle);
TraderRuntimeControlDecision validateExposureIncrease(TraderDecisionCycle cycle,
TraderPositionManagerDecision decision);
}
@@ -3,7 +3,6 @@ package com.quantai.trader.util;
import java.math.BigDecimal;
import java.math.MathContext;
import java.util.Collection;
import java.util.Objects;
public final class TraderNumbers {
public static final BigDecimal ZERO = BigDecimal.ZERO;
@@ -14,7 +13,10 @@ public final class TraderNumbers {
}
public static BigDecimal required(BigDecimal value, String field) {
return Objects.requireNonNull(value, field + " is required");
if (value == null) {
throw new IllegalArgumentException(field + " is required");
}
return value;
}
public static String requiredText(String value, String field) {
@@ -68,6 +70,9 @@ public final class TraderNumbers {
}
public static <T> Collection<T> requiredCollection(Collection<T> value, String field) {
return Objects.requireNonNull(value, field + " is required");
if (value == null) {
throw new IllegalArgumentException(field + " is required");
}
return value;
}
}