Files
quant-trader-service/training/trader_training/exporter.py
T

245 lines
11 KiB
Python

from __future__ import annotations
import logging
import shutil
from pathlib import Path
from typing import Any
import pandas as pd
from trader_training.io_utils import read_json, read_parquet, run_root, sha256_file, sha256_json, utc_now_text, write_json
from trader_training.pm import default_pm_config
from trader_training.schemas import (
CALIBRATION_BUNDLE_VERSION,
FEATURE_ORDER,
FEATURE_VERSION,
FIT_SPLIT,
LABEL_VERSION,
MODEL_BUNDLE_VERSION,
MODEL_OUTPUTS,
OUTPUT_MAPPING,
OUTPUT_SCHEMA,
PM_CONFIG_VERSION,
SPLIT_VERSION,
TUNE_SPLIT,
VALIDATION_LOCKED_SPLIT,
)
# Model names that export_artifact_bundle iterates over; for each one it copies
# an ONNX file and a calibrator (looked up under the lower-cased name) into the
# bundle and emits one model-manifest row and one calibration-manifest row.
REQUIRED_MODELS = ("DIRECTION", "ENTRY", "CONTINUE", "EXIT", "RISK")
# Value written to the "status" field of every manifest produced by this module.
EXPORT_STATUS = "CANDIDATE"
def export_artifact_bundle(args: Any) -> None:
    """Assemble the exportable artifact bundle for one training run.

    Copies the feature schemas, ONNX models and calibrators produced by earlier
    pipeline stages from the run root into ``<export_root>/artifact_bundle`` and
    writes the accompanying example payloads and manifests (position manager,
    model, calibration, model bundle, training export).

    Args:
        args: Run arguments. This function reads ``args.run_id`` and
            ``args.export_root`` (optional; when falsy the bundle is placed
            under ``<run root>/export/trader-model-bundle-<run_id>``). ``args``
            is also forwarded to ``run_root``.

    Raises:
        KeyError: If a position-manager config file exists but has no
            ``"config"`` entry.
        Errors raised by ``shutil.copy2`` / ``read_json`` propagate unchanged
        when an expected input artifact is missing.
    """
    root = run_root(args)
    # Coerce so that a plain string (e.g. straight from a CLI parser) supports
    # the "/" joins below; a Path value passes through unchanged.
    if args.export_root:
        export_root = Path(args.export_root)
    else:
        export_root = root / "export" / f"trader-model-bundle-{args.run_id}"
    artifact_root = export_root / "artifact_bundle"
    for folder in ("models", "schemas", "calibrators", "manifests", "examples"):
        (artifact_root / folder).mkdir(parents=True, exist_ok=True)

    # --- Schemas -----------------------------------------------------------
    schemas_dir = artifact_root / "schemas"
    feature_schema_dst = schemas_dir / "feature_schema.json"
    feature_order_dst = schemas_dir / "feature_order.json"
    output_schema_path = schemas_dir / "output_schema.json"
    shutil.copy2(root / "feature" / "feature_schema.json", feature_schema_dst)
    shutil.copy2(root / "feature" / "feature_order.json", feature_order_dst)
    write_json(output_schema_path, OUTPUT_SCHEMA)

    # --- Examples and price-plan context -----------------------------------
    sample_input = _sample_input(root)
    write_json(artifact_root / "examples" / "sample_input.json", sample_input)
    write_json(artifact_root / "examples" / "sample_output.json", _sample_output())
    price_plan = read_json(root / "label" / "price_plan_context.json")
    write_json(artifact_root / "price_plan_context.json", price_plan)

    # --- Values shared by every per-model row ------------------------------
    train_start = _split_time(root, FIT_SPLIT, "start")
    train_end = _split_time(root, FIT_SPLIT, "end")
    validation_start = _split_time(root, VALIDATION_LOCKED_SPLIT, "start")
    validation_end = _split_time(root, VALIDATION_LOCKED_SPLIT, "end")
    calibration_train_manifest = read_json(root / "calibration" / "calibration_train_manifest.json")
    calibration_quality = {
        row["model_name"]: row for row in calibration_train_manifest.get("calibrators", [])
    }
    # The schema files and the split manifest do not change while the models
    # are being copied, so hash / look them up once instead of once per model.
    feature_schema_hash = sha256_file(feature_schema_dst)
    feature_order_hash = sha256_file(feature_order_dst)
    output_schema_hash = sha256_file(output_schema_path)
    calibration_from = _split_time(root, TUNE_SPLIT, "start")
    calibration_to = _split_time(root, TUNE_SPLIT, "end")

    # --- Per-model artifacts and manifest rows -----------------------------
    model_manifest_rows = []
    calibration_manifest_rows = []
    model_hashes = {}
    for model_name in REQUIRED_MODELS:
        slug = model_name.lower()

        model_dst = artifact_root / "models" / f"{slug}.onnx"
        shutil.copy2(root / "model" / slug / f"{slug}.onnx", model_dst)
        model_hashes[model_name] = sha256_file(model_dst)

        cal_dst = artifact_root / "calibrators" / slug / "calibrator.json"
        cal_dst.parent.mkdir(parents=True, exist_ok=True)
        shutil.copy2(root / "calibration" / slug / "calibrator.json", cal_dst)
        cal_hash = sha256_file(cal_dst)

        train_result = read_json(root / "model" / slug / "model_train_result.json")
        model_manifest_rows.append(
            _model_manifest_row(
                model_name=model_name,
                artifact_path=f"models/{slug}.onnx",
                artifact_hash=model_hashes[model_name],
                feature_schema_hash=feature_schema_hash,
                feature_order_hash=feature_order_hash,
                output_schema_hash=output_schema_hash,
                metrics=train_result.get("metrics", {}),
                quality_status=train_result.get("quality_status", "UNKNOWN"),
                quality_reasons=train_result.get("quality_reasons", []),
                train_start=train_start,
                train_end=train_end,
                validation_start=validation_start,
                validation_end=validation_end,
                status=EXPORT_STATUS,
            )
        )

        # Models absent from the calibration training manifest fall back to
        # an UNKNOWN quality status with no reasons.
        cal_quality = calibration_quality.get(model_name, {})
        calibration_manifest_rows.append(
            {
                "calibration_bundle_version": CALIBRATION_BUNDLE_VERSION,
                "model_bundle_version": MODEL_BUNDLE_VERSION,
                "model_name": model_name,
                "calibrator_version": f"{slug}-cal-v4-btc-p0",
                "calibration_method": "BINNING",
                "calibrator_path": f"calibrators/{slug}/calibrator.json",
                "calibrator_hash_sha256": cal_hash,
                "calibration_window_from": calibration_from,
                "calibration_window_to": calibration_to,
                "calibration_metrics_json": {},
                "bucket_metrics_json": {},
                "output_after_calibration_schema_hash": output_schema_hash,
                "quality_status": cal_quality.get("quality_status", "UNKNOWN"),
                "quality_reasons_json": cal_quality.get("quality_reasons", []),
                "status": EXPORT_STATUS,
            }
        )

    # --- Position-manager config -------------------------------------------
    # Use the searched config when the pm-search stage produced one; otherwise
    # fall back to the default config with empty threshold-stability data.
    pm_config_path = root / "pm-search" / "position_manager_config.json"
    if pm_config_path.is_file():
        pm_payload = read_json(pm_config_path)
    else:
        pm_payload = {"config": default_pm_config(), "threshold_stability_json": {}}
    pm_config = pm_payload["config"]
    pm_hash = sha256_json(pm_config)
    backtest_manifest = read_json(root / "backtest" / "backtest_manifest.json")

    # --- Manifests ---------------------------------------------------------
    manifests_dir = artifact_root / "manifests"
    write_json(
        manifests_dir / "position_manager_manifest.json",
        {
            "pm_config_version": PM_CONFIG_VERSION,
            "model_bundle_version": MODEL_BUNDLE_VERSION,
            "calibration_bundle_version": CALIBRATION_BUNDLE_VERSION,
            "threshold_stability_json": pm_payload.get("threshold_stability_json", {}),
            "allowed_run_modes_json": ["REPLAY_SIM", "SHADOW"],
            "config_json": pm_config,
            "config_hash_sha256": pm_hash,
            "status": EXPORT_STATUS,
        },
    )
    write_json(manifests_dir / "model_manifest.json", model_manifest_rows)
    write_json(manifests_dir / "calibration_manifest.json", calibration_manifest_rows)

    # The bundle hash covers the model files, the feature order and the PM config.
    bundle_hash = sha256_json(
        {"models": model_hashes, "feature_order_hash": feature_order_hash, "pm_hash": pm_hash}
    )
    write_json(
        manifests_dir / "model_bundle_manifest.json",
        {
            "manifest_schema_version": "trader-model-bundle-manifest-v4-p0",
            "model_bundle_version": MODEL_BUNDLE_VERSION,
            "calibration_bundle_version": CALIBRATION_BUNDLE_VERSION,
            "feature_version": FEATURE_VERSION,
            "label_version": LABEL_VERSION,
            "split_version": SPLIT_VERSION,
            "training_run_id": args.run_id,
            "training_export_id": f"export-{args.run_id}",
            "backtest_manifest_id": f"backtest-{args.run_id}",
            "required_models_json": list(REQUIRED_MODELS),
            "provided_models_json": list(REQUIRED_MODELS),
            "missing_models_json": [],
            "allowed_run_modes_json": ["REPLAY_SIM", "SHADOW"],
            "bundle_hash_sha256": bundle_hash,
            "model_quality_status_json": {
                row["model_name"]: row["quality_status"] for row in model_manifest_rows
            },
            "calibration_quality_status_json": {
                row["model_name"]: row["quality_status"] for row in calibration_manifest_rows
            },
            "backtest_status": backtest_manifest.get("status"),
            "backtest_status_reasons_json": backtest_manifest.get("status_reasons", []),
            "backtest_metrics_json": backtest_manifest.get("metrics", {}),
            "complete": True,
            "status": EXPORT_STATUS,
        },
    )
    write_json(
        manifests_dir / "training_export_manifest.json",
        {
            "created_at": utc_now_text(),
            "status": EXPORT_STATUS,
            "artifact_root": str(artifact_root),
            "bundle_hash_sha256": bundle_hash,
        },
    )
    logging.info(
        "trader.training.artifact_exported runId=%s status=%s bundleHash=%s path=%s",
        args.run_id,
        EXPORT_STATUS,
        bundle_hash,
        artifact_root,
    )
