245 lines
11 KiB
Python
245 lines
11 KiB
Python
|
|
from __future__ import annotations
|
||
|
|
|
||
|
|
import logging
|
||
|
|
import shutil
|
||
|
|
from pathlib import Path
|
||
|
|
from typing import Any
|
||
|
|
|
||
|
|
import pandas as pd
|
||
|
|
|
||
|
|
from trader_training.io_utils import read_json, read_parquet, run_root, sha256_file, sha256_json, utc_now_text, write_json
|
||
|
|
from trader_training.pm import default_pm_config
|
||
|
|
from trader_training.schemas import (
|
||
|
|
CALIBRATION_BUNDLE_VERSION,
|
||
|
|
FEATURE_ORDER,
|
||
|
|
FEATURE_VERSION,
|
||
|
|
FIT_SPLIT,
|
||
|
|
LABEL_VERSION,
|
||
|
|
MODEL_BUNDLE_VERSION,
|
||
|
|
MODEL_OUTPUTS,
|
||
|
|
OUTPUT_MAPPING,
|
||
|
|
OUTPUT_SCHEMA,
|
||
|
|
PM_CONFIG_VERSION,
|
||
|
|
SPLIT_VERSION,
|
||
|
|
TUNE_SPLIT,
|
||
|
|
VALIDATION_LOCKED_SPLIT,
|
||
|
|
)
|
||
|
|
|
||
|
|
|
||
|
|
# Model names the export iterates over: each one gets its ONNX file and
# calibrator copied into the bundle and a row in the model/calibration manifests.
REQUIRED_MODELS = ("DIRECTION", "ENTRY", "CONTINUE", "EXIT", "RISK")
|
||
|
|
# Status value stamped on every manifest written by the export and included
# in the final export log line.
EXPORT_STATUS = "CANDIDATE"
|
||
|
|
|
||
|
|
|
||
|
|
def export_artifact_bundle(args: Any) -> None:
    """Assemble the exportable artifact bundle for one training run.

    Reads the run directory resolved by ``run_root(args)`` and writes an
    ``artifact_bundle`` folder containing:

    * ``schemas/``      - copied feature schema/order plus the output schema,
    * ``examples/``     - one sample input row and a zero-filled sample output,
    * ``models/``       - one ``.onnx`` file per entry in ``REQUIRED_MODELS``,
    * ``calibrators/``  - one ``calibrator.json`` per required model,
    * ``manifests/``    - model, calibration, position-manager, bundle and
      training-export manifests,
    * ``price_plan_context.json`` - copied from the label stage.

    Args:
        args: Object exposing ``run_id`` and ``export_root`` (plus whatever
            ``run_root`` needs). When ``export_root`` is falsy the bundle is
            placed under ``<run root>/export/trader-model-bundle-<run_id>``.

    Raises:
        Whatever the underlying file operations raise when an expected input
        (model, calibrator, manifest, schema) is missing from the run root.
    """
    root = run_root(args)
    export_root = args.export_root or root / "export" / f"trader-model-bundle-{args.run_id}"
    if isinstance(export_root, str):
        # A plain string (e.g. straight from a CLI parser) cannot be joined
        # with "/"; normalise it so the path arithmetic below works.
        export_root = Path(export_root)
    artifact_root = export_root / "artifact_bundle"
    for folder in ("models", "schemas", "calibrators", "manifests", "examples"):
        (artifact_root / folder).mkdir(parents=True, exist_ok=True)

    # --- schemas -----------------------------------------------------------
    schemas_dir = artifact_root / "schemas"
    feature_schema_dst = schemas_dir / "feature_schema.json"
    feature_order_dst = schemas_dir / "feature_order.json"
    output_schema_path = schemas_dir / "output_schema.json"
    shutil.copy2(root / "feature" / "feature_schema.json", feature_schema_dst)
    shutil.copy2(root / "feature" / "feature_order.json", feature_order_dst)
    write_json(output_schema_path, OUTPUT_SCHEMA)

    # The schema files are not touched again after this point, so hash them
    # once here instead of once per model inside the loop.
    feature_schema_hash = sha256_file(feature_schema_dst)
    feature_order_hash = sha256_file(feature_order_dst)
    output_schema_hash = sha256_file(output_schema_path)

    # --- examples ----------------------------------------------------------
    sample_input = _sample_input(root)
    write_json(artifact_root / "examples" / "sample_input.json", sample_input)
    write_json(artifact_root / "examples" / "sample_output.json", _sample_output())

    # --- price plan context ------------------------------------------------
    price_plan = read_json(root / "label" / "price_plan_context.json")
    write_json(artifact_root / "price_plan_context.json", price_plan)

    # --- per-model artifacts and manifest rows -----------------------------
    model_manifest_rows = []
    calibration_manifest_rows = []
    model_hashes = {}
    train_start = _split_time(root, FIT_SPLIT, "start")
    train_end = _split_time(root, FIT_SPLIT, "end")
    validation_start = _split_time(root, VALIDATION_LOCKED_SPLIT, "start")
    validation_end = _split_time(root, VALIDATION_LOCKED_SPLIT, "end")
    # The calibration window is the same for every model; look it up once.
    calibration_from = _split_time(root, TUNE_SPLIT, "start")
    calibration_to = _split_time(root, TUNE_SPLIT, "end")
    calibration_train_manifest = read_json(root / "calibration" / "calibration_train_manifest.json")
    calibration_quality = {row["model_name"]: row for row in calibration_train_manifest.get("calibrators", [])}
    for model_name in REQUIRED_MODELS:
        slug = model_name.lower()

        model_dst = artifact_root / "models" / f"{slug}.onnx"
        shutil.copy2(root / "model" / slug / f"{slug}.onnx", model_dst)
        model_hashes[model_name] = sha256_file(model_dst)

        cal_dst = artifact_root / "calibrators" / slug / "calibrator.json"
        cal_dst.parent.mkdir(parents=True, exist_ok=True)
        shutil.copy2(root / "calibration" / slug / "calibrator.json", cal_dst)
        cal_hash = sha256_file(cal_dst)

        train_result = read_json(root / "model" / slug / "model_train_result.json")
        # Keyword arguments keep this long call from silently transposing
        # same-typed values (hashes, timestamps).
        model_manifest_rows.append(
            _model_manifest_row(
                model_name=model_name,
                artifact_path=f"models/{slug}.onnx",
                artifact_hash=model_hashes[model_name],
                feature_schema_hash=feature_schema_hash,
                feature_order_hash=feature_order_hash,
                output_schema_hash=output_schema_hash,
                metrics=train_result.get("metrics", {}),
                quality_status=train_result.get("quality_status", "UNKNOWN"),
                quality_reasons=train_result.get("quality_reasons", []),
                train_start=train_start,
                train_end=train_end,
                validation_start=validation_start,
                validation_end=validation_end,
                status=EXPORT_STATUS,
            )
        )

        # Models absent from the calibration train manifest fall back to
        # UNKNOWN quality with no reasons.
        cal_quality = calibration_quality.get(model_name, {})
        calibration_manifest_rows.append(
            {
                "calibration_bundle_version": CALIBRATION_BUNDLE_VERSION,
                "model_bundle_version": MODEL_BUNDLE_VERSION,
                "model_name": model_name,
                "calibrator_version": f"{slug}-cal-v4-btc-p0",
                "calibration_method": "BINNING",
                "calibrator_path": f"calibrators/{slug}/calibrator.json",
                "calibrator_hash_sha256": cal_hash,
                "calibration_window_from": calibration_from,
                "calibration_window_to": calibration_to,
                "calibration_metrics_json": {},
                "bucket_metrics_json": {},
                "output_after_calibration_schema_hash": output_schema_hash,
                "quality_status": cal_quality.get("quality_status", "UNKNOWN"),
                "quality_reasons_json": cal_quality.get("quality_reasons", []),
                "status": EXPORT_STATUS,
            }
        )

    # --- position manager --------------------------------------------------
    # Use the searched PM config when the pm-search stage produced one,
    # otherwise fall back to the default configuration.
    pm_config_path = root / "pm-search" / "position_manager_config.json"
    if pm_config_path.is_file():
        pm_payload = read_json(pm_config_path)
    else:
        pm_payload = {"config": default_pm_config(), "threshold_stability_json": {}}
    pm_config = pm_payload["config"]
    pm_hash = sha256_json(pm_config)
    backtest_manifest = read_json(root / "backtest" / "backtest_manifest.json")

    manifests_dir = artifact_root / "manifests"
    write_json(
        manifests_dir / "position_manager_manifest.json",
        {
            "pm_config_version": PM_CONFIG_VERSION,
            "model_bundle_version": MODEL_BUNDLE_VERSION,
            "calibration_bundle_version": CALIBRATION_BUNDLE_VERSION,
            "threshold_stability_json": pm_payload.get("threshold_stability_json", {}),
            "allowed_run_modes_json": ["REPLAY_SIM", "SHADOW"],
            "config_json": pm_config,
            "config_hash_sha256": pm_hash,
            "status": EXPORT_STATUS,
        },
    )
    write_json(manifests_dir / "model_manifest.json", model_manifest_rows)
    write_json(manifests_dir / "calibration_manifest.json", calibration_manifest_rows)

    # --- bundle-level manifests --------------------------------------------
    # The bundle hash covers the model files, the feature order and the PM
    # config; calibrator hashes are not part of it.
    bundle_hash = sha256_json({"models": model_hashes, "feature_order_hash": feature_order_hash, "pm_hash": pm_hash})
    write_json(
        manifests_dir / "model_bundle_manifest.json",
        {
            "manifest_schema_version": "trader-model-bundle-manifest-v4-p0",
            "model_bundle_version": MODEL_BUNDLE_VERSION,
            "calibration_bundle_version": CALIBRATION_BUNDLE_VERSION,
            "feature_version": FEATURE_VERSION,
            "label_version": LABEL_VERSION,
            "split_version": SPLIT_VERSION,
            "training_run_id": args.run_id,
            "training_export_id": f"export-{args.run_id}",
            "backtest_manifest_id": f"backtest-{args.run_id}",
            "required_models_json": list(REQUIRED_MODELS),
            "provided_models_json": list(REQUIRED_MODELS),
            "missing_models_json": [],
            "allowed_run_modes_json": ["REPLAY_SIM", "SHADOW"],
            "bundle_hash_sha256": bundle_hash,
            "model_quality_status_json": {row["model_name"]: row["quality_status"] for row in model_manifest_rows},
            "calibration_quality_status_json": {row["model_name"]: row["quality_status"] for row in calibration_manifest_rows},
            "backtest_status": backtest_manifest.get("status"),
            "backtest_status_reasons_json": backtest_manifest.get("status_reasons", []),
            "backtest_metrics_json": backtest_manifest.get("metrics", {}),
            "complete": True,
            "status": EXPORT_STATUS,
        },
    )
    write_json(
        manifests_dir / "training_export_manifest.json",
        {"created_at": utc_now_text(), "status": EXPORT_STATUS, "artifact_root": str(artifact_root), "bundle_hash_sha256": bundle_hash},
    )
    logging.info("trader.training.artifact_exported runId=%s status=%s bundleHash=%s path=%s", args.run_id, EXPORT_STATUS, bundle_hash, artifact_root)
