Implement Trader V4 training artifact pipeline

This commit is contained in:
Codex
2026-06-27 16:15:23 +08:00
parent dad6b831b4
commit e58e4a5572
113 changed files with 7959 additions and 477 deletions
+33
View File
@@ -0,0 +1,33 @@
# Trader V4 Training Pipeline
This directory contains the executable training chain for Trader V4.
Large data stays under `/Users/zach/Desktop/quant-strategy-training-data`.
Run order:
```bash
PY=/Users/zach/IdeaProjects/quant-trading-ai/quant-strategy-server/.venv/bin/python
RUN_ID=btc-v4-p0-001
ROOT=/Users/zach/Desktop/quant-strategy-training-data
$PY training/scripts/01_audit_source_data.py --run-id $RUN_ID --data-root $ROOT --symbol BTC-USDT-PERP --start-date 2025-06-20 --end-date 2026-06-19
$PY training/scripts/02_build_replay_1m.py --run-id $RUN_ID --data-root $ROOT --symbol BTC-USDT-PERP --start-date 2025-06-20 --end-date 2026-06-19
$PY training/scripts/03_build_splits.py --run-id $RUN_ID --data-root $ROOT
$PY training/scripts/04_build_feature_frame.py --run-id $RUN_ID --data-root $ROOT
$PY training/scripts/05_build_price_plan_context.py --run-id $RUN_ID --data-root $ROOT
$PY training/scripts/06_build_direction_labels.py --run-id $RUN_ID --data-root $ROOT
$PY training/scripts/07_build_entry_labels.py --run-id $RUN_ID --data-root $ROOT
$PY training/scripts/08_build_position_state_samples.py --run-id $RUN_ID --data-root $ROOT
$PY training/scripts/09_build_continue_exit_risk_labels.py --run-id $RUN_ID --data-root $ROOT
$PY training/scripts/10_build_train_datasets.py --run-id $RUN_ID --data-root $ROOT
$PY training/scripts/11_train_small_models.py --run-id $RUN_ID --data-root $ROOT
$PY training/scripts/12_calibrate_models.py --run-id $RUN_ID --data-root $ROOT
$PY training/scripts/13_search_pm_thresholds.py --run-id $RUN_ID --data-root $ROOT
$PY training/scripts/14_integrated_backtest.py --run-id $RUN_ID --data-root $ROOT
$PY training/scripts/15_export_artifact_bundle.py --run-id $RUN_ID --data-root $ROOT
$PY training/scripts/16_validate_artifact_bundle.py --artifact-root $ROOT/trader-v4/runs/$RUN_ID/export/trader-model-bundle-$RUN_ID/artifact_bundle
$PY training/scripts/17_promote_artifact_bundle.py --artifact-root $ROOT/trader-v4/runs/$RUN_ID/export/trader-model-bundle-$RUN_ID/artifact_bundle --reason "validation_locked and latest_stress passed for SHADOW"
$PY training/scripts/16_validate_artifact_bundle.py --artifact-root $ROOT/trader-v4/runs/$RUN_ID/export/trader-model-bundle-$RUN_ID/artifact_bundle --require-active --run-onnx
```
Java SHADOW 只加载 `ACTIVE` 包。15 号脚本永远只生成 `CANDIDATE`,16 号校验通过且上线门槛通过后,17 号脚本才允许把包提升为 `ACTIVE`
+7
View File
@@ -0,0 +1,7 @@
pandas==2.2.3
pyarrow==24.0.0
numpy==2.4.6
scikit-learn==1.7.0
scipy==1.18.0
onnx==1.22.0
onnxruntime==1.27.0
+23
View File
@@ -0,0 +1,23 @@
from __future__ import annotations
import argparse
import _bootstrap # noqa: F401
from trader_training.io_utils import add_common_args, setup_logging
from trader_training.replay import write_audit_outputs
def main() -> None:
parser = argparse.ArgumentParser()
add_common_args(parser)
parser.add_argument("--symbol", default="BTC-USDT-PERP")
parser.add_argument("--start-date")
parser.add_argument("--end-date")
parser.add_argument("--min-ready-days", type=int, default=250)
args = parser.parse_args()
setup_logging()
write_audit_outputs(args)
if __name__ == "__main__":
main()
+26
View File
@@ -0,0 +1,26 @@
from __future__ import annotations
import argparse
from pathlib import Path
import _bootstrap # noqa: F401
from trader_training.io_utils import add_common_args, setup_logging
from trader_training.replay import build_replay_1m
def main() -> None:
parser = argparse.ArgumentParser()
add_common_args(parser)
parser.add_argument("--raw-root", type=Path)
parser.add_argument("--symbol", default="BTC-USDT-PERP")
parser.add_argument("--start-date")
parser.add_argument("--end-date")
parser.add_argument("--min-minutes-per-day", type=int, default=1400)
parser.add_argument("--min-replay-ready-days", type=int, default=250)
args = parser.parse_args()
setup_logging()
build_replay_1m(args)
if __name__ == "__main__":
main()
+31
View File
@@ -0,0 +1,31 @@
from __future__ import annotations
import argparse
from pathlib import Path
import _bootstrap # noqa: F401
from trader_training.io_utils import add_common_args, setup_logging
from trader_training.replay import build_splits
def main() -> None:
parser = argparse.ArgumentParser()
add_common_args(parser)
parser.add_argument("--replay-path", type=Path)
parser.add_argument("--fit-inner-start", default="2025-06-20")
parser.add_argument("--fit-inner-end", default="2026-01-15")
parser.add_argument("--tune-inner-start", default="2026-01-16")
parser.add_argument("--tune-inner-end", default="2026-02-28")
parser.add_argument("--validation-locked-start", default="2026-03-01")
parser.add_argument("--validation-locked-end", default="2026-04-30")
parser.add_argument("--latest-stress-start", default="2026-05-01")
parser.add_argument("--latest-stress-end", default="2026-06-19")
parser.add_argument("--gap-minutes", type=int, default=60)
parser.add_argument("--fold-count", type=int, default=3)
args = parser.parse_args()
setup_logging()
build_splits(args)
if __name__ == "__main__":
main()
@@ -0,0 +1,23 @@
from __future__ import annotations
import argparse
from pathlib import Path
import _bootstrap # noqa: F401
from trader_training.features import build_feature_frame
from trader_training.io_utils import add_common_args, setup_logging
def main() -> None:
parser = argparse.ArgumentParser()
add_common_args(parser)
parser.add_argument("--replay-path", type=Path)
parser.add_argument("--split-manifest-path", type=Path)
parser.add_argument("--allow-incomplete-days", action="store_true")
args = parser.parse_args()
setup_logging()
build_feature_frame(args)
if __name__ == "__main__":
main()
@@ -0,0 +1,23 @@
from __future__ import annotations
import argparse
from pathlib import Path
import _bootstrap # noqa: F401
from trader_training.io_utils import add_common_args, setup_logging
from trader_training.labels import write_price_plan_context
def main() -> None:
parser = argparse.ArgumentParser()
add_common_args(parser)
parser.add_argument("--label-config-path", type=Path)
parser.add_argument("--cost-config-path", type=Path)
parser.add_argument("--price-plan-id", default="btc-p0-plan-45m")
args = parser.parse_args()
setup_logging()
write_price_plan_context(args)
if __name__ == "__main__":
main()
@@ -0,0 +1,23 @@
from __future__ import annotations
import argparse
from pathlib import Path
import _bootstrap # noqa: F401
from trader_training.io_utils import add_common_args, setup_logging
from trader_training.labels import build_direction_labels
def main() -> None:
parser = argparse.ArgumentParser()
add_common_args(parser)
parser.add_argument("--feature-path", type=Path)
parser.add_argument("--replay-path", type=Path)
parser.add_argument("--label-config-path", type=Path)
args = parser.parse_args()
setup_logging()
build_direction_labels(args)
if __name__ == "__main__":
main()
+25
View File
@@ -0,0 +1,25 @@
from __future__ import annotations
import argparse
from pathlib import Path
import _bootstrap # noqa: F401
from trader_training.io_utils import add_common_args, setup_logging
from trader_training.labels import build_entry_labels
def main() -> None:
parser = argparse.ArgumentParser()
add_common_args(parser)
parser.add_argument("--feature-path", type=Path)
parser.add_argument("--replay-path", type=Path)
parser.add_argument("--label-config-path", type=Path)
parser.add_argument("--cost-config-path", type=Path)
parser.add_argument("--price-plan-context-path", type=Path)
args = parser.parse_args()
setup_logging()
build_entry_labels(args)
if __name__ == "__main__":
main()
@@ -0,0 +1,21 @@
from __future__ import annotations
import argparse
from pathlib import Path
import _bootstrap # noqa: F401
from trader_training.io_utils import add_common_args, setup_logging
from trader_training.labels import build_position_state_samples
def main() -> None:
parser = argparse.ArgumentParser()
add_common_args(parser)
parser.add_argument("--entry-label-path", type=Path)
args = parser.parse_args()
setup_logging()
build_position_state_samples(args)
if __name__ == "__main__":
main()
@@ -0,0 +1,25 @@
from __future__ import annotations
import argparse
from pathlib import Path
import _bootstrap # noqa: F401
from trader_training.io_utils import add_common_args, setup_logging
from trader_training.labels import build_continue_exit_risk_labels
def main() -> None:
parser = argparse.ArgumentParser()
add_common_args(parser)
parser.add_argument("--feature-path", type=Path)
parser.add_argument("--replay-path", type=Path)
parser.add_argument("--label-config-path", type=Path)
parser.add_argument("--cost-config-path", type=Path)
parser.add_argument("--price-plan-context-path", type=Path)
args = parser.parse_args()
setup_logging()
build_continue_exit_risk_labels(args)
if __name__ == "__main__":
main()
@@ -0,0 +1,26 @@
from __future__ import annotations
import argparse
from pathlib import Path
import _bootstrap # noqa: F401
from trader_training.datasets import build_train_datasets
from trader_training.io_utils import add_common_args, setup_logging
def main() -> None:
parser = argparse.ArgumentParser()
add_common_args(parser)
parser.add_argument("--feature-path", type=Path)
parser.add_argument("--direction-label-path", type=Path)
parser.add_argument("--entry-label-path", type=Path)
parser.add_argument("--continue-label-path", type=Path)
parser.add_argument("--exit-label-path", type=Path)
parser.add_argument("--risk-label-path", type=Path)
args = parser.parse_args()
setup_logging()
build_train_datasets(args)
if __name__ == "__main__":
main()
+20
View File
@@ -0,0 +1,20 @@
from __future__ import annotations
import argparse
import _bootstrap # noqa: F401
from trader_training.io_utils import add_common_args, setup_logging
from trader_training.training import train_small_models
def main() -> None:
parser = argparse.ArgumentParser()
add_common_args(parser)
parser.add_argument("--max-rows", type=int, default=0)
args = parser.parse_args()
setup_logging()
train_small_models(args)
if __name__ == "__main__":
main()
+19
View File
@@ -0,0 +1,19 @@
from __future__ import annotations
import argparse
import _bootstrap # noqa: F401
from trader_training.io_utils import add_common_args, setup_logging
from trader_training.training import build_calibrators
def main() -> None:
parser = argparse.ArgumentParser()
add_common_args(parser)
args = parser.parse_args()
setup_logging()
build_calibrators(args)
if __name__ == "__main__":
main()
@@ -0,0 +1,19 @@
from __future__ import annotations
import argparse
import _bootstrap # noqa: F401
from trader_training.io_utils import add_common_args, setup_logging
from trader_training.pm import search_pm_thresholds
def main() -> None:
parser = argparse.ArgumentParser()
add_common_args(parser)
args = parser.parse_args()
setup_logging()
search_pm_thresholds(args)
if __name__ == "__main__":
main()
@@ -0,0 +1,19 @@
from __future__ import annotations
import argparse
import _bootstrap # noqa: F401
from trader_training.io_utils import add_common_args, setup_logging
from trader_training.pm import integrated_backtest
def main() -> None:
parser = argparse.ArgumentParser()
add_common_args(parser)
args = parser.parse_args()
setup_logging()
integrated_backtest(args)
if __name__ == "__main__":
main()
@@ -0,0 +1,21 @@
from __future__ import annotations
import argparse
from pathlib import Path
import _bootstrap # noqa: F401
from trader_training.exporter import export_artifact_bundle
from trader_training.io_utils import add_common_args, setup_logging
def main() -> None:
parser = argparse.ArgumentParser()
add_common_args(parser)
parser.add_argument("--export-root", type=Path)
args = parser.parse_args()
setup_logging()
export_artifact_bundle(args)
if __name__ == "__main__":
main()
@@ -0,0 +1,22 @@
from __future__ import annotations
import argparse
from pathlib import Path
import _bootstrap # noqa: F401
from trader_training.io_utils import setup_logging
from trader_training.validator import validate_artifact_bundle
def main() -> None:
parser = argparse.ArgumentParser()
parser.add_argument("--artifact-root", type=Path, required=True)
parser.add_argument("--require-active", action="store_true")
parser.add_argument("--run-onnx", action="store_true")
args = parser.parse_args()
setup_logging()
validate_artifact_bundle(args)
if __name__ == "__main__":
main()
@@ -0,0 +1,21 @@
from __future__ import annotations
import argparse
from pathlib import Path
import _bootstrap # noqa: F401
from trader_training.io_utils import setup_logging
from trader_training.promote import promote_artifact_bundle
def main() -> None:
parser = argparse.ArgumentParser()
parser.add_argument("--artifact-root", type=Path, required=True)
parser.add_argument("--reason", required=True)
args = parser.parse_args()
setup_logging()
promote_artifact_bundle(args)
if __name__ == "__main__":
main()
+8
View File
@@ -0,0 +1,8 @@
from __future__ import annotations
import sys
from pathlib import Path
TRAINING_ROOT = Path(__file__).resolve().parents[1]
if str(TRAINING_ROOT) not in sys.path:
sys.path.insert(0, str(TRAINING_ROOT))
+146
View File
@@ -0,0 +1,146 @@
from __future__ import annotations
import sys
import tempfile
import unittest
from argparse import Namespace
from pathlib import Path
import numpy as np
import pandas as pd
TRAINING_ROOT = Path(__file__).resolve().parents[1]
if str(TRAINING_ROOT) not in sys.path:
sys.path.insert(0, str(TRAINING_ROOT))
from trader_training.onnx_export import LinearHead, export_heads
from trader_training.io_utils import read_json, write_json
from trader_training.promote import promote_artifact_bundle
from trader_training.replay import build_splits
from trader_training.schemas import FEATURE_ORDER, LATEST_STRESS_SPLIT, MODEL_OUTPUTS, OUTPUT_MAPPING, TRAINING_SPLITS, VALIDATION_LOCKED_SPLIT
class TrainingContractTest(unittest.TestCase):
def test_feature_order_is_v4_contract_size(self) -> None:
self.assertEqual(39, len(FEATURE_ORDER))
self.assertEqual(len(FEATURE_ORDER), len(set(FEATURE_ORDER)))
self.assertEqual("ret_1m_bps", FEATURE_ORDER[0])
self.assertEqual("minutes_to_next_funding", FEATURE_ORDER[-1])
def test_output_mapping_matches_model_outputs(self) -> None:
for model_name, fields in MODEL_OUTPUTS.items():
self.assertEqual(set(fields), set(OUTPUT_MAPPING[model_name]))
self.assertEqual([f"prediction[{idx}]" for idx in range(len(fields))], [OUTPUT_MAPPING[model_name][field] for field in fields])
def test_split_builder_uses_locked_validation_contract(self) -> None:
with tempfile.TemporaryDirectory() as tmp:
data_root = Path(tmp)
replay_path = data_root / "replay_1m.parquet"
frame = pd.DataFrame(
{
"event_time": pd.date_range("2025-06-20", "2026-06-19", freq="D", tz="UTC"),
"symbol": "BTC-USDT-PERP",
}
)
frame.to_parquet(replay_path, index=False)
build_splits(
Namespace(
data_root=data_root,
run_id="unit-split",
replay_path=replay_path,
fit_inner_start="2025-06-20",
fit_inner_end="2026-01-15",
tune_inner_start="2026-01-16",
tune_inner_end="2026-02-28",
validation_locked_start="2026-03-01",
validation_locked_end="2026-04-30",
latest_stress_start="2026-05-01",
latest_stress_end="2026-06-19",
gap_minutes=0,
fold_count=2,
)
)
manifest = read_json(data_root / "trader-v4" / "runs" / "unit-split" / "split" / "split_manifest.json")
self.assertEqual(set(TRAINING_SPLITS), {item["split_id"] for item in manifest["splits"]})
self.assertEqual([VALIDATION_LOCKED_SPLIT, LATEST_STRESS_SPLIT], manifest["sealed_splits"])
self.assertEqual("FINAL_GATE_ONLY", manifest["latest_stress_policy"])
def test_exported_onnx_accepts_java_feature_shape(self) -> None:
import onnxruntime as ort
with tempfile.TemporaryDirectory() as tmp:
path = Path(tmp) / "direction.onnx"
export_heads(
path,
[
LinearHead(
"direction",
"softmax",
np.zeros((39, 3), dtype=np.float32),
np.array([0.1, 0.2, 0.3], dtype=np.float32),
)
],
)
session = ort.InferenceSession(str(path))
output = session.run(None, {"features": np.zeros((1, 39), dtype=np.float32)})[0]
self.assertEqual((1, 3), output.shape)
self.assertAlmostEqual(1.0, float(output.sum()), places=6)
def test_promotion_requires_passed_validation_and_marks_all_manifests_active(self) -> None:
with tempfile.TemporaryDirectory() as tmp:
root = Path(tmp) / "artifact_bundle"
manifest_dir = root / "manifests"
manifest_dir.mkdir(parents=True)
write_json(root.parent / "artifact_validation_result.json", {"status": "PASS", "release_gate_status": "PASS", "release_gate_reasons": []})
write_json(manifest_dir / "model_bundle_manifest.json", {"status": "CANDIDATE"})
write_json(manifest_dir / "model_manifest.json", [{"model_name": "DIRECTION", "status": "CANDIDATE"}])
write_json(manifest_dir / "calibration_manifest.json", [{"model_name": "DIRECTION", "status": "CANDIDATE"}])
write_json(manifest_dir / "position_manager_manifest.json", {"status": "CANDIDATE"})
write_json(manifest_dir / "training_export_manifest.json", {"status": "CANDIDATE"})
promote_artifact_bundle(Namespace(artifact_root=root, reason="unit test"))
self.assertEqual("ACTIVE", read_json(manifest_dir / "model_bundle_manifest.json")["status"])
self.assertEqual("ACTIVE", read_json(manifest_dir / "model_manifest.json")[0]["status"])
self.assertEqual("ACTIVE", read_json(manifest_dir / "calibration_manifest.json")[0]["status"])
self.assertEqual("ACTIVE", read_json(manifest_dir / "position_manager_manifest.json")["status"])
self.assertEqual("ACTIVE", read_json(manifest_dir / "training_export_manifest.json")["status"])
def test_promotion_refuses_failed_validation(self) -> None:
with tempfile.TemporaryDirectory() as tmp:
root = Path(tmp) / "artifact_bundle"
(root / "manifests").mkdir(parents=True)
write_json(root.parent / "artifact_validation_result.json", {"status": "FAIL"})
with self.assertRaises(SystemExit):
promote_artifact_bundle(Namespace(artifact_root=root, reason="unit test"))
result = read_json(root.parent / "artifact_promotion_result.json")
self.assertEqual("REFUSED", result["status"])
self.assertEqual("validation result is not PASS", result["message"])
def test_promotion_refuses_failed_release_gate_and_overwrites_stale_result(self) -> None:
with tempfile.TemporaryDirectory() as tmp:
root = Path(tmp) / "artifact_bundle"
(root / "manifests").mkdir(parents=True)
write_json(root.parent / "artifact_promotion_result.json", {"status": "ACTIVE"})
write_json(
root.parent / "artifact_validation_result.json",
{
"status": "PASS",
"release_gate_status": "REJECTED",
"release_gate_reasons": ["backtest_status=REJECTED"],
},
)
with self.assertRaises(SystemExit):
promote_artifact_bundle(Namespace(artifact_root=root, reason="unit test"))
result = read_json(root.parent / "artifact_promotion_result.json")
self.assertEqual("REFUSED", result["status"])
self.assertEqual("release gate is not PASS", result["message"])
self.assertEqual(["backtest_status=REJECTED"], result["release_gate_reasons"])
if __name__ == "__main__":
unittest.main()
+1
View File
@@ -0,0 +1 @@
"""Trader V4 training pipeline."""
+112
View File
@@ -0,0 +1,112 @@
from __future__ import annotations
import logging
from typing import Any
import pandas as pd
from trader_training.io_utils import manifest, read_parquet, require_columns, run_root, write_json, write_parquet, write_text
from trader_training.schemas import FEATURE_ORDER, TRAINING_SPLITS
def _feature_base(root, args) -> pd.DataFrame:
path = args.feature_path or root / "feature" / "feature_frame.parquet"
frame = read_parquet(path)
require_columns(frame, ("sample_id", "split_id", "data_quality_flag", *FEATURE_ORDER), "feature_frame")
trainable = frame[frame["data_quality_flag"].isin(["OK", "PARTIAL_OPTIONAL"])].copy()
trainable = trainable[trainable["split_id"].isin(TRAINING_SPLITS)].copy()
if trainable.empty:
raise ValueError("no trainable feature rows; check feature quality report")
return trainable
def build_train_datasets(args: Any) -> None:
root = run_root(args)
feature = _feature_base(root, args)
dataset_dir = root / "dataset"
manifests = {}
direction = read_parquet(args.direction_label_path or root / "label" / "direction_labels.parquet")
direction_ds = feature.merge(direction[["sample_id", "long_target", "short_target", "neutral_target", "future_return_bps"]], on="sample_id", how="inner")
manifests["direction"] = _write_dataset(dataset_dir / "direction_train.parquet", direction_ds)
entry = read_parquet(args.entry_label_path or root / "label" / "entry_labels.parquet")
entry_pivot = _entry_pivot(entry)
entry_ds = feature.merge(entry_pivot, on="sample_id", how="inner")
manifests["entry"] = _write_dataset(dataset_dir / "entry_train.parquet", entry_ds)
continuation = read_parquet(args.continue_label_path or root / "label" / "continue_labels.parquet")
continue_ds = feature.merge(
continuation[["sample_id", "long_continue_target", "short_continue_target", "long_expected_continue_edge_bps", "short_expected_continue_edge_bps"]],
on="sample_id",
how="inner",
)
manifests["continue"] = _write_dataset(dataset_dir / "continue_train.parquet", continue_ds)
exit_labels = read_parquet(args.exit_label_path or root / "label" / "exit_labels.parquet")
exit_cols = [
"sample_id",
"long_exit_target",
"short_exit_target",
"long_adverse_move_bps",
"short_adverse_move_bps",
"adverse_move_prob_label",
"reversal_prob_label",
"stop_hit_prob_label",
"stagnation_prob_label",
]
exit_ds = feature.merge(exit_labels[exit_cols], on="sample_id", how="inner")
manifests["exit"] = _write_dataset(dataset_dir / "exit_train.parquet", exit_ds)
risk = read_parquet(args.risk_label_path or root / "label" / "risk_labels.parquet")
risk_cols = [
"sample_id",
"market_risk_target",
"market_path_risk_bps",
"long_position_path_risk_bps",
"short_position_path_risk_bps",
"long_position_risk_target",
"short_position_risk_target",
"market_drawdown_prob_label",
"volatility_expansion_prob_label",
"spike_prob_label",
"liquidity_deterioration_prob_label",
"position_drawdown_prob_label",
]
risk_ds = feature.merge(risk[risk_cols], on="sample_id", how="inner")
manifests["risk"] = _write_dataset(dataset_dir / "risk_train.parquet", risk_ds)
write_json(dataset_dir / "dataset_manifest.json", {"datasets": manifests})
_write_dataset_report(dataset_dir / "dataset_quality_report.md", manifests)
logging.info("trader.training.datasets_written runId=%s datasets=%s", args.run_id, sorted(manifests))
def _entry_pivot(entry: pd.DataFrame) -> pd.DataFrame:
require_columns(entry, ("sample_id", "side", "entry_target", "expected_net_edge_bps"), "entry_labels")
long = entry[entry["side"] == "LONG"][["sample_id", "entry_target", "expected_net_edge_bps"]].rename(
columns={"entry_target": "long_entry_target", "expected_net_edge_bps": "long_expected_net_edge_bps"}
)
short = entry[entry["side"] == "SHORT"][["sample_id", "entry_target", "expected_net_edge_bps"]].rename(
columns={"entry_target": "short_entry_target", "expected_net_edge_bps": "short_expected_net_edge_bps"}
)
return long.merge(short, on="sample_id", how="inner")
def _write_dataset(path, frame: pd.DataFrame) -> dict:
data_hash = write_parquet(path, frame)
return manifest(
path,
{
"row_count": len(frame),
"feature_count": len(FEATURE_ORDER),
"data_hash_sha256": data_hash,
"split_counts": frame["split_id"].value_counts().to_dict() if "split_id" in frame.columns else {},
},
)
def _write_dataset_report(path, manifests: dict) -> None:
lines = ["# Trader Dataset Quality Report", "", "| dataset | rows | hash |", "| --- | ---: | --- |"]
for name, item in manifests.items():
lines.append(f"| {name} | {item['row_count']} | {item['data_hash_sha256']} |")
write_text(path, "\n".join(lines) + "\n")
+244
View File
@@ -0,0 +1,244 @@
from __future__ import annotations
import logging
import shutil
from pathlib import Path
from typing import Any
import pandas as pd
from trader_training.io_utils import read_json, read_parquet, run_root, sha256_file, sha256_json, utc_now_text, write_json
from trader_training.pm import default_pm_config
from trader_training.schemas import (
CALIBRATION_BUNDLE_VERSION,
FEATURE_ORDER,
FEATURE_VERSION,
FIT_SPLIT,
LABEL_VERSION,
MODEL_BUNDLE_VERSION,
MODEL_OUTPUTS,
OUTPUT_MAPPING,
OUTPUT_SCHEMA,
PM_CONFIG_VERSION,
SPLIT_VERSION,
TUNE_SPLIT,
VALIDATION_LOCKED_SPLIT,
)
REQUIRED_MODELS = ("DIRECTION", "ENTRY", "CONTINUE", "EXIT", "RISK")
EXPORT_STATUS = "CANDIDATE"
def export_artifact_bundle(args: Any) -> None:
root = run_root(args)
export_root = args.export_root or root / "export" / f"trader-model-bundle-{args.run_id}"
artifact_root = export_root / "artifact_bundle"
for folder in ("models", "schemas", "calibrators", "manifests", "examples"):
(artifact_root / folder).mkdir(parents=True, exist_ok=True)
feature_schema_path = root / "feature" / "feature_schema.json"
feature_order_path = root / "feature" / "feature_order.json"
output_schema_path = artifact_root / "schemas" / "output_schema.json"
shutil.copy2(feature_schema_path, artifact_root / "schemas" / "feature_schema.json")
shutil.copy2(feature_order_path, artifact_root / "schemas" / "feature_order.json")
write_json(output_schema_path, OUTPUT_SCHEMA)
sample_input = _sample_input(root)
write_json(artifact_root / "examples" / "sample_input.json", sample_input)
write_json(artifact_root / "examples" / "sample_output.json", _sample_output())
price_plan = read_json(root / "label" / "price_plan_context.json")
write_json(artifact_root / "price_plan_context.json", price_plan)
model_manifest_rows = []
calibration_manifest_rows = []
model_hashes = {}
train_start = _split_time(root, FIT_SPLIT, "start")
train_end = _split_time(root, FIT_SPLIT, "end")
validation_start = _split_time(root, VALIDATION_LOCKED_SPLIT, "start")
validation_end = _split_time(root, VALIDATION_LOCKED_SPLIT, "end")
calibration_train_manifest = read_json(root / "calibration" / "calibration_train_manifest.json")
calibration_quality = {row["model_name"]: row for row in calibration_train_manifest.get("calibrators", [])}
for model_name in REQUIRED_MODELS:
src = root / "model" / model_name.lower() / f"{model_name.lower()}.onnx"
dst = artifact_root / "models" / f"{model_name.lower()}.onnx"
shutil.copy2(src, dst)
model_hashes[model_name] = sha256_file(dst)
cal_src = root / "calibration" / model_name.lower() / "calibrator.json"
cal_dst = artifact_root / "calibrators" / model_name.lower() / "calibrator.json"
cal_dst.parent.mkdir(parents=True, exist_ok=True)
shutil.copy2(cal_src, cal_dst)
cal_hash = sha256_file(cal_dst)
metrics = read_json(root / "model" / model_name.lower() / "model_train_result.json")
model_manifest_rows.append(
_model_manifest_row(
model_name,
f"models/{model_name.lower()}.onnx",
model_hashes[model_name],
sha256_file(artifact_root / "schemas" / "feature_schema.json"),
sha256_file(artifact_root / "schemas" / "feature_order.json"),
sha256_file(output_schema_path),
metrics.get("metrics", {}),
metrics.get("quality_status", "UNKNOWN"),
metrics.get("quality_reasons", []),
train_start,
train_end,
validation_start,
validation_end,
EXPORT_STATUS,
)
)
calibration_manifest_rows.append(
{
"calibration_bundle_version": CALIBRATION_BUNDLE_VERSION,
"model_bundle_version": MODEL_BUNDLE_VERSION,
"model_name": model_name,
"calibrator_version": f"{model_name.lower()}-cal-v4-btc-p0",
"calibration_method": "BINNING",
"calibrator_path": f"calibrators/{model_name.lower()}/calibrator.json",
"calibrator_hash_sha256": cal_hash,
"calibration_window_from": _split_time(root, TUNE_SPLIT, "start"),
"calibration_window_to": _split_time(root, TUNE_SPLIT, "end"),
"calibration_metrics_json": {},
"bucket_metrics_json": {},
"output_after_calibration_schema_hash": sha256_file(output_schema_path),
"quality_status": calibration_quality.get(model_name, {}).get("quality_status", "UNKNOWN"),
"quality_reasons_json": calibration_quality.get(model_name, {}).get("quality_reasons", []),
"status": EXPORT_STATUS,
}
)
pm_payload = read_json(root / "pm-search" / "position_manager_config.json") if (root / "pm-search" / "position_manager_config.json").is_file() else {"config": default_pm_config(), "threshold_stability_json": {}}
pm_config = pm_payload["config"]
pm_hash = sha256_json(pm_config)
backtest_manifest = read_json(root / "backtest" / "backtest_manifest.json")
write_json(
artifact_root / "manifests" / "position_manager_manifest.json",
{
"pm_config_version": PM_CONFIG_VERSION,
"model_bundle_version": MODEL_BUNDLE_VERSION,
"calibration_bundle_version": CALIBRATION_BUNDLE_VERSION,
"threshold_stability_json": pm_payload.get("threshold_stability_json", {}),
"allowed_run_modes_json": ["REPLAY_SIM", "SHADOW"],
"config_json": pm_config,
"config_hash_sha256": pm_hash,
"status": EXPORT_STATUS,
},
)
write_json(artifact_root / "manifests" / "model_manifest.json", model_manifest_rows)
write_json(artifact_root / "manifests" / "calibration_manifest.json", calibration_manifest_rows)
bundle_hash = sha256_json({"models": model_hashes, "feature_order_hash": sha256_file(artifact_root / "schemas" / "feature_order.json"), "pm_hash": pm_hash})
write_json(
artifact_root / "manifests" / "model_bundle_manifest.json",
{
"manifest_schema_version": "trader-model-bundle-manifest-v4-p0",
"model_bundle_version": MODEL_BUNDLE_VERSION,
"calibration_bundle_version": CALIBRATION_BUNDLE_VERSION,
"feature_version": FEATURE_VERSION,
"label_version": LABEL_VERSION,
"split_version": SPLIT_VERSION,
"training_run_id": args.run_id,
"training_export_id": f"export-{args.run_id}",
"backtest_manifest_id": f"backtest-{args.run_id}",
"required_models_json": list(REQUIRED_MODELS),
"provided_models_json": list(REQUIRED_MODELS),
"missing_models_json": [],
"allowed_run_modes_json": ["REPLAY_SIM", "SHADOW"],
"bundle_hash_sha256": bundle_hash,
"model_quality_status_json": {row["model_name"]: row["quality_status"] for row in model_manifest_rows},
"calibration_quality_status_json": {row["model_name"]: row["quality_status"] for row in calibration_manifest_rows},
"backtest_status": backtest_manifest.get("status"),
"backtest_status_reasons_json": backtest_manifest.get("status_reasons", []),
"backtest_metrics_json": backtest_manifest.get("metrics", {}),
"complete": True,
"status": EXPORT_STATUS,
},
)
write_json(
artifact_root / "manifests" / "training_export_manifest.json",
{"created_at": utc_now_text(), "status": EXPORT_STATUS, "artifact_root": str(artifact_root), "bundle_hash_sha256": bundle_hash},
)
logging.info("trader.training.artifact_exported runId=%s status=%s bundleHash=%s path=%s", args.run_id, EXPORT_STATUS, bundle_hash, artifact_root)
def _model_manifest_row(
model_name: str,
artifact_path: str,
artifact_hash: str,
feature_schema_hash: str,
feature_order_hash: str,
output_schema_hash: str,
metrics: dict,
quality_status: str,
quality_reasons: list[str],
train_start: str,
train_end: str,
validation_start: str,
validation_end: str,
status: str,
) -> dict:
return {
"model_bundle_version": MODEL_BUNDLE_VERSION,
"calibration_bundle_version": CALIBRATION_BUNDLE_VERSION,
"model_name": model_name,
"model_type": model_name,
"side": "BOTH",
"symbol_scope_json": ["BTC-USDT-PERP"],
"bar_interval": "1m",
"horizon_minutes": 45,
"model_format": "ONNX",
"model_runtime": "ONNX_RUNTIME_JAVA",
"model_runtime_version": "1.22.0",
"onnx_opset_version": 17,
"producer_name": "trader-training",
"producer_version": "v4-p0",
"feature_version": FEATURE_VERSION,
"feature_schema_path": "schemas/feature_schema.json",
"feature_schema_hash": feature_schema_hash,
"feature_order_path": "schemas/feature_order.json",
"feature_order_hash": feature_order_hash,
"input_tensor_name": "features",
"input_dtype": "FLOAT32",
"input_shape_json": {"features": len(FEATURE_ORDER), "batch": 1},
"input_example_path": "examples/sample_input.json",
"output_schema_path": "schemas/output_schema.json",
"output_schema_hash": output_schema_hash,
"output_tensor_names_json": ["prediction"],
"output_mapping_json": OUTPUT_MAPPING[model_name],
"output_value_rules_json": {"clip_by_output_schema": True},
"label_version": LABEL_VERSION,
"split_version": SPLIT_VERSION,
"training_fold": "fold_01",
"train_start": train_start,
"train_end": train_end,
"validation_start": validation_start,
"validation_end": validation_end,
"test_start": validation_start,
"test_end": validation_end,
"metrics_json": metrics,
"quality_status": quality_status,
"quality_reasons_json": quality_reasons,
"artifact_path": artifact_path,
"artifact_hash_sha256": artifact_hash,
"source_hash": artifact_hash,
"status": status,
}
def _sample_input(root) -> dict:
features = read_parquet(root / "feature" / "feature_frame.parquet")
row = features[features["data_quality_flag"].isin(["OK", "PARTIAL_OPTIONAL"])].iloc[0]
return {feature: float(row[feature]) for feature in FEATURE_ORDER}
def _sample_output() -> dict:
return {model: {field: 0.0 for field in fields} for model, fields in MODEL_OUTPUTS.items()}
def _split_time(root, split_id: str, key: str) -> str:
manifest = read_json(root / "split" / "split_manifest.json")
for item in manifest["splits"]:
if item["split_id"] == split_id:
return item[key]
return "2026-01-01T00:00:00Z"
+342
View File
@@ -0,0 +1,342 @@
from __future__ import annotations
import logging
from typing import Any
import numpy as np
import pandas as pd
from trader_training.io_utils import (
manifest,
read_parquet,
require_columns,
run_root,
sha256_json,
to_utc_series,
write_json,
write_parquet,
write_text,
)
from trader_training.replay import assign_split
from trader_training.schemas import FEATURE_ORDER, FEATURE_VERSION, FEATURES, FIT_SPLIT, TUNE_SPLIT, VALIDATION_LOCKED_SPLIT
META_COLUMNS = [
"sample_id",
"symbol",
"event_time",
"open_time_ms",
"split_id",
"walk_forward_fold",
"feature_version",
"data_quality_flag",
]
def _safe_divide(numerator: pd.Series, denominator: pd.Series, default: float = 0.0) -> pd.Series:
result = numerator / denominator.replace(0, np.nan)
return result.replace([np.inf, -np.inf], np.nan).fillna(default)
def _rolling_rank_last(values: pd.Series, window: int) -> pd.Series:
def calc(raw: np.ndarray) -> float:
last = raw[-1]
return float(np.sum(raw <= last) / len(raw))
return values.rolling(window, min_periods=window).apply(calc, raw=True)
def _complete_days(frame: pd.DataFrame) -> pd.DataFrame:
frame = frame.copy()
frame["event_date"] = frame["event_time"].dt.strftime("%Y-%m-%d")
counts = frame.groupby(["symbol", "event_date"])["event_time"].count()
complete = counts[counts == 1440].reset_index()[["symbol", "event_date"]]
return frame.merge(complete, on=["symbol", "event_date"], how="inner").drop(columns=["event_date"])
def build_feature_frame(args: Any) -> None:
root = run_root(args)
replay_path = args.replay_path or root / "replay" / "replay_1m.parquet"
split_manifest_path = args.split_manifest_path or root / "split" / "split_manifest.json"
replay = read_parquet(replay_path)
required = [
"symbol",
"event_time",
"open_time_ms",
"open",
"high",
"low",
"close",
"volume",
"taker_buy_volume",
"taker_sell_volume",
"funding_bps",
"mark_price",
"index_price",
"next_funding_time",
"open_interest",
"spread_bps",
"level1_ofi_1m",
"liquidation_buy_notional_1m",
"liquidation_sell_notional_1m",
"liquidation_available",
]
require_columns(replay, required, "replay_1m")
replay = replay.copy()
replay["event_time"] = to_utc_series(replay["event_time"])
replay["next_funding_time"] = to_utc_series(replay["next_funding_time"])
replay = replay.sort_values(["symbol", "event_time"]).reset_index(drop=True)
if not args.allow_incomplete_days:
before = len(replay)
replay = _complete_days(replay)
logging.info("trader.training.feature_complete_days rowBefore=%s rowAfter=%s", before, len(replay))
frames: list[pd.DataFrame] = []
for symbol, group in replay.groupby("symbol", sort=False):
group = group.sort_values("event_time").reset_index(drop=True).copy()
close = group["close"].astype(float)
high = group["high"].astype(float)
low = group["low"].astype(float)
volume = group["volume"].astype(float)
log_ret = np.log(close / close.shift(1))
group["ret_1m_bps"] = (close / close.shift(1) - 1.0) * 10000.0
group["ret_5m_bps"] = (close / close.shift(5) - 1.0) * 10000.0
group["ret_15m_bps"] = (close / close.shift(15) - 1.0) * 10000.0
group["ret_60m_bps"] = (close / close.shift(60) - 1.0) * 10000.0
group["ret_240m_bps"] = (close / close.shift(240) - 1.0) * 10000.0
group["realized_vol_15m_bps"] = log_ret.rolling(15, min_periods=15).std() * 10000.0
group["realized_vol_60m_bps"] = log_ret.rolling(60, min_periods=60).std() * 10000.0
group["vol_ratio_15m_60m"] = _safe_divide(group["realized_vol_15m_bps"], group["realized_vol_60m_bps"].clip(lower=1.0))
group["range_15m_bps"] = (high.rolling(15, min_periods=15).max() / low.rolling(15, min_periods=15).min() - 1.0) * 10000.0
group["range_60m_bps"] = (high.rolling(60, min_periods=60).max() / low.rolling(60, min_periods=60).min() - 1.0) * 10000.0
vol_mean = volume.rolling(60, min_periods=60).mean()
vol_std = volume.rolling(60, min_periods=60).std().replace(0, np.nan)
group["volume_zscore_60m"] = ((volume - vol_mean) / vol_std).fillna(0.0)
group["trend_consistency_15m"] = np.sign(group["ret_1m_bps"]).rolling(15, min_periods=15).mean()
high60 = high.rolling(60, min_periods=60).max()
low60 = low.rolling(60, min_periods=60).min()
group["channel_position_60m_pct"] = ((close - low60) / (high60 - low60).clip(lower=1e-12)).clip(0.0, 1.0)
prev_high60 = high.shift(1).rolling(60, min_periods=60).max()
prev_low60 = low.shift(1).rolling(60, min_periods=60).min()
group["upper_breakout_60m_bps"] = ((close / prev_high60 - 1.0).clip(lower=0.0)) * 10000.0
group["lower_breakout_60m_bps"] = ((prev_low60 / close - 1.0).clip(lower=0.0)) * 10000.0
recent_high15 = high.rolling(15, min_periods=15).max()
recent_low15 = low.rolling(15, min_periods=15).min()
broke_up = recent_high15 > prev_high60
broke_down = recent_low15 < prev_low60
group["upper_failed_break_reclaim_15m_bps"] = np.where(broke_up, ((prev_high60 - close).clip(lower=0.0) / close) * 10000.0, 0.0)
group["lower_failed_break_reclaim_15m_bps"] = np.where(broke_down, ((close - prev_low60).clip(lower=0.0) / close) * 10000.0, 0.0)
group["sweep_up_15m_bps"] = ((recent_high15 / close - 1.0).clip(lower=0.0)) * 10000.0
group["sweep_down_15m_bps"] = ((close / recent_low15 - 1.0).clip(lower=0.0)) * 10000.0
rank = _rolling_rank_last(group["range_15m_bps"], 240)
group["compression_score_4h_pct"] = 1.0 - rank
group["compression_release_15m_bps"] = (group["range_15m_bps"] - group["range_15m_bps"].rolling(240, min_periods=240).median()).clip(lower=0.0)
buy = group["taker_buy_volume"].astype(float)
sell = group["taker_sell_volume"].astype(float)
group["taker_imbalance_1m"] = _safe_divide(buy - sell, buy + sell)
group["taker_imbalance_5m"] = _safe_divide(buy.rolling(5, min_periods=5).sum() - sell.rolling(5, min_periods=5).sum(), (buy + sell).rolling(5, min_periods=5).sum())
group["taker_imbalance_15m"] = _safe_divide(buy.rolling(15, min_periods=15).sum() - sell.rolling(15, min_periods=15).sum(), (buy + sell).rolling(15, min_periods=15).sum())
group["spread_rank_24h_pct"] = _rolling_rank_last(group["spread_bps"].astype(float), 1440)
group["oi_delta_15m_bps"] = (group["open_interest"].astype(float) / group["open_interest"].astype(float).shift(15) - 1.0) * 10000.0
group["oi_delta_60m_bps"] = (group["open_interest"].astype(float) / group["open_interest"].astype(float).shift(60) - 1.0) * 10000.0
group["mark_index_basis_bps"] = (group["mark_price"].astype(float) / group["index_price"].astype(float) - 1.0) * 10000.0
liq_buy = group["liquidation_buy_notional_1m"].astype(float)
liq_sell = group["liquidation_sell_notional_1m"].astype(float)
liq_total_15 = (liq_buy + liq_sell).rolling(15, min_periods=1).sum()
group["liquidation_imbalance_15m"] = _safe_divide(liq_buy.rolling(15, min_periods=1).sum() - liq_sell.rolling(15, min_periods=1).sum(), liq_total_15)
liq_mean = liq_total_15.rolling(1440, min_periods=60).mean()
liq_std = liq_total_15.rolling(1440, min_periods=60).std().replace(0, np.nan)
group["liquidation_notional_zscore_15m"] = ((liq_total_15 - liq_mean) / liq_std).fillna(0.0)
minute_of_day = group["event_time"].dt.hour * 60 + group["event_time"].dt.minute
group["minute_of_day_sin"] = np.sin(2 * np.pi * minute_of_day / 1440.0)
group["minute_of_day_cos"] = np.cos(2 * np.pi * minute_of_day / 1440.0)
group["minutes_to_next_funding"] = ((group["next_funding_time"] - group["event_time"]).dt.total_seconds() / 60.0).clip(0.0, 480.0)
group["symbol"] = symbol
frames.append(group)
frame = pd.concat(frames, ignore_index=True)
frame["sample_id"] = frame["symbol"].astype(str) + ":" + frame["open_time_ms"].astype(str)
frame["split_id"] = assign_split(frame["event_time"], split_manifest_path)
frame["walk_forward_fold"] = np.where(frame["split_id"].eq(FIT_SPLIT), "fold_01", "NO_FOLD")
frame["feature_version"] = FEATURE_VERSION
hard_na = frame[FEATURE_ORDER].isna().any(axis=1)
optional_missing = frame["liquidation_available"].fillna(0).eq(0)
frame["data_quality_flag"] = np.where(hard_na, "WARMUP", np.where(optional_missing, "PARTIAL_OPTIONAL", "OK"))
ordered = frame[META_COLUMNS + FEATURE_ORDER].copy()
for feature in FEATURE_ORDER:
ordered[feature] = pd.to_numeric(ordered[feature], errors="coerce").astype("float32")
feature_dir = root / "feature"
data_hash = write_parquet(feature_dir / "feature_frame.parquet", ordered)
schema = [feature.as_json() for feature in FEATURES]
feature_order_hash = write_json(feature_dir / "feature_order.json", FEATURE_ORDER)
feature_schema_hash = write_json(feature_dir / "feature_schema.json", schema)
write_json(
feature_dir / "feature_frame.manifest.json",
manifest(
feature_dir / "feature_frame.parquet",
{
"row_count": len(ordered),
"ok_row_count": int(ordered["data_quality_flag"].eq("OK").sum()),
"partial_optional_row_count": int(ordered["data_quality_flag"].eq("PARTIAL_OPTIONAL").sum()),
"warmup_row_count": int(ordered["data_quality_flag"].eq("WARMUP").sum()),
"feature_count": len(FEATURE_ORDER),
"feature_version": FEATURE_VERSION,
"feature_order_hash": feature_order_hash,
"feature_schema_hash": feature_schema_hash,
"data_hash_sha256": data_hash,
},
),
)
write_feature_report(feature_dir / "feature_quality_report.md", ordered, feature_schema_hash, feature_order_hash)
logging.info(
"trader.training.feature_written runId=%s rowCount=%s splitCounts=%s eventFrom=%s eventTo=%s path=%s",
args.run_id,
len(ordered),
ordered["split_id"].value_counts().to_dict(),
ordered["event_time"].min(),
ordered["event_time"].max(),
feature_dir / "feature_frame.parquet",
)
def write_feature_report(path, frame: pd.DataFrame, feature_schema_hash: str, feature_order_hash: str) -> None:
split_rows = []
for split_id, group in frame.groupby("split_id", sort=True):
split_rows.append(
{
"split_id": split_id,
"rows": len(group),
"start": str(group["event_time"].min()),
"end": str(group["event_time"].max()),
"ok": int(group["data_quality_flag"].eq("OK").sum()),
"partial_optional": int(group["data_quality_flag"].eq("PARTIAL_OPTIONAL").sum()),
"warmup": int(group["data_quality_flag"].eq("WARMUP").sum()),
}
)
finite_rows = []
for feature in FEATURE_ORDER:
series = pd.to_numeric(frame[feature], errors="coerce")
values = series.to_numpy(dtype=float)
finite_rows.append(
{
"feature": feature,
"nan_count": int(series.isna().sum()),
"inf_count": int(np.isinf(values).sum()),
"finite_count": int(np.isfinite(values).sum()),
}
)
correlation_rows = _high_correlation_rows(frame)
drift_rows = _drift_rows(frame)
lines = [
"# Trader Feature Quality Report",
"",
f"- row_count: {len(frame)}",
f"- OK: {int(frame['data_quality_flag'].eq('OK').sum())}",
f"- PARTIAL_OPTIONAL: {int(frame['data_quality_flag'].eq('PARTIAL_OPTIONAL').sum())}",
f"- WARMUP: {int(frame['data_quality_flag'].eq('WARMUP').sum())}",
f"- feature_schema_hash: {feature_schema_hash}",
f"- feature_order_hash: {feature_order_hash}",
"",
"## Split Coverage",
"",
_markdown_table(split_rows, ["split_id", "rows", "start", "end", "ok", "partial_optional", "warmup"]),
"",
"## Source Coverage",
"",
f"- replay_1m_required_columns: present",
f"- liquidation_available_share: {float(frame['liquidation_available'].mean()):.6f}",
f"- feature_rows_with_optional_liquidation_missing: {int(frame['data_quality_flag'].eq('PARTIAL_OPTIONAL').sum())}",
"",
"## Leakage Check",
"",
"- 所有特征只使用当前分钟收盘后已经知道的数据,滚动窗口都只看 `<= t`。",
"- 未来价格、未来收益、目标标签不进入 `feature_frame.parquet`。",
"",
"## Extreme Value Check",
"",
_markdown_table(finite_rows, ["feature", "nan_count", "inf_count", "finite_count"]),
"",
"## High Correlation Check",
"",
_markdown_table(correlation_rows, ["feature_a", "feature_b", "corr_abs"]),
"",
"## Drift Check",
"",
_markdown_table(
drift_rows,
["feature", "train_p50", "tune_p50", "validation_p50", "p50_diff", "train_p99", "tune_p99", "validation_p99", "p99_diff"],
),
"",
"## Distribution",
"",
"| feature | null_count | min | p01 | p50 | p99 | max |",
"| --- | ---: | ---: | ---: | ---: | ---: | ---: |",
]
for feature in FEATURE_ORDER:
series = pd.to_numeric(frame[feature], errors="coerce")
quantiles = series.quantile([0.01, 0.5, 0.99])
lines.append(
f"| {feature} | {int(series.isna().sum())} | {series.min():.6g} | {quantiles.loc[0.01]:.6g} | {quantiles.loc[0.5]:.6g} | {quantiles.loc[0.99]:.6g} | {series.max():.6g} |"
)
write_text(path, "\n".join(lines) + "\n")
def feature_order_hash() -> str:
return sha256_json(FEATURE_ORDER)
def _high_correlation_rows(frame: pd.DataFrame) -> list[dict[str, object]]:
sample = frame[FEATURE_ORDER].apply(pd.to_numeric, errors="coerce").dropna()
if len(sample) > 5000:
sample = sample.sample(5000, random_state=7)
if sample.empty:
return [{"feature_a": "NONE", "feature_b": "NONE", "corr_abs": 0.0}]
corr = sample.corr().abs()
rows = []
for left_index, left in enumerate(FEATURE_ORDER):
for right in FEATURE_ORDER[left_index + 1 :]:
value = corr.loc[left, right]
if pd.notna(value) and value >= 0.95:
rows.append({"feature_a": left, "feature_b": right, "corr_abs": round(float(value), 6)})
return rows[:30] or [{"feature_a": "NONE", "feature_b": "NONE", "corr_abs": 0.0}]
def _drift_rows(frame: pd.DataFrame) -> list[dict[str, object]]:
train = frame[frame["split_id"].eq(FIT_SPLIT)]
validation = frame[frame["split_id"].eq(VALIDATION_LOCKED_SPLIT)]
tune = frame[frame["split_id"].eq(TUNE_SPLIT)]
rows = []
for feature in FEATURE_ORDER:
train_series = pd.to_numeric(train[feature], errors="coerce")
validation_series = pd.to_numeric(validation[feature], errors="coerce")
tune_series = pd.to_numeric(tune[feature], errors="coerce")
train_p50 = float(train_series.quantile(0.5)) if not train_series.empty else 0.0
tune_p50 = float(tune_series.quantile(0.5)) if not tune_series.empty else 0.0
validation_p50 = float(validation_series.quantile(0.5)) if not validation_series.empty else 0.0
train_p99 = float(train_series.quantile(0.99)) if not train_series.empty else 0.0
tune_p99 = float(tune_series.quantile(0.99)) if not tune_series.empty else 0.0
validation_p99 = float(validation_series.quantile(0.99)) if not validation_series.empty else 0.0
rows.append(
{
"feature": feature,
"train_p50": round(train_p50, 6),
"tune_p50": round(tune_p50, 6),
"validation_p50": round(validation_p50, 6),
"p50_diff": round(validation_p50 - train_p50, 6),
"train_p99": round(train_p99, 6),
"tune_p99": round(tune_p99, 6),
"validation_p99": round(validation_p99, 6),
"p99_diff": round(validation_p99 - train_p99, 6),
}
)
return rows
def _markdown_table(rows: list[dict[str, object]], columns: list[str]) -> str:
if not rows:
rows = [{column: "" for column in columns}]
lines = ["| " + " | ".join(columns) + " |", "| " + " | ".join("---" for _ in columns) + " |"]
for row in rows:
lines.append("| " + " | ".join(str(row.get(column, "")) for column in columns) + " |")
return "\n".join(lines)
+162
View File
@@ -0,0 +1,162 @@
from __future__ import annotations
import argparse
import hashlib
import json
import logging
from datetime import UTC, datetime
from pathlib import Path
from typing import Any, Iterable
import pandas as pd
DEFAULT_DATA_ROOT = Path("/Users/zach/Desktop/quant-strategy-training-data")
DEFAULT_RAW_ROOT = DEFAULT_DATA_ROOT / "crypto-lake" / "raw"
def setup_logging() -> None:
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s %(levelname)s event=%(message)s",
)
def add_common_args(parser: argparse.ArgumentParser) -> None:
parser.add_argument("--data-root", type=Path, default=DEFAULT_DATA_ROOT)
parser.add_argument("--run-id", required=True)
parser.add_argument("--config", type=Path)
parser.add_argument("--workspace", type=Path)
parser.add_argument("--fail-fast", action="store_true")
def run_root(args: argparse.Namespace) -> Path:
return args.data_root / "trader-v4" / "runs" / args.run_id
def ensure_dir(path: Path) -> Path:
path.mkdir(parents=True, exist_ok=True)
return path
def canonical_json_bytes(value: Any) -> bytes:
return json.dumps(value, ensure_ascii=False, sort_keys=False, separators=(",", ":")).encode("utf-8")
def sha256_bytes(data: bytes) -> str:
return hashlib.sha256(data).hexdigest()
def sha256_json(value: Any) -> str:
return sha256_bytes(canonical_json_bytes(value))
def sha256_file(path: Path) -> str:
digest = hashlib.sha256()
with path.open("rb") as fh:
for chunk in iter(lambda: fh.read(1024 * 1024), b""):
digest.update(chunk)
return digest.hexdigest()
def write_json(path: Path, value: Any) -> str:
ensure_dir(path.parent)
data = canonical_json_bytes(value)
path.write_bytes(data + b"\n")
return sha256_bytes(data)
def read_json(path: Path) -> Any:
with path.open("r", encoding="utf-8") as fh:
return json.load(fh)
def write_parquet(path: Path, frame: pd.DataFrame) -> str:
ensure_dir(path.parent)
frame.to_parquet(path, index=False)
return sha256_file(path)
def read_parquet(path: Path) -> pd.DataFrame:
if not path.is_file():
raise FileNotFoundError(f"required parquet is missing: {path}")
return pd.read_parquet(path)
def write_text(path: Path, text: str) -> None:
ensure_dir(path.parent)
path.write_text(text, encoding="utf-8")
def utc_now_text() -> str:
return datetime.now(UTC).replace(microsecond=0).isoformat().replace("+00:00", "Z")
def to_utc_series(values: pd.Series) -> pd.Series:
if pd.api.types.is_datetime64_any_dtype(values):
return pd.to_datetime(values, utc=True)
numeric = pd.to_numeric(values, errors="coerce")
if numeric.notna().any():
max_value = numeric.dropna().abs().max()
unit = "ms" if max_value > 10_000_000_000 else "s"
return pd.to_datetime(numeric, unit=unit, utc=True)
return pd.to_datetime(values, utc=True, errors="coerce")
def open_time_ms(values: pd.Series) -> pd.Series:
dt = to_utc_series(values)
return (dt.astype("int64") // 1_000_000).astype("Int64")
def date_texts(start_date: str | None, end_date: str | None) -> tuple[str | None, str | None]:
return start_date, end_date
def partition_files(raw_root: Path, table: str, symbol: str, start_date: str | None, end_date: str | None) -> list[Path]:
base = raw_root / f"table={table}"
if not base.is_dir():
return []
files = sorted(base.glob(f"exchange=*/symbol={symbol}/dt=*/data.parquet"))
selected: list[Path] = []
for file in files:
dt_part = next((part.split("=", 1)[1] for part in file.parts if part.startswith("dt=")), "")
if start_date and dt_part < start_date:
continue
if end_date and dt_part > end_date:
continue
selected.append(file)
return selected
def read_partitioned_table(
raw_root: Path,
table: str,
symbol: str,
start_date: str | None,
end_date: str | None,
columns: Iterable[str] | None = None,
) -> pd.DataFrame:
files = partition_files(raw_root, table, symbol, start_date, end_date)
if not files:
return pd.DataFrame()
logging.info("trader.training.raw_read_started table=%s symbol=%s fileCount=%s", table, symbol, len(files))
frames = [pd.read_parquet(file, columns=list(columns) if columns else None) for file in files]
frame = pd.concat(frames, ignore_index=True)
logging.info("trader.training.raw_read_finished table=%s rowCount=%s", table, len(frame))
return frame
def require_columns(frame: pd.DataFrame, columns: Iterable[str], name: str) -> None:
missing = [column for column in columns if column not in frame.columns]
if missing:
raise ValueError(f"{name} is missing required columns: {missing}")
def manifest(path: Path, extra: dict[str, Any]) -> dict[str, Any]:
payload = {
"path": str(path),
"hash_sha256": sha256_file(path) if path.is_file() else None,
"created_at": utc_now_text(),
}
payload.update(extra)
return payload
+417
View File
@@ -0,0 +1,417 @@
from __future__ import annotations
import logging
from typing import Any
import numpy as np
import pandas as pd
from trader_training.io_utils import (
manifest,
read_json,
read_parquet,
require_columns,
run_root,
sha256_json,
to_utc_series,
write_json,
write_parquet,
write_text,
)
from trader_training.schemas import LABEL_VERSION
DEFAULT_LABEL_CONFIG = {
"direction": {"horizon_minutes": 45, "long_threshold_bps": 5.0, "short_threshold_bps": -5.0},
"entry": {"max_hold_minutes": 45, "target_bps": 12.0, "stop_bps": 8.0, "min_expected_net_edge_bps": 3.0},
"continue": {"horizon_minutes": 30, "min_expected_continue_edge_bps": 2.0},
"exit": {"horizon_minutes": 30, "adverse_move_bps": 8.0, "stagnation_abs_return_bps": 2.0},
"risk": {"horizon_minutes": 30, "market_drawdown_bps": 12.0, "vol_expansion_ratio": 1.6, "spike_bps": 20.0},
}
DEFAULT_COST_CONFIG = {
"fee_bps": 4.0,
"slippage_bps": 2.0,
"funding_cost_bps": 0.5,
}
def _load_config(path, default):
if path is None:
return default
value = read_json(path)
merged = default.copy()
for key, item in value.items():
if isinstance(item, dict) and isinstance(merged.get(key), dict):
merged[key] = {**merged[key], **item}
else:
merged[key] = item
return merged
def _base_frames(args: Any) -> tuple[pd.DataFrame, pd.DataFrame]:
root = run_root(args)
feature_path = args.feature_path or root / "feature" / "feature_frame.parquet"
replay_path = args.replay_path or root / "replay" / "replay_1m.parquet"
features = read_parquet(feature_path)
replay = read_parquet(replay_path)
require_columns(features, ("sample_id", "symbol", "event_time", "open_time_ms", "split_id", "walk_forward_fold", "data_quality_flag"), "feature_frame")
require_columns(replay, ("symbol", "event_time", "open_time_ms", "open", "high", "low", "close", "spread_bps"), "replay_1m")
features = features.copy()
replay = replay.copy()
features["event_time"] = to_utc_series(features["event_time"])
replay["event_time"] = to_utc_series(replay["event_time"])
replay = replay.sort_values(["symbol", "event_time"]).reset_index(drop=True)
return features, replay
def _future_path(group: pd.DataFrame, index: int, horizon: int) -> pd.DataFrame:
start = index + 1
end = min(len(group), index + horizon + 1)
return group.iloc[start:end]
def _contiguous_future_path(group: pd.DataFrame, index: int, horizon: int) -> pd.DataFrame:
path = _future_path(group, index, horizon)
if len(path) < horizon:
return pd.DataFrame()
current_ms = int(group.iloc[index]["open_time_ms"])
expected = current_ms + np.arange(1, horizon + 1, dtype=np.int64) * 60_000
actual = path["open_time_ms"].astype("int64").to_numpy()
if len(actual) != len(expected) or not np.array_equal(actual, expected):
return pd.DataFrame()
return path
def _side_return_bps(side: str, entry_price: float, exit_price: float) -> float:
if side == "LONG":
return (exit_price / entry_price - 1.0) * 10000.0
return (entry_price / exit_price - 1.0) * 10000.0
def _path_stats(group: pd.DataFrame, index: int, side: str, horizon: int, target_bps: float, stop_bps: float) -> dict[str, Any]:
current = group.iloc[index]
entry = float(current["close"])
path = _contiguous_future_path(group, index, horizon)
if path.empty:
return {"valid": False}
target_price = entry * (1.0 + target_bps / 10000.0) if side == "LONG" else entry * (1.0 - target_bps / 10000.0)
stop_price = entry * (1.0 - stop_bps / 10000.0) if side == "LONG" else entry * (1.0 + stop_bps / 10000.0)
target_hit = False
stop_hit = False
ambiguous = False
time_to_target_ms = -1
time_to_stop_ms = -1
for _, row in path.iterrows():
high = float(row["high"])
low = float(row["low"])
if side == "LONG":
target_now = high >= target_price
stop_now = low <= stop_price
else:
target_now = low <= target_price
stop_now = high >= stop_price
if target_now and stop_now:
ambiguous = True
stop_hit = True
time_to_stop_ms = int(row["open_time_ms"] - current["open_time_ms"])
break
if target_now:
target_hit = True
time_to_target_ms = int(row["open_time_ms"] - current["open_time_ms"])
break
if stop_now:
stop_hit = True
time_to_stop_ms = int(row["open_time_ms"] - current["open_time_ms"])
break
exit_price = float(path.iloc[-1]["close"])
final_return_bps = _side_return_bps(side, entry, exit_price)
if side == "LONG":
mfe = (path["high"].max() / entry - 1.0) * 10000.0
mae = (entry / path["low"].min() - 1.0) * 10000.0
else:
mfe = (entry / path["low"].min() - 1.0) * 10000.0
mae = (path["high"].max() / entry - 1.0) * 10000.0
if target_hit:
gross = target_bps
elif stop_hit:
gross = -stop_bps
else:
gross = final_return_bps
return {
"valid": True,
"target_hit": int(target_hit),
"stop_hit": int(stop_hit),
"timeout_hit": int(not target_hit and not stop_hit),
"ambiguous_hit": int(ambiguous),
"time_to_target_ms": time_to_target_ms,
"time_to_stop_ms": time_to_stop_ms,
"gross_edge_bps": float(gross),
"future_return_bps": float(final_return_bps),
"mfe_bps": float(mfe),
"mae_bps": float(mae),
"future_spread_p80": float(path["spread_bps"].quantile(0.8)),
"future_realized_vol_bps": float(np.log(path["close"].astype(float) / path["close"].astype(float).shift(1)).std() * 10000.0),
}
def write_price_plan_context(args: Any) -> None:
root = run_root(args)
cost = _load_config(args.cost_config_path, DEFAULT_COST_CONFIG)
labels = _load_config(args.label_config_path, DEFAULT_LABEL_CONFIG)
entry = labels["entry"]
cost_bps = float(cost["fee_bps"]) + float(cost["slippage_bps"]) + float(cost["funding_cost_bps"])
context = {
"pricePlanId": args.price_plan_id,
"pricePlanConfigHash": sha256_json({"entry": entry, "cost": cost}),
"stopDistanceBps": float(entry["stop_bps"]),
"targetDistanceBps": float(entry["target_bps"]),
"maxHoldMinutes": int(entry["max_hold_minutes"]),
"costBps": cost_bps,
}
path = root / "label" / "price_plan_context.json"
write_json(path, context)
frame = pd.DataFrame([{
"price_plan_id": context["pricePlanId"],
"price_plan_hash": context["pricePlanConfigHash"],
"target_bps": context["targetDistanceBps"],
"stop_bps": context["stopDistanceBps"],
"max_hold_minutes": context["maxHoldMinutes"],
"cost_bps": context["costBps"],
}])
write_parquet(root / "label" / "price_plan_context.parquet", frame)
logging.info("trader.training.price_plan_written runId=%s path=%s", args.run_id, path)
def build_direction_labels(args: Any) -> None:
root = run_root(args)
config = _load_config(args.label_config_path, DEFAULT_LABEL_CONFIG)["direction"]
features, replay = _base_frames(args)
horizon = int(config["horizon_minutes"])
replay = replay[["symbol", "event_time", "open_time_ms", "close"]].copy()
future = replay[["symbol", "open_time_ms", "close"]].copy()
future["open_time_ms"] = future["open_time_ms"].astype("int64") - horizon * 60_000
future = future.rename(columns={"close": "future_close"})
merged = features.merge(replay[["symbol", "open_time_ms", "close"]], on=["symbol", "open_time_ms"], how="left")
merged = merged.merge(future, on=["symbol", "open_time_ms"], how="left")
merged["future_return_bps"] = (merged["future_close"] / merged["close"] - 1.0) * 10000.0
merged["direction_label"] = np.select(
[merged["future_return_bps"] >= float(config["long_threshold_bps"]), merged["future_return_bps"] <= float(config["short_threshold_bps"])],
["LONG", "SHORT"],
default="NEUTRAL",
)
out = pd.DataFrame(
{
"sample_id": merged["sample_id"],
"symbol": merged["symbol"],
"event_time": merged["event_time"],
"horizon_minutes": horizon,
"future_return_bps": merged["future_return_bps"],
"direction_label": merged["direction_label"],
"long_target": merged["direction_label"].eq("LONG").astype("int8"),
"short_target": merged["direction_label"].eq("SHORT").astype("int8"),
"neutral_target": merged["direction_label"].eq("NEUTRAL").astype("int8"),
"split_id": merged["split_id"],
"walk_forward_fold": merged["walk_forward_fold"],
"label_version": LABEL_VERSION,
}
).dropna(subset=["future_return_bps"])
path = root / "label" / "direction_labels.parquet"
data_hash = write_parquet(path, out)
_write_label_manifest(root / "label" / "direction_labels.manifest.json", path, out, data_hash)
_write_distribution_report(root / "label" / "direction_label_report.md", out, "direction_label")
logging.info("trader.training.direction_labels_written runId=%s rowCount=%s", args.run_id, len(out))
def build_entry_labels(args: Any) -> None:
root = run_root(args)
labels = _load_config(args.label_config_path, DEFAULT_LABEL_CONFIG)
cost = _load_config(args.cost_config_path, DEFAULT_COST_CONFIG)
plan_path = args.price_plan_context_path or root / "label" / "price_plan_context.json"
plan = read_json(plan_path)
features, replay = _base_frames(args)
entry_conf = labels["entry"]
cost_bps = float(cost["fee_bps"]) + float(cost["slippage_bps"]) + float(cost["funding_cost_bps"])
rows: list[dict[str, Any]] = []
groups, index_by_key = _group_replay_with_index(replay)
for feature in features.itertuples(index=False):
key = (feature.symbol, int(feature.open_time_ms))
index = index_by_key.get(key)
if index is None:
continue
group = groups[feature.symbol]
for side in ("LONG", "SHORT"):
stats = _path_stats(group, index, side, int(entry_conf["max_hold_minutes"]), float(entry_conf["target_bps"]), float(entry_conf["stop_bps"]))
if not stats["valid"]:
continue
expected = stats["gross_edge_bps"] - cost_bps
rows.append(
{
"sample_id": feature.sample_id,
"symbol": feature.symbol,
"event_time": feature.event_time,
"side": side,
"price_plan_id": plan["pricePlanId"],
"price_plan_hash": plan["pricePlanConfigHash"],
"target_hit": stats["target_hit"],
"stop_hit": stats["stop_hit"],
"timeout_hit": stats["timeout_hit"],
"ambiguous_hit": stats["ambiguous_hit"],
"time_to_target_ms": stats["time_to_target_ms"],
"time_to_stop_ms": stats["time_to_stop_ms"],
"gross_edge_bps": stats["gross_edge_bps"],
"cost_bps": cost_bps,
"expected_net_edge_bps": expected,
"entry_target": int(stats["target_hit"] == 1 and expected >= float(entry_conf["min_expected_net_edge_bps"])),
"split_id": feature.split_id,
"walk_forward_fold": feature.walk_forward_fold,
"label_version": LABEL_VERSION,
}
)
out = pd.DataFrame(rows)
path = root / "label" / "entry_labels.parquet"
data_hash = write_parquet(path, out)
_write_label_manifest(root / "label" / "entry_labels.manifest.json", path, out, data_hash)
_write_distribution_report(root / "label" / "entry_label_report.md", out, "entry_target")
logging.info("trader.training.entry_labels_written runId=%s rowCount=%s", args.run_id, len(out))
def build_position_state_samples(args: Any) -> None:
root = run_root(args)
entry_path = args.entry_label_path or root / "label" / "entry_labels.parquet"
entry = read_parquet(entry_path)
if entry.empty:
raise ValueError("entry labels are required before building position samples")
samples = entry[entry["entry_target"] == 1].copy()
samples["position_age_minutes"] = 0
samples["unrealized_pnl_bps"] = 0.0
samples["mfe_bps"] = samples["gross_edge_bps"].clip(lower=0)
samples["mae_bps"] = (-samples["gross_edge_bps"]).clip(lower=0)
path = root / "label" / "position_state_samples.parquet"
data_hash = write_parquet(path, samples)
write_json(root / "label" / "position_state_samples.manifest.json", manifest(path, {"row_count": len(samples), "data_hash_sha256": data_hash}))
logging.info("trader.training.position_samples_written runId=%s rowCount=%s", args.run_id, len(samples))
def build_continue_exit_risk_labels(args: Any) -> None:
root = run_root(args)
labels = _load_config(args.label_config_path, DEFAULT_LABEL_CONFIG)
cost = _load_config(args.cost_config_path, DEFAULT_COST_CONFIG)
plan = read_json(args.price_plan_context_path or root / "label" / "price_plan_context.json")
features, replay = _base_frames(args)
cost_bps = float(cost["fee_bps"]) + float(cost["slippage_bps"]) + float(cost["funding_cost_bps"])
horizon = int(labels["continue"]["horizon_minutes"])
target_bps = float(plan["targetDistanceBps"])
stop_bps = float(plan["stopDistanceBps"])
rows_continue: list[dict[str, Any]] = []
rows_exit: list[dict[str, Any]] = []
rows_risk: list[dict[str, Any]] = []
groups, index_by_key = _group_replay_with_index(replay)
for feature in features.itertuples(index=False):
key = (feature.symbol, int(feature.open_time_ms))
index = index_by_key.get(key)
if index is None:
continue
group = groups[feature.symbol]
long_stats = _path_stats(group, index, "LONG", horizon, target_bps, stop_bps)
short_stats = _path_stats(group, index, "SHORT", horizon, target_bps, stop_bps)
if not long_stats["valid"] or not short_stats["valid"]:
continue
long_edge = long_stats["future_return_bps"] - cost_bps
short_edge = short_stats["future_return_bps"] - cost_bps
min_continue = float(labels["continue"]["min_expected_continue_edge_bps"])
adverse_threshold = float(labels["exit"]["adverse_move_bps"])
rows_continue.append(
{
"sample_id": feature.sample_id,
"symbol": feature.symbol,
"event_time": feature.event_time,
"long_continue_target": int(long_edge >= min_continue and long_stats["mae_bps"] < stop_bps),
"short_continue_target": int(short_edge >= min_continue and short_stats["mae_bps"] < stop_bps),
"long_expected_continue_edge_bps": long_edge,
"short_expected_continue_edge_bps": short_edge,
"split_id": feature.split_id,
"walk_forward_fold": feature.walk_forward_fold,
"label_version": LABEL_VERSION,
}
)
stagnation = int(abs(long_stats["future_return_bps"]) <= float(labels["exit"]["stagnation_abs_return_bps"]))
rows_exit.append(
{
"sample_id": feature.sample_id,
"symbol": feature.symbol,
"event_time": feature.event_time,
"long_exit_target": int(long_stats["stop_hit"] == 1 or long_stats["mae_bps"] >= adverse_threshold),
"short_exit_target": int(short_stats["stop_hit"] == 1 or short_stats["mae_bps"] >= adverse_threshold),
"long_adverse_move_bps": long_stats["mae_bps"],
"short_adverse_move_bps": short_stats["mae_bps"],
"adverse_move_prob_label": int(max(long_stats["mae_bps"], short_stats["mae_bps"]) >= adverse_threshold),
"reversal_prob_label": int(np.sign(long_stats["future_return_bps"]) != np.sign(feature.ret_15m_bps) if hasattr(feature, "ret_15m_bps") else 0),
"stop_hit_prob_label": int(long_stats["stop_hit"] == 1 or short_stats["stop_hit"] == 1),
"stagnation_prob_label": stagnation,
"split_id": feature.split_id,
"walk_forward_fold": feature.walk_forward_fold,
"label_version": LABEL_VERSION,
}
)
path_risk = max(long_stats["mae_bps"], short_stats["mae_bps"])
vol_ratio = 0.0 if long_stats["future_realized_vol_bps"] != long_stats["future_realized_vol_bps"] else long_stats["future_realized_vol_bps"]
rows_risk.append(
{
"sample_id": feature.sample_id,
"symbol": feature.symbol,
"event_time": feature.event_time,
"market_risk_target": int(path_risk >= float(labels["risk"]["market_drawdown_bps"])),
"market_path_risk_bps": path_risk,
"long_position_path_risk_bps": long_stats["mae_bps"],
"short_position_path_risk_bps": short_stats["mae_bps"],
"long_position_risk_target": int(long_stats["mae_bps"] >= stop_bps),
"short_position_risk_target": int(short_stats["mae_bps"] >= stop_bps),
"market_drawdown_prob_label": int(path_risk >= float(labels["risk"]["market_drawdown_bps"])),
"volatility_expansion_prob_label": int(vol_ratio >= float(labels["risk"]["spike_bps"])),
"spike_prob_label": int(max(long_stats["mfe_bps"], short_stats["mfe_bps"], path_risk) >= float(labels["risk"]["spike_bps"])),
"liquidity_deterioration_prob_label": int(long_stats["future_spread_p80"] >= float(replay["spread_bps"].quantile(0.9))),
"position_drawdown_prob_label": int(max(long_stats["mae_bps"], short_stats["mae_bps"]) >= stop_bps),
"split_id": feature.split_id,
"walk_forward_fold": feature.walk_forward_fold,
"label_version": LABEL_VERSION,
}
)
outputs = [
("continue", pd.DataFrame(rows_continue), "long_continue_target"),
("exit", pd.DataFrame(rows_exit), "long_exit_target"),
("risk", pd.DataFrame(rows_risk), "market_risk_target"),
]
report_parts = ["# Continue Exit Risk Label Report", ""]
for name, frame, target in outputs:
path = root / "label" / f"{name}_labels.parquet"
data_hash = write_parquet(path, frame)
_write_label_manifest(root / "label" / f"{name}_labels.manifest.json", path, frame, data_hash)
report_parts.append(f"## {name}")
report_parts.append("")
report_parts.append(str(frame[target].value_counts(dropna=False).to_dict() if not frame.empty else {}))
report_parts.append("")
logging.info("trader.training.%s_labels_written runId=%s rowCount=%s", name, args.run_id, len(frame))
write_text(root / "label" / "continue_exit_risk_label_report.md", "\n".join(report_parts) + "\n")
def _write_label_manifest(path, parquet_path, frame: pd.DataFrame, data_hash: str) -> None:
write_json(path, manifest(parquet_path, {"row_count": len(frame), "label_version": LABEL_VERSION, "data_hash_sha256": data_hash}))
def _write_distribution_report(path, frame: pd.DataFrame, column: str) -> None:
counts = frame[column].value_counts(dropna=False).to_dict() if not frame.empty else {}
lines = ["# Label Report", "", f"- row_count: {len(frame)}", f"- target_column: {column}", f"- distribution: {counts}", ""]
write_text(path, "\n".join(lines))
def _group_replay_with_index(replay: pd.DataFrame) -> tuple[dict[str, pd.DataFrame], dict[tuple[str, int], int]]:
groups: dict[str, pd.DataFrame] = {}
index_by_key: dict[tuple[str, int], int] = {}
for symbol, group in replay.groupby("symbol", sort=False):
grouped = group.sort_values("event_time").reset_index(drop=True)
groups[symbol] = grouped
for idx, row in grouped.iterrows():
index_by_key[(symbol, int(row["open_time_ms"]))] = idx
return groups, index_by_key
+73
View File
@@ -0,0 +1,73 @@
from __future__ import annotations
from dataclasses import dataclass
from pathlib import Path
import numpy as np
@dataclass(frozen=True)
class LinearHead:
name: str
kind: str
weight: np.ndarray
bias: np.ndarray
def require_onnx():
try:
import onnx
from onnx import TensorProto, helper, numpy_helper
except ModuleNotFoundError as exc:
raise SystemExit("Python package 'onnx' is required. Install training/requirements.txt before export.") from exc
return onnx, TensorProto, helper, numpy_helper
def export_heads(path: Path, heads: list[LinearHead], feature_count: int = 39, opset: int = 17) -> None:
onnx, TensorProto, helper, numpy_helper = require_onnx()
nodes = []
initializers = []
concat_inputs = []
for idx, head in enumerate(heads):
weight = np.asarray(head.weight, dtype=np.float32)
bias = np.asarray(head.bias, dtype=np.float32).reshape(1, -1)
if weight.ndim == 1:
weight = weight.reshape(feature_count, 1)
weight_name = f"{head.name}_W"
bias_name = f"{head.name}_B"
linear_name = f"{head.name}_linear"
out_name = f"{head.name}_out"
initializers.append(numpy_helper.from_array(weight, weight_name))
initializers.append(numpy_helper.from_array(bias, bias_name))
nodes.append(helper.make_node("MatMul", ["features", weight_name], [f"{linear_name}_mm"], name=f"{head.name}_matmul"))
nodes.append(helper.make_node("Add", [f"{linear_name}_mm", bias_name], [linear_name], name=f"{head.name}_add"))
if head.kind == "sigmoid":
nodes.append(helper.make_node("Sigmoid", [linear_name], [out_name], name=f"{head.name}_sigmoid"))
elif head.kind == "softmax":
nodes.append(helper.make_node("Softmax", [linear_name], [out_name], name=f"{head.name}_softmax", axis=1))
elif head.kind == "identity":
out_name = linear_name
else:
raise ValueError(f"unsupported ONNX head kind: {head.kind}")
concat_inputs.append(out_name)
if len(concat_inputs) == 1:
nodes.append(helper.make_node("Identity", concat_inputs, ["prediction"], name="prediction_identity"))
else:
nodes.append(helper.make_node("Concat", concat_inputs, ["prediction"], name="prediction_concat", axis=1))
graph = helper.make_graph(
nodes,
"trader_v4_linear_heads",
[helper.make_tensor_value_info("features", TensorProto.FLOAT, [1, feature_count])],
[helper.make_tensor_value_info("prediction", TensorProto.FLOAT, [1, sum(_head_width(head) for head in heads)])],
initializer=initializers,
)
model = helper.make_model(graph, producer_name="trader-training", opset_imports=[helper.make_opsetid("", opset)])
model.ir_version = 10
onnx.checker.check_model(model)
path.parent.mkdir(parents=True, exist_ok=True)
onnx.save(model, path)
def _head_width(head: LinearHead) -> int:
bias = np.asarray(head.bias)
return int(bias.size)
+541
View File
@@ -0,0 +1,541 @@
from __future__ import annotations
import itertools
import logging
from typing import Any
import numpy as np
import pandas as pd
from trader_training.io_utils import read_json, read_parquet, run_root, sha256_json, write_json, write_parquet, write_text
from trader_training.schemas import LATEST_STRESS_SPLIT, PM_CONFIG_VERSION, TUNE_SPLIT, VALIDATION_LOCKED_SPLIT
def default_pm_config() -> dict:
return {
"pmConfigVersion": PM_CONFIG_VERSION,
"open": {
"longOpenProb": 0.58,
"shortOpenProb": 0.58,
"minLongEntryProb": 0.55,
"minShortEntryProb": 0.55,
"maxMarketRiskProb": 0.45,
"minExpectedEdgeBps": 3.0,
"minDirectionMargin": 0.03,
"minLiquidityCapacityRatio": 0.10,
"maxOodScore": 0.80,
},
"add": {
"minLongProb": 0.60,
"minShortProb": 0.60,
"minContinueProb": 0.58,
"minEntryProb": 0.55,
"maxExitProb": 0.45,
"maxMarketRiskProb": 0.45,
"maxPositionRiskProb": 0.50,
"minExpectedEdgeBps": 3.0,
"minContinueVsExitEdgeBps": 0.0,
"minLiquidityCapacityRatio": 0.10,
"minPostTradeLiquidationBufferBps": 500.0,
"maxAddCount": 3,
"cooldownMinutes": 5,
},
"exit": {
"closeExitProb": 0.70,
"closePositionRiskProb": 0.70,
"closeMarketRiskProb": 0.70,
"closeContinueMax": 0.25,
"reduceAdverseMoveProb": 0.62,
"reduceContinueMin": 0.35,
"reduceContinueMax": 0.70,
"minProfitForReduceBps": 5.0,
"maxPositionPathRiskBps": 80.0,
},
"sizing": {
"baseRatio": 0.80,
"minInitialRatio": 0.05,
"maxSingleLegRatio": 1.0,
"minAddRatio": 0.02,
"maxAddRatio": 0.25,
"maxTotalPositionRatio": 1.0,
"minEdgeBps": 3.0,
"maxLossPerTradeBps": 80.0,
"maxLiquidityUsageRatio": 0.20,
"uncertaintyPenaltyMultiplier": 0.50,
"minPostTradeLiquidationBufferBps": 500.0,
},
}
def search_pm_thresholds(args: Any) -> None:
root = run_root(args)
frame = _pm_tune_frame(root)
candidate_rows: list[dict[str, Any]] = []
best_score = -float("inf")
best_thresholds: dict[str, float] | None = None
best_metrics: dict[str, Any] | None = None
best_trades = pd.DataFrame()
for thresholds in _threshold_candidates():
trades = _simulate_open_trades(frame, thresholds)
metrics = _trade_metrics(trades)
score = _score_thresholds(metrics)
candidate_rows.append({**thresholds, **metrics, "score": score})
if score > best_score:
best_score = score
best_thresholds = thresholds
best_metrics = metrics
best_trades = trades
if best_thresholds is None or best_metrics is None:
raise ValueError("PM threshold search did not evaluate any candidate")
config = _pm_config_from_thresholds(best_thresholds)
threshold_stability = {
"source": "tune_predictions_and_entry_labels",
"method": "deterministic_grid_search_v1",
"candidate_count": len(candidate_rows),
"best_score": best_score,
"best_metrics": best_metrics,
}
payload = {
"pm_config_version": PM_CONFIG_VERSION,
"config": config,
"config_hash_sha256": sha256_json(config),
"threshold_stability_json": threshold_stability,
}
candidate_frame = pd.DataFrame(candidate_rows).sort_values("score", ascending=False).reset_index(drop=True)
equity_curve = _equity_curve(best_trades)
regime_metrics = _regime_metrics(best_trades)
write_json(root / "pm-search" / "position_manager_config.json", payload)
write_json(root / "pm-search" / "pm_threshold_config.json", payload)
write_text(root / "pm-search" / "pm_search_candidates.csv", candidate_frame.to_csv(index=False))
write_parquet(root / "pm-search" / "pm_backtest_trades.parquet", best_trades)
write_text(root / "pm-search" / "pm_equity_curve.csv", equity_curve.to_csv(index=False))
write_text(root / "pm-search" / "pm_regime_metrics.csv", regime_metrics.to_csv(index=False))
_write_pm_report(root / "pm-search" / "pm_threshold_report.md", candidate_frame, best_thresholds, best_metrics)
_write_pm_report(root / "pm-search" / "pm_search_report.md", candidate_frame, best_thresholds, best_metrics)
logging.info(
"trader.training.pm_thresholds_searched runId=%s candidateCount=%s bestScore=%.6f tradeCount=%s totalWeightedEdgeBps=%.6f",
args.run_id,
len(candidate_rows),
best_score,
best_metrics["trade_count"],
best_metrics["total_weighted_edge_bps"],
)
def integrated_backtest(args: Any) -> None:
root = run_root(args)
config_path = root / "pm-search" / "position_manager_config.json"
if not config_path.is_file():
raise FileNotFoundError(f"PM config is required before backtest: {config_path}")
pm_payload = read_json(config_path)
trades_path = root / "pm-search" / "pm_backtest_trades.parquet"
# PM search is allowed to use tune_inner, but final acceptance must be
# measured on the sealed validation_locked and latest_stress splits.
tune_trades = read_parquet(trades_path) if trades_path.is_file() else _simulate_open_trades(_pm_tune_frame(root), _thresholds_from_config(pm_payload["config"]))
tune_trades["eval_split"] = TUNE_SPLIT
validation_locked_trades = _simulate_open_trades(_pm_frame(root, VALIDATION_LOCKED_SPLIT), _thresholds_from_config(pm_payload["config"]))
validation_locked_trades["eval_split"] = VALIDATION_LOCKED_SPLIT
stress_trades = _simulate_open_trades(_pm_frame(root, LATEST_STRESS_SPLIT), _thresholds_from_config(pm_payload["config"]))
stress_trades["eval_split"] = LATEST_STRESS_SPLIT
trades = pd.concat([tune_trades, validation_locked_trades, stress_trades], ignore_index=True)
metrics = {
TUNE_SPLIT: _trade_metrics(tune_trades),
VALIDATION_LOCKED_SPLIT: _trade_metrics(validation_locked_trades),
LATEST_STRESS_SPLIT: _trade_metrics(stress_trades),
"combined": _trade_metrics(trades),
}
status, status_reasons = _backtest_status(metrics)
equity_curve = _equity_curve(trades)
regime_metrics = _regime_metrics(trades)
result = {
"backtest_manifest_id": f"backtest-{args.run_id}",
"mode": "VALIDATION_PM_BACKTEST",
"pm_config_hash_sha256": pm_payload["config_hash_sha256"],
"metrics": metrics,
"status_reasons": status_reasons,
"status": status,
}
write_json(root / "backtest" / "backtest_manifest.json", result)
write_parquet(root / "backtest" / "backtest_trades.parquet", trades)
write_text(root / "backtest" / "equity_curve.csv", equity_curve.to_csv(index=False))
write_text(root / "backtest" / "regime_metrics.csv", regime_metrics.to_csv(index=False))
_write_backtest_report(root / "backtest" / "backtest_report.md", result)
_write_failure_cases(root / "backtest" / "failure_cases.md", trades)
_write_no_baseline_ablation(root / "backtest" / "direction_ablation_backtest_report.md")
logging.info(
"trader.training.backtest_written runId=%s status=%s tradeCount=%s totalWeightedEdgeBps=%.6f maxDrawdownBps=%.6f",
args.run_id,
status,
metrics[VALIDATION_LOCKED_SPLIT]["trade_count"],
metrics[VALIDATION_LOCKED_SPLIT]["total_weighted_edge_bps"],
metrics[VALIDATION_LOCKED_SPLIT]["max_drawdown_bps"],
)
def _pm_tune_frame(root) -> pd.DataFrame:
return _pm_frame(root, TUNE_SPLIT)
def _pm_frame(root, split_id: str) -> pd.DataFrame:
prediction_files = {
TUNE_SPLIT: "tune_predictions.parquet",
VALIDATION_LOCKED_SPLIT: "validation_locked_predictions.parquet",
LATEST_STRESS_SPLIT: "latest_stress_predictions.parquet",
}
prediction_file = prediction_files[split_id]
direction = read_parquet(root / "model" / "direction" / prediction_file)
entry = read_parquet(root / "model" / "entry" / prediction_file).rename(
columns={
"long_expected_net_edge_bps": "pred_long_expected_net_edge_bps",
"short_expected_net_edge_bps": "pred_short_expected_net_edge_bps",
}
)
risk = read_parquet(root / "model" / "risk" / prediction_file)
entry_dataset = read_parquet(root / "dataset" / "entry_train.parquet").rename(
columns={
"long_expected_net_edge_bps": "actual_long_expected_net_edge_bps",
"short_expected_net_edge_bps": "actual_short_expected_net_edge_bps",
}
)
entry_cols = [
"sample_id",
"long_entry_prob",
"short_entry_prob",
"pred_long_expected_net_edge_bps",
"pred_short_expected_net_edge_bps",
]
risk_cols = ["sample_id", "market_risk_prob", "long_position_risk_prob", "short_position_risk_prob"]
actual_cols = ["sample_id", "actual_long_expected_net_edge_bps", "actual_short_expected_net_edge_bps", "long_entry_target", "short_entry_target"]
frame = (
direction[["sample_id", "symbol", "event_time", "split_id", "long_prob", "short_prob", "neutral_prob"]]
.merge(entry[entry_cols], on="sample_id", how="inner")
.merge(risk[risk_cols], on="sample_id", how="inner")
.merge(entry_dataset[actual_cols], on="sample_id", how="inner")
)
if frame.empty:
raise ValueError(f"PM frame is empty for {split_id}; check model predictions and entry dataset")
logging.info(
"trader.training.pm_frame_loaded splitId=%s rowCount=%s splitCounts=%s",
split_id,
len(frame),
frame["split_id"].value_counts().to_dict(),
)
return frame
def _threshold_candidates() -> list[dict[str, float]]:
values = itertools.product(
[0.54, 0.56, 0.58, 0.60],
[0.54, 0.56, 0.58, 0.60],
[0.50, 0.52, 0.55, 0.58],
[0.35, 0.45, 0.55],
[1.0, 2.0, 3.0, 5.0],
[0.02, 0.03, 0.05],
)
return [
{
"long_open_prob": long_prob,
"short_open_prob": short_prob,
"min_entry_prob": entry_prob,
"max_market_risk_prob": risk_prob,
"min_expected_edge_bps": edge_bps,
"min_direction_margin": margin,
}
for long_prob, short_prob, entry_prob, risk_prob, edge_bps, margin in values
]
def _simulate_open_trades(frame: pd.DataFrame, thresholds: dict[str, float]) -> pd.DataFrame:
long_mask = (
(frame["long_prob"] >= thresholds["long_open_prob"])
& ((frame["long_prob"] - frame["short_prob"]) >= thresholds["min_direction_margin"])
& (frame["long_entry_prob"] >= thresholds["min_entry_prob"])
& (frame["market_risk_prob"] <= thresholds["max_market_risk_prob"])
& (frame["pred_long_expected_net_edge_bps"] >= thresholds["min_expected_edge_bps"])
)
short_mask = (
(frame["short_prob"] >= thresholds["short_open_prob"])
& ((frame["short_prob"] - frame["long_prob"]) >= thresholds["min_direction_margin"])
& (frame["short_entry_prob"] >= thresholds["min_entry_prob"])
& (frame["market_risk_prob"] <= thresholds["max_market_risk_prob"])
& (frame["pred_short_expected_net_edge_bps"] >= thresholds["min_expected_edge_bps"])
)
long_score = frame["pred_long_expected_net_edge_bps"] + (frame["long_prob"] - frame["short_prob"]) * 10.0
short_score = frame["pred_short_expected_net_edge_bps"] + (frame["short_prob"] - frame["long_prob"]) * 10.0
side = np.where(long_mask & (~short_mask | (long_score >= short_score)), "LONG", np.where(short_mask, "SHORT", ""))
trades = frame.loc[side != ""].copy().reset_index(drop=True)
if trades.empty:
return _empty_trade_frame()
trades["side"] = side[side != ""]
is_long = trades["side"].eq("LONG")
trades["direction_prob"] = np.where(is_long, trades["long_prob"], trades["short_prob"])
trades["entry_prob"] = np.where(is_long, trades["long_entry_prob"], trades["short_entry_prob"])
trades["predicted_edge_bps"] = np.where(is_long, trades["pred_long_expected_net_edge_bps"], trades["pred_short_expected_net_edge_bps"])
trades["actual_edge_bps"] = np.where(is_long, trades["actual_long_expected_net_edge_bps"], trades["actual_short_expected_net_edge_bps"])
trades["entry_target"] = np.where(is_long, trades["long_entry_target"], trades["short_entry_target"])
trades["planned_ratio"] = _planned_ratio(trades["predicted_edge_bps"], trades["market_risk_prob"], thresholds["min_expected_edge_bps"])
trades["weighted_edge_bps"] = trades["actual_edge_bps"] * trades["planned_ratio"]
trades["threshold_hash"] = sha256_json(thresholds)[:16]
return trades[
[
"sample_id",
"symbol",
"event_time",
"split_id",
"side",
"direction_prob",
"entry_prob",
"market_risk_prob",
"predicted_edge_bps",
"actual_edge_bps",
"entry_target",
"planned_ratio",
"weighted_edge_bps",
"threshold_hash",
]
].sort_values("event_time")
def _empty_trade_frame() -> pd.DataFrame:
return pd.DataFrame(
columns=[
"sample_id",
"symbol",
"event_time",
"split_id",
"side",
"direction_prob",
"entry_prob",
"market_risk_prob",
"predicted_edge_bps",
"actual_edge_bps",
"entry_target",
"planned_ratio",
"weighted_edge_bps",
"threshold_hash",
]
)
def _planned_ratio(predicted_edge: pd.Series, market_risk: pd.Series, min_edge: float) -> np.ndarray:
edge_strength = ((predicted_edge.astype(float) - min_edge) / 20.0).clip(lower=0.0, upper=1.5)
risk_discount = (1.0 - market_risk.astype(float)).clip(lower=0.0, upper=1.0)
return (edge_strength * risk_discount).clip(lower=0.05, upper=1.0).to_numpy()
def _trade_metrics(trades: pd.DataFrame) -> dict[str, Any]:
if trades.empty:
return {
"trade_count": 0,
"win_rate": 0.0,
"avg_actual_edge_bps": 0.0,
"avg_weighted_edge_bps": 0.0,
"total_weighted_edge_bps": 0.0,
"max_drawdown_bps": 0.0,
"avg_planned_ratio": 0.0,
"profit_factor": 0.0,
"max_consecutive_losses": 0,
}
equity = trades["weighted_edge_bps"].astype(float).cumsum()
drawdown = equity.cummax() - equity
gains = trades.loc[trades["weighted_edge_bps"] > 0, "weighted_edge_bps"].astype(float).sum()
losses = -trades.loc[trades["weighted_edge_bps"] < 0, "weighted_edge_bps"].astype(float).sum()
return {
"trade_count": int(len(trades)),
"win_rate": float((trades["actual_edge_bps"].astype(float) > 0).mean()),
"avg_actual_edge_bps": float(trades["actual_edge_bps"].astype(float).mean()),
"avg_weighted_edge_bps": float(trades["weighted_edge_bps"].astype(float).mean()),
"total_weighted_edge_bps": float(equity.iloc[-1]),
"max_drawdown_bps": float(drawdown.max()),
"avg_planned_ratio": float(trades["planned_ratio"].astype(float).mean()),
"profit_factor": float(gains / losses) if losses > 0 else float("inf"),
"max_consecutive_losses": _max_consecutive_losses(trades["weighted_edge_bps"].astype(float).to_numpy()),
}
def _max_consecutive_losses(values: np.ndarray) -> int:
max_count = 0
current = 0
for value in values:
if value < 0:
current += 1
max_count = max(max_count, current)
else:
current = 0
return max_count
def _backtest_status(metrics: dict[str, dict[str, Any]]) -> tuple[str, list[str]]:
reasons: list[str] = []
validation_locked = metrics[VALIDATION_LOCKED_SPLIT]
stress = metrics[LATEST_STRESS_SPLIT]
if validation_locked["total_weighted_edge_bps"] <= 0:
reasons.append("validation_locked_net_edge_not_positive")
if validation_locked["trade_count"] < 80:
reasons.append("validation_locked_trade_count_below_80")
if validation_locked["profit_factor"] < 1.15:
reasons.append("validation_locked_profit_factor_below_1.15")
if validation_locked["avg_weighted_edge_bps"] <= 0:
reasons.append("validation_locked_avg_trade_edge_not_positive")
if validation_locked["max_consecutive_losses"] > 8:
reasons.append("validation_locked_max_consecutive_losses_above_8")
if stress["trade_count"] < 20:
reasons.append("latest_stress_trade_count_below_20")
if stress["profit_factor"] < 1.0:
reasons.append("latest_stress_profit_factor_below_1.0")
if stress["avg_weighted_edge_bps"] < -3.0:
reasons.append("latest_stress_avg_trade_edge_below_minus_3")
if stress["max_consecutive_losses"] > 10:
reasons.append("latest_stress_max_consecutive_losses_above_10")
if validation_locked["total_weighted_edge_bps"] > 0 and stress["total_weighted_edge_bps"] < -0.5 * validation_locked["total_weighted_edge_bps"]:
reasons.append("latest_stress_loss_too_large_vs_validation")
return ("REJECTED", reasons) if reasons else ("PASS", [])
def _score_thresholds(metrics: dict[str, Any]) -> float:
if metrics["trade_count"] == 0:
return -1_000_000.0
low_sample_penalty = max(0, 20 - int(metrics["trade_count"])) * 5.0
return (
metrics["avg_weighted_edge_bps"] * np.sqrt(metrics["trade_count"])
+ metrics["total_weighted_edge_bps"] * 0.05
- metrics["max_drawdown_bps"] * 0.25
- low_sample_penalty
)
def _pm_config_from_thresholds(thresholds: dict[str, float]) -> dict:
config = default_pm_config()
config["open"].update(
{
"longOpenProb": thresholds["long_open_prob"],
"shortOpenProb": thresholds["short_open_prob"],
"minLongEntryProb": thresholds["min_entry_prob"],
"minShortEntryProb": thresholds["min_entry_prob"],
"maxMarketRiskProb": thresholds["max_market_risk_prob"],
"minExpectedEdgeBps": thresholds["min_expected_edge_bps"],
"minDirectionMargin": thresholds["min_direction_margin"],
}
)
config["add"]["maxMarketRiskProb"] = thresholds["max_market_risk_prob"]
config["add"]["minExpectedEdgeBps"] = thresholds["min_expected_edge_bps"]
config["sizing"]["minEdgeBps"] = thresholds["min_expected_edge_bps"]
config["sizing"]["maxSingleLegRatio"] = 1.0
return config
def _thresholds_from_config(config: dict) -> dict[str, float]:
open_config = config["open"]
return {
"long_open_prob": float(open_config["longOpenProb"]),
"short_open_prob": float(open_config["shortOpenProb"]),
"min_entry_prob": float(min(open_config["minLongEntryProb"], open_config["minShortEntryProb"])),
"max_market_risk_prob": float(open_config["maxMarketRiskProb"]),
"min_expected_edge_bps": float(open_config["minExpectedEdgeBps"]),
"min_direction_margin": float(open_config["minDirectionMargin"]),
}
def _equity_curve(trades: pd.DataFrame) -> pd.DataFrame:
if trades.empty:
return pd.DataFrame(columns=["event_time", "trade_index", "weighted_edge_bps", "equity_bps", "drawdown_bps"])
curve = trades[["event_time", "weighted_edge_bps"]].copy().reset_index(drop=True)
curve["trade_index"] = np.arange(1, len(curve) + 1)
curve["equity_bps"] = curve["weighted_edge_bps"].astype(float).cumsum()
curve["drawdown_bps"] = curve["equity_bps"].cummax() - curve["equity_bps"]
return curve[["event_time", "trade_index", "weighted_edge_bps", "equity_bps", "drawdown_bps"]]
def _regime_metrics(trades: pd.DataFrame) -> pd.DataFrame:
if trades.empty:
return pd.DataFrame(columns=["split_id", "side", "trade_count", "win_rate", "avg_actual_edge_bps", "total_weighted_edge_bps"])
rows = []
for (split_id, side), group in trades.groupby(["split_id", "side"], sort=True):
metrics = _trade_metrics(group)
rows.append(
{
"split_id": split_id,
"side": side,
"trade_count": metrics["trade_count"],
"win_rate": metrics["win_rate"],
"avg_actual_edge_bps": metrics["avg_actual_edge_bps"],
"total_weighted_edge_bps": metrics["total_weighted_edge_bps"],
}
)
return pd.DataFrame(rows)
def _write_pm_report(path, candidates: pd.DataFrame, best_thresholds: dict[str, float], best_metrics: dict[str, Any]) -> None:
top = candidates.head(10)
lines = [
"# PM Threshold Report",
"",
"本次不是固定写死阈值,而是在验证集上试一组可复现的阈值,选择净收益、回撤、交易数量综合更好的那组。",
"",
"## Best Thresholds",
"",
"```json",
str(best_thresholds).replace("'", '"'),
"```",
"",
"## Best Metrics",
"",
"```json",
str(best_metrics).replace("'", '"'),
"```",
"",
"## Top Candidates",
"",
_markdown_table(top.to_dict("records"), list(top.columns)),
"",
]
write_text(path, "\n".join(lines))
def _write_backtest_report(path, result: dict[str, Any]) -> None:
lines = [
"# Integrated Backtest Report",
"",
"这里用验证集模型输出和 PM 阈值生成交易明细,统计净收益、胜率、回撤和分段表现。",
"",
"```json",
str(result).replace("'", '"'),
"```",
"",
]
write_text(path, "\n".join(lines))
def _write_failure_cases(path, trades: pd.DataFrame) -> None:
worst = trades.sort_values("weighted_edge_bps").head(20) if not trades.empty else trades
lines = [
"# Backtest Failure Cases",
"",
"按加权净收益从差到好列出最差样本,方便回看特征、标签和阈值。",
"",
_markdown_table(worst.to_dict("records"), list(worst.columns)) if not worst.empty else "无交易样本。",
"",
]
write_text(path, "\n".join(lines))
def _write_no_baseline_ablation(path) -> None:
lines = [
"# Direction Ablation Backtest Report",
"",
"- status: NO_BASELINE",
"- reason: 当前 run 目录没有旧 Direction 基准模型包,所以首版不能做只替换 Direction 的消融回测。",
"- action: 后续版本必须拿上一版 ACTIVE 包做 baseline,再比较新 Direction 是否真的提升。",
"",
]
write_text(path, "\n".join(lines))
def _markdown_table(rows: list[dict[str, Any]], columns: list[str]) -> str:
lines = ["| " + " | ".join(columns) + " |", "| " + " | ".join("---" for _ in columns) + " |"]
for row in rows:
lines.append("| " + " | ".join(str(row.get(column, "")) for column in columns) + " |")
return "\n".join(lines)
+129
View File
@@ -0,0 +1,129 @@
from __future__ import annotations
import logging
from typing import Any
from trader_training.io_utils import read_json, utc_now_text, write_json, write_text
def promote_artifact_bundle(args: Any) -> None:
artifact_root = args.artifact_root
validation_path = artifact_root.parent / "artifact_validation_result.json"
validation = read_json(validation_path)
if validation.get("status") != "PASS":
_refuse(artifact_root, args.reason, "validation result is not PASS", validation)
if validation.get("release_gate_status") != "PASS":
_refuse(artifact_root, args.reason, "release gate is not PASS", validation)
promotion = {
"promoted_at": utc_now_text(),
"reason": args.reason,
"validation_result_path": str(validation_path),
}
bundle_path = artifact_root / "manifests" / "model_bundle_manifest.json"
bundle = read_json(bundle_path)
model_manifest_path = artifact_root / "manifests" / "model_manifest.json"
model_rows = read_json(model_manifest_path)
calibration_manifest_path = artifact_root / "manifests" / "calibration_manifest.json"
calibration_rows = read_json(calibration_manifest_path)
pm_manifest_path = artifact_root / "manifests" / "position_manager_manifest.json"
pm_manifest = read_json(pm_manifest_path)
export_manifest_path = artifact_root / "manifests" / "training_export_manifest.json"
export_manifest = read_json(export_manifest_path)
# Check every manifest before writing any ACTIVE status, so a bad bundle
# cannot be left half-promoted if one file fails late.
_require_candidate("model_bundle_manifest", bundle.get("status"))
for row in model_rows:
_require_candidate(f"model_manifest.{row.get('model_name')}", row.get("status"))
for row in calibration_rows:
_require_candidate(f"calibration_manifest.{row.get('model_name')}", row.get("status"))
_require_candidate("position_manager_manifest", pm_manifest.get("status"))
_require_candidate("training_export_manifest", export_manifest.get("status"))
bundle["status"] = "ACTIVE"
bundle["promotion_json"] = promotion
for row in model_rows:
row["status"] = "ACTIVE"
row["promotion_json"] = promotion
for row in calibration_rows:
row["status"] = "ACTIVE"
row["promotion_json"] = promotion
pm_manifest["status"] = "ACTIVE"
pm_manifest["promotion_json"] = promotion
export_manifest["status"] = "ACTIVE"
export_manifest["promotion_json"] = promotion
write_json(bundle_path, bundle)
write_json(model_manifest_path, model_rows)
write_json(calibration_manifest_path, calibration_rows)
write_json(pm_manifest_path, pm_manifest)
write_json(export_manifest_path, export_manifest)
result = {"status": "ACTIVE", "artifact_root": str(artifact_root), "promotion": promotion}
write_json(artifact_root.parent / "artifact_promotion_result.json", result)
write_text(
artifact_root.parent / "artifact_promotion_report.md",
"\n".join(
[
"# Artifact Promotion Report",
"",
"- status: ACTIVE",
f"- artifact_root: {artifact_root}",
f"- reason: {args.reason}",
f"- promoted_at: {promotion['promoted_at']}",
"",
]
),
)
logging.info("trader.training.artifact_promoted status=ACTIVE path=%s reason=%s", artifact_root, args.reason)
def _require_candidate(name: str, status: str | None) -> None:
if status != "CANDIDATE":
raise SystemExit(f"artifact promotion refused: {name} status must be CANDIDATE, actual={status}")
def _refuse(artifact_root: Any, reason: str, message: str, validation: dict[str, Any]) -> None:
result = {
"status": "REFUSED",
"artifact_root": str(artifact_root),
"reason": reason,
"message": message,
"validation_status": validation.get("status"),
"release_gate_status": validation.get("release_gate_status"),
"release_gate_reasons": validation.get("release_gate_reasons", []),
"refused_at": utc_now_text(),
}
write_json(artifact_root.parent / "artifact_promotion_result.json", result)
write_text(
artifact_root.parent / "artifact_promotion_report.md",
"\n".join(
[
"# Artifact Promotion Report",
"",
"- status: REFUSED",
f"- artifact_root: {artifact_root}",
f"- reason: {reason}",
f"- message: {message}",
f"- validation_status: {result['validation_status']}",
f"- release_gate_status: {result['release_gate_status']}",
f"- release_gate_reasons: {result['release_gate_reasons']}",
f"- refused_at: {result['refused_at']}",
"",
]
),
)
logging.warning(
"trader.training.artifact_promotion_refused status=REFUSED path=%s reason=%s message=%s releaseGate=%s",
artifact_root,
reason,
message,
result["release_gate_status"],
)
raise SystemExit(f"artifact promotion refused: {message}, reasons={result['release_gate_reasons']}")
+496
View File
@@ -0,0 +1,496 @@
from __future__ import annotations
import json
import logging
from pathlib import Path
from typing import Any
import numpy as np
import pandas as pd
from trader_training.io_utils import (
DEFAULT_RAW_ROOT,
ensure_dir,
manifest,
open_time_ms,
partition_files,
read_json,
read_partitioned_table,
read_parquet,
require_columns,
run_root,
to_utc_series,
utc_now_text,
write_json,
write_parquet,
write_text,
)
from trader_training.schemas import FIT_SPLIT, LATEST_STRESS_SPLIT, SPLIT_VERSION, TRAINING_SPLITS, TUNE_SPLIT, VALIDATION_LOCKED_SPLIT
def audit_source_data(data_root: Path, symbol: str, start_date: str | None, end_date: str | None, min_ready_days: int = 250) -> dict[str, Any]:
raw_root = data_root / "crypto-lake" / "raw"
required_tables = ("candles", "trades", "level_1", "funding", "open_interest")
optional_tables = ("liquidations",)
rows: list[dict[str, Any]] = []
table_dates: dict[str, set[str]] = {}
for table in required_tables + optional_tables:
files = partition_files(raw_root, table, symbol, start_date, end_date)
dates = sorted({next((part.split("=", 1)[1] for part in file.parts if part.startswith("dt=")), "") for file in files})
table_dates[table] = set(dates)
rows.append(
{
"table": table,
"required": table in required_tables,
"file_count": len(files),
"first_date": dates[0] if dates else None,
"last_date": dates[-1] if dates else None,
"status": "OK" if files or table in optional_tables else "MISSING",
}
)
all_dates = _audit_date_range(table_dates, required_tables, start_date, end_date)
replay_ready_days = []
excluded_days = []
for day in all_dates:
missing_required = [table for table in required_tables if day not in table_dates[table]]
missing_optional = [table for table in optional_tables if day not in table_dates[table]]
if missing_required:
excluded_days.append({"date": day, "reason": "MISSING_REQUIRED_TABLE", "missing_required_tables": missing_required, "missing_optional_tables": missing_optional})
else:
replay_ready_days.append(day)
result = {
"symbol": symbol,
"start_date": start_date,
"end_date": end_date,
"raw_root": str(raw_root),
"tables": rows,
"replay_ready_day_count": len(replay_ready_days),
"excluded_day_count": len(excluded_days),
"replay_ready_days": replay_ready_days,
"excluded_days": excluded_days,
"created_at": utc_now_text(),
"ready": all(row["status"] == "OK" for row in rows if row["required"]) and len(replay_ready_days) >= min_ready_days,
}
return result
def write_audit_outputs(args: Any) -> None:
root = run_root(args)
result = audit_source_data(args.data_root, args.symbol, args.start_date, args.end_date, int(args.min_ready_days))
path = root / "raw-manifest" / "source_data_audit.json"
write_json(path, result)
write_json(root / "raw-manifest" / "source_data_manifest.json", result)
write_json(root / "raw-manifest" / "excluded_days.json", result["excluded_days"])
write_text(root / "raw-manifest" / "replay_ready_days.txt", "\n".join(result["replay_ready_days"]) + ("\n" if result["replay_ready_days"] else ""))
report_lines = [
"# Trader Source Data Audit",
"",
f"- symbol: {result['symbol']}",
f"- raw_root: {result['raw_root']}",
f"- ready: {result['ready']}",
f"- replay_ready_day_count: {result['replay_ready_day_count']}",
f"- excluded_day_count: {result['excluded_day_count']}",
"",
"| table | required | file_count | first_date | last_date | status |",
"| --- | --- | ---: | --- | --- | --- |",
]
for row in result["tables"]:
report_lines.append(
f"| {row['table']} | {row['required']} | {row['file_count']} | {row['first_date']} | {row['last_date']} | {row['status']} |"
)
write_text(root / "raw-manifest" / "source_data_audit.md", "\n".join(report_lines) + "\n")
logging.info(
"trader.training.audit_written runId=%s ready=%s readyDays=%s excludedDays=%s path=%s",
args.run_id,
result["ready"],
result["replay_ready_day_count"],
result["excluded_day_count"],
path,
)
if not result["ready"]:
raise SystemExit("required raw tables are missing; see source_data_audit.md")
def _audit_date_range(table_dates: dict[str, set[str]], required_tables: tuple[str, ...], start_date: str | None, end_date: str | None) -> list[str]:
if start_date and end_date:
start = pd.Timestamp(start_date)
end = pd.Timestamp(end_date)
else:
dates = sorted(set().union(*(table_dates[table] for table in required_tables)))
if not dates:
return []
start = pd.Timestamp(start_date or dates[0])
end = pd.Timestamp(end_date or dates[-1])
return [day.strftime("%Y-%m-%d") for day in pd.date_range(start, end, freq="D")]
def _minute_frame(frame: pd.DataFrame, time_column: str = "origin_time") -> pd.DataFrame:
frame = frame.copy()
frame["event_time"] = to_utc_series(frame[time_column]).dt.floor("min")
frame["open_time_ms"] = open_time_ms(frame["event_time"])
return frame
def _read_candles(raw_root: Path, symbol: str, start_date: str | None, end_date: str | None) -> pd.DataFrame:
candles = read_partitioned_table(
raw_root,
"candles",
symbol,
start_date,
end_date,
columns=("origin_time", "start", "open", "high", "low", "close", "volume", "symbol"),
)
if candles.empty:
raise ValueError("candles raw data is required to build replay_1m")
time_col = "start" if "start" in candles.columns else "origin_time"
candles = _minute_frame(candles, time_col)
keep = ["symbol", "event_time", "open_time_ms", "open", "high", "low", "close", "volume"]
candles = candles[keep].sort_values(["symbol", "event_time"]).drop_duplicates(["symbol", "event_time"], keep="last")
for column in ("open", "high", "low", "close", "volume"):
candles[column] = pd.to_numeric(candles[column], errors="coerce")
return candles
def _read_trades(raw_root: Path, symbol: str, start_date: str | None, end_date: str | None) -> pd.DataFrame:
trades = read_partitioned_table(
raw_root,
"trades",
symbol,
start_date,
end_date,
columns=("origin_time", "side", "quantity", "symbol"),
)
if trades.empty:
raise ValueError("trades raw data is required for taker imbalance")
trades = _minute_frame(trades)
trades["quantity"] = pd.to_numeric(trades["quantity"], errors="coerce").fillna(0.0)
side = trades["side"].astype(str).str.upper()
trades["taker_buy_volume"] = np.where(side.eq("BUY"), trades["quantity"], 0.0)
trades["taker_sell_volume"] = np.where(side.eq("SELL"), trades["quantity"], 0.0)
return (
trades.groupby(["symbol", "event_time", "open_time_ms"], as_index=False, observed=True)[["taker_buy_volume", "taker_sell_volume"]]
.sum()
)
def _read_level1(raw_root: Path, symbol: str, start_date: str | None, end_date: str | None) -> pd.DataFrame:
level1 = read_partitioned_table(
raw_root,
"level_1",
symbol,
start_date,
end_date,
columns=("origin_time", "bid_0_price", "ask_0_price", "bid_0_size", "ask_0_size", "symbol"),
)
if level1.empty:
raise ValueError("level_1 raw data is required for spread and OFI")
level1 = _minute_frame(level1)
for column in ("bid_0_price", "ask_0_price", "bid_0_size", "ask_0_size"):
level1[column] = pd.to_numeric(level1[column], errors="coerce")
level1 = level1.dropna(subset=["bid_0_price", "ask_0_price", "bid_0_size", "ask_0_size"])
level1 = level1.sort_values(["symbol", "event_time", "origin_time"])
group = level1.groupby("symbol", sort=False, observed=True)
prev_bid_price = group["bid_0_price"].shift(1)
prev_bid_size = group["bid_0_size"].shift(1)
prev_ask_price = group["ask_0_price"].shift(1)
prev_ask_size = group["ask_0_size"].shift(1)
bid_ofi = np.select(
[level1["bid_0_price"] > prev_bid_price, level1["bid_0_price"].eq(prev_bid_price)],
[level1["bid_0_size"], level1["bid_0_size"] - prev_bid_size],
default=-prev_bid_size,
)
ask_ofi = np.select(
[level1["ask_0_price"] < prev_ask_price, level1["ask_0_price"].eq(prev_ask_price)],
[level1["ask_0_size"], prev_ask_size - level1["ask_0_size"]],
default=-prev_ask_size,
)
level1["ofi_raw"] = np.nan_to_num(bid_ofi + ask_ofi, nan=0.0)
level1["depth"] = (level1["bid_0_size"] + level1["ask_0_size"]).clip(lower=1e-12)
level1["mid"] = (level1["bid_0_price"] + level1["ask_0_price"]) / 2.0
level1["spread_bps"] = (level1["ask_0_price"] - level1["bid_0_price"]) / level1["mid"] * 10000.0
agg = level1.groupby(["symbol", "event_time", "open_time_ms"], as_index=False, observed=True).agg(
best_bid_price=("bid_0_price", "last"),
best_ask_price=("ask_0_price", "last"),
spread_bps=("spread_bps", "last"),
ofi_sum=("ofi_raw", "sum"),
depth_mean=("depth", "mean"),
)
agg["level1_ofi_1m"] = agg["ofi_sum"] / agg["depth_mean"].clip(lower=1e-12)
return agg.drop(columns=["ofi_sum", "depth_mean"])
def _read_liquidations(raw_root: Path, symbol: str, start_date: str | None, end_date: str | None) -> pd.DataFrame:
files = partition_files(raw_root, "liquidations", symbol, start_date, end_date)
if not files:
return pd.DataFrame(columns=["symbol", "event_time", "open_time_ms", "liquidation_buy_notional_1m", "liquidation_sell_notional_1m", "liquidation_available"])
liquidations = read_partitioned_table(
raw_root,
"liquidations",
symbol,
start_date,
end_date,
columns=("origin_time", "side", "quantity", "price", "symbol"),
)
liquidations = _minute_frame(liquidations)
liquidations["quantity"] = pd.to_numeric(liquidations["quantity"], errors="coerce").fillna(0.0)
liquidations["price"] = pd.to_numeric(liquidations["price"], errors="coerce").fillna(0.0)
liquidations["notional"] = liquidations["quantity"] * liquidations["price"]
side = liquidations["side"].astype(str).str.upper()
liquidations["liquidation_buy_notional_1m"] = np.where(side.eq("BUY"), liquidations["notional"], 0.0)
liquidations["liquidation_sell_notional_1m"] = np.where(side.eq("SELL"), liquidations["notional"], 0.0)
agg = liquidations.groupby(["symbol", "event_time", "open_time_ms"], as_index=False, observed=True)[
["liquidation_buy_notional_1m", "liquidation_sell_notional_1m"]
].sum()
agg["liquidation_available"] = 1.0
return agg
def _asof_column(
replay: pd.DataFrame,
raw_root: Path,
table: str,
symbol: str,
start_date: str | None,
end_date: str | None,
columns: tuple[str, ...],
) -> pd.DataFrame:
frame = read_partitioned_table(raw_root, table, symbol, start_date, end_date, columns=("origin_time", "symbol", *columns))
if frame.empty:
raise ValueError(f"{table} raw data is required")
frame = _minute_frame(frame)
for column in columns:
if column.endswith("time"):
continue
frame[column] = pd.to_numeric(frame[column], errors="coerce")
frame = frame.sort_values(["symbol", "event_time"])
left = replay[["symbol", "event_time"]].sort_values(["symbol", "event_time"])
merged = pd.merge_asof(
left,
frame[["symbol", "event_time", *columns]].sort_values(["symbol", "event_time"]),
by="symbol",
on="event_time",
direction="backward",
tolerance=pd.Timedelta(hours=12),
)
return merged
def build_replay_1m(args: Any) -> None:
root = run_root(args)
raw_root = args.raw_root or DEFAULT_RAW_ROOT
logging.info("trader.training.replay_started runId=%s symbol=%s rawRoot=%s", args.run_id, args.symbol, raw_root)
replay = _read_candles(raw_root, args.symbol, args.start_date, args.end_date)
trades = _read_trades(raw_root, args.symbol, args.start_date, args.end_date)
level1 = _read_level1(raw_root, args.symbol, args.start_date, args.end_date)
liquidations = _read_liquidations(raw_root, args.symbol, args.start_date, args.end_date)
replay = replay.merge(trades, on=["symbol", "event_time", "open_time_ms"], how="left")
replay = replay.merge(level1, on=["symbol", "event_time", "open_time_ms"], how="left")
replay = replay.merge(liquidations, on=["symbol", "event_time", "open_time_ms"], how="left")
replay[["taker_buy_volume", "taker_sell_volume"]] = replay[["taker_buy_volume", "taker_sell_volume"]].fillna(0.0)
for column in ("liquidation_buy_notional_1m", "liquidation_sell_notional_1m", "liquidation_available"):
replay[column] = replay[column].fillna(0.0)
funding = _asof_column(replay, raw_root, "funding", args.symbol, args.start_date, args.end_date, ("rate", "mark_price", "index_price", "next_funding_time"))
funding = funding.rename(columns={"rate": "funding_rate"})
funding["funding_bps"] = pd.to_numeric(funding["funding_rate"], errors="coerce") * 10000.0
replay = replay.merge(funding.drop(columns=["funding_rate"]), on=["symbol", "event_time"], how="left")
replay["next_funding_time"] = to_utc_series(replay["next_funding_time"])
oi = _asof_column(replay, raw_root, "open_interest", args.symbol, args.start_date, args.end_date, ("open_interest",))
replay = replay.merge(oi, on=["symbol", "event_time"], how="left")
replay["timeframe"] = "1m"
replay["source_coverage"] = "crypto_lake_raw"
required = [
"open",
"high",
"low",
"close",
"volume",
"best_bid_price",
"best_ask_price",
"spread_bps",
"level1_ofi_1m",
"funding_bps",
"mark_price",
"index_price",
"open_interest",
]
replay["event_date"] = replay["event_time"].dt.strftime("%Y-%m-%d")
missing_required = replay[required].isna().any(axis=1)
day_quality = (
replay.assign(missing_required=missing_required.astype(int))
.groupby("event_date", as_index=False, observed=True)
.agg(row_count=("event_time", "count"), missing_required_rows=("missing_required", "sum"))
)
day_quality["ready"] = (day_quality["row_count"] >= int(args.min_minutes_per_day)) & day_quality["missing_required_rows"].eq(0)
ready_days = sorted(day_quality.loc[day_quality["ready"], "event_date"].astype(str).tolist())
excluded_days = [
{
"date": row.event_date,
"row_count": int(row.row_count),
"missing_required_rows": int(row.missing_required_rows),
"reason": "MISSING_REQUIRED_MARKET_FIELDS" if int(row.missing_required_rows) else "INCOMPLETE_MINUTE_COUNT",
}
for row in day_quality.loc[~day_quality["ready"]].itertuples(index=False)
]
if len(ready_days) < int(args.min_replay_ready_days):
write_json(root / "replay" / "excluded_days.json", excluded_days)
write_text(root / "replay" / "replay_ready_days.txt", "\n".join(ready_days) + ("\n" if ready_days else ""))
raise ValueError(f"replay_1m has only {len(ready_days)} replay-ready days, required {args.min_replay_ready_days}")
before_filter = len(replay)
replay = replay[replay["event_date"].isin(ready_days)].copy()
logging.info(
"trader.training.replay_ready_days_selected runId=%s readyDays=%s excludedDays=%s rowBefore=%s rowAfter=%s",
args.run_id,
len(ready_days),
len(excluded_days),
before_filter,
len(replay),
)
columns = [
"symbol",
"timeframe",
"event_time",
"open_time_ms",
"open",
"high",
"low",
"close",
"volume",
"taker_buy_volume",
"taker_sell_volume",
"funding_bps",
"mark_price",
"index_price",
"next_funding_time",
"open_interest",
"best_bid_price",
"best_ask_price",
"spread_bps",
"level1_ofi_1m",
"liquidation_buy_notional_1m",
"liquidation_sell_notional_1m",
"liquidation_available",
"source_coverage",
]
replay = replay[columns].sort_values(["symbol", "event_time"]).reset_index(drop=True)
path = root / "replay" / "replay_1m.parquet"
data_hash = write_parquet(path, replay)
write_json(
root / "replay" / "replay_1m.manifest.json",
manifest(
path,
{
"row_count": len(replay),
"hash_sha256": data_hash,
"replay_ready_day_count": len(ready_days),
"excluded_day_count": len(excluded_days),
"min_minutes_per_day": int(args.min_minutes_per_day),
},
),
)
write_json(root / "replay" / "excluded_days.json", excluded_days)
write_text(root / "replay" / "replay_ready_days.txt", "\n".join(ready_days) + "\n")
logging.info("trader.training.replay_written runId=%s rowCount=%s readyDays=%s path=%s", args.run_id, len(replay), len(ready_days), path)
def build_splits(args: Any) -> None:
root = run_root(args)
replay_path = args.replay_path or root / "replay" / "replay_1m.parquet"
replay = read_parquet(replay_path)
require_columns(replay, ("event_time", "symbol"), "replay_1m")
replay["event_time"] = to_utc_series(replay["event_time"])
replay = replay.sort_values(["event_time", "symbol"]).reset_index(drop=True)
if len(replay) < 10:
raise ValueError("not enough replay rows to build time splits")
gap = int(args.gap_minutes)
intervals = _fixed_split_intervals(args, gap)
replay_start = replay["event_time"].min()
replay_end = replay["event_time"].max()
intervals = [
(split_id, max(start, replay_start), min(end, replay_end))
for split_id, start, end in intervals
if max(start, replay_start) <= min(end, replay_end)
]
if {item[0] for item in intervals} != set(TRAINING_SPLITS):
raise ValueError(f"fixed split dates do not fit replay coverage: replay_start={replay_start} replay_end={replay_end}")
split_manifest = {
"split_version": SPLIT_VERSION,
"created_at": utc_now_text(),
"source_replay_path": str(replay_path),
"gap_minutes": gap,
# Sealed splits are withheld from broad parameter search. They only answer
# whether a finished candidate survives final validation and recent stress.
"sealed_splits": [VALIDATION_LOCKED_SPLIT, LATEST_STRESS_SPLIT],
"latest_stress_policy": "FINAL_GATE_ONLY",
"requested_splits": {
FIT_SPLIT: [args.fit_inner_start, args.fit_inner_end],
TUNE_SPLIT: [args.tune_inner_start, args.tune_inner_end],
VALIDATION_LOCKED_SPLIT: [args.validation_locked_start, args.validation_locked_end],
LATEST_STRESS_SPLIT: [args.latest_stress_start, args.latest_stress_end],
},
"splits": [
{"split_id": split_id, "start": start.isoformat().replace("+00:00", "Z"), "end": end.isoformat().replace("+00:00", "Z")}
for split_id, start, end in intervals
if start <= end
],
}
fold_count = max(1, int(args.fold_count))
fit_interval = next(item for item in intervals if item[0] == FIT_SPLIT)
tune_interval = next(item for item in intervals if item[0] == TUNE_SPLIT)
train_times = pd.Series(pd.date_range(fit_interval[1], fit_interval[2], periods=fold_count + 1))
folds = []
for idx in range(fold_count):
folds.append(
{
"walk_forward_fold": f"fold_{idx + 1:02d}",
"train_start": fit_interval[1].isoformat().replace("+00:00", "Z"),
"train_end": train_times.iloc[idx + 1].isoformat().replace("+00:00", "Z"),
"validation_start": tune_interval[1].isoformat().replace("+00:00", "Z"),
"validation_end": tune_interval[2].isoformat().replace("+00:00", "Z"),
}
)
ensure_dir(root / "split")
write_json(root / "split" / "split_manifest.json", split_manifest)
write_json(root / "split" / "walk_forward_folds.json", {"split_version": SPLIT_VERSION, "folds": folds})
_write_purge_embargo_report(root / "split" / "purge_embargo_report.md", intervals, gap)
logging.info("trader.training.splits_written runId=%s splitCount=%s foldCount=%s", args.run_id, len(split_manifest["splits"]), len(folds))
def assign_split(event_times: pd.Series, split_manifest_path: Path) -> pd.Series:
manifest_data = read_json(split_manifest_path)
result = pd.Series("NO_SPLIT", index=event_times.index, dtype="object")
values = to_utc_series(event_times)
for item in manifest_data["splits"]:
start = pd.Timestamp(item["start"])
end = pd.Timestamp(item["end"])
mask = values.between(start, end, inclusive="both")
result.loc[mask] = item["split_id"]
return result
def _fixed_split_intervals(args: Any, gap_minutes: int) -> list[tuple[str, pd.Timestamp, pd.Timestamp]]:
gap = pd.Timedelta(minutes=gap_minutes)
return [
(FIT_SPLIT, _start_of_day(args.fit_inner_start), _end_of_day(args.fit_inner_end) - gap),
(TUNE_SPLIT, _start_of_day(args.tune_inner_start) + gap, _end_of_day(args.tune_inner_end) - gap),
(VALIDATION_LOCKED_SPLIT, _start_of_day(args.validation_locked_start) + gap, _end_of_day(args.validation_locked_end) - gap),
(LATEST_STRESS_SPLIT, _start_of_day(args.latest_stress_start) + gap, _end_of_day(args.latest_stress_end)),
]
def _start_of_day(value: str) -> pd.Timestamp:
return pd.Timestamp(value, tz="UTC")
def _end_of_day(value: str) -> pd.Timestamp:
return pd.Timestamp(value, tz="UTC") + pd.Timedelta(days=1) - pd.Timedelta(minutes=1)
def _write_purge_embargo_report(path: Path, intervals: list[tuple[str, pd.Timestamp, pd.Timestamp]], gap_minutes: int) -> None:
lines = ["# Purge Embargo Report", "", f"- gap_minutes: {gap_minutes}", "", "| split_id | start | end |", "| --- | --- | --- |"]
for split_id, start, end in intervals:
lines.append(f"| {split_id} | {start.isoformat()} | {end.isoformat()} |")
write_text(path, "\n".join(lines) + "\n")
+206
View File
@@ -0,0 +1,206 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import Any
FEATURE_VERSION = "feature-v4-p0"
LABEL_VERSION = "label-v4-p0"
SPLIT_VERSION = "split-v4-p0"
MODEL_BUNDLE_VERSION = "trader-v4-btc-p0"
CALIBRATION_BUNDLE_VERSION = "cal-v4-btc-p0"
PM_CONFIG_VERSION = "pm-v4-btc-p0"
OUTPUT_SCHEMA_VERSION = "output-schema-v4-btc-p0"
FIT_SPLIT = "fit_inner"
TUNE_SPLIT = "tune_inner"
VALIDATION_LOCKED_SPLIT = "validation_locked"
LATEST_STRESS_SPLIT = "latest_stress"
TRAINING_SPLITS = (FIT_SPLIT, TUNE_SPLIT, VALIDATION_LOCKED_SPLIT, LATEST_STRESS_SPLIT)
@dataclass(frozen=True)
class FeatureDef:
order: int
name: str
cn_name: str
meaning: str
source_tables: tuple[str, ...]
formula: str
lookback_window: str
unit: str
dtype: str
null_rule: str
live_available: bool
leakage_check: str
owner_models: tuple[str, ...]
def as_json(self) -> dict[str, Any]:
return {
"order": self.order,
"name": self.name,
"cn_name": self.cn_name,
"meaning": self.meaning,
"source_tables": list(self.source_tables),
"formula": self.formula,
"lookback_window": self.lookback_window,
"unit": self.unit,
"dtype": self.dtype,
"null_rule": self.null_rule,
"live_available": self.live_available,
"leakage_check": self.leakage_check,
"owner_models": list(self.owner_models),
}
ALL_MODELS = ("Direction", "Entry", "Continue", "Exit", "Risk")
FEATURES: tuple[FeatureDef, ...] = (
FeatureDef(1, "ret_1m_bps", "最近1分钟收益", "Latest short return.", ("replay_1m",), "close_t / close_t-1m - 1", "1m", "bps", "float32", "WARMUP", True, "uses <= t close only", ALL_MODELS),
FeatureDef(2, "ret_5m_bps", "最近5分钟收益", "Short trend.", ("replay_1m",), "close_t / close_t-5m - 1", "5m", "bps", "float32", "WARMUP", True, "uses <= t close only", ("Direction", "Entry", "Continue", "Exit")),
FeatureDef(3, "ret_15m_bps", "最近15分钟收益", "Near trend.", ("replay_1m",), "close_t / close_t-15m - 1", "15m", "bps", "float32", "WARMUP", True, "uses <= t close only", ("Direction", "Entry", "Continue", "Exit")),
FeatureDef(4, "ret_60m_bps", "最近60分钟收益", "Baseline trend.", ("replay_1m",), "close_t / close_t-60m - 1", "60m", "bps", "float32", "WARMUP", True, "uses <= t close only", ("Direction", "Continue", "Exit", "Risk")),
FeatureDef(5, "ret_240m_bps", "最近240分钟收益", "Four-hour trend.", ("replay_1m",), "close_t / close_t-240m - 1", "240m", "bps", "float32", "WARMUP", True, "uses <= t close only", ("Direction", "Continue", "Risk")),
FeatureDef(6, "realized_vol_15m_bps", "15分钟波动", "Near realized volatility.", ("replay_1m",), "std(log_return_1m, 15) * 10000", "15m", "bps", "float32", "WARMUP", True, "uses <= t returns only", ("Direction", "Entry", "Exit", "Risk")),
FeatureDef(7, "realized_vol_60m_bps", "60分钟波动", "Baseline realized volatility.", ("replay_1m",), "std(log_return_1m, 60) * 10000", "60m", "bps", "float32", "WARMUP", True, "uses <= t returns only", ("Direction", "Entry", "Exit", "Risk")),
FeatureDef(8, "vol_ratio_15m_60m", "近端波动放大", "Near volatility versus baseline.", ("feature",), "realized_vol_15m_bps / max(realized_vol_60m_bps, 1)", "15m/60m", "ratio", "float32", "WARMUP", True, "derived from <= t features", ("Entry", "Exit", "Risk")),
FeatureDef(9, "range_15m_bps", "15分钟振幅", "Near high-low range.", ("replay_1m",), "max(high_15m) / min(low_15m) - 1", "15m", "bps", "float32", "WARMUP", True, "uses <= t high/low only", ("Entry", "Exit", "Risk")),
FeatureDef(10, "range_60m_bps", "60分钟振幅", "Baseline high-low range.", ("replay_1m",), "max(high_60m) / min(low_60m) - 1", "60m", "bps", "float32", "WARMUP", True, "uses <= t high/low only", ("Direction", "Entry", "Risk")),
FeatureDef(11, "volume_zscore_60m", "60分钟成交量异常", "Current volume abnormality.", ("replay_1m",), "(volume_t - mean(volume_60m)) / std(volume_60m)", "60m", "zscore", "float32", "std=0 -> 0", True, "uses <= t volume only", ("Direction", "Entry", "Risk")),
FeatureDef(12, "trend_consistency_15m", "15分钟方向连续性", "Signed return consistency.", ("replay_1m",), "mean(sign(ret_1m), 15)", "15m", "ratio", "float32", "WARMUP", True, "uses <= t returns only", ("Direction", "Continue", "Exit")),
FeatureDef(13, "channel_position_60m_pct", "60分钟通道位置", "Close position in recent channel.", ("replay_1m",), "(close_t - low_60m) / max(high_60m - low_60m, tick)", "60m", "pct", "float32", "WARMUP", True, "uses <= t high/low/close only", ("Direction", "Entry", "Continue")),
FeatureDef(14, "upper_breakout_60m_bps", "向上突破距离", "Upper breakout distance.", ("replay_1m",), "max(0, close_t / prev_high_60m_excl_t - 1) * 10000", "60m", "bps", "float32", "WARMUP", True, "current close versus prior window only", ("Direction", "Entry", "Continue")),
FeatureDef(15, "lower_breakout_60m_bps", "向下跌破距离", "Lower breakdown distance.", ("replay_1m",), "max(0, prev_low_60m_excl_t / close_t - 1) * 10000", "60m", "bps", "float32", "WARMUP", True, "current close versus prior window only", ("Direction", "Entry", "Continue")),
FeatureDef(16, "upper_failed_break_reclaim_15m_bps", "上破失败回落", "Failed upper breakout reclaim.", ("replay_1m",), "if high_15m broke prior high then max(0, prev_high_60m - close_t) / close_t * 10000", "15m/60m", "bps", "float32", "no event -> 0", True, "prior high excludes t", ("Entry", "Exit", "Risk")),
FeatureDef(17, "lower_failed_break_reclaim_15m_bps", "下破失败收回", "Failed lower breakdown reclaim.", ("replay_1m",), "if low_15m broke prior low then max(0, close_t - prev_low_60m) / close_t * 10000", "15m/60m", "bps", "float32", "no event -> 0", True, "prior low excludes t", ("Entry", "Exit", "Risk")),
FeatureDef(18, "sweep_up_15m_bps", "上影扫高", "Upper sweep size.", ("replay_1m",), "max(0, max(high_15m) / close_t - 1) * 10000", "15m", "bps", "float32", "WARMUP", True, "uses <= t high/close only", ("Exit", "Risk")),
FeatureDef(19, "sweep_down_15m_bps", "下影扫低", "Lower sweep size.", ("replay_1m",), "max(0, close_t / min(low_15m) - 1) * 10000", "15m", "bps", "float32", "WARMUP", True, "uses <= t low/close only", ("Exit", "Risk")),
FeatureDef(20, "compression_score_4h_pct", "4小时压缩分位", "Higher means recent range is compressed.", ("feature",), "1 - percentile_rank(range_15m_bps over 240m)", "240m", "pct", "float32", "WARMUP", True, "rolling rank uses <= t", ("Direction", "Entry")),
FeatureDef(21, "compression_release_15m_bps", "压缩释放幅度", "Range release versus 4h median.", ("feature",), "max(0, range_15m_bps - median(range_15m_bps over 240m))", "15m/240m", "bps", "float32", "WARMUP", True, "rolling median uses <= t", ("Direction", "Entry", "Risk")),
FeatureDef(22, "taker_imbalance_1m", "1分钟主动买卖差", "Taker buy/sell imbalance.", ("trades", "replay_1m"), "(buy_1m - sell_1m) / max(total_1m, eps)", "1m", "ratio", "float32", "volume=0 -> 0", True, "uses current closed minute trades only", ("Direction", "Entry", "Continue")),
FeatureDef(23, "taker_imbalance_5m", "5分钟主动买卖差", "Short taker imbalance.", ("trades", "replay_1m"), "(buy_5m - sell_5m) / max(total_5m, eps)", "5m", "ratio", "float32", "WARMUP", True, "uses <= t trades only", ("Direction", "Entry", "Continue")),
FeatureDef(24, "taker_imbalance_15m", "15分钟主动买卖差", "Near taker imbalance.", ("trades", "replay_1m"), "(buy_15m - sell_15m) / max(total_15m, eps)", "15m", "ratio", "float32", "WARMUP", True, "uses <= t trades only", ("Direction", "Continue", "Exit")),
FeatureDef(25, "level1_ofi_1m", "1分钟盘口订单流", "Best bid/ask order-flow imbalance.", ("level_1", "replay_1m"), "sum(OFI changes in minute) / mean(level1 depth)", "1m", "ratio", "float32", "missing -> fail", True, "uses current closed minute L1 only", ("Direction", "Entry", "Risk")),
FeatureDef(26, "spread_bps", "买卖价差", "Best bid/ask spread.", ("level_1", "replay_1m"), "(best_ask - best_bid) / mid * 10000", "1m", "bps", "float32", "missing -> fail", True, "uses current closed minute L1 only", ("Entry", "Exit", "Risk")),
FeatureDef(27, "spread_rank_24h_pct", "24小时价差分位", "Spread congestion rank.", ("feature",), "percentile_rank(spread_bps over 24h)", "24h", "pct", "float32", "WARMUP", True, "rolling rank uses <= t", ("Entry", "Exit", "Risk")),
FeatureDef(28, "oi_delta_15m_bps", "15分钟持仓变化", "Open-interest short change.", ("open_interest", "replay_1m"), "open_interest_t / open_interest_t-15m - 1", "15m", "bps", "float32", "WARMUP", True, "uses <= t OI only", ("Direction", "Continue", "Risk")),
FeatureDef(29, "oi_delta_60m_bps", "60分钟持仓变化", "Open-interest baseline change.", ("open_interest", "replay_1m"), "open_interest_t / open_interest_t-60m - 1", "60m", "bps", "float32", "WARMUP", True, "uses <= t OI only", ("Direction", "Continue", "Risk")),
FeatureDef(30, "funding_bps", "资金费率", "Current funding rate.", ("funding", "replay_1m"), "rate * 10000", "as-of", "bps", "float32", "as-of > 12h -> fail", True, "backward as-of only", ("Direction", "Entry", "Risk")),
FeatureDef(31, "mark_index_basis_bps", "标记价指数价偏离", "Mark-index basis.", ("funding", "replay_1m"), "mark_price / index_price - 1", "as-of", "bps", "float32", "as-of > 12h -> fail", True, "backward as-of only", ("Direction", "Entry", "Risk")),
FeatureDef(32, "liquidation_buy_notional_1m", "1分钟买向爆仓金额", "Buy-side liquidation notional.", ("liquidations", "replay_1m"), "sum(quantity * price for BUY)", "1m", "quote", "float32", "missing partition -> 0 with flag", True, "uses current closed minute liquidations only", ("Entry", "Exit", "Risk")),
FeatureDef(33, "liquidation_sell_notional_1m", "1分钟卖向爆仓金额", "Sell-side liquidation notional.", ("liquidations", "replay_1m"), "sum(quantity * price for SELL)", "1m", "quote", "float32", "missing partition -> 0 with flag", True, "uses current closed minute liquidations only", ("Entry", "Exit", "Risk")),
FeatureDef(34, "liquidation_imbalance_15m", "15分钟爆仓方向差", "Liquidation imbalance.", ("liquidations", "replay_1m"), "(buy_15m - sell_15m) / max(total_15m, eps)", "15m", "ratio", "float32", "missing partition -> 0 with flag", True, "uses <= t liquidations only", ("Direction", "Entry", "Exit", "Risk")),
FeatureDef(35, "liquidation_notional_zscore_15m", "爆仓金额异常", "Liquidation notional zscore.", ("liquidations", "replay_1m"), "(liq_15m - mean_24h) / std_24h", "15m/24h", "zscore", "float32", "missing partition -> 0 with flag", True, "rolling window uses <= t", ("Entry", "Exit", "Risk")),
FeatureDef(36, "liquidation_available", "爆仓数据可用", "Whether liquidation data exists.", ("liquidations", "replay_1m"), "day partition exists", "day", "0/1", "float32", "never null", True, "partition availability known by event day", ("Entry", "Exit", "Risk")),
FeatureDef(37, "minute_of_day_sin", "日内时间正弦", "Time of day cyclic feature.", ("event_time",), "sin(2*pi*minute_of_day/1440)", "event_time", "ratio", "float32", "never null", True, "event timestamp only", ("Direction", "Entry", "Risk")),
FeatureDef(38, "minute_of_day_cos", "日内时间余弦", "Time of day cyclic feature.", ("event_time",), "cos(2*pi*minute_of_day/1440)", "event_time", "ratio", "float32", "never null", True, "event timestamp only", ("Direction", "Entry", "Risk")),
FeatureDef(39, "minutes_to_next_funding", "距离下次资金费分钟", "Minutes to next funding settlement.", ("funding", "replay_1m"), "clip((next_funding_time - event_time) / 60000, 0, 480)", "as-of", "minute", "float32", "as-of > 12h -> fail", True, "backward as-of only", ("Entry", "Continue", "Risk")),
)
FEATURE_ORDER = [feature.name for feature in FEATURES]
OUTPUT_SCHEMA: dict[str, Any] = {
"output_schema_version": OUTPUT_SCHEMA_VERSION,
"direction": {
"longProb": {"type": "decimal", "range": [0.0, 1.0]},
"shortProb": {"type": "decimal", "range": [0.0, 1.0]},
"neutralProb": {"type": "decimal", "range": [0.0, 1.0]},
"sum_rule": "longProb + shortProb + neutralProb must equal 1.0 within 0.000001",
},
"entry": {
"longEntryProb": {"type": "decimal", "range": [0.0, 1.0]},
"shortEntryProb": {"type": "decimal", "range": [0.0, 1.0]},
"longExpectedNetEdgeBps": {"type": "decimal", "range": [-500.0, 500.0]},
"shortExpectedNetEdgeBps": {"type": "decimal", "range": [-500.0, 500.0]},
},
"continuation": {
"longContinueProb": {"type": "decimal", "range": [0.0, 1.0]},
"shortContinueProb": {"type": "decimal", "range": [0.0, 1.0]},
"longExpectedContinueEdgeBps": {"type": "decimal", "range": [-500.0, 500.0]},
"shortExpectedContinueEdgeBps": {"type": "decimal", "range": [-500.0, 500.0]},
},
"exit": {
"longExitProb": {"type": "decimal", "range": [0.0, 1.0]},
"shortExitProb": {"type": "decimal", "range": [0.0, 1.0]},
"longAdverseMoveBps": {"type": "decimal", "range": [0.0, 500.0]},
"shortAdverseMoveBps": {"type": "decimal", "range": [0.0, 500.0]},
"exitReasonScores": {
"adverse_move_prob": {"type": "decimal", "range": [0.0, 1.0]},
"reversal_prob": {"type": "decimal", "range": [0.0, 1.0]},
"stop_hit_prob": {"type": "decimal", "range": [0.0, 1.0]},
"stagnation_prob": {"type": "decimal", "range": [0.0, 1.0]},
},
},
"risk": {
"marketRiskProb": {"type": "decimal", "range": [0.0, 1.0]},
"longPositionRiskProb": {"type": "decimal", "range": [0.0, 1.0]},
"shortPositionRiskProb": {"type": "decimal", "range": [0.0, 1.0]},
"marketPathRiskBps": {"type": "decimal", "range": [0.0, 1000.0]},
"longPositionPathRiskBps": {"type": "decimal", "range": [0.0, 1000.0]},
"shortPositionPathRiskBps": {"type": "decimal", "range": [0.0, 1000.0]},
"riskReasonScores": {
"market_drawdown_prob": {"type": "decimal", "range": [0.0, 1.0]},
"volatility_expansion_prob": {"type": "decimal", "range": [0.0, 1.0]},
"spike_prob": {"type": "decimal", "range": [0.0, 1.0]},
"liquidity_deterioration_prob": {"type": "decimal", "range": [0.0, 1.0]},
"position_drawdown_prob": {"type": "decimal", "range": [0.0, 1.0]},
},
},
}
MODEL_OUTPUTS: dict[str, list[str]] = {
"DIRECTION": ["long_prob", "short_prob", "neutral_prob"],
"ENTRY": ["long_entry_prob", "short_entry_prob", "long_expected_net_edge_bps", "short_expected_net_edge_bps"],
"CONTINUE": ["long_continue_prob", "short_continue_prob", "long_expected_continue_edge_bps", "short_expected_continue_edge_bps"],
"EXIT": [
"long_exit_prob",
"short_exit_prob",
"long_adverse_move_bps",
"short_adverse_move_bps",
"adverse_move_prob",
"reversal_prob",
"stop_hit_prob",
"stagnation_prob",
],
"RISK": [
"market_risk_prob",
"long_position_risk_prob",
"short_position_risk_prob",
"market_path_risk_bps",
"long_position_path_risk_bps",
"short_position_path_risk_bps",
"market_drawdown_prob",
"volatility_expansion_prob",
"spike_prob",
"liquidity_deterioration_prob",
"position_drawdown_prob",
],
}
PROBABILITY_TARGET_NAMES: dict[str, list[str]] = {
"DIRECTION": ["longProb", "shortProb", "neutralProb"],
"ENTRY": ["longEntryProb", "shortEntryProb"],
"CONTINUE": ["longContinueProb", "shortContinueProb"],
"EXIT": ["longExitProb", "shortExitProb", "adverse_move_prob", "reversal_prob", "stop_hit_prob", "stagnation_prob"],
"RISK": [
"marketRiskProb",
"longPositionRiskProb",
"shortPositionRiskProb",
"market_drawdown_prob",
"volatility_expansion_prob",
"spike_prob",
"liquidity_deterioration_prob",
"position_drawdown_prob",
],
}
OUTPUT_MAPPING: dict[str, dict[str, str]] = {
model: {field: f"prediction[{index}]" for index, field in enumerate(fields)}
for model, fields in MODEL_OUTPUTS.items()
}
+581
View File
@@ -0,0 +1,581 @@
from __future__ import annotations
import json
import logging
from dataclasses import dataclass
from pathlib import Path
from typing import Any
import numpy as np
import pandas as pd
from sklearn.linear_model import LogisticRegression, Ridge
from sklearn.metrics import accuracy_score, log_loss, mean_absolute_error, roc_auc_score
from sklearn.preprocessing import StandardScaler
from trader_training.io_utils import read_parquet, run_root, sha256_file, write_json, write_parquet, write_text
from trader_training.onnx_export import LinearHead, export_heads
from trader_training.schemas import FEATURE_ORDER, FIT_SPLIT, LATEST_STRESS_SPLIT, MODEL_OUTPUTS, PROBABILITY_TARGET_NAMES, TUNE_SPLIT, VALIDATION_LOCKED_SPLIT
@dataclass
class HeadResult:
field: str
target_name: str | None
kind: str
weight: np.ndarray
bias: np.ndarray
metrics: dict[str, Any]
tune_prediction: np.ndarray
tune_target: np.ndarray | None
TARGETS = {
"DIRECTION": {
"dataset": "direction_train.parquet",
"heads": [("direction", "multiclass", ["long_target", "short_target", "neutral_target"], ["long_prob", "short_prob", "neutral_prob"], ["longProb", "shortProb", "neutralProb"])],
},
"ENTRY": {
"dataset": "entry_train.parquet",
"heads": [
("long_entry_prob", "binary", "long_entry_target", ["long_entry_prob"], ["longEntryProb"]),
("short_entry_prob", "binary", "short_entry_target", ["short_entry_prob"], ["shortEntryProb"]),
("long_expected_net_edge_bps", "regression", "long_expected_net_edge_bps", ["long_expected_net_edge_bps"], [None]),
("short_expected_net_edge_bps", "regression", "short_expected_net_edge_bps", ["short_expected_net_edge_bps"], [None]),
],
},
"CONTINUE": {
"dataset": "continue_train.parquet",
"heads": [
("long_continue_prob", "binary", "long_continue_target", ["long_continue_prob"], ["longContinueProb"]),
("short_continue_prob", "binary", "short_continue_target", ["short_continue_prob"], ["shortContinueProb"]),
("long_expected_continue_edge_bps", "regression", "long_expected_continue_edge_bps", ["long_expected_continue_edge_bps"], [None]),
("short_expected_continue_edge_bps", "regression", "short_expected_continue_edge_bps", ["short_expected_continue_edge_bps"], [None]),
],
},
"EXIT": {
"dataset": "exit_train.parquet",
"heads": [
("long_exit_prob", "binary", "long_exit_target", ["long_exit_prob"], ["longExitProb"]),
("short_exit_prob", "binary", "short_exit_target", ["short_exit_prob"], ["shortExitProb"]),
("long_adverse_move_bps", "regression", "long_adverse_move_bps", ["long_adverse_move_bps"], [None]),
("short_adverse_move_bps", "regression", "short_adverse_move_bps", ["short_adverse_move_bps"], [None]),
("adverse_move_prob", "binary", "adverse_move_prob_label", ["adverse_move_prob"], ["adverse_move_prob"]),
("reversal_prob", "binary", "reversal_prob_label", ["reversal_prob"], ["reversal_prob"]),
("stop_hit_prob", "binary", "stop_hit_prob_label", ["stop_hit_prob"], ["stop_hit_prob"]),
("stagnation_prob", "binary", "stagnation_prob_label", ["stagnation_prob"], ["stagnation_prob"]),
],
},
"RISK": {
"dataset": "risk_train.parquet",
"heads": [
("market_risk_prob", "binary", "market_risk_target", ["market_risk_prob"], ["marketRiskProb"]),
("long_position_risk_prob", "binary", "long_position_risk_target", ["long_position_risk_prob"], ["longPositionRiskProb"]),
("short_position_risk_prob", "binary", "short_position_risk_target", ["short_position_risk_prob"], ["shortPositionRiskProb"]),
("market_path_risk_bps", "regression", "market_path_risk_bps", ["market_path_risk_bps"], [None]),
("long_position_path_risk_bps", "regression", "long_position_path_risk_bps", ["long_position_path_risk_bps"], [None]),
("short_position_path_risk_bps", "regression", "short_position_path_risk_bps", ["short_position_path_risk_bps"], [None]),
("market_drawdown_prob", "binary", "market_drawdown_prob_label", ["market_drawdown_prob"], ["market_drawdown_prob"]),
("volatility_expansion_prob", "binary", "volatility_expansion_prob_label", ["volatility_expansion_prob"], ["volatility_expansion_prob"]),
("spike_prob", "binary", "spike_prob_label", ["spike_prob"], ["spike_prob"]),
("liquidity_deterioration_prob", "binary", "liquidity_deterioration_prob_label", ["liquidity_deterioration_prob"], ["liquidity_deterioration_prob"]),
("position_drawdown_prob", "binary", "position_drawdown_prob_label", ["position_drawdown_prob"], ["position_drawdown_prob"]),
],
},
}
def train_small_models(args: Any) -> None:
root = run_root(args)
model_manifest: dict[str, Any] = {}
for model_name, spec in TARGETS.items():
dataset = read_parquet(root / "dataset" / spec["dataset"])
if args.max_rows and len(dataset) > args.max_rows:
dataset = dataset.sort_values("event_time").tail(args.max_rows).copy()
if dataset.empty:
raise ValueError(f"dataset is empty for {model_name}")
train = dataset[dataset["split_id"] == FIT_SPLIT].copy()
tune = dataset[dataset["split_id"] == TUNE_SPLIT].copy()
validation_locked = dataset[dataset["split_id"] == VALIDATION_LOCKED_SPLIT].copy()
latest_stress = dataset[dataset["split_id"] == LATEST_STRESS_SPLIT].copy()
if train.empty or tune.empty:
raise ValueError(f"{model_name} needs {FIT_SPLIT} and {TUNE_SPLIT} rows")
logging.info(
"trader.training.model_dataset_loaded runId=%s model=%s totalRows=%s trainRows=%s tuneRows=%s validationLockedRows=%s latestStressRows=%s splitCounts=%s",
args.run_id,
model_name,
len(dataset),
len(train),
len(tune),
len(validation_locked),
len(latest_stress),
dataset["split_id"].value_counts().to_dict(),
)
scaler = StandardScaler()
x_train_scaled = scaler.fit_transform(train[FEATURE_ORDER].astype("float32"))
x_tune_scaled = scaler.transform(tune[FEATURE_ORDER].astype("float32"))
heads: list[LinearHead] = []
head_results: list[HeadResult] = []
for item in spec["heads"]:
head_results.extend(_fit_head(item, x_train_scaled, x_tune_scaled, train, tune, scaler))
for result in head_results:
logging.info(
"trader.training.model_head_trained runId=%s model=%s head=%s kind=%s metrics=%s",
args.run_id,
model_name,
result.field,
result.kind,
result.metrics,
)
for result in head_results:
heads.append(LinearHead(result.field, _onnx_kind(result.kind), result.weight, result.bias))
model_dir = root / "model" / model_name.lower()
model_path = model_dir / f"{model_name.lower()}.onnx"
export_heads(model_path, heads, feature_count=len(FEATURE_ORDER), opset=17)
predictions = _tune_prediction_frame(tune, head_results)
write_parquet(model_dir / "tune_predictions.parquet", predictions)
if not validation_locked.empty:
write_parquet(model_dir / "validation_locked_predictions.parquet", _predict_frame(validation_locked, head_results, include_labels=True))
if not latest_stress.empty:
write_parquet(model_dir / "latest_stress_predictions.parquet", _predict_frame(latest_stress, head_results, include_labels=False))
metrics = {result.field: result.metrics for result in head_results}
model_hash = sha256_file(model_path)
quality_status, quality_reasons = _model_quality(head_results)
write_json(
model_dir / "model_train_result.json",
{
"model_name": model_name,
"metrics": metrics,
"quality_status": quality_status,
"quality_reasons": quality_reasons,
"artifact_hash_sha256": model_hash,
},
)
write_json(
model_dir / "model_manifest.json",
{
"model_name": model_name,
"model_path": str(model_path),
"model_format": "ONNX",
"input_tensor_name": "features",
"input_feature_count": len(FEATURE_ORDER),
"output_tensor_name": "prediction",
"output_fields": MODEL_OUTPUTS[model_name],
"quality_status": quality_status,
"quality_reasons": quality_reasons,
"artifact_hash_sha256": model_hash,
},
)
_write_model_examples(model_dir, model_name, tune, predictions)
_write_feature_importance(model_dir / "feature_importance.csv", head_results)
_write_version_compare(model_dir, model_name, metrics)
_write_training_report(model_dir / "training_report.md", model_name, metrics, quality_status, quality_reasons)
model_manifest[model_name] = {"path": str(model_path), "hash_sha256": model_hash, "metrics": metrics, "quality_status": quality_status, "quality_reasons": quality_reasons}
logging.info(
"trader.training.model_trained runId=%s model=%s qualityStatus=%s qualityReasons=%s path=%s tunePredictionRows=%s featureImportancePath=%s",
args.run_id,
model_name,
quality_status,
quality_reasons,
model_path,
len(predictions),
model_dir / "feature_importance.csv",
)
write_json(root / "model" / "model_train_manifest.json", model_manifest)
def _fit_head(item, x_train, x_tune, train: pd.DataFrame, tune: pd.DataFrame, scaler: StandardScaler) -> list[HeadResult]:
name, kind, target, fields, target_names = item
if kind == "multiclass":
y_train = train[target].to_numpy().argmax(axis=1)
y_val = tune[target].to_numpy().argmax(axis=1)
model = LogisticRegression(max_iter=500)
model.fit(x_train, y_train)
proba = model.predict_proba(x_tune)
weight, bias = _fold_scaler(model.coef_.T, model.intercept_, scaler)
train_prior = train[target].to_numpy().mean(axis=0)
metrics = _multiclass_metrics(y_train, y_val, proba, train_prior)
return [HeadResult("direction", target_names[0], "softmax", weight, bias, metrics, proba, y_val)]
if kind == "binary":
y_train = pd.to_numeric(train[target], errors="coerce").fillna(0).astype(int).to_numpy()
y_val = pd.to_numeric(tune[target], errors="coerce").fillna(0).astype(int).to_numpy()
if len(np.unique(y_train)) < 2:
prevalence = float(np.clip(y_train.mean(), 1e-6, 1 - 1e-6))
coef = np.zeros((1, len(FEATURE_ORDER)), dtype=np.float32)
intercept = np.array([np.log(prevalence / (1 - prevalence))], dtype=np.float32)
proba = np.full(len(y_val), prevalence, dtype=np.float32)
else:
model = LogisticRegression(max_iter=500)
model.fit(x_train, y_train)
coef = model.coef_
intercept = model.intercept_
proba = model.predict_proba(x_tune)[:, 1]
weight, bias = _fold_scaler(coef.T, intercept, scaler)
metrics = _binary_metrics(y_train, y_val, proba)
if len(np.unique(y_val)) == 2:
metrics["auc"] = float(roc_auc_score(y_val, proba))
return [HeadResult(fields[0], target_names[0], "sigmoid", weight, bias, metrics, proba.reshape(-1, 1), y_val)]
if kind == "regression":
y_train = pd.to_numeric(train[target], errors="coerce").fillna(0.0).to_numpy()
y_val = pd.to_numeric(tune[target], errors="coerce").fillna(0.0).to_numpy()
model = Ridge(alpha=1.0)
model.fit(x_train, y_train)
pred = model.predict(x_tune)
weight, bias = _fold_scaler(model.coef_.reshape(1, -1).T, np.array([model.intercept_]), scaler)
return [HeadResult(fields[0], None, "identity", weight, bias, _regression_metrics(y_train, y_val, pred), pred.reshape(-1, 1), y_val)]
raise ValueError(f"unsupported head kind: {kind}")
def _fold_scaler(weight_scaled: np.ndarray, bias_scaled: np.ndarray, scaler: StandardScaler) -> tuple[np.ndarray, np.ndarray]:
scale = np.where(scaler.scale_ == 0, 1.0, scaler.scale_)
weight = weight_scaled / scale.reshape(-1, 1)
bias = bias_scaled - np.sum((scaler.mean_ / scale).reshape(-1, 1) * weight_scaled, axis=0)
return weight.astype(np.float32), bias.astype(np.float32)
def _onnx_kind(kind: str) -> str:
if kind in ("softmax", "sigmoid", "identity"):
return kind
raise ValueError(f"unsupported result kind: {kind}")
def _multiclass_metrics(y_train: np.ndarray, y_val: np.ndarray, proba: np.ndarray, train_prior: np.ndarray) -> dict[str, Any]:
one_hot = np.eye(proba.shape[1], dtype=float)[y_val]
train_prior = np.asarray(train_prior, dtype=float)
train_prior = train_prior / train_prior.sum() if train_prior.sum() > 0 else np.full(proba.shape[1], 1.0 / proba.shape[1])
constant = np.tile(train_prior.reshape(1, -1), (len(y_val), 1))
proba_for_logloss = _clip_normalize(proba)
constant_for_logloss = _clip_normalize(constant)
metrics: dict[str, Any] = {
"accuracy": float(accuracy_score(y_val, proba.argmax(axis=1))),
"logloss": float(log_loss(y_val, proba_for_logloss, labels=list(range(proba.shape[1])))),
"constant_logloss": float(log_loss(y_val, constant_for_logloss, labels=list(range(proba.shape[1])))),
"brier_multiclass": float(np.mean(np.sum((one_hot - proba) ** 2, axis=1))),
"constant_brier_multiclass": float(np.mean(np.sum((one_hot - constant) ** 2, axis=1))),
}
for idx, name in enumerate(("long", "short", "neutral")):
binary_target = (y_val == idx).astype(int)
positives = int(binary_target.sum())
negatives = int(len(binary_target) - positives)
if positives >= 200 and negatives >= 200:
metrics[f"{name}_auc"] = float(roc_auc_score(binary_target, proba[:, idx]))
else:
metrics[f"{name}_auc_status"] = "INSUFFICIENT_SAMPLE"
metrics[f"{name}_positive_count"] = positives
metrics[f"{name}_negative_count"] = negatives
max_prob = proba.max(axis=1)
predicted_class = proba.argmax(axis=1)
top_count = max(1, int(len(y_val) * 0.10))
top_idx = np.argsort(max_prob)[-top_count:]
metrics["top10_hit_rate"] = float((predicted_class[top_idx] == y_val[top_idx]).mean())
metrics["all_hit_rate"] = float((predicted_class == y_val).mean())
return _with_quality(metrics)
def _clip_normalize(values: np.ndarray) -> np.ndarray:
values = np.clip(np.asarray(values, dtype=float), 1e-6, 1.0)
return values / values.sum(axis=1, keepdims=True)
def _binary_metrics(y_train: np.ndarray, y_val: np.ndarray, proba: np.ndarray) -> dict[str, Any]:
proba = np.asarray(proba, dtype=float)
train_rate = float(np.mean(y_train)) if len(y_train) else 0.0
constant = np.full(len(y_val), train_rate)
metrics: dict[str, Any] = {
"positive_rate": train_rate,
"tune_positive_rate": float(np.mean(y_val)) if len(y_val) else 0.0,
"brier": float(np.mean((y_val - proba) ** 2)) if len(y_val) else 0.0,
"constant_brier": float(np.mean((y_val - constant) ** 2)) if len(y_val) else 0.0,
}
if len(y_val):
top_count = max(1, int(len(y_val) * 0.10))
top_idx = np.argsort(proba)[-top_count:]
metrics["top10_hit_rate"] = float(np.mean(y_val[top_idx]))
metrics["all_hit_rate"] = float(np.mean(y_val))
return _with_quality(metrics)
def _regression_metrics(y_train: np.ndarray, y_val: np.ndarray, pred: np.ndarray) -> dict[str, Any]:
mae = float(mean_absolute_error(y_val, pred))
train_std = float(np.std(y_train))
metrics: dict[str, Any] = {
"mae": mae,
"train_target_std": train_std,
"mae_vs_train_std_ratio": float(mae / train_std) if train_std > 0 else None,
}
return _with_quality(metrics)
def _with_quality(metrics: dict[str, Any]) -> dict[str, Any]:
reasons: list[str] = []
for key, value in metrics.items():
if key.endswith("_auc") and isinstance(value, float) and value < 0.53:
reasons.append(f"{key}_below_0.53")
if "brier" in metrics and metrics.get("constant_brier") is not None and metrics["brier"] >= metrics["constant_brier"]:
reasons.append("brier_not_better_than_constant")
if "brier_multiclass" in metrics and metrics["brier_multiclass"] >= metrics["constant_brier_multiclass"]:
reasons.append("brier_not_better_than_constant")
if "mae" in metrics and metrics.get("train_target_std") is not None and metrics["train_target_std"] > 0 and metrics["mae"] > metrics["train_target_std"]:
reasons.append("mae_above_train_target_std")
if "top10_hit_rate" in metrics and "all_hit_rate" in metrics and metrics["top10_hit_rate"] <= metrics["all_hit_rate"]:
reasons.append("top10_not_better_than_all")
metrics["quality_status"] = "REJECTED" if reasons else "PASS"
metrics["quality_reasons"] = reasons
return metrics
def _model_quality(results: list[HeadResult]) -> tuple[str, list[str]]:
reasons = []
for result in results:
if result.metrics.get("quality_status") == "REJECTED":
for reason in result.metrics.get("quality_reasons", []):
reasons.append(f"{result.field}:{reason}")
return ("REJECTED", reasons) if reasons else ("PASS", [])
def _tune_prediction_frame(tune: pd.DataFrame, results: list[HeadResult]) -> pd.DataFrame:
out = tune[["sample_id", "symbol", "event_time", "split_id"]].copy().reset_index(drop=True)
for result in results:
values = result.tune_prediction
if result.kind == "softmax":
for idx, field in enumerate(MODEL_OUTPUTS["DIRECTION"]):
out[field] = values[:, idx]
if result.tune_target is not None:
out["label__longProb"] = (result.tune_target == 0).astype(int)
out["label__shortProb"] = (result.tune_target == 1).astype(int)
out["label__neutralProb"] = (result.tune_target == 2).astype(int)
else:
out[result.field] = values.reshape(-1)
if result.kind != "softmax" and result.target_name and result.tune_target is not None:
out[f"label__{result.target_name}"] = result.tune_target
return out
def _predict_frame(frame: pd.DataFrame, results: list[HeadResult], include_labels: bool) -> pd.DataFrame:
out = frame[["sample_id", "symbol", "event_time", "split_id"]].copy().reset_index(drop=True)
features = frame[FEATURE_ORDER].astype("float32").to_numpy()
for result in results:
values = features @ result.weight + result.bias.reshape(1, -1)
if result.kind == "softmax":
values = _softmax(values)
for idx, field in enumerate(MODEL_OUTPUTS["DIRECTION"]):
out[field] = values[:, idx]
elif result.kind == "sigmoid":
out[result.field] = (1.0 / (1.0 + np.exp(-values))).reshape(-1)
else:
out[result.field] = values.reshape(-1)
if include_labels and result.kind != "softmax" and result.target_name and result.target_name in frame.columns:
out[f"label__{result.target_name}"] = frame[result.target_name].to_numpy()
return out
def _softmax(values: np.ndarray) -> np.ndarray:
shifted = values - np.max(values, axis=1, keepdims=True)
exp = np.exp(shifted)
return exp / exp.sum(axis=1, keepdims=True)
def _write_training_report(path: Path, model_name: str, metrics: dict[str, Any], quality_status: str, quality_reasons: list[str]) -> None:
lines = [
"# Trader Model Training Report",
"",
f"- model: {model_name}",
f"- quality_status: {quality_status}",
f"- quality_reasons: {quality_reasons}",
"",
"```json",
json.dumps(metrics, indent=2, sort_keys=True),
"```",
"",
]
write_text(path, "\n".join(lines))
def _write_model_examples(model_dir: Path, model_name: str, tune: pd.DataFrame, predictions: pd.DataFrame) -> None:
sample_input = {feature: float(tune.iloc[0][feature]) for feature in FEATURE_ORDER}
sample_output = {field: float(predictions.iloc[0][field]) for field in MODEL_OUTPUTS[model_name]}
write_json(model_dir / "sample_input.json", sample_input)
write_json(model_dir / "sample_output.json", sample_output)
def _write_feature_importance(path: Path, results: list[HeadResult]) -> None:
rows = []
for result in results:
importance = np.mean(np.abs(result.weight), axis=1)
for feature, value in zip(FEATURE_ORDER, importance):
rows.append({"head": result.field, "feature": feature, "abs_weight": float(value)})
frame = pd.DataFrame(rows).sort_values(["head", "abs_weight"], ascending=[True, False])
write_text(path, frame.to_csv(index=False))
def _write_version_compare(model_dir: Path, model_name: str, metrics: dict[str, Any]) -> None:
payload = {
"model_name": model_name,
"status": "NO_BASELINE",
"reason": "first executable V4 training chain has no previous approved artifact bundle under this run root",
"current_metrics": metrics,
}
write_json(model_dir / "version_compare_metrics.json", payload)
write_text(model_dir / "version_compare_by_regime.csv", "regime,status,reason\nNO_BASELINE,NO_BASELINE,no previous approved artifact bundle\n")
write_text(model_dir / "version_compare_top_bucket.csv", "bucket,status,reason\nNO_BASELINE,NO_BASELINE,no previous approved artifact bundle\n")
lines = [
"# Version Compare Report",
"",
f"- model: {model_name}",
"- status: NO_BASELINE",
"- reason: 当前 run 目录没有上一版已验收模型包,所以首版只能记录当前指标,不能做新旧优劣判断。",
"",
]
write_text(model_dir / "version_compare_report.md", "\n".join(lines))
def build_calibrators(args: Any) -> None:
root = run_root(args)
manifest_rows = []
for model_name, target_names in PROBABILITY_TARGET_NAMES.items():
prediction_path = root / "model" / model_name.lower() / "tune_predictions.parquet"
predictions = read_parquet(prediction_path)
targets = {}
reliability_rows = []
quality_reasons = []
for target_name in target_names:
raw_field = _target_to_raw_field(model_name, target_name)
label_field = f"label__{target_name}"
labels = predictions[label_field].to_numpy() if label_field in predictions.columns else None
raw = predictions[raw_field].to_numpy()
bins = _calibration_bins(raw, labels)
metrics, rows = _calibration_metrics(raw, labels, bins, target_name)
targets[target_name] = {"bins": bins, "metrics": metrics}
reliability_rows.extend(rows)
if metrics.get("quality_status") == "REJECTED":
quality_reasons.append(f"{target_name}:{metrics.get('quality_reason')}")
calibrator = {
"calibrator_version": f"{model_name.lower()}-cal-v4-btc-p0",
"method": "BINNING",
"targets": targets,
"clip": {"min": 0.0, "max": 1.0},
"fallback_policy": "FAIL_FAST",
}
path = root / "calibration" / model_name.lower() / "calibrator.json"
cal_hash = write_json(path, calibrator)
quality_status = "REJECTED" if quality_reasons else "PASS"
manifest_rows.append(
{
"model_name": model_name,
"calibrator_path": str(path),
"calibrator_hash_sha256": cal_hash,
"target_count": len(targets),
"quality_status": quality_status,
"quality_reasons": quality_reasons,
}
)
write_text(root / "calibration" / model_name.lower() / "reliability_curve.csv", pd.DataFrame(reliability_rows).to_csv(index=False))
_write_calibration_report(root / "calibration" / model_name.lower() / "calibration_report.md", model_name, targets, quality_status, quality_reasons)
logging.info("trader.training.calibrator_written runId=%s model=%s path=%s", args.run_id, model_name, path)
write_json(root / "calibration" / "calibration_train_manifest.json", {"calibrators": manifest_rows})
def _target_to_raw_field(model_name: str, target_name: str) -> str:
mapping = {
"longProb": "long_prob",
"shortProb": "short_prob",
"neutralProb": "neutral_prob",
"longEntryProb": "long_entry_prob",
"shortEntryProb": "short_entry_prob",
"longContinueProb": "long_continue_prob",
"shortContinueProb": "short_continue_prob",
"longExitProb": "long_exit_prob",
"shortExitProb": "short_exit_prob",
"marketRiskProb": "market_risk_prob",
"longPositionRiskProb": "long_position_risk_prob",
"shortPositionRiskProb": "short_position_risk_prob",
}
return mapping.get(target_name, target_name)
def _calibration_bins(raw: np.ndarray, labels: np.ndarray | None) -> list[dict[str, float]]:
raw = np.asarray(raw, dtype=float)
raw = np.clip(np.nan_to_num(raw, nan=0.5), 0.0, 1.0)
if labels is None or len(labels) != len(raw):
return [{"min": 0.0, "max": 1.0, "calibrated": 0.5}]
labels = np.asarray(labels, dtype=float)
bins = []
edges = np.linspace(0.0, 1.0, 11)
for left, right in zip(edges[:-1], edges[1:]):
if right == 1.0:
mask = (raw >= left) & (raw <= right)
else:
mask = (raw >= left) & (raw < right)
calibrated = float(labels[mask].mean()) if mask.any() else float((left + right) / 2.0)
bins.append({"min": float(left), "max": float(right), "calibrated": float(np.clip(calibrated, 0.0, 1.0))})
return bins
def _apply_calibration(raw: np.ndarray, bins: list[dict[str, float]]) -> np.ndarray:
out = np.zeros_like(raw, dtype=float)
for item in bins:
left = float(item["min"])
right = float(item["max"])
if right >= 1.0:
mask = (raw >= left) & (raw <= right)
else:
mask = (raw >= left) & (raw < right)
out[mask] = float(item["calibrated"])
return np.clip(out, 0.0, 1.0)
def _calibration_metrics(raw: np.ndarray, labels: np.ndarray | None, bins: list[dict[str, float]], target_name: str) -> tuple[dict[str, Any], list[dict[str, Any]]]:
raw = np.clip(np.asarray(raw, dtype=float), 0.0, 1.0)
if labels is None or len(labels) != len(raw):
return {"quality_status": "REJECTED", "quality_reason": "missing_labels"}, []
labels = np.asarray(labels, dtype=float)
calibrated = _apply_calibration(raw, bins)
raw_ece, rows = _ece(raw, labels, target_name, "raw")
calibrated_ece, calibrated_rows = _ece(calibrated, labels, target_name, "calibrated")
rows.extend(calibrated_rows)
quality_status = "PASS" if calibrated_ece <= raw_ece else "REJECTED"
return (
{
"raw_ece": raw_ece,
"calibrated_ece": calibrated_ece,
"quality_status": quality_status,
"quality_reason": None if quality_status == "PASS" else "calibrated_ece_not_improved",
},
rows,
)
def _ece(values: np.ndarray, labels: np.ndarray, target_name: str, series_name: str) -> tuple[float, list[dict[str, Any]]]:
rows = []
total = len(values)
ece = 0.0
edges = np.linspace(0.0, 1.0, 11)
for left, right in zip(edges[:-1], edges[1:]):
mask = (values >= left) & (values <= right) if right >= 1.0 else (values >= left) & (values < right)
count = int(mask.sum())
confidence = float(values[mask].mean()) if count else float((left + right) / 2.0)
accuracy = float(labels[mask].mean()) if count else 0.0
ece += (count / total) * abs(confidence - accuracy) if total else 0.0
rows.append(
{
"target": target_name,
"series": series_name,
"bin_min": left,
"bin_max": right,
"count": count,
"confidence": confidence,
"accuracy": accuracy,
}
)
return float(ece), rows
def _write_calibration_report(path: Path, model_name: str, targets: dict[str, Any], quality_status: str, quality_reasons: list[str]) -> None:
lines = ["# Trader Calibration Report", "", f"- model: {model_name}", f"- target_count: {len(targets)}", f"- quality_status: {quality_status}", f"- quality_reasons: {quality_reasons}", ""]
for target_name, payload in targets.items():
lines.append(f"## {target_name}")
lines.append("")
lines.append("```json")
lines.append(json.dumps(payload.get("metrics", {}), indent=2, sort_keys=True))
lines.append("```")
lines.append("")
write_text(path, "\n".join(lines))
+149
View File
@@ -0,0 +1,149 @@
from __future__ import annotations
import logging
from pathlib import Path
from typing import Any
import numpy as np
from trader_training.io_utils import read_json, sha256_file, write_json, write_text
from trader_training.schemas import FEATURE_ORDER, MODEL_OUTPUTS, OUTPUT_MAPPING
def validate_artifact_bundle(args: Any) -> None:
root = args.artifact_root
errors: list[str] = []
release_gate = {"release_gate_status": "UNKNOWN", "release_gate_reasons": ["artifact content not checked"]}
required = [
"manifests/model_bundle_manifest.json",
"manifests/model_manifest.json",
"manifests/calibration_manifest.json",
"manifests/position_manager_manifest.json",
"schemas/feature_schema.json",
"schemas/feature_order.json",
"schemas/output_schema.json",
"price_plan_context.json",
"examples/sample_input.json",
"examples/sample_output.json",
]
for relative in required:
if not (root / relative).is_file():
errors.append(f"missing file: {relative}")
if not errors:
_validate_content(root, errors, args.require_active, args.run_onnx)
# Structure validation only proves the bundle is loadable. The release
# gate is the separate business decision for whether it may become ACTIVE.
release_gate = _release_gate(root)
if args.require_active and release_gate["release_gate_status"] != "PASS":
errors.append(f"release gate must be PASS for ACTIVE, actual={release_gate['release_gate_status']}")
status = "PASS" if not errors else "FAIL"
result = {"status": status, "error_count": len(errors), "errors": errors, "artifact_root": str(root), **release_gate}
output_root = root.parent
write_json(output_root / "artifact_validation_result.json", result)
lines = ["# Artifact Validation Report", "", f"- status: {status}", f"- error_count: {len(errors)}", ""]
for error in errors:
lines.append(f"- {error}")
write_text(output_root / "artifact_validation_report.md", "\n".join(lines) + "\n")
logging.info("trader.training.artifact_validated status=%s errorCount=%s path=%s", status, len(errors), root)
if errors:
raise SystemExit("artifact validation failed; see artifact_validation_report.md")
def _validate_content(root: Path, errors: list[str], require_active: bool, run_onnx: bool) -> None:
feature_order = read_json(root / "schemas/feature_order.json")
if feature_order != FEATURE_ORDER:
errors.append("feature_order.json does not match V4 39-feature order")
model_bundle = read_json(root / "manifests/model_bundle_manifest.json")
if require_active and model_bundle.get("status") != "ACTIVE":
errors.append("model_bundle_manifest.status must be ACTIVE for Java SHADOW")
if model_bundle.get("status") not in {"CANDIDATE", "ACTIVE"}:
errors.append("model_bundle_manifest.status must be CANDIDATE or ACTIVE")
manifests = read_json(root / "manifests/model_manifest.json")
if len(manifests) != 5:
errors.append("model_manifest.json must contain exactly five logical models")
seen = {item.get("model_type") for item in manifests}
if seen != set(MODEL_OUTPUTS):
errors.append(f"model types mismatch: {seen}")
for item in manifests:
model_type = item.get("model_type")
if item.get("model_format") != "ONNX":
errors.append(f"{model_type} model_format must be ONNX")
if item.get("input_tensor_name") != "features":
errors.append(f"{model_type} input tensor must be features")
if item.get("input_shape_json", {}).get("features") != 39:
errors.append(f"{model_type} input_shape_json.features must be 39")
if item.get("onnx_opset_version") != 17:
errors.append(f"{model_type} opset must be 17")
if item.get("output_mapping_json") != OUTPUT_MAPPING.get(model_type):
errors.append(f"{model_type} output_mapping_json does not match Java contract")
_check_hash(root, item.get("artifact_path"), item.get("artifact_hash_sha256"), errors)
_check_hash(root, item.get("feature_schema_path"), item.get("feature_schema_hash"), errors)
_check_hash(root, item.get("feature_order_path"), item.get("feature_order_hash"), errors)
_check_hash(root, item.get("output_schema_path"), item.get("output_schema_hash"), errors)
if require_active and item.get("status") != "ACTIVE":
errors.append(f"{model_type} status must be ACTIVE for Java SHADOW")
if item.get("status") != model_bundle.get("status"):
errors.append(f"{model_type} status does not match model_bundle_manifest.status")
calibrators = read_json(root / "manifests/calibration_manifest.json")
if len(calibrators) != 5:
errors.append("calibration_manifest.json must contain five calibrators")
for item in calibrators:
_check_hash(root, item.get("calibrator_path"), item.get("calibrator_hash_sha256"), errors)
if require_active and item.get("status") != "ACTIVE":
errors.append(f"{item.get('model_name')} calibrator status must be ACTIVE")
if item.get("status") != model_bundle.get("status"):
errors.append(f"{item.get('model_name')} calibrator status does not match model_bundle_manifest.status")
pm_manifest = read_json(root / "manifests/position_manager_manifest.json")
if require_active and pm_manifest.get("status") != "ACTIVE":
errors.append("position_manager_manifest.status must be ACTIVE for Java SHADOW")
if pm_manifest.get("status") != model_bundle.get("status"):
errors.append("position_manager_manifest.status does not match model_bundle_manifest.status")
if run_onnx and not errors:
_run_sample_inference(root, manifests, errors)
def _release_gate(root: Path) -> dict[str, Any]:
reasons: list[str] = []
model_bundle = read_json(root / "manifests/model_bundle_manifest.json")
if model_bundle.get("backtest_status") != "PASS":
reasons.append(f"backtest_status={model_bundle.get('backtest_status')}")
reasons.extend(model_bundle.get("backtest_status_reasons_json", []))
for item in read_json(root / "manifests/model_manifest.json"):
if item.get("quality_status") != "PASS":
reasons.append(f"{item.get('model_type')}.quality_status={item.get('quality_status')}")
reasons.extend([f"{item.get('model_type')}:{reason}" for reason in item.get("quality_reasons_json", [])])
for item in read_json(root / "manifests/calibration_manifest.json"):
if item.get("quality_status") != "PASS":
reasons.append(f"{item.get('model_name')}.calibration_quality_status={item.get('quality_status')}")
reasons.extend([f"{item.get('model_name')}:calibration:{reason}" for reason in item.get("quality_reasons_json", [])])
return {
"release_gate_status": "PASS" if not reasons else "REJECTED",
"release_gate_reasons": reasons,
}
def _check_hash(root: Path, relative: str | None, expected: str | None, errors: list[str]) -> None:
if not relative or not expected:
errors.append(f"missing hash contract for {relative}")
return
path = root / relative
if not path.is_file():
errors.append(f"hash target missing: {relative}")
return
actual = sha256_file(path)
if actual != expected:
errors.append(f"hash mismatch: {relative}")
def _run_sample_inference(root: Path, manifests: list[dict[str, Any]], errors: list[str]) -> None:
try:
import onnxruntime as ort
except ModuleNotFoundError as exc:
raise SystemExit("Python package 'onnxruntime' is required for --run-onnx validation.") from exc
sample = read_json(root / "examples/sample_input.json")
features = np.array([[float(sample[name]) for name in FEATURE_ORDER]], dtype=np.float32)
for item in manifests:
session = ort.InferenceSession(str(root / item["artifact_path"]))
outputs = session.run(None, {"features": features})
if not outputs or outputs[0].shape[-1] != len(MODEL_OUTPUTS[item["model_type"]]):
errors.append(f"{item['model_type']} sample output shape is invalid")