def _model_manifest_row(
    model_name: str,
    artifact_path: str,
    artifact_hash: str,
    feature_schema_hash: str,
    feature_order_hash: str,
    output_schema_hash: str,
    metrics: dict,
    quality_status: str,
    quality_reasons: list[str],
    train_start: str,
    train_end: str,
    validation_start: str,
    validation_end: str,
    status: str,
) -> dict:
    """Build the model-manifest entry for a single exported model.

    The caller-supplied hashes, metrics, quality fields and split boundaries
    are combined with module-level version constants and fixed descriptors
    (symbol scope, bar interval, runtime, tensor names, ...).

    The test window is filled with the validation window, and ``source_hash``
    repeats the artifact hash.

    Returns:
        A dict with the manifest fields in a fixed key order.
    """
    # Sections are merged in this order so the resulting key order is stable.
    identity = {
        "model_bundle_version": MODEL_BUNDLE_VERSION,
        "calibration_bundle_version": CALIBRATION_BUNDLE_VERSION,
        "model_name": model_name,
        "model_type": model_name,
        "side": "BOTH",
        "symbol_scope_json": ["BTC-USDT-PERP"],
        "bar_interval": "1m",
        "horizon_minutes": 45,
    }
    runtime = {
        "model_format": "ONNX",
        "model_runtime": "ONNX_RUNTIME_JAVA",
        "model_runtime_version": "1.22.0",
        "onnx_opset_version": 17,
        "producer_name": "trader-training",
        "producer_version": "v4-p0",
    }
    model_input = {
        "feature_version": FEATURE_VERSION,
        "feature_schema_path": "schemas/feature_schema.json",
        "feature_schema_hash": feature_schema_hash,
        "feature_order_path": "schemas/feature_order.json",
        "feature_order_hash": feature_order_hash,
        "input_tensor_name": "features",
        "input_dtype": "FLOAT32",
        "input_shape_json": {"features": len(FEATURE_ORDER), "batch": 1},
        "input_example_path": "examples/sample_input.json",
    }
    model_output = {
        "output_schema_path": "schemas/output_schema.json",
        "output_schema_hash": output_schema_hash,
        "output_tensor_names_json": ["prediction"],
        "output_mapping_json": OUTPUT_MAPPING[model_name],
        "output_value_rules_json": {"clip_by_output_schema": True},
    }
    training_windows = {
        "label_version": LABEL_VERSION,
        "split_version": SPLIT_VERSION,
        "training_fold": "fold_01",
        "train_start": train_start,
        "train_end": train_end,
        "validation_start": validation_start,
        "validation_end": validation_end,
        "test_start": validation_start,
        "test_end": validation_end,
    }
    quality = {
        "metrics_json": metrics,
        "quality_status": quality_status,
        "quality_reasons_json": quality_reasons,
    }
    artifact = {
        "artifact_path": artifact_path,
        "artifact_hash_sha256": artifact_hash,
        "source_hash": artifact_hash,
        "status": status,
    }
    return {
        **identity,
        **runtime,
        **model_input,
        **model_output,
        **training_windows,
        **quality,
        **artifact,
    }