|
||
|
|
|
||
|
|
|
||
|
|
def _model_manifest_row(
    model_name: str,
    artifact_path: str,
    artifact_hash: str,
    feature_schema_hash: str,
    feature_order_hash: str,
    output_schema_hash: str,
    metrics: dict,
    quality_status: str,
    quality_reasons: list[str],
    train_start: str,
    train_end: str,
    validation_start: str,
    validation_end: str,
    status: str,
) -> dict:
    """Build the model-manifest entry describing one exported model file.

    The entry is assembled from thematic sections that are merged in a fixed
    order, so the resulting key order is stable. Values not passed in are
    either module-level version constants or literals fixed in this function.

    Returns:
        A flat dict with one key per manifest field.

    Raises:
        KeyError: if ``model_name`` has no entry in ``OUTPUT_MAPPING``.
    """
    # Which bundle the model belongs to and the fixed scope it is declared for.
    identity = {
        "model_bundle_version": MODEL_BUNDLE_VERSION,
        "calibration_bundle_version": CALIBRATION_BUNDLE_VERSION,
        "model_name": model_name,
        "model_type": model_name,
        "side": "BOTH",
        "symbol_scope_json": ["BTC-USDT-PERP"],
        "bar_interval": "1m",
        "horizon_minutes": 45,
    }

    # Serialization format and the runtime/producer the manifest declares.
    runtime = {
        "model_format": "ONNX",
        "model_runtime": "ONNX_RUNTIME_JAVA",
        "model_runtime_version": "1.22.0",
        "onnx_opset_version": 17,
        "producer_name": "trader-training",
        "producer_version": "v4-p0",
    }

    # Input contract: schema/order files (bundle-relative) and tensor layout.
    input_contract = {
        "feature_version": FEATURE_VERSION,
        "feature_schema_path": "schemas/feature_schema.json",
        "feature_schema_hash": feature_schema_hash,
        "feature_order_path": "schemas/feature_order.json",
        "feature_order_hash": feature_order_hash,
        "input_tensor_name": "features",
        "input_dtype": "FLOAT32",
        "input_shape_json": {"features": len(FEATURE_ORDER), "batch": 1},
        "input_example_path": "examples/sample_input.json",
    }

    # Output contract: schema file plus the per-model output mapping.
    output_contract = {
        "output_schema_path": "schemas/output_schema.json",
        "output_schema_hash": output_schema_hash,
        "output_tensor_names_json": ["prediction"],
        "output_mapping_json": OUTPUT_MAPPING[model_name],
        "output_value_rules_json": {"clip_by_output_schema": True},
    }

    # Training provenance. The test window is reported with the same bounds
    # as the validation window.
    provenance = {
        "label_version": LABEL_VERSION,
        "split_version": SPLIT_VERSION,
        "training_fold": "fold_01",
        "train_start": train_start,
        "train_end": train_end,
        "validation_start": validation_start,
        "validation_end": validation_end,
        "test_start": validation_start,
        "test_end": validation_end,
    }

    quality = {
        "metrics_json": metrics,
        "quality_status": quality_status,
        "quality_reasons_json": quality_reasons,
    }

    # ``source_hash`` carries the same value as the artifact hash.
    artifact = {
        "artifact_path": artifact_path,
        "artifact_hash_sha256": artifact_hash,
        "source_hash": artifact_hash,
        "status": status,
    }

    return {
        **identity,
        **runtime,
        **input_contract,
        **output_contract,
        **provenance,
        **quality,
        **artifact,
    }
