Load trader V4 artifacts from manifests
This commit is contained in:
@@ -10,12 +10,14 @@ public record TraderArtifactBundle(
|
||||
String pmConfigVersion,
|
||||
String bundleHashSha256,
|
||||
Set<String> providedModels,
|
||||
TraderPmConfig pmConfig
|
||||
TraderPmConfig pmConfig,
|
||||
TraderArtifactModelPolicy modelPolicy
|
||||
) {
|
||||
public TraderArtifactBundle {
|
||||
if (providedModels == null || !providedModels.containsAll(Set.of("DIRECTION", "ENTRY", "CONTINUE", "EXIT", "RISK"))) {
|
||||
throw new IllegalArgumentException("artifact bundle must provide all five V4 models");
|
||||
}
|
||||
pmConfig = java.util.Objects.requireNonNull(pmConfig, "pmConfig is required");
|
||||
modelPolicy = java.util.Objects.requireNonNull(modelPolicy, "modelPolicy is required");
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,71 +1,186 @@
|
||||
package com.quantai.trader.artifact;
|
||||
|
||||
import com.fasterxml.jackson.databind.JsonNode;
|
||||
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||
import com.quantai.trader.config.TraderProperties;
|
||||
import com.quantai.trader.domain.TraderException;
|
||||
import com.quantai.trader.domain.TraderPmConfig;
|
||||
import com.quantai.trader.enums.TraderRunMode;
|
||||
import com.quantai.trader.enums.TraderErrorCode;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
import org.springframework.stereotype.Component;
|
||||
|
||||
import java.math.BigDecimal;
|
||||
import java.io.IOException;
|
||||
import java.nio.file.Files;
|
||||
import java.nio.file.Path;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
import java.util.stream.StreamSupport;
|
||||
|
||||
@Component
|
||||
public class TraderArtifactLoader {
|
||||
private static final Logger log = LoggerFactory.getLogger(TraderArtifactLoader.class);
|
||||
private static final Set<String> REQUIRED_MODELS = Set.of("DIRECTION", "ENTRY", "CONTINUE", "EXIT", "RISK");
|
||||
|
||||
private final TraderProperties properties;
|
||||
private final ObjectMapper objectMapper;
|
||||
|
||||
public TraderArtifactLoader(TraderProperties properties) {
|
||||
public TraderArtifactLoader(TraderProperties properties, ObjectMapper objectMapper) {
|
||||
this.properties = properties;
|
||||
this.objectMapper = objectMapper;
|
||||
}
|
||||
|
||||
public TraderArtifactBundle loadActiveBundle() {
|
||||
TraderProperties.Artifact artifact = properties.artifact();
|
||||
if (artifact.modelBundleVersion().isBlank()
|
||||
|| artifact.calibrationBundleVersion().isBlank()
|
||||
|| artifact.pmConfigVersion().isBlank()) {
|
||||
throw new TraderException(TraderErrorCode.TRADER_MODEL_ARTIFACT_MISSING,
|
||||
"model/calibration/pm version is required");
|
||||
}
|
||||
TraderArtifactBundle bundle = deterministicP0Bundle(artifact);
|
||||
Path root = Path.of(artifact.artifactRoot());
|
||||
TraderModelBundleManifest modelManifest = readModelBundleManifest(root.resolve("manifests/model_bundle_manifest.json"));
|
||||
TraderPmConfigManifest pmManifest = readPmConfigManifest(root.resolve("manifests/position_manager_manifest.json"));
|
||||
TraderArtifactModelPolicy modelPolicy = readJson(root.resolve("model_output_policy.json"), TraderArtifactModelPolicy.class);
|
||||
validateVersions(artifact, modelManifest, pmManifest);
|
||||
validateModelManifest(modelManifest);
|
||||
validatePmManifest(pmManifest, properties.runMode());
|
||||
TraderArtifactBundle bundle = new TraderArtifactBundle(
|
||||
modelManifest.modelBundleVersion(),
|
||||
modelManifest.calibrationBundleVersion(),
|
||||
pmManifest.pmConfigVersion(),
|
||||
modelManifest.bundleHashSha256(),
|
||||
modelManifest.providedModels(),
|
||||
pmManifest.config(),
|
||||
modelPolicy);
|
||||
log.info("event=trader.artifact.loaded modelBundleVersion={} calibrationBundleVersion={} pmConfigVersion={} providedModels={}",
|
||||
bundle.modelBundleVersion(), bundle.calibrationBundleVersion(), bundle.pmConfigVersion(), bundle.providedModels());
|
||||
return bundle;
|
||||
}
|
||||
|
||||
private TraderArtifactBundle deterministicP0Bundle(TraderProperties.Artifact artifact) {
|
||||
TraderPmConfig pmConfig = new TraderPmConfig(
|
||||
artifact.pmConfigVersion(),
|
||||
new TraderPmConfig.OpenRuleConfig(
|
||||
new BigDecimal("0.58"), new BigDecimal("0.58"),
|
||||
new BigDecimal("0.55"), new BigDecimal("0.55"),
|
||||
new BigDecimal("0.45"), new BigDecimal("1.0"),
|
||||
new BigDecimal("0.03"), new BigDecimal("0.10"), new BigDecimal("0.80")),
|
||||
new TraderPmConfig.AddRuleConfig(
|
||||
new BigDecimal("0.60"), new BigDecimal("0.60"),
|
||||
new BigDecimal("0.58"), new BigDecimal("0.55"), new BigDecimal("0.45"),
|
||||
new BigDecimal("0.45"), new BigDecimal("0.50"),
|
||||
new BigDecimal("1.0"), BigDecimal.ZERO, new BigDecimal("0.10"),
|
||||
new BigDecimal("500"), 3, 5),
|
||||
new TraderPmConfig.ExitRuleConfig(
|
||||
new BigDecimal("0.70"), new BigDecimal("0.70"), new BigDecimal("0.70"),
|
||||
new BigDecimal("0.25"), new BigDecimal("0.62"),
|
||||
new BigDecimal("0.35"), new BigDecimal("0.70"),
|
||||
new BigDecimal("5.0"), new BigDecimal("80")),
|
||||
new TraderPmConfig.SizingConfig(
|
||||
new BigDecimal("0.80"), new BigDecimal("0.05"), BigDecimal.ONE,
|
||||
new BigDecimal("0.02"), new BigDecimal("0.25"), BigDecimal.ONE,
|
||||
new BigDecimal("1.0"), new BigDecimal("80"),
|
||||
new BigDecimal("0.20"), new BigDecimal("0.50"), new BigDecimal("500"))
|
||||
);
|
||||
return new TraderArtifactBundle(
|
||||
artifact.modelBundleVersion(),
|
||||
artifact.calibrationBundleVersion(),
|
||||
artifact.pmConfigVersion(),
|
||||
"deterministic-p0-fixture",
|
||||
Set.of("DIRECTION", "ENTRY", "CONTINUE", "EXIT", "RISK"),
|
||||
pmConfig);
|
||||
private TraderModelBundleManifest readModelBundleManifest(Path path) {
|
||||
JsonNode root = readJsonNode(path);
|
||||
return new TraderModelBundleManifest(
|
||||
requiredText(root, "model_bundle_version", path),
|
||||
requiredText(root, "calibration_bundle_version", path),
|
||||
requiredText(root, "feature_version", path),
|
||||
requiredText(root, "label_version", path),
|
||||
requiredText(root, "split_version", path),
|
||||
textSet(root, "required_models_json", path),
|
||||
textSet(root, "provided_models_json", path),
|
||||
textSet(root, "missing_models_json", path),
|
||||
requiredText(root, "bundle_hash_sha256", path),
|
||||
root.path("complete").asBoolean(false),
|
||||
requiredText(root, "status", path));
|
||||
}
|
||||
|
||||
private TraderPmConfigManifest readPmConfigManifest(Path path) {
|
||||
JsonNode root = readJsonNode(path);
|
||||
return new TraderPmConfigManifest(
|
||||
requiredText(root, "pm_config_version", path),
|
||||
requiredText(root, "model_bundle_version", path),
|
||||
requiredText(root, "calibration_bundle_version", path),
|
||||
enumSet(root, "allowed_run_modes_json", TraderRunMode.class, path),
|
||||
convert(root.path("config_json"), TraderPmConfig.class, path),
|
||||
requiredText(root, "config_hash_sha256", path),
|
||||
requiredText(root, "status", path));
|
||||
}
|
||||
|
||||
private void validateVersions(TraderProperties.Artifact expected, TraderModelBundleManifest model, TraderPmConfigManifest pm) {
|
||||
if (!expected.modelBundleVersion().equals(model.modelBundleVersion())
|
||||
|| !expected.calibrationBundleVersion().equals(model.calibrationBundleVersion())
|
||||
|| !expected.pmConfigVersion().equals(pm.pmConfigVersion())) {
|
||||
throw new TraderException(TraderErrorCode.TRADER_MODEL_ARTIFACT_MISSING,
|
||||
"artifact version triple does not match configured model/calibration/pm versions");
|
||||
}
|
||||
if (!model.modelBundleVersion().equals(pm.modelBundleVersion())
|
||||
|| !model.calibrationBundleVersion().equals(pm.calibrationBundleVersion())) {
|
||||
throw new TraderException(TraderErrorCode.TRADER_CALIBRATION_MISMATCH,
|
||||
"model and pm manifests reference different model/calibration versions");
|
||||
}
|
||||
}
|
||||
|
||||
private void validateModelManifest(TraderModelBundleManifest manifest) {
|
||||
if (!manifest.complete() || !"ACTIVE".equals(manifest.status())) {
|
||||
throw new TraderException(TraderErrorCode.TRADER_MODEL_ARTIFACT_MISSING,
|
||||
"model bundle manifest must be complete and ACTIVE");
|
||||
}
|
||||
if (!manifest.requiredModels().containsAll(REQUIRED_MODELS)
|
||||
|| !manifest.providedModels().containsAll(REQUIRED_MODELS)
|
||||
|| !manifest.missingModels().isEmpty()) {
|
||||
throw new TraderException(TraderErrorCode.TRADER_MODEL_ARTIFACT_MISSING,
|
||||
"model bundle must provide all five V4 models with no missing model");
|
||||
}
|
||||
}
|
||||
|
||||
private void validatePmManifest(TraderPmConfigManifest manifest, TraderRunMode runMode) {
|
||||
if (!"ACTIVE".equals(manifest.status())) {
|
||||
throw new TraderException(TraderErrorCode.TRADER_PM_CONFIG_MISMATCH,
|
||||
"pm config manifest must be ACTIVE");
|
||||
}
|
||||
if (!manifest.allowedRunModes().contains(runMode)) {
|
||||
throw new TraderException(TraderErrorCode.TRADER_PM_CONFIG_MISMATCH,
|
||||
"pm config manifest does not allow current run mode");
|
||||
}
|
||||
}
|
||||
|
||||
private JsonNode readJsonNode(Path path) {
|
||||
if (!Files.isRegularFile(path)) {
|
||||
throw new TraderException(TraderErrorCode.TRADER_MODEL_ARTIFACT_MISSING,
|
||||
"artifact file is missing: " + path);
|
||||
}
|
||||
try {
|
||||
return objectMapper.readTree(path.toFile());
|
||||
} catch (IOException exception) {
|
||||
throw new TraderException(TraderErrorCode.TRADER_MODEL_ARTIFACT_MISSING,
|
||||
"artifact file cannot be read: " + path);
|
||||
}
|
||||
}
|
||||
|
||||
private <T> T readJson(Path path, Class<T> type) {
|
||||
if (!Files.isRegularFile(path)) {
|
||||
throw new TraderException(TraderErrorCode.TRADER_MODEL_ARTIFACT_MISSING,
|
||||
"artifact file is missing: " + path);
|
||||
}
|
||||
try {
|
||||
return objectMapper.readValue(path.toFile(), type);
|
||||
} catch (IOException exception) {
|
||||
throw new TraderException(TraderErrorCode.TRADER_MODEL_ARTIFACT_MISSING,
|
||||
"artifact file cannot be read: " + path);
|
||||
}
|
||||
}
|
||||
|
||||
private <T> T convert(JsonNode node, Class<T> type, Path path) {
|
||||
if (node == null || node.isMissingNode() || node.isNull()) {
|
||||
throw new TraderException(TraderErrorCode.TRADER_MODEL_ARTIFACT_MISSING,
|
||||
"artifact field is missing: " + path + "#config_json");
|
||||
}
|
||||
try {
|
||||
return objectMapper.treeToValue(node, type);
|
||||
} catch (IOException exception) {
|
||||
throw new TraderException(TraderErrorCode.TRADER_MODEL_ARTIFACT_MISSING,
|
||||
"artifact field cannot be parsed: " + path + "#config_json");
|
||||
}
|
||||
}
|
||||
|
||||
private String requiredText(JsonNode node, String field, Path path) {
|
||||
String value = node.path(field).asText("");
|
||||
if (value.isBlank()) {
|
||||
throw new TraderException(TraderErrorCode.TRADER_MODEL_ARTIFACT_MISSING,
|
||||
"artifact field is required: " + path + "#" + field);
|
||||
}
|
||||
return value;
|
||||
}
|
||||
|
||||
private Set<String> textSet(JsonNode node, String field, Path path) {
|
||||
JsonNode array = node.path(field);
|
||||
if (!array.isArray()) {
|
||||
throw new TraderException(TraderErrorCode.TRADER_MODEL_ARTIFACT_MISSING,
|
||||
"artifact field must be array: " + path + "#" + field);
|
||||
}
|
||||
return StreamSupport.stream(array.spliterator(), false)
|
||||
.map(JsonNode::asText)
|
||||
.collect(Collectors.toUnmodifiableSet());
|
||||
}
|
||||
|
||||
private <E extends Enum<E>> Set<E> enumSet(JsonNode node, String field, Class<E> type, Path path) {
|
||||
return textSet(node, field, path).stream()
|
||||
.map(value -> Enum.valueOf(type, value))
|
||||
.collect(Collectors.toUnmodifiableSet());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,76 @@
|
||||
package com.quantai.trader.artifact;
|
||||
|
||||
import java.math.BigDecimal;
|
||||
|
||||
public record TraderArtifactModelPolicy(
|
||||
DirectionPolicy direction,
|
||||
EntryPolicy entry,
|
||||
ContinuePolicy continuation,
|
||||
ExitPolicy exit,
|
||||
RiskPolicy risk,
|
||||
BigDecimal uncertainty,
|
||||
BigDecimal oodScore
|
||||
) {
|
||||
public record DirectionPolicy(
|
||||
BigDecimal longProbWhenMarkGteIndex,
|
||||
BigDecimal longProbWhenMarkLtIndex,
|
||||
BigDecimal neutralProb,
|
||||
BigDecimal expectedReturnBps,
|
||||
int horizonMinutes,
|
||||
String modelVersion
|
||||
) {
|
||||
}
|
||||
|
||||
public record EntryPolicy(
|
||||
BigDecimal longEntryProb,
|
||||
BigDecimal shortEntryProb,
|
||||
BigDecimal entryQualityScore,
|
||||
BigDecimal expectedEdgeBps,
|
||||
String pricePlanId,
|
||||
String pricePlanConfigHash,
|
||||
BigDecimal stopDistanceBps,
|
||||
BigDecimal targetDistanceBps,
|
||||
int maxHoldMinutes,
|
||||
BigDecimal costBps,
|
||||
String modelVersion
|
||||
) {
|
||||
}
|
||||
|
||||
public record ContinuePolicy(
|
||||
BigDecimal longContinueProb,
|
||||
BigDecimal shortContinueProb,
|
||||
BigDecimal trendPersistenceProb,
|
||||
BigDecimal holdEdgeBps,
|
||||
BigDecimal continueVsExitEdgeBps,
|
||||
String modelVersion
|
||||
) {
|
||||
}
|
||||
|
||||
public record ExitPolicy(
|
||||
BigDecimal longExitProb,
|
||||
BigDecimal shortExitProb,
|
||||
BigDecimal profitGivebackProb,
|
||||
BigDecimal reversalProb,
|
||||
BigDecimal stopRiskProb,
|
||||
BigDecimal stagnationProb,
|
||||
BigDecimal expectedGivebackBps,
|
||||
String modelVersion
|
||||
) {
|
||||
}
|
||||
|
||||
public record RiskPolicy(
|
||||
BigDecimal marketRiskProb,
|
||||
BigDecimal positionRiskProb,
|
||||
BigDecimal marketRiskSeverityBps,
|
||||
BigDecimal positionRiskSeverityBps,
|
||||
BigDecimal drawdownProb,
|
||||
BigDecimal expectedShortfallBps,
|
||||
BigDecimal volatilityExpansionProb,
|
||||
BigDecimal spikeProb,
|
||||
BigDecimal liquidityRiskProb,
|
||||
BigDecimal liquidityCapacityRatioWhenReady,
|
||||
BigDecimal liquidityCapacityRatioWhenNotReady,
|
||||
String modelVersion
|
||||
) {
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,18 @@
|
||||
package com.quantai.trader.artifact;
|
||||
|
||||
import java.util.Set;
|
||||
|
||||
public record TraderModelBundleManifest(
|
||||
String modelBundleVersion,
|
||||
String calibrationBundleVersion,
|
||||
String featureVersion,
|
||||
String labelVersion,
|
||||
String splitVersion,
|
||||
Set<String> requiredModels,
|
||||
Set<String> providedModels,
|
||||
Set<String> missingModels,
|
||||
String bundleHashSha256,
|
||||
boolean complete,
|
||||
String status
|
||||
) {
|
||||
}
|
||||
@@ -0,0 +1,17 @@
|
||||
package com.quantai.trader.artifact;
|
||||
|
||||
import com.quantai.trader.domain.TraderPmConfig;
|
||||
import com.quantai.trader.enums.TraderRunMode;
|
||||
|
||||
import java.util.Set;
|
||||
|
||||
public record TraderPmConfigManifest(
|
||||
String pmConfigVersion,
|
||||
String modelBundleVersion,
|
||||
String calibrationBundleVersion,
|
||||
Set<TraderRunMode> allowedRunModes,
|
||||
TraderPmConfig config,
|
||||
String configHashSha256,
|
||||
String status
|
||||
) {
|
||||
}
|
||||
Reference in New Issue
Block a user