def _sample_input(root) -> dict:
    """Return one example feature vector taken from the run's feature frame.

    Reads ``feature/feature_frame.parquet`` under ``root``, keeps rows whose
    ``data_quality_flag`` is ``"OK"`` or ``"PARTIAL_OPTIONAL"`` and converts
    the first such row into ``{feature name: float value}`` for every name in
    ``FEATURE_ORDER``.

    NOTE(review): presumably ``read_parquet`` returns a pandas DataFrame, in
    which case an empty selection makes ``.iloc[0]`` raise IndexError - confirm.
    """
    frame = read_parquet(root / "feature" / "feature_frame.parquet")
    usable_mask = frame["data_quality_flag"].isin(["OK", "PARTIAL_OPTIONAL"])
    first_usable = frame[usable_mask].iloc[0]
    sample = {}
    for name in FEATURE_ORDER:
        sample[name] = float(first_usable[name])
    return sample
def _sample_output() -> dict:
    """Return a placeholder output payload.

    For each model in ``MODEL_OUTPUTS`` the result maps every one of its
    output fields to ``0.0``.
    """
    placeholder = {}
    for model, fields in MODEL_OUTPUTS.items():
        placeholder[model] = dict.fromkeys(fields, 0.0)
    return placeholder
def _split_time(root, split_id: str, key: str) -> str:
    """Look up one field of a split in ``split/split_manifest.json``.

    Args:
        root: Run root directory containing the ``split`` folder.
        split_id: Value matched against each entry's ``"split_id"``.
        key: Field to return from the first matching entry (callers in this
            module pass ``"start"`` or ``"end"``).

    Returns:
        The requested field of the first matching split, or the fixed
        timestamp ``"2026-01-01T00:00:00Z"`` when no entry matches.
    """
    manifest = read_json(root / "split" / "split_manifest.json")
    matching_values = (
        entry[key] for entry in manifest["splits"] if entry["split_id"] == split_id
    )
    # NOTE(review): an unknown split_id silently yields this placeholder
    # timestamp rather than an error - confirm that is intended.
    return next(matching_values, "2026-01-01T00:00:00Z")