|
||
|
|
|
||
|
|
|
||
|
|
def _sample_input(root) -> dict:
    """Pick one example feature row from the run's feature frame.

    Loads ``feature/feature_frame.parquet`` under ``root``, keeps the rows
    whose ``data_quality_flag`` is ``OK`` or ``PARTIAL_OPTIONAL`` and takes
    the first of them.

    Returns:
        Mapping of every name in ``FEATURE_ORDER`` to that row's value as a
        ``float``.

    Raises:
        IndexError: if no row carries an accepted quality flag.
    """
    frame = read_parquet(root / "feature" / "feature_frame.parquet")
    accepted = frame["data_quality_flag"].isin(["OK", "PARTIAL_OPTIONAL"])
    first_usable = frame[accepted].iloc[0]

    sample = {}
    for name in FEATURE_ORDER:
        sample[name] = float(first_usable[name])
    return sample
|
||
|
|
|
||
|
|
|
||
|
|
def _sample_output() -> dict:
    """Return a placeholder output: every field of every model set to 0.0.

    The structure follows ``MODEL_OUTPUTS`` (model name -> iterable of field
    names).
    """
    placeholder = {}
    for model, fields in MODEL_OUTPUTS.items():
        placeholder[model] = dict.fromkeys(fields, 0.0)
    return placeholder
|
||
|
|
|
||
|
|
|
||
|
|
def _split_time(root, split_id: str, key: str) -> str:
    """Look up one boundary of a named split in the split manifest.

    Reads ``split/split_manifest.json`` under ``root`` and returns ``item[key]``
    for the first entry in ``manifest["splits"]`` whose ``split_id`` matches.

    Args:
        root: Run root directory (path-like supporting ``/``).
        split_id: Identifier of the split to look up.
        key: Field to read from the matching split entry (callers in this
            file use ``"start"`` and ``"end"``).

    Returns:
        The stored value for ``key``; when no split matches, a fixed
        placeholder timestamp.

    Raises:
        KeyError: if the manifest has no ``splits`` list, an entry lacks
            ``split_id``, or the matching entry lacks ``key``.
    """
    fallback = "2026-01-01T00:00:00Z"
    manifest = read_json(root / "split" / "split_manifest.json")
    for item in manifest["splits"]:
        if item["split_id"] == split_id:
            return item[key]
    # The placeholder ends up in exported manifests, so make the substitution
    # visible instead of returning it silently.
    logging.warning(
        "trader.training.split_time_fallback splitId=%s key=%s fallback=%s",
        split_id,
        key,
        fallback,
    )
    return fallback
|