from __future__ import annotations import logging import shutil from pathlib import Path from typing import Any import pandas as pd from trader_training.io_utils import read_json, read_parquet, run_root, sha256_file, sha256_json, utc_now_text, write_json from trader_training.pm import default_pm_config from trader_training.schemas import ( CALIBRATION_BUNDLE_VERSION, FEATURE_ORDER, FEATURE_VERSION, FIT_SPLIT, LABEL_VERSION, MODEL_BUNDLE_VERSION, MODEL_OUTPUTS, OUTPUT_MAPPING, OUTPUT_SCHEMA, PM_CONFIG_VERSION, SPLIT_VERSION, TUNE_SPLIT, VALIDATION_LOCKED_SPLIT, ) REQUIRED_MODELS = ("DIRECTION", "ENTRY", "CONTINUE", "EXIT", "RISK") EXPORT_STATUS = "CANDIDATE" def export_artifact_bundle(args: Any) -> None: root = run_root(args) export_root = args.export_root or root / "export" / f"trader-model-bundle-{args.run_id}" artifact_root = export_root / "artifact_bundle" for folder in ("models", "schemas", "calibrators", "manifests", "examples"): (artifact_root / folder).mkdir(parents=True, exist_ok=True) feature_schema_path = root / "feature" / "feature_schema.json" feature_order_path = root / "feature" / "feature_order.json" output_schema_path = artifact_root / "schemas" / "output_schema.json" shutil.copy2(feature_schema_path, artifact_root / "schemas" / "feature_schema.json") shutil.copy2(feature_order_path, artifact_root / "schemas" / "feature_order.json") write_json(output_schema_path, OUTPUT_SCHEMA) sample_input = _sample_input(root) write_json(artifact_root / "examples" / "sample_input.json", sample_input) write_json(artifact_root / "examples" / "sample_output.json", _sample_output()) price_plan = read_json(root / "label" / "price_plan_context.json") write_json(artifact_root / "price_plan_context.json", price_plan) model_manifest_rows = [] calibration_manifest_rows = [] model_hashes = {} train_start = _split_time(root, FIT_SPLIT, "start") train_end = _split_time(root, FIT_SPLIT, "end") validation_start = _split_time(root, VALIDATION_LOCKED_SPLIT, "start") validation_end = _split_time(root, VALIDATION_LOCKED_SPLIT, "end") calibration_train_manifest = read_json(root / "calibration" / "calibration_train_manifest.json") calibration_quality = {row["model_name"]: row for row in calibration_train_manifest.get("calibrators", [])} for model_name in REQUIRED_MODELS: src = root / "model" / model_name.lower() / f"{model_name.lower()}.onnx" dst = artifact_root / "models" / f"{model_name.lower()}.onnx" shutil.copy2(src, dst) model_hashes[model_name] = sha256_file(dst) cal_src = root / "calibration" / model_name.lower() / "calibrator.json" cal_dst = artifact_root / "calibrators" / model_name.lower() / "calibrator.json" cal_dst.parent.mkdir(parents=True, exist_ok=True) shutil.copy2(cal_src, cal_dst) cal_hash = sha256_file(cal_dst) metrics = read_json(root / "model" / model_name.lower() / "model_train_result.json") model_manifest_rows.append( _model_manifest_row( model_name, f"models/{model_name.lower()}.onnx", model_hashes[model_name], sha256_file(artifact_root / "schemas" / "feature_schema.json"), sha256_file(artifact_root / "schemas" / "feature_order.json"), sha256_file(output_schema_path), metrics.get("metrics", {}), metrics.get("quality_status", "UNKNOWN"), metrics.get("quality_reasons", []), train_start, train_end, validation_start, validation_end, EXPORT_STATUS, ) ) calibration_manifest_rows.append( { "calibration_bundle_version": CALIBRATION_BUNDLE_VERSION, "model_bundle_version": MODEL_BUNDLE_VERSION, "model_name": model_name, "calibrator_version": f"{model_name.lower()}-cal-v4-btc-p0", "calibration_method": "BINNING", "calibrator_path": f"calibrators/{model_name.lower()}/calibrator.json", "calibrator_hash_sha256": cal_hash, "calibration_window_from": _split_time(root, TUNE_SPLIT, "start"), "calibration_window_to": _split_time(root, TUNE_SPLIT, "end"), "calibration_metrics_json": {}, "bucket_metrics_json": {}, "output_after_calibration_schema_hash": sha256_file(output_schema_path), "quality_status": calibration_quality.get(model_name, {}).get("quality_status", "UNKNOWN"), "quality_reasons_json": calibration_quality.get(model_name, {}).get("quality_reasons", []), "status": EXPORT_STATUS, } ) pm_payload = read_json(root / "pm-search" / "position_manager_config.json") if (root / "pm-search" / "position_manager_config.json").is_file() else {"config": default_pm_config(), "threshold_stability_json": {}} pm_config = pm_payload["config"] pm_hash = sha256_json(pm_config) backtest_manifest = read_json(root / "backtest" / "backtest_manifest.json") write_json( artifact_root / "manifests" / "position_manager_manifest.json", { "pm_config_version": PM_CONFIG_VERSION, "model_bundle_version": MODEL_BUNDLE_VERSION, "calibration_bundle_version": CALIBRATION_BUNDLE_VERSION, "threshold_stability_json": pm_payload.get("threshold_stability_json", {}), "allowed_run_modes_json": ["REPLAY_SIM", "SHADOW"], "config_json": pm_config, "config_hash_sha256": pm_hash, "status": EXPORT_STATUS, }, ) write_json(artifact_root / "manifests" / "model_manifest.json", model_manifest_rows) write_json(artifact_root / "manifests" / "calibration_manifest.json", calibration_manifest_rows) bundle_hash = sha256_json({"models": model_hashes, "feature_order_hash": sha256_file(artifact_root / "schemas" / "feature_order.json"), "pm_hash": pm_hash}) write_json( artifact_root / "manifests" / "model_bundle_manifest.json", { "manifest_schema_version": "trader-model-bundle-manifest-v4-p0", "model_bundle_version": MODEL_BUNDLE_VERSION, "calibration_bundle_version": CALIBRATION_BUNDLE_VERSION, "feature_version": FEATURE_VERSION, "label_version": LABEL_VERSION, "split_version": SPLIT_VERSION, "training_run_id": args.run_id, "training_export_id": f"export-{args.run_id}", "backtest_manifest_id": f"backtest-{args.run_id}", "required_models_json": list(REQUIRED_MODELS), "provided_models_json": list(REQUIRED_MODELS), "missing_models_json": [], "allowed_run_modes_json": ["REPLAY_SIM", "SHADOW"], "bundle_hash_sha256": bundle_hash, "model_quality_status_json": {row["model_name"]: row["quality_status"] for row in model_manifest_rows}, "calibration_quality_status_json": {row["model_name"]: row["quality_status"] for row in calibration_manifest_rows}, "backtest_status": backtest_manifest.get("status"), "backtest_status_reasons_json": backtest_manifest.get("status_reasons", []), "backtest_metrics_json": backtest_manifest.get("metrics", {}), "complete": True, "status": EXPORT_STATUS, }, ) write_json( artifact_root / "manifests" / "training_export_manifest.json", {"created_at": utc_now_text(), "status": EXPORT_STATUS, "artifact_root": str(artifact_root), "bundle_hash_sha256": bundle_hash}, ) logging.info("trader.training.artifact_exported runId=%s status=%s bundleHash=%s path=%s", args.run_id, EXPORT_STATUS, bundle_hash, artifact_root) def _model_manifest_row( model_name: str, artifact_path: str, artifact_hash: str, feature_schema_hash: str, feature_order_hash: str, output_schema_hash: str, metrics: dict, quality_status: str, quality_reasons: list[str], train_start: str, train_end: str, validation_start: str, validation_end: str, status: str, ) -> dict: return { "model_bundle_version": MODEL_BUNDLE_VERSION, "calibration_bundle_version": CALIBRATION_BUNDLE_VERSION, "model_name": model_name, "model_type": model_name, "side": "BOTH", "symbol_scope_json": ["BTC-USDT-PERP"], "bar_interval": "1m", "horizon_minutes": 45, "model_format": "ONNX", "model_runtime": "ONNX_RUNTIME_JAVA", "model_runtime_version": "1.22.0", "onnx_opset_version": 17, "producer_name": "trader-training", "producer_version": "v4-p0", "feature_version": FEATURE_VERSION, "feature_schema_path": "schemas/feature_schema.json", "feature_schema_hash": feature_schema_hash, "feature_order_path": "schemas/feature_order.json", "feature_order_hash": feature_order_hash, "input_tensor_name": "features", "input_dtype": "FLOAT32", "input_shape_json": {"features": len(FEATURE_ORDER), "batch": 1}, "input_example_path": "examples/sample_input.json", "output_schema_path": "schemas/output_schema.json", "output_schema_hash": output_schema_hash, "output_tensor_names_json": ["prediction"], "output_mapping_json": OUTPUT_MAPPING[model_name], "output_value_rules_json": {"clip_by_output_schema": True}, "label_version": LABEL_VERSION, "split_version": SPLIT_VERSION, "training_fold": "fold_01", "train_start": train_start, "train_end": train_end, "validation_start": validation_start, "validation_end": validation_end, "test_start": validation_start, "test_end": validation_end, "metrics_json": metrics, "quality_status": quality_status, "quality_reasons_json": quality_reasons, "artifact_path": artifact_path, "artifact_hash_sha256": artifact_hash, "source_hash": artifact_hash, "status": status, } def _sample_input(root) -> dict: features = read_parquet(root / "feature" / "feature_frame.parquet") row = features[features["data_quality_flag"].isin(["OK", "PARTIAL_OPTIONAL"])].iloc[0] return {feature: float(row[feature]) for feature in FEATURE_ORDER} def _sample_output() -> dict: return {model: {field: 0.0 for field in fields} for model, fields in MODEL_OUTPUTS.items()} def _split_time(root, split_id: str, key: str) -> str: manifest = read_json(root / "split" / "split_manifest.json") for item in manifest["splits"]: if item["split_id"] == split_id: return item[key] return "2026-01-01T00:00:00Z"