Implement Trader V4 training artifact pipeline
This commit is contained in:
@@ -0,0 +1,33 @@
|
||||
# Trader V4 Training Pipeline
|
||||
|
||||
This directory contains the executable training chain for Trader V4.
|
||||
Large data stays under `/Users/zach/Desktop/quant-strategy-training-data`.
|
||||
|
||||
Run order:
|
||||
|
||||
```bash
|
||||
PY=/Users/zach/IdeaProjects/quant-trading-ai/quant-strategy-server/.venv/bin/python
|
||||
RUN_ID=btc-v4-p0-001
|
||||
ROOT=/Users/zach/Desktop/quant-strategy-training-data
|
||||
|
||||
$PY training/scripts/01_audit_source_data.py --run-id $RUN_ID --data-root $ROOT --symbol BTC-USDT-PERP --start-date 2025-06-20 --end-date 2026-06-19
|
||||
$PY training/scripts/02_build_replay_1m.py --run-id $RUN_ID --data-root $ROOT --symbol BTC-USDT-PERP --start-date 2025-06-20 --end-date 2026-06-19
|
||||
$PY training/scripts/03_build_splits.py --run-id $RUN_ID --data-root $ROOT
|
||||
$PY training/scripts/04_build_feature_frame.py --run-id $RUN_ID --data-root $ROOT
|
||||
$PY training/scripts/05_build_price_plan_context.py --run-id $RUN_ID --data-root $ROOT
|
||||
$PY training/scripts/06_build_direction_labels.py --run-id $RUN_ID --data-root $ROOT
|
||||
$PY training/scripts/07_build_entry_labels.py --run-id $RUN_ID --data-root $ROOT
|
||||
$PY training/scripts/08_build_position_state_samples.py --run-id $RUN_ID --data-root $ROOT
|
||||
$PY training/scripts/09_build_continue_exit_risk_labels.py --run-id $RUN_ID --data-root $ROOT
|
||||
$PY training/scripts/10_build_train_datasets.py --run-id $RUN_ID --data-root $ROOT
|
||||
$PY training/scripts/11_train_small_models.py --run-id $RUN_ID --data-root $ROOT
|
||||
$PY training/scripts/12_calibrate_models.py --run-id $RUN_ID --data-root $ROOT
|
||||
$PY training/scripts/13_search_pm_thresholds.py --run-id $RUN_ID --data-root $ROOT
|
||||
$PY training/scripts/14_integrated_backtest.py --run-id $RUN_ID --data-root $ROOT
|
||||
$PY training/scripts/15_export_artifact_bundle.py --run-id $RUN_ID --data-root $ROOT
|
||||
$PY training/scripts/16_validate_artifact_bundle.py --artifact-root $ROOT/trader-v4/runs/$RUN_ID/export/trader-model-bundle-$RUN_ID/artifact_bundle
|
||||
$PY training/scripts/17_promote_artifact_bundle.py --artifact-root $ROOT/trader-v4/runs/$RUN_ID/export/trader-model-bundle-$RUN_ID/artifact_bundle --reason "validation_locked and latest_stress passed for SHADOW"
|
||||
$PY training/scripts/16_validate_artifact_bundle.py --artifact-root $ROOT/trader-v4/runs/$RUN_ID/export/trader-model-bundle-$RUN_ID/artifact_bundle --require-active --run-onnx
|
||||
```
|
||||
|
||||
Java SHADOW 只加载 `ACTIVE` 包。15 号脚本永远只生成 `CANDIDATE`,16 号校验通过且上线门槛通过后,17 号脚本才允许把包提升为 `ACTIVE`。
|
||||
@@ -0,0 +1,7 @@
|
||||
pandas==2.2.3
|
||||
pyarrow==24.0.0
|
||||
numpy==2.4.6
|
||||
scikit-learn==1.7.0
|
||||
scipy==1.18.0
|
||||
onnx==1.22.0
|
||||
onnxruntime==1.27.0
|
||||
@@ -0,0 +1,23 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
|
||||
import _bootstrap # noqa: F401
|
||||
from trader_training.io_utils import add_common_args, setup_logging
|
||||
from trader_training.replay import write_audit_outputs
|
||||
|
||||
|
||||
def main() -> None:
|
||||
parser = argparse.ArgumentParser()
|
||||
add_common_args(parser)
|
||||
parser.add_argument("--symbol", default="BTC-USDT-PERP")
|
||||
parser.add_argument("--start-date")
|
||||
parser.add_argument("--end-date")
|
||||
parser.add_argument("--min-ready-days", type=int, default=250)
|
||||
args = parser.parse_args()
|
||||
setup_logging()
|
||||
write_audit_outputs(args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,26 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
|
||||
import _bootstrap # noqa: F401
|
||||
from trader_training.io_utils import add_common_args, setup_logging
|
||||
from trader_training.replay import build_replay_1m
|
||||
|
||||
|
||||
def main() -> None:
|
||||
parser = argparse.ArgumentParser()
|
||||
add_common_args(parser)
|
||||
parser.add_argument("--raw-root", type=Path)
|
||||
parser.add_argument("--symbol", default="BTC-USDT-PERP")
|
||||
parser.add_argument("--start-date")
|
||||
parser.add_argument("--end-date")
|
||||
parser.add_argument("--min-minutes-per-day", type=int, default=1400)
|
||||
parser.add_argument("--min-replay-ready-days", type=int, default=250)
|
||||
args = parser.parse_args()
|
||||
setup_logging()
|
||||
build_replay_1m(args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,31 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
|
||||
import _bootstrap # noqa: F401
|
||||
from trader_training.io_utils import add_common_args, setup_logging
|
||||
from trader_training.replay import build_splits
|
||||
|
||||
|
||||
def main() -> None:
|
||||
parser = argparse.ArgumentParser()
|
||||
add_common_args(parser)
|
||||
parser.add_argument("--replay-path", type=Path)
|
||||
parser.add_argument("--fit-inner-start", default="2025-06-20")
|
||||
parser.add_argument("--fit-inner-end", default="2026-01-15")
|
||||
parser.add_argument("--tune-inner-start", default="2026-01-16")
|
||||
parser.add_argument("--tune-inner-end", default="2026-02-28")
|
||||
parser.add_argument("--validation-locked-start", default="2026-03-01")
|
||||
parser.add_argument("--validation-locked-end", default="2026-04-30")
|
||||
parser.add_argument("--latest-stress-start", default="2026-05-01")
|
||||
parser.add_argument("--latest-stress-end", default="2026-06-19")
|
||||
parser.add_argument("--gap-minutes", type=int, default=60)
|
||||
parser.add_argument("--fold-count", type=int, default=3)
|
||||
args = parser.parse_args()
|
||||
setup_logging()
|
||||
build_splits(args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,23 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
|
||||
import _bootstrap # noqa: F401
|
||||
from trader_training.features import build_feature_frame
|
||||
from trader_training.io_utils import add_common_args, setup_logging
|
||||
|
||||
|
||||
def main() -> None:
|
||||
parser = argparse.ArgumentParser()
|
||||
add_common_args(parser)
|
||||
parser.add_argument("--replay-path", type=Path)
|
||||
parser.add_argument("--split-manifest-path", type=Path)
|
||||
parser.add_argument("--allow-incomplete-days", action="store_true")
|
||||
args = parser.parse_args()
|
||||
setup_logging()
|
||||
build_feature_frame(args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,23 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
|
||||
import _bootstrap # noqa: F401
|
||||
from trader_training.io_utils import add_common_args, setup_logging
|
||||
from trader_training.labels import write_price_plan_context
|
||||
|
||||
|
||||
def main() -> None:
|
||||
parser = argparse.ArgumentParser()
|
||||
add_common_args(parser)
|
||||
parser.add_argument("--label-config-path", type=Path)
|
||||
parser.add_argument("--cost-config-path", type=Path)
|
||||
parser.add_argument("--price-plan-id", default="btc-p0-plan-45m")
|
||||
args = parser.parse_args()
|
||||
setup_logging()
|
||||
write_price_plan_context(args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,23 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
|
||||
import _bootstrap # noqa: F401
|
||||
from trader_training.io_utils import add_common_args, setup_logging
|
||||
from trader_training.labels import build_direction_labels
|
||||
|
||||
|
||||
def main() -> None:
|
||||
parser = argparse.ArgumentParser()
|
||||
add_common_args(parser)
|
||||
parser.add_argument("--feature-path", type=Path)
|
||||
parser.add_argument("--replay-path", type=Path)
|
||||
parser.add_argument("--label-config-path", type=Path)
|
||||
args = parser.parse_args()
|
||||
setup_logging()
|
||||
build_direction_labels(args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,25 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
|
||||
import _bootstrap # noqa: F401
|
||||
from trader_training.io_utils import add_common_args, setup_logging
|
||||
from trader_training.labels import build_entry_labels
|
||||
|
||||
|
||||
def main() -> None:
|
||||
parser = argparse.ArgumentParser()
|
||||
add_common_args(parser)
|
||||
parser.add_argument("--feature-path", type=Path)
|
||||
parser.add_argument("--replay-path", type=Path)
|
||||
parser.add_argument("--label-config-path", type=Path)
|
||||
parser.add_argument("--cost-config-path", type=Path)
|
||||
parser.add_argument("--price-plan-context-path", type=Path)
|
||||
args = parser.parse_args()
|
||||
setup_logging()
|
||||
build_entry_labels(args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,21 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
|
||||
import _bootstrap # noqa: F401
|
||||
from trader_training.io_utils import add_common_args, setup_logging
|
||||
from trader_training.labels import build_position_state_samples
|
||||
|
||||
|
||||
def main() -> None:
|
||||
parser = argparse.ArgumentParser()
|
||||
add_common_args(parser)
|
||||
parser.add_argument("--entry-label-path", type=Path)
|
||||
args = parser.parse_args()
|
||||
setup_logging()
|
||||
build_position_state_samples(args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,25 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
|
||||
import _bootstrap # noqa: F401
|
||||
from trader_training.io_utils import add_common_args, setup_logging
|
||||
from trader_training.labels import build_continue_exit_risk_labels
|
||||
|
||||
|
||||
def main() -> None:
|
||||
parser = argparse.ArgumentParser()
|
||||
add_common_args(parser)
|
||||
parser.add_argument("--feature-path", type=Path)
|
||||
parser.add_argument("--replay-path", type=Path)
|
||||
parser.add_argument("--label-config-path", type=Path)
|
||||
parser.add_argument("--cost-config-path", type=Path)
|
||||
parser.add_argument("--price-plan-context-path", type=Path)
|
||||
args = parser.parse_args()
|
||||
setup_logging()
|
||||
build_continue_exit_risk_labels(args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,26 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
|
||||
import _bootstrap # noqa: F401
|
||||
from trader_training.datasets import build_train_datasets
|
||||
from trader_training.io_utils import add_common_args, setup_logging
|
||||
|
||||
|
||||
def main() -> None:
|
||||
parser = argparse.ArgumentParser()
|
||||
add_common_args(parser)
|
||||
parser.add_argument("--feature-path", type=Path)
|
||||
parser.add_argument("--direction-label-path", type=Path)
|
||||
parser.add_argument("--entry-label-path", type=Path)
|
||||
parser.add_argument("--continue-label-path", type=Path)
|
||||
parser.add_argument("--exit-label-path", type=Path)
|
||||
parser.add_argument("--risk-label-path", type=Path)
|
||||
args = parser.parse_args()
|
||||
setup_logging()
|
||||
build_train_datasets(args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,20 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
|
||||
import _bootstrap # noqa: F401
|
||||
from trader_training.io_utils import add_common_args, setup_logging
|
||||
from trader_training.training import train_small_models
|
||||
|
||||
|
||||
def main() -> None:
|
||||
parser = argparse.ArgumentParser()
|
||||
add_common_args(parser)
|
||||
parser.add_argument("--max-rows", type=int, default=0)
|
||||
args = parser.parse_args()
|
||||
setup_logging()
|
||||
train_small_models(args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,19 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
|
||||
import _bootstrap # noqa: F401
|
||||
from trader_training.io_utils import add_common_args, setup_logging
|
||||
from trader_training.training import build_calibrators
|
||||
|
||||
|
||||
def main() -> None:
|
||||
parser = argparse.ArgumentParser()
|
||||
add_common_args(parser)
|
||||
args = parser.parse_args()
|
||||
setup_logging()
|
||||
build_calibrators(args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,19 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
|
||||
import _bootstrap # noqa: F401
|
||||
from trader_training.io_utils import add_common_args, setup_logging
|
||||
from trader_training.pm import search_pm_thresholds
|
||||
|
||||
|
||||
def main() -> None:
|
||||
parser = argparse.ArgumentParser()
|
||||
add_common_args(parser)
|
||||
args = parser.parse_args()
|
||||
setup_logging()
|
||||
search_pm_thresholds(args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,19 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
|
||||
import _bootstrap # noqa: F401
|
||||
from trader_training.io_utils import add_common_args, setup_logging
|
||||
from trader_training.pm import integrated_backtest
|
||||
|
||||
|
||||
def main() -> None:
|
||||
parser = argparse.ArgumentParser()
|
||||
add_common_args(parser)
|
||||
args = parser.parse_args()
|
||||
setup_logging()
|
||||
integrated_backtest(args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,21 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
|
||||
import _bootstrap # noqa: F401
|
||||
from trader_training.exporter import export_artifact_bundle
|
||||
from trader_training.io_utils import add_common_args, setup_logging
|
||||
|
||||
|
||||
def main() -> None:
|
||||
parser = argparse.ArgumentParser()
|
||||
add_common_args(parser)
|
||||
parser.add_argument("--export-root", type=Path)
|
||||
args = parser.parse_args()
|
||||
setup_logging()
|
||||
export_artifact_bundle(args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,22 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
|
||||
import _bootstrap # noqa: F401
|
||||
from trader_training.io_utils import setup_logging
|
||||
from trader_training.validator import validate_artifact_bundle
|
||||
|
||||
|
||||
def main() -> None:
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--artifact-root", type=Path, required=True)
|
||||
parser.add_argument("--require-active", action="store_true")
|
||||
parser.add_argument("--run-onnx", action="store_true")
|
||||
args = parser.parse_args()
|
||||
setup_logging()
|
||||
validate_artifact_bundle(args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,21 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
|
||||
import _bootstrap # noqa: F401
|
||||
from trader_training.io_utils import setup_logging
|
||||
from trader_training.promote import promote_artifact_bundle
|
||||
|
||||
|
||||
def main() -> None:
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--artifact-root", type=Path, required=True)
|
||||
parser.add_argument("--reason", required=True)
|
||||
args = parser.parse_args()
|
||||
setup_logging()
|
||||
promote_artifact_bundle(args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,8 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
TRAINING_ROOT = Path(__file__).resolve().parents[1]
|
||||
if str(TRAINING_ROOT) not in sys.path:
|
||||
sys.path.insert(0, str(TRAINING_ROOT))
|
||||
@@ -0,0 +1,146 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
import tempfile
|
||||
import unittest
|
||||
from argparse import Namespace
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
TRAINING_ROOT = Path(__file__).resolve().parents[1]
|
||||
if str(TRAINING_ROOT) not in sys.path:
|
||||
sys.path.insert(0, str(TRAINING_ROOT))
|
||||
|
||||
from trader_training.onnx_export import LinearHead, export_heads
|
||||
from trader_training.io_utils import read_json, write_json
|
||||
from trader_training.promote import promote_artifact_bundle
|
||||
from trader_training.replay import build_splits
|
||||
from trader_training.schemas import FEATURE_ORDER, LATEST_STRESS_SPLIT, MODEL_OUTPUTS, OUTPUT_MAPPING, TRAINING_SPLITS, VALIDATION_LOCKED_SPLIT
|
||||
|
||||
|
||||
class TrainingContractTest(unittest.TestCase):
|
||||
def test_feature_order_is_v4_contract_size(self) -> None:
|
||||
self.assertEqual(39, len(FEATURE_ORDER))
|
||||
self.assertEqual(len(FEATURE_ORDER), len(set(FEATURE_ORDER)))
|
||||
self.assertEqual("ret_1m_bps", FEATURE_ORDER[0])
|
||||
self.assertEqual("minutes_to_next_funding", FEATURE_ORDER[-1])
|
||||
|
||||
def test_output_mapping_matches_model_outputs(self) -> None:
|
||||
for model_name, fields in MODEL_OUTPUTS.items():
|
||||
self.assertEqual(set(fields), set(OUTPUT_MAPPING[model_name]))
|
||||
self.assertEqual([f"prediction[{idx}]" for idx in range(len(fields))], [OUTPUT_MAPPING[model_name][field] for field in fields])
|
||||
|
||||
def test_split_builder_uses_locked_validation_contract(self) -> None:
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
data_root = Path(tmp)
|
||||
replay_path = data_root / "replay_1m.parquet"
|
||||
frame = pd.DataFrame(
|
||||
{
|
||||
"event_time": pd.date_range("2025-06-20", "2026-06-19", freq="D", tz="UTC"),
|
||||
"symbol": "BTC-USDT-PERP",
|
||||
}
|
||||
)
|
||||
frame.to_parquet(replay_path, index=False)
|
||||
|
||||
build_splits(
|
||||
Namespace(
|
||||
data_root=data_root,
|
||||
run_id="unit-split",
|
||||
replay_path=replay_path,
|
||||
fit_inner_start="2025-06-20",
|
||||
fit_inner_end="2026-01-15",
|
||||
tune_inner_start="2026-01-16",
|
||||
tune_inner_end="2026-02-28",
|
||||
validation_locked_start="2026-03-01",
|
||||
validation_locked_end="2026-04-30",
|
||||
latest_stress_start="2026-05-01",
|
||||
latest_stress_end="2026-06-19",
|
||||
gap_minutes=0,
|
||||
fold_count=2,
|
||||
)
|
||||
)
|
||||
|
||||
manifest = read_json(data_root / "trader-v4" / "runs" / "unit-split" / "split" / "split_manifest.json")
|
||||
self.assertEqual(set(TRAINING_SPLITS), {item["split_id"] for item in manifest["splits"]})
|
||||
self.assertEqual([VALIDATION_LOCKED_SPLIT, LATEST_STRESS_SPLIT], manifest["sealed_splits"])
|
||||
self.assertEqual("FINAL_GATE_ONLY", manifest["latest_stress_policy"])
|
||||
|
||||
def test_exported_onnx_accepts_java_feature_shape(self) -> None:
|
||||
import onnxruntime as ort
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
path = Path(tmp) / "direction.onnx"
|
||||
export_heads(
|
||||
path,
|
||||
[
|
||||
LinearHead(
|
||||
"direction",
|
||||
"softmax",
|
||||
np.zeros((39, 3), dtype=np.float32),
|
||||
np.array([0.1, 0.2, 0.3], dtype=np.float32),
|
||||
)
|
||||
],
|
||||
)
|
||||
session = ort.InferenceSession(str(path))
|
||||
output = session.run(None, {"features": np.zeros((1, 39), dtype=np.float32)})[0]
|
||||
self.assertEqual((1, 3), output.shape)
|
||||
self.assertAlmostEqual(1.0, float(output.sum()), places=6)
|
||||
|
||||
def test_promotion_requires_passed_validation_and_marks_all_manifests_active(self) -> None:
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
root = Path(tmp) / "artifact_bundle"
|
||||
manifest_dir = root / "manifests"
|
||||
manifest_dir.mkdir(parents=True)
|
||||
write_json(root.parent / "artifact_validation_result.json", {"status": "PASS", "release_gate_status": "PASS", "release_gate_reasons": []})
|
||||
write_json(manifest_dir / "model_bundle_manifest.json", {"status": "CANDIDATE"})
|
||||
write_json(manifest_dir / "model_manifest.json", [{"model_name": "DIRECTION", "status": "CANDIDATE"}])
|
||||
write_json(manifest_dir / "calibration_manifest.json", [{"model_name": "DIRECTION", "status": "CANDIDATE"}])
|
||||
write_json(manifest_dir / "position_manager_manifest.json", {"status": "CANDIDATE"})
|
||||
write_json(manifest_dir / "training_export_manifest.json", {"status": "CANDIDATE"})
|
||||
|
||||
promote_artifact_bundle(Namespace(artifact_root=root, reason="unit test"))
|
||||
|
||||
self.assertEqual("ACTIVE", read_json(manifest_dir / "model_bundle_manifest.json")["status"])
|
||||
self.assertEqual("ACTIVE", read_json(manifest_dir / "model_manifest.json")[0]["status"])
|
||||
self.assertEqual("ACTIVE", read_json(manifest_dir / "calibration_manifest.json")[0]["status"])
|
||||
self.assertEqual("ACTIVE", read_json(manifest_dir / "position_manager_manifest.json")["status"])
|
||||
self.assertEqual("ACTIVE", read_json(manifest_dir / "training_export_manifest.json")["status"])
|
||||
|
||||
def test_promotion_refuses_failed_validation(self) -> None:
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
root = Path(tmp) / "artifact_bundle"
|
||||
(root / "manifests").mkdir(parents=True)
|
||||
write_json(root.parent / "artifact_validation_result.json", {"status": "FAIL"})
|
||||
with self.assertRaises(SystemExit):
|
||||
promote_artifact_bundle(Namespace(artifact_root=root, reason="unit test"))
|
||||
result = read_json(root.parent / "artifact_promotion_result.json")
|
||||
self.assertEqual("REFUSED", result["status"])
|
||||
self.assertEqual("validation result is not PASS", result["message"])
|
||||
|
||||
def test_promotion_refuses_failed_release_gate_and_overwrites_stale_result(self) -> None:
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
root = Path(tmp) / "artifact_bundle"
|
||||
(root / "manifests").mkdir(parents=True)
|
||||
write_json(root.parent / "artifact_promotion_result.json", {"status": "ACTIVE"})
|
||||
write_json(
|
||||
root.parent / "artifact_validation_result.json",
|
||||
{
|
||||
"status": "PASS",
|
||||
"release_gate_status": "REJECTED",
|
||||
"release_gate_reasons": ["backtest_status=REJECTED"],
|
||||
},
|
||||
)
|
||||
|
||||
with self.assertRaises(SystemExit):
|
||||
promote_artifact_bundle(Namespace(artifact_root=root, reason="unit test"))
|
||||
|
||||
result = read_json(root.parent / "artifact_promotion_result.json")
|
||||
self.assertEqual("REFUSED", result["status"])
|
||||
self.assertEqual("release gate is not PASS", result["message"])
|
||||
self.assertEqual(["backtest_status=REJECTED"], result["release_gate_reasons"])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
@@ -0,0 +1 @@
|
||||
"""Trader V4 training pipeline."""
|
||||
@@ -0,0 +1,112 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
import pandas as pd
|
||||
|
||||
from trader_training.io_utils import manifest, read_parquet, require_columns, run_root, write_json, write_parquet, write_text
|
||||
from trader_training.schemas import FEATURE_ORDER, TRAINING_SPLITS
|
||||
|
||||
|
||||
def _feature_base(root, args) -> pd.DataFrame:
|
||||
path = args.feature_path or root / "feature" / "feature_frame.parquet"
|
||||
frame = read_parquet(path)
|
||||
require_columns(frame, ("sample_id", "split_id", "data_quality_flag", *FEATURE_ORDER), "feature_frame")
|
||||
trainable = frame[frame["data_quality_flag"].isin(["OK", "PARTIAL_OPTIONAL"])].copy()
|
||||
trainable = trainable[trainable["split_id"].isin(TRAINING_SPLITS)].copy()
|
||||
if trainable.empty:
|
||||
raise ValueError("no trainable feature rows; check feature quality report")
|
||||
return trainable
|
||||
|
||||
|
||||
def build_train_datasets(args: Any) -> None:
|
||||
root = run_root(args)
|
||||
feature = _feature_base(root, args)
|
||||
dataset_dir = root / "dataset"
|
||||
manifests = {}
|
||||
|
||||
direction = read_parquet(args.direction_label_path or root / "label" / "direction_labels.parquet")
|
||||
direction_ds = feature.merge(direction[["sample_id", "long_target", "short_target", "neutral_target", "future_return_bps"]], on="sample_id", how="inner")
|
||||
manifests["direction"] = _write_dataset(dataset_dir / "direction_train.parquet", direction_ds)
|
||||
|
||||
entry = read_parquet(args.entry_label_path or root / "label" / "entry_labels.parquet")
|
||||
entry_pivot = _entry_pivot(entry)
|
||||
entry_ds = feature.merge(entry_pivot, on="sample_id", how="inner")
|
||||
manifests["entry"] = _write_dataset(dataset_dir / "entry_train.parquet", entry_ds)
|
||||
|
||||
continuation = read_parquet(args.continue_label_path or root / "label" / "continue_labels.parquet")
|
||||
continue_ds = feature.merge(
|
||||
continuation[["sample_id", "long_continue_target", "short_continue_target", "long_expected_continue_edge_bps", "short_expected_continue_edge_bps"]],
|
||||
on="sample_id",
|
||||
how="inner",
|
||||
)
|
||||
manifests["continue"] = _write_dataset(dataset_dir / "continue_train.parquet", continue_ds)
|
||||
|
||||
exit_labels = read_parquet(args.exit_label_path or root / "label" / "exit_labels.parquet")
|
||||
exit_cols = [
|
||||
"sample_id",
|
||||
"long_exit_target",
|
||||
"short_exit_target",
|
||||
"long_adverse_move_bps",
|
||||
"short_adverse_move_bps",
|
||||
"adverse_move_prob_label",
|
||||
"reversal_prob_label",
|
||||
"stop_hit_prob_label",
|
||||
"stagnation_prob_label",
|
||||
]
|
||||
exit_ds = feature.merge(exit_labels[exit_cols], on="sample_id", how="inner")
|
||||
manifests["exit"] = _write_dataset(dataset_dir / "exit_train.parquet", exit_ds)
|
||||
|
||||
risk = read_parquet(args.risk_label_path or root / "label" / "risk_labels.parquet")
|
||||
risk_cols = [
|
||||
"sample_id",
|
||||
"market_risk_target",
|
||||
"market_path_risk_bps",
|
||||
"long_position_path_risk_bps",
|
||||
"short_position_path_risk_bps",
|
||||
"long_position_risk_target",
|
||||
"short_position_risk_target",
|
||||
"market_drawdown_prob_label",
|
||||
"volatility_expansion_prob_label",
|
||||
"spike_prob_label",
|
||||
"liquidity_deterioration_prob_label",
|
||||
"position_drawdown_prob_label",
|
||||
]
|
||||
risk_ds = feature.merge(risk[risk_cols], on="sample_id", how="inner")
|
||||
manifests["risk"] = _write_dataset(dataset_dir / "risk_train.parquet", risk_ds)
|
||||
|
||||
write_json(dataset_dir / "dataset_manifest.json", {"datasets": manifests})
|
||||
_write_dataset_report(dataset_dir / "dataset_quality_report.md", manifests)
|
||||
logging.info("trader.training.datasets_written runId=%s datasets=%s", args.run_id, sorted(manifests))
|
||||
|
||||
|
||||
def _entry_pivot(entry: pd.DataFrame) -> pd.DataFrame:
|
||||
require_columns(entry, ("sample_id", "side", "entry_target", "expected_net_edge_bps"), "entry_labels")
|
||||
long = entry[entry["side"] == "LONG"][["sample_id", "entry_target", "expected_net_edge_bps"]].rename(
|
||||
columns={"entry_target": "long_entry_target", "expected_net_edge_bps": "long_expected_net_edge_bps"}
|
||||
)
|
||||
short = entry[entry["side"] == "SHORT"][["sample_id", "entry_target", "expected_net_edge_bps"]].rename(
|
||||
columns={"entry_target": "short_entry_target", "expected_net_edge_bps": "short_expected_net_edge_bps"}
|
||||
)
|
||||
return long.merge(short, on="sample_id", how="inner")
|
||||
|
||||
|
||||
def _write_dataset(path, frame: pd.DataFrame) -> dict:
|
||||
data_hash = write_parquet(path, frame)
|
||||
return manifest(
|
||||
path,
|
||||
{
|
||||
"row_count": len(frame),
|
||||
"feature_count": len(FEATURE_ORDER),
|
||||
"data_hash_sha256": data_hash,
|
||||
"split_counts": frame["split_id"].value_counts().to_dict() if "split_id" in frame.columns else {},
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def _write_dataset_report(path, manifests: dict) -> None:
|
||||
lines = ["# Trader Dataset Quality Report", "", "| dataset | rows | hash |", "| --- | ---: | --- |"]
|
||||
for name, item in manifests.items():
|
||||
lines.append(f"| {name} | {item['row_count']} | {item['data_hash_sha256']} |")
|
||||
write_text(path, "\n".join(lines) + "\n")
|
||||
@@ -0,0 +1,244 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import pandas as pd
|
||||
|
||||
from trader_training.io_utils import read_json, read_parquet, run_root, sha256_file, sha256_json, utc_now_text, write_json
|
||||
from trader_training.pm import default_pm_config
|
||||
from trader_training.schemas import (
|
||||
CALIBRATION_BUNDLE_VERSION,
|
||||
FEATURE_ORDER,
|
||||
FEATURE_VERSION,
|
||||
FIT_SPLIT,
|
||||
LABEL_VERSION,
|
||||
MODEL_BUNDLE_VERSION,
|
||||
MODEL_OUTPUTS,
|
||||
OUTPUT_MAPPING,
|
||||
OUTPUT_SCHEMA,
|
||||
PM_CONFIG_VERSION,
|
||||
SPLIT_VERSION,
|
||||
TUNE_SPLIT,
|
||||
VALIDATION_LOCKED_SPLIT,
|
||||
)
|
||||
|
||||
|
||||
REQUIRED_MODELS = ("DIRECTION", "ENTRY", "CONTINUE", "EXIT", "RISK")
|
||||
EXPORT_STATUS = "CANDIDATE"
|
||||
|
||||
|
||||
def export_artifact_bundle(args: Any) -> None:
|
||||
root = run_root(args)
|
||||
export_root = args.export_root or root / "export" / f"trader-model-bundle-{args.run_id}"
|
||||
artifact_root = export_root / "artifact_bundle"
|
||||
for folder in ("models", "schemas", "calibrators", "manifests", "examples"):
|
||||
(artifact_root / folder).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
feature_schema_path = root / "feature" / "feature_schema.json"
|
||||
feature_order_path = root / "feature" / "feature_order.json"
|
||||
output_schema_path = artifact_root / "schemas" / "output_schema.json"
|
||||
shutil.copy2(feature_schema_path, artifact_root / "schemas" / "feature_schema.json")
|
||||
shutil.copy2(feature_order_path, artifact_root / "schemas" / "feature_order.json")
|
||||
write_json(output_schema_path, OUTPUT_SCHEMA)
|
||||
|
||||
sample_input = _sample_input(root)
|
||||
write_json(artifact_root / "examples" / "sample_input.json", sample_input)
|
||||
write_json(artifact_root / "examples" / "sample_output.json", _sample_output())
|
||||
|
||||
price_plan = read_json(root / "label" / "price_plan_context.json")
|
||||
write_json(artifact_root / "price_plan_context.json", price_plan)
|
||||
|
||||
model_manifest_rows = []
|
||||
calibration_manifest_rows = []
|
||||
model_hashes = {}
|
||||
train_start = _split_time(root, FIT_SPLIT, "start")
|
||||
train_end = _split_time(root, FIT_SPLIT, "end")
|
||||
validation_start = _split_time(root, VALIDATION_LOCKED_SPLIT, "start")
|
||||
validation_end = _split_time(root, VALIDATION_LOCKED_SPLIT, "end")
|
||||
calibration_train_manifest = read_json(root / "calibration" / "calibration_train_manifest.json")
|
||||
calibration_quality = {row["model_name"]: row for row in calibration_train_manifest.get("calibrators", [])}
|
||||
for model_name in REQUIRED_MODELS:
|
||||
src = root / "model" / model_name.lower() / f"{model_name.lower()}.onnx"
|
||||
dst = artifact_root / "models" / f"{model_name.lower()}.onnx"
|
||||
shutil.copy2(src, dst)
|
||||
model_hashes[model_name] = sha256_file(dst)
|
||||
cal_src = root / "calibration" / model_name.lower() / "calibrator.json"
|
||||
cal_dst = artifact_root / "calibrators" / model_name.lower() / "calibrator.json"
|
||||
cal_dst.parent.mkdir(parents=True, exist_ok=True)
|
||||
shutil.copy2(cal_src, cal_dst)
|
||||
cal_hash = sha256_file(cal_dst)
|
||||
metrics = read_json(root / "model" / model_name.lower() / "model_train_result.json")
|
||||
model_manifest_rows.append(
|
||||
_model_manifest_row(
|
||||
model_name,
|
||||
f"models/{model_name.lower()}.onnx",
|
||||
model_hashes[model_name],
|
||||
sha256_file(artifact_root / "schemas" / "feature_schema.json"),
|
||||
sha256_file(artifact_root / "schemas" / "feature_order.json"),
|
||||
sha256_file(output_schema_path),
|
||||
metrics.get("metrics", {}),
|
||||
metrics.get("quality_status", "UNKNOWN"),
|
||||
metrics.get("quality_reasons", []),
|
||||
train_start,
|
||||
train_end,
|
||||
validation_start,
|
||||
validation_end,
|
||||
EXPORT_STATUS,
|
||||
)
|
||||
)
|
||||
calibration_manifest_rows.append(
|
||||
{
|
||||
"calibration_bundle_version": CALIBRATION_BUNDLE_VERSION,
|
||||
"model_bundle_version": MODEL_BUNDLE_VERSION,
|
||||
"model_name": model_name,
|
||||
"calibrator_version": f"{model_name.lower()}-cal-v4-btc-p0",
|
||||
"calibration_method": "BINNING",
|
||||
"calibrator_path": f"calibrators/{model_name.lower()}/calibrator.json",
|
||||
"calibrator_hash_sha256": cal_hash,
|
||||
"calibration_window_from": _split_time(root, TUNE_SPLIT, "start"),
|
||||
"calibration_window_to": _split_time(root, TUNE_SPLIT, "end"),
|
||||
"calibration_metrics_json": {},
|
||||
"bucket_metrics_json": {},
|
||||
"output_after_calibration_schema_hash": sha256_file(output_schema_path),
|
||||
"quality_status": calibration_quality.get(model_name, {}).get("quality_status", "UNKNOWN"),
|
||||
"quality_reasons_json": calibration_quality.get(model_name, {}).get("quality_reasons", []),
|
||||
"status": EXPORT_STATUS,
|
||||
}
|
||||
)
|
||||
|
||||
pm_payload = read_json(root / "pm-search" / "position_manager_config.json") if (root / "pm-search" / "position_manager_config.json").is_file() else {"config": default_pm_config(), "threshold_stability_json": {}}
|
||||
pm_config = pm_payload["config"]
|
||||
pm_hash = sha256_json(pm_config)
|
||||
backtest_manifest = read_json(root / "backtest" / "backtest_manifest.json")
|
||||
write_json(
|
||||
artifact_root / "manifests" / "position_manager_manifest.json",
|
||||
{
|
||||
"pm_config_version": PM_CONFIG_VERSION,
|
||||
"model_bundle_version": MODEL_BUNDLE_VERSION,
|
||||
"calibration_bundle_version": CALIBRATION_BUNDLE_VERSION,
|
||||
"threshold_stability_json": pm_payload.get("threshold_stability_json", {}),
|
||||
"allowed_run_modes_json": ["REPLAY_SIM", "SHADOW"],
|
||||
"config_json": pm_config,
|
||||
"config_hash_sha256": pm_hash,
|
||||
"status": EXPORT_STATUS,
|
||||
},
|
||||
)
|
||||
write_json(artifact_root / "manifests" / "model_manifest.json", model_manifest_rows)
|
||||
write_json(artifact_root / "manifests" / "calibration_manifest.json", calibration_manifest_rows)
|
||||
bundle_hash = sha256_json({"models": model_hashes, "feature_order_hash": sha256_file(artifact_root / "schemas" / "feature_order.json"), "pm_hash": pm_hash})
|
||||
write_json(
|
||||
artifact_root / "manifests" / "model_bundle_manifest.json",
|
||||
{
|
||||
"manifest_schema_version": "trader-model-bundle-manifest-v4-p0",
|
||||
"model_bundle_version": MODEL_BUNDLE_VERSION,
|
||||
"calibration_bundle_version": CALIBRATION_BUNDLE_VERSION,
|
||||
"feature_version": FEATURE_VERSION,
|
||||
"label_version": LABEL_VERSION,
|
||||
"split_version": SPLIT_VERSION,
|
||||
"training_run_id": args.run_id,
|
||||
"training_export_id": f"export-{args.run_id}",
|
||||
"backtest_manifest_id": f"backtest-{args.run_id}",
|
||||
"required_models_json": list(REQUIRED_MODELS),
|
||||
"provided_models_json": list(REQUIRED_MODELS),
|
||||
"missing_models_json": [],
|
||||
"allowed_run_modes_json": ["REPLAY_SIM", "SHADOW"],
|
||||
"bundle_hash_sha256": bundle_hash,
|
||||
"model_quality_status_json": {row["model_name"]: row["quality_status"] for row in model_manifest_rows},
|
||||
"calibration_quality_status_json": {row["model_name"]: row["quality_status"] for row in calibration_manifest_rows},
|
||||
"backtest_status": backtest_manifest.get("status"),
|
||||
"backtest_status_reasons_json": backtest_manifest.get("status_reasons", []),
|
||||
"backtest_metrics_json": backtest_manifest.get("metrics", {}),
|
||||
"complete": True,
|
||||
"status": EXPORT_STATUS,
|
||||
},
|
||||
)
|
||||
write_json(
|
||||
artifact_root / "manifests" / "training_export_manifest.json",
|
||||
{"created_at": utc_now_text(), "status": EXPORT_STATUS, "artifact_root": str(artifact_root), "bundle_hash_sha256": bundle_hash},
|
||||
)
|
||||
logging.info("trader.training.artifact_exported runId=%s status=%s bundleHash=%s path=%s", args.run_id, EXPORT_STATUS, bundle_hash, artifact_root)
|
||||
|
||||
|
||||
def _model_manifest_row(
|
||||
model_name: str,
|
||||
artifact_path: str,
|
||||
artifact_hash: str,
|
||||
feature_schema_hash: str,
|
||||
feature_order_hash: str,
|
||||
output_schema_hash: str,
|
||||
metrics: dict,
|
||||
quality_status: str,
|
||||
quality_reasons: list[str],
|
||||
train_start: str,
|
||||
train_end: str,
|
||||
validation_start: str,
|
||||
validation_end: str,
|
||||
status: str,
|
||||
) -> dict:
|
||||
return {
|
||||
"model_bundle_version": MODEL_BUNDLE_VERSION,
|
||||
"calibration_bundle_version": CALIBRATION_BUNDLE_VERSION,
|
||||
"model_name": model_name,
|
||||
"model_type": model_name,
|
||||
"side": "BOTH",
|
||||
"symbol_scope_json": ["BTC-USDT-PERP"],
|
||||
"bar_interval": "1m",
|
||||
"horizon_minutes": 45,
|
||||
"model_format": "ONNX",
|
||||
"model_runtime": "ONNX_RUNTIME_JAVA",
|
||||
"model_runtime_version": "1.22.0",
|
||||
"onnx_opset_version": 17,
|
||||
"producer_name": "trader-training",
|
||||
"producer_version": "v4-p0",
|
||||
"feature_version": FEATURE_VERSION,
|
||||
"feature_schema_path": "schemas/feature_schema.json",
|
||||
"feature_schema_hash": feature_schema_hash,
|
||||
"feature_order_path": "schemas/feature_order.json",
|
||||
"feature_order_hash": feature_order_hash,
|
||||
"input_tensor_name": "features",
|
||||
"input_dtype": "FLOAT32",
|
||||
"input_shape_json": {"features": len(FEATURE_ORDER), "batch": 1},
|
||||
"input_example_path": "examples/sample_input.json",
|
||||
"output_schema_path": "schemas/output_schema.json",
|
||||
"output_schema_hash": output_schema_hash,
|
||||
"output_tensor_names_json": ["prediction"],
|
||||
"output_mapping_json": OUTPUT_MAPPING[model_name],
|
||||
"output_value_rules_json": {"clip_by_output_schema": True},
|
||||
"label_version": LABEL_VERSION,
|
||||
"split_version": SPLIT_VERSION,
|
||||
"training_fold": "fold_01",
|
||||
"train_start": train_start,
|
||||
"train_end": train_end,
|
||||
"validation_start": validation_start,
|
||||
"validation_end": validation_end,
|
||||
"test_start": validation_start,
|
||||
"test_end": validation_end,
|
||||
"metrics_json": metrics,
|
||||
"quality_status": quality_status,
|
||||
"quality_reasons_json": quality_reasons,
|
||||
"artifact_path": artifact_path,
|
||||
"artifact_hash_sha256": artifact_hash,
|
||||
"source_hash": artifact_hash,
|
||||
"status": status,
|
||||
}
|
||||
|
||||
|
||||
def _sample_input(root) -> dict:
|
||||
features = read_parquet(root / "feature" / "feature_frame.parquet")
|
||||
row = features[features["data_quality_flag"].isin(["OK", "PARTIAL_OPTIONAL"])].iloc[0]
|
||||
return {feature: float(row[feature]) for feature in FEATURE_ORDER}
|
||||
|
||||
|
||||
def _sample_output() -> dict:
|
||||
return {model: {field: 0.0 for field in fields} for model, fields in MODEL_OUTPUTS.items()}
|
||||
|
||||
|
||||
def _split_time(root, split_id: str, key: str) -> str:
|
||||
manifest = read_json(root / "split" / "split_manifest.json")
|
||||
for item in manifest["splits"]:
|
||||
if item["split_id"] == split_id:
|
||||
return item[key]
|
||||
return "2026-01-01T00:00:00Z"
|
||||
@@ -0,0 +1,342 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from trader_training.io_utils import (
|
||||
manifest,
|
||||
read_parquet,
|
||||
require_columns,
|
||||
run_root,
|
||||
sha256_json,
|
||||
to_utc_series,
|
||||
write_json,
|
||||
write_parquet,
|
||||
write_text,
|
||||
)
|
||||
from trader_training.replay import assign_split
|
||||
from trader_training.schemas import FEATURE_ORDER, FEATURE_VERSION, FEATURES, FIT_SPLIT, TUNE_SPLIT, VALIDATION_LOCKED_SPLIT
|
||||
|
||||
|
||||
META_COLUMNS = [
|
||||
"sample_id",
|
||||
"symbol",
|
||||
"event_time",
|
||||
"open_time_ms",
|
||||
"split_id",
|
||||
"walk_forward_fold",
|
||||
"feature_version",
|
||||
"data_quality_flag",
|
||||
]
|
||||
|
||||
|
||||
def _safe_divide(numerator: pd.Series, denominator: pd.Series, default: float = 0.0) -> pd.Series:
|
||||
result = numerator / denominator.replace(0, np.nan)
|
||||
return result.replace([np.inf, -np.inf], np.nan).fillna(default)
|
||||
|
||||
|
||||
def _rolling_rank_last(values: pd.Series, window: int) -> pd.Series:
|
||||
def calc(raw: np.ndarray) -> float:
|
||||
last = raw[-1]
|
||||
return float(np.sum(raw <= last) / len(raw))
|
||||
|
||||
return values.rolling(window, min_periods=window).apply(calc, raw=True)
|
||||
|
||||
|
||||
def _complete_days(frame: pd.DataFrame) -> pd.DataFrame:
|
||||
frame = frame.copy()
|
||||
frame["event_date"] = frame["event_time"].dt.strftime("%Y-%m-%d")
|
||||
counts = frame.groupby(["symbol", "event_date"])["event_time"].count()
|
||||
complete = counts[counts == 1440].reset_index()[["symbol", "event_date"]]
|
||||
return frame.merge(complete, on=["symbol", "event_date"], how="inner").drop(columns=["event_date"])
|
||||
|
||||
|
||||
def build_feature_frame(args: Any) -> None:
|
||||
root = run_root(args)
|
||||
replay_path = args.replay_path or root / "replay" / "replay_1m.parquet"
|
||||
split_manifest_path = args.split_manifest_path or root / "split" / "split_manifest.json"
|
||||
replay = read_parquet(replay_path)
|
||||
required = [
|
||||
"symbol",
|
||||
"event_time",
|
||||
"open_time_ms",
|
||||
"open",
|
||||
"high",
|
||||
"low",
|
||||
"close",
|
||||
"volume",
|
||||
"taker_buy_volume",
|
||||
"taker_sell_volume",
|
||||
"funding_bps",
|
||||
"mark_price",
|
||||
"index_price",
|
||||
"next_funding_time",
|
||||
"open_interest",
|
||||
"spread_bps",
|
||||
"level1_ofi_1m",
|
||||
"liquidation_buy_notional_1m",
|
||||
"liquidation_sell_notional_1m",
|
||||
"liquidation_available",
|
||||
]
|
||||
require_columns(replay, required, "replay_1m")
|
||||
replay = replay.copy()
|
||||
replay["event_time"] = to_utc_series(replay["event_time"])
|
||||
replay["next_funding_time"] = to_utc_series(replay["next_funding_time"])
|
||||
replay = replay.sort_values(["symbol", "event_time"]).reset_index(drop=True)
|
||||
if not args.allow_incomplete_days:
|
||||
before = len(replay)
|
||||
replay = _complete_days(replay)
|
||||
logging.info("trader.training.feature_complete_days rowBefore=%s rowAfter=%s", before, len(replay))
|
||||
|
||||
frames: list[pd.DataFrame] = []
|
||||
for symbol, group in replay.groupby("symbol", sort=False):
|
||||
group = group.sort_values("event_time").reset_index(drop=True).copy()
|
||||
close = group["close"].astype(float)
|
||||
high = group["high"].astype(float)
|
||||
low = group["low"].astype(float)
|
||||
volume = group["volume"].astype(float)
|
||||
log_ret = np.log(close / close.shift(1))
|
||||
group["ret_1m_bps"] = (close / close.shift(1) - 1.0) * 10000.0
|
||||
group["ret_5m_bps"] = (close / close.shift(5) - 1.0) * 10000.0
|
||||
group["ret_15m_bps"] = (close / close.shift(15) - 1.0) * 10000.0
|
||||
group["ret_60m_bps"] = (close / close.shift(60) - 1.0) * 10000.0
|
||||
group["ret_240m_bps"] = (close / close.shift(240) - 1.0) * 10000.0
|
||||
group["realized_vol_15m_bps"] = log_ret.rolling(15, min_periods=15).std() * 10000.0
|
||||
group["realized_vol_60m_bps"] = log_ret.rolling(60, min_periods=60).std() * 10000.0
|
||||
group["vol_ratio_15m_60m"] = _safe_divide(group["realized_vol_15m_bps"], group["realized_vol_60m_bps"].clip(lower=1.0))
|
||||
group["range_15m_bps"] = (high.rolling(15, min_periods=15).max() / low.rolling(15, min_periods=15).min() - 1.0) * 10000.0
|
||||
group["range_60m_bps"] = (high.rolling(60, min_periods=60).max() / low.rolling(60, min_periods=60).min() - 1.0) * 10000.0
|
||||
vol_mean = volume.rolling(60, min_periods=60).mean()
|
||||
vol_std = volume.rolling(60, min_periods=60).std().replace(0, np.nan)
|
||||
group["volume_zscore_60m"] = ((volume - vol_mean) / vol_std).fillna(0.0)
|
||||
group["trend_consistency_15m"] = np.sign(group["ret_1m_bps"]).rolling(15, min_periods=15).mean()
|
||||
high60 = high.rolling(60, min_periods=60).max()
|
||||
low60 = low.rolling(60, min_periods=60).min()
|
||||
group["channel_position_60m_pct"] = ((close - low60) / (high60 - low60).clip(lower=1e-12)).clip(0.0, 1.0)
|
||||
prev_high60 = high.shift(1).rolling(60, min_periods=60).max()
|
||||
prev_low60 = low.shift(1).rolling(60, min_periods=60).min()
|
||||
group["upper_breakout_60m_bps"] = ((close / prev_high60 - 1.0).clip(lower=0.0)) * 10000.0
|
||||
group["lower_breakout_60m_bps"] = ((prev_low60 / close - 1.0).clip(lower=0.0)) * 10000.0
|
||||
recent_high15 = high.rolling(15, min_periods=15).max()
|
||||
recent_low15 = low.rolling(15, min_periods=15).min()
|
||||
broke_up = recent_high15 > prev_high60
|
||||
broke_down = recent_low15 < prev_low60
|
||||
group["upper_failed_break_reclaim_15m_bps"] = np.where(broke_up, ((prev_high60 - close).clip(lower=0.0) / close) * 10000.0, 0.0)
|
||||
group["lower_failed_break_reclaim_15m_bps"] = np.where(broke_down, ((close - prev_low60).clip(lower=0.0) / close) * 10000.0, 0.0)
|
||||
group["sweep_up_15m_bps"] = ((recent_high15 / close - 1.0).clip(lower=0.0)) * 10000.0
|
||||
group["sweep_down_15m_bps"] = ((close / recent_low15 - 1.0).clip(lower=0.0)) * 10000.0
|
||||
rank = _rolling_rank_last(group["range_15m_bps"], 240)
|
||||
group["compression_score_4h_pct"] = 1.0 - rank
|
||||
group["compression_release_15m_bps"] = (group["range_15m_bps"] - group["range_15m_bps"].rolling(240, min_periods=240).median()).clip(lower=0.0)
|
||||
buy = group["taker_buy_volume"].astype(float)
|
||||
sell = group["taker_sell_volume"].astype(float)
|
||||
group["taker_imbalance_1m"] = _safe_divide(buy - sell, buy + sell)
|
||||
group["taker_imbalance_5m"] = _safe_divide(buy.rolling(5, min_periods=5).sum() - sell.rolling(5, min_periods=5).sum(), (buy + sell).rolling(5, min_periods=5).sum())
|
||||
group["taker_imbalance_15m"] = _safe_divide(buy.rolling(15, min_periods=15).sum() - sell.rolling(15, min_periods=15).sum(), (buy + sell).rolling(15, min_periods=15).sum())
|
||||
group["spread_rank_24h_pct"] = _rolling_rank_last(group["spread_bps"].astype(float), 1440)
|
||||
group["oi_delta_15m_bps"] = (group["open_interest"].astype(float) / group["open_interest"].astype(float).shift(15) - 1.0) * 10000.0
|
||||
group["oi_delta_60m_bps"] = (group["open_interest"].astype(float) / group["open_interest"].astype(float).shift(60) - 1.0) * 10000.0
|
||||
group["mark_index_basis_bps"] = (group["mark_price"].astype(float) / group["index_price"].astype(float) - 1.0) * 10000.0
|
||||
liq_buy = group["liquidation_buy_notional_1m"].astype(float)
|
||||
liq_sell = group["liquidation_sell_notional_1m"].astype(float)
|
||||
liq_total_15 = (liq_buy + liq_sell).rolling(15, min_periods=1).sum()
|
||||
group["liquidation_imbalance_15m"] = _safe_divide(liq_buy.rolling(15, min_periods=1).sum() - liq_sell.rolling(15, min_periods=1).sum(), liq_total_15)
|
||||
liq_mean = liq_total_15.rolling(1440, min_periods=60).mean()
|
||||
liq_std = liq_total_15.rolling(1440, min_periods=60).std().replace(0, np.nan)
|
||||
group["liquidation_notional_zscore_15m"] = ((liq_total_15 - liq_mean) / liq_std).fillna(0.0)
|
||||
minute_of_day = group["event_time"].dt.hour * 60 + group["event_time"].dt.minute
|
||||
group["minute_of_day_sin"] = np.sin(2 * np.pi * minute_of_day / 1440.0)
|
||||
group["minute_of_day_cos"] = np.cos(2 * np.pi * minute_of_day / 1440.0)
|
||||
group["minutes_to_next_funding"] = ((group["next_funding_time"] - group["event_time"]).dt.total_seconds() / 60.0).clip(0.0, 480.0)
|
||||
group["symbol"] = symbol
|
||||
frames.append(group)
|
||||
|
||||
frame = pd.concat(frames, ignore_index=True)
|
||||
frame["sample_id"] = frame["symbol"].astype(str) + ":" + frame["open_time_ms"].astype(str)
|
||||
frame["split_id"] = assign_split(frame["event_time"], split_manifest_path)
|
||||
frame["walk_forward_fold"] = np.where(frame["split_id"].eq(FIT_SPLIT), "fold_01", "NO_FOLD")
|
||||
frame["feature_version"] = FEATURE_VERSION
|
||||
hard_na = frame[FEATURE_ORDER].isna().any(axis=1)
|
||||
optional_missing = frame["liquidation_available"].fillna(0).eq(0)
|
||||
frame["data_quality_flag"] = np.where(hard_na, "WARMUP", np.where(optional_missing, "PARTIAL_OPTIONAL", "OK"))
|
||||
ordered = frame[META_COLUMNS + FEATURE_ORDER].copy()
|
||||
for feature in FEATURE_ORDER:
|
||||
ordered[feature] = pd.to_numeric(ordered[feature], errors="coerce").astype("float32")
|
||||
|
||||
feature_dir = root / "feature"
|
||||
data_hash = write_parquet(feature_dir / "feature_frame.parquet", ordered)
|
||||
schema = [feature.as_json() for feature in FEATURES]
|
||||
feature_order_hash = write_json(feature_dir / "feature_order.json", FEATURE_ORDER)
|
||||
feature_schema_hash = write_json(feature_dir / "feature_schema.json", schema)
|
||||
write_json(
|
||||
feature_dir / "feature_frame.manifest.json",
|
||||
manifest(
|
||||
feature_dir / "feature_frame.parquet",
|
||||
{
|
||||
"row_count": len(ordered),
|
||||
"ok_row_count": int(ordered["data_quality_flag"].eq("OK").sum()),
|
||||
"partial_optional_row_count": int(ordered["data_quality_flag"].eq("PARTIAL_OPTIONAL").sum()),
|
||||
"warmup_row_count": int(ordered["data_quality_flag"].eq("WARMUP").sum()),
|
||||
"feature_count": len(FEATURE_ORDER),
|
||||
"feature_version": FEATURE_VERSION,
|
||||
"feature_order_hash": feature_order_hash,
|
||||
"feature_schema_hash": feature_schema_hash,
|
||||
"data_hash_sha256": data_hash,
|
||||
},
|
||||
),
|
||||
)
|
||||
write_feature_report(feature_dir / "feature_quality_report.md", ordered, feature_schema_hash, feature_order_hash)
|
||||
logging.info(
|
||||
"trader.training.feature_written runId=%s rowCount=%s splitCounts=%s eventFrom=%s eventTo=%s path=%s",
|
||||
args.run_id,
|
||||
len(ordered),
|
||||
ordered["split_id"].value_counts().to_dict(),
|
||||
ordered["event_time"].min(),
|
||||
ordered["event_time"].max(),
|
||||
feature_dir / "feature_frame.parquet",
|
||||
)
|
||||
|
||||
|
||||
def write_feature_report(path, frame: pd.DataFrame, feature_schema_hash: str, feature_order_hash: str) -> None:
|
||||
split_rows = []
|
||||
for split_id, group in frame.groupby("split_id", sort=True):
|
||||
split_rows.append(
|
||||
{
|
||||
"split_id": split_id,
|
||||
"rows": len(group),
|
||||
"start": str(group["event_time"].min()),
|
||||
"end": str(group["event_time"].max()),
|
||||
"ok": int(group["data_quality_flag"].eq("OK").sum()),
|
||||
"partial_optional": int(group["data_quality_flag"].eq("PARTIAL_OPTIONAL").sum()),
|
||||
"warmup": int(group["data_quality_flag"].eq("WARMUP").sum()),
|
||||
}
|
||||
)
|
||||
finite_rows = []
|
||||
for feature in FEATURE_ORDER:
|
||||
series = pd.to_numeric(frame[feature], errors="coerce")
|
||||
values = series.to_numpy(dtype=float)
|
||||
finite_rows.append(
|
||||
{
|
||||
"feature": feature,
|
||||
"nan_count": int(series.isna().sum()),
|
||||
"inf_count": int(np.isinf(values).sum()),
|
||||
"finite_count": int(np.isfinite(values).sum()),
|
||||
}
|
||||
)
|
||||
correlation_rows = _high_correlation_rows(frame)
|
||||
drift_rows = _drift_rows(frame)
|
||||
lines = [
|
||||
"# Trader Feature Quality Report",
|
||||
"",
|
||||
f"- row_count: {len(frame)}",
|
||||
f"- OK: {int(frame['data_quality_flag'].eq('OK').sum())}",
|
||||
f"- PARTIAL_OPTIONAL: {int(frame['data_quality_flag'].eq('PARTIAL_OPTIONAL').sum())}",
|
||||
f"- WARMUP: {int(frame['data_quality_flag'].eq('WARMUP').sum())}",
|
||||
f"- feature_schema_hash: {feature_schema_hash}",
|
||||
f"- feature_order_hash: {feature_order_hash}",
|
||||
"",
|
||||
"## Split Coverage",
|
||||
"",
|
||||
_markdown_table(split_rows, ["split_id", "rows", "start", "end", "ok", "partial_optional", "warmup"]),
|
||||
"",
|
||||
"## Source Coverage",
|
||||
"",
|
||||
f"- replay_1m_required_columns: present",
|
||||
f"- liquidation_available_share: {float(frame['liquidation_available'].mean()):.6f}",
|
||||
f"- feature_rows_with_optional_liquidation_missing: {int(frame['data_quality_flag'].eq('PARTIAL_OPTIONAL').sum())}",
|
||||
"",
|
||||
"## Leakage Check",
|
||||
"",
|
||||
"- 所有特征只使用当前分钟收盘后已经知道的数据,滚动窗口都只看 `<= t`。",
|
||||
"- 未来价格、未来收益、目标标签不进入 `feature_frame.parquet`。",
|
||||
"",
|
||||
"## Extreme Value Check",
|
||||
"",
|
||||
_markdown_table(finite_rows, ["feature", "nan_count", "inf_count", "finite_count"]),
|
||||
"",
|
||||
"## High Correlation Check",
|
||||
"",
|
||||
_markdown_table(correlation_rows, ["feature_a", "feature_b", "corr_abs"]),
|
||||
"",
|
||||
"## Drift Check",
|
||||
"",
|
||||
_markdown_table(
|
||||
drift_rows,
|
||||
["feature", "train_p50", "tune_p50", "validation_p50", "p50_diff", "train_p99", "tune_p99", "validation_p99", "p99_diff"],
|
||||
),
|
||||
"",
|
||||
"## Distribution",
|
||||
"",
|
||||
"| feature | null_count | min | p01 | p50 | p99 | max |",
|
||||
"| --- | ---: | ---: | ---: | ---: | ---: | ---: |",
|
||||
]
|
||||
for feature in FEATURE_ORDER:
|
||||
series = pd.to_numeric(frame[feature], errors="coerce")
|
||||
quantiles = series.quantile([0.01, 0.5, 0.99])
|
||||
lines.append(
|
||||
f"| {feature} | {int(series.isna().sum())} | {series.min():.6g} | {quantiles.loc[0.01]:.6g} | {quantiles.loc[0.5]:.6g} | {quantiles.loc[0.99]:.6g} | {series.max():.6g} |"
|
||||
)
|
||||
write_text(path, "\n".join(lines) + "\n")
|
||||
|
||||
|
||||
def feature_order_hash() -> str:
|
||||
return sha256_json(FEATURE_ORDER)
|
||||
|
||||
|
||||
def _high_correlation_rows(frame: pd.DataFrame) -> list[dict[str, object]]:
|
||||
sample = frame[FEATURE_ORDER].apply(pd.to_numeric, errors="coerce").dropna()
|
||||
if len(sample) > 5000:
|
||||
sample = sample.sample(5000, random_state=7)
|
||||
if sample.empty:
|
||||
return [{"feature_a": "NONE", "feature_b": "NONE", "corr_abs": 0.0}]
|
||||
corr = sample.corr().abs()
|
||||
rows = []
|
||||
for left_index, left in enumerate(FEATURE_ORDER):
|
||||
for right in FEATURE_ORDER[left_index + 1 :]:
|
||||
value = corr.loc[left, right]
|
||||
if pd.notna(value) and value >= 0.95:
|
||||
rows.append({"feature_a": left, "feature_b": right, "corr_abs": round(float(value), 6)})
|
||||
return rows[:30] or [{"feature_a": "NONE", "feature_b": "NONE", "corr_abs": 0.0}]
|
||||
|
||||
|
||||
def _drift_rows(frame: pd.DataFrame) -> list[dict[str, object]]:
|
||||
train = frame[frame["split_id"].eq(FIT_SPLIT)]
|
||||
validation = frame[frame["split_id"].eq(VALIDATION_LOCKED_SPLIT)]
|
||||
tune = frame[frame["split_id"].eq(TUNE_SPLIT)]
|
||||
rows = []
|
||||
for feature in FEATURE_ORDER:
|
||||
train_series = pd.to_numeric(train[feature], errors="coerce")
|
||||
validation_series = pd.to_numeric(validation[feature], errors="coerce")
|
||||
tune_series = pd.to_numeric(tune[feature], errors="coerce")
|
||||
train_p50 = float(train_series.quantile(0.5)) if not train_series.empty else 0.0
|
||||
tune_p50 = float(tune_series.quantile(0.5)) if not tune_series.empty else 0.0
|
||||
validation_p50 = float(validation_series.quantile(0.5)) if not validation_series.empty else 0.0
|
||||
train_p99 = float(train_series.quantile(0.99)) if not train_series.empty else 0.0
|
||||
tune_p99 = float(tune_series.quantile(0.99)) if not tune_series.empty else 0.0
|
||||
validation_p99 = float(validation_series.quantile(0.99)) if not validation_series.empty else 0.0
|
||||
rows.append(
|
||||
{
|
||||
"feature": feature,
|
||||
"train_p50": round(train_p50, 6),
|
||||
"tune_p50": round(tune_p50, 6),
|
||||
"validation_p50": round(validation_p50, 6),
|
||||
"p50_diff": round(validation_p50 - train_p50, 6),
|
||||
"train_p99": round(train_p99, 6),
|
||||
"tune_p99": round(tune_p99, 6),
|
||||
"validation_p99": round(validation_p99, 6),
|
||||
"p99_diff": round(validation_p99 - train_p99, 6),
|
||||
}
|
||||
)
|
||||
return rows
|
||||
|
||||
|
||||
def _markdown_table(rows: list[dict[str, object]], columns: list[str]) -> str:
|
||||
if not rows:
|
||||
rows = [{column: "" for column in columns}]
|
||||
lines = ["| " + " | ".join(columns) + " |", "| " + " | ".join("---" for _ in columns) + " |"]
|
||||
for row in rows:
|
||||
lines.append("| " + " | ".join(str(row.get(column, "")) for column in columns) + " |")
|
||||
return "\n".join(lines)
|
||||
@@ -0,0 +1,162 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import hashlib
|
||||
import json
|
||||
import logging
|
||||
from datetime import UTC, datetime
|
||||
from pathlib import Path
|
||||
from typing import Any, Iterable
|
||||
|
||||
import pandas as pd
|
||||
|
||||
|
||||
DEFAULT_DATA_ROOT = Path("/Users/zach/Desktop/quant-strategy-training-data")
|
||||
DEFAULT_RAW_ROOT = DEFAULT_DATA_ROOT / "crypto-lake" / "raw"
|
||||
|
||||
|
||||
def setup_logging() -> None:
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format="%(asctime)s %(levelname)s event=%(message)s",
|
||||
)
|
||||
|
||||
|
||||
def add_common_args(parser: argparse.ArgumentParser) -> None:
|
||||
parser.add_argument("--data-root", type=Path, default=DEFAULT_DATA_ROOT)
|
||||
parser.add_argument("--run-id", required=True)
|
||||
parser.add_argument("--config", type=Path)
|
||||
parser.add_argument("--workspace", type=Path)
|
||||
parser.add_argument("--fail-fast", action="store_true")
|
||||
|
||||
|
||||
def run_root(args: argparse.Namespace) -> Path:
|
||||
return args.data_root / "trader-v4" / "runs" / args.run_id
|
||||
|
||||
|
||||
def ensure_dir(path: Path) -> Path:
|
||||
path.mkdir(parents=True, exist_ok=True)
|
||||
return path
|
||||
|
||||
|
||||
def canonical_json_bytes(value: Any) -> bytes:
|
||||
return json.dumps(value, ensure_ascii=False, sort_keys=False, separators=(",", ":")).encode("utf-8")
|
||||
|
||||
|
||||
def sha256_bytes(data: bytes) -> str:
|
||||
return hashlib.sha256(data).hexdigest()
|
||||
|
||||
|
||||
def sha256_json(value: Any) -> str:
|
||||
return sha256_bytes(canonical_json_bytes(value))
|
||||
|
||||
|
||||
def sha256_file(path: Path) -> str:
|
||||
digest = hashlib.sha256()
|
||||
with path.open("rb") as fh:
|
||||
for chunk in iter(lambda: fh.read(1024 * 1024), b""):
|
||||
digest.update(chunk)
|
||||
return digest.hexdigest()
|
||||
|
||||
|
||||
def write_json(path: Path, value: Any) -> str:
|
||||
ensure_dir(path.parent)
|
||||
data = canonical_json_bytes(value)
|
||||
path.write_bytes(data + b"\n")
|
||||
return sha256_bytes(data)
|
||||
|
||||
|
||||
def read_json(path: Path) -> Any:
|
||||
with path.open("r", encoding="utf-8") as fh:
|
||||
return json.load(fh)
|
||||
|
||||
|
||||
def write_parquet(path: Path, frame: pd.DataFrame) -> str:
|
||||
ensure_dir(path.parent)
|
||||
frame.to_parquet(path, index=False)
|
||||
return sha256_file(path)
|
||||
|
||||
|
||||
def read_parquet(path: Path) -> pd.DataFrame:
|
||||
if not path.is_file():
|
||||
raise FileNotFoundError(f"required parquet is missing: {path}")
|
||||
return pd.read_parquet(path)
|
||||
|
||||
|
||||
def write_text(path: Path, text: str) -> None:
|
||||
ensure_dir(path.parent)
|
||||
path.write_text(text, encoding="utf-8")
|
||||
|
||||
|
||||
def utc_now_text() -> str:
|
||||
return datetime.now(UTC).replace(microsecond=0).isoformat().replace("+00:00", "Z")
|
||||
|
||||
|
||||
def to_utc_series(values: pd.Series) -> pd.Series:
|
||||
if pd.api.types.is_datetime64_any_dtype(values):
|
||||
return pd.to_datetime(values, utc=True)
|
||||
numeric = pd.to_numeric(values, errors="coerce")
|
||||
if numeric.notna().any():
|
||||
max_value = numeric.dropna().abs().max()
|
||||
unit = "ms" if max_value > 10_000_000_000 else "s"
|
||||
return pd.to_datetime(numeric, unit=unit, utc=True)
|
||||
return pd.to_datetime(values, utc=True, errors="coerce")
|
||||
|
||||
|
||||
def open_time_ms(values: pd.Series) -> pd.Series:
|
||||
dt = to_utc_series(values)
|
||||
return (dt.astype("int64") // 1_000_000).astype("Int64")
|
||||
|
||||
|
||||
def date_texts(start_date: str | None, end_date: str | None) -> tuple[str | None, str | None]:
|
||||
return start_date, end_date
|
||||
|
||||
|
||||
def partition_files(raw_root: Path, table: str, symbol: str, start_date: str | None, end_date: str | None) -> list[Path]:
|
||||
base = raw_root / f"table={table}"
|
||||
if not base.is_dir():
|
||||
return []
|
||||
files = sorted(base.glob(f"exchange=*/symbol={symbol}/dt=*/data.parquet"))
|
||||
selected: list[Path] = []
|
||||
for file in files:
|
||||
dt_part = next((part.split("=", 1)[1] for part in file.parts if part.startswith("dt=")), "")
|
||||
if start_date and dt_part < start_date:
|
||||
continue
|
||||
if end_date and dt_part > end_date:
|
||||
continue
|
||||
selected.append(file)
|
||||
return selected
|
||||
|
||||
|
||||
def read_partitioned_table(
|
||||
raw_root: Path,
|
||||
table: str,
|
||||
symbol: str,
|
||||
start_date: str | None,
|
||||
end_date: str | None,
|
||||
columns: Iterable[str] | None = None,
|
||||
) -> pd.DataFrame:
|
||||
files = partition_files(raw_root, table, symbol, start_date, end_date)
|
||||
if not files:
|
||||
return pd.DataFrame()
|
||||
logging.info("trader.training.raw_read_started table=%s symbol=%s fileCount=%s", table, symbol, len(files))
|
||||
frames = [pd.read_parquet(file, columns=list(columns) if columns else None) for file in files]
|
||||
frame = pd.concat(frames, ignore_index=True)
|
||||
logging.info("trader.training.raw_read_finished table=%s rowCount=%s", table, len(frame))
|
||||
return frame
|
||||
|
||||
|
||||
def require_columns(frame: pd.DataFrame, columns: Iterable[str], name: str) -> None:
|
||||
missing = [column for column in columns if column not in frame.columns]
|
||||
if missing:
|
||||
raise ValueError(f"{name} is missing required columns: {missing}")
|
||||
|
||||
|
||||
def manifest(path: Path, extra: dict[str, Any]) -> dict[str, Any]:
|
||||
payload = {
|
||||
"path": str(path),
|
||||
"hash_sha256": sha256_file(path) if path.is_file() else None,
|
||||
"created_at": utc_now_text(),
|
||||
}
|
||||
payload.update(extra)
|
||||
return payload
|
||||
@@ -0,0 +1,417 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from trader_training.io_utils import (
|
||||
manifest,
|
||||
read_json,
|
||||
read_parquet,
|
||||
require_columns,
|
||||
run_root,
|
||||
sha256_json,
|
||||
to_utc_series,
|
||||
write_json,
|
||||
write_parquet,
|
||||
write_text,
|
||||
)
|
||||
from trader_training.schemas import LABEL_VERSION
|
||||
|
||||
|
||||
DEFAULT_LABEL_CONFIG = {
|
||||
"direction": {"horizon_minutes": 45, "long_threshold_bps": 5.0, "short_threshold_bps": -5.0},
|
||||
"entry": {"max_hold_minutes": 45, "target_bps": 12.0, "stop_bps": 8.0, "min_expected_net_edge_bps": 3.0},
|
||||
"continue": {"horizon_minutes": 30, "min_expected_continue_edge_bps": 2.0},
|
||||
"exit": {"horizon_minutes": 30, "adverse_move_bps": 8.0, "stagnation_abs_return_bps": 2.0},
|
||||
"risk": {"horizon_minutes": 30, "market_drawdown_bps": 12.0, "vol_expansion_ratio": 1.6, "spike_bps": 20.0},
|
||||
}
|
||||
|
||||
|
||||
DEFAULT_COST_CONFIG = {
|
||||
"fee_bps": 4.0,
|
||||
"slippage_bps": 2.0,
|
||||
"funding_cost_bps": 0.5,
|
||||
}
|
||||
|
||||
|
||||
def _load_config(path, default):
|
||||
if path is None:
|
||||
return default
|
||||
value = read_json(path)
|
||||
merged = default.copy()
|
||||
for key, item in value.items():
|
||||
if isinstance(item, dict) and isinstance(merged.get(key), dict):
|
||||
merged[key] = {**merged[key], **item}
|
||||
else:
|
||||
merged[key] = item
|
||||
return merged
|
||||
|
||||
|
||||
def _base_frames(args: Any) -> tuple[pd.DataFrame, pd.DataFrame]:
|
||||
root = run_root(args)
|
||||
feature_path = args.feature_path or root / "feature" / "feature_frame.parquet"
|
||||
replay_path = args.replay_path or root / "replay" / "replay_1m.parquet"
|
||||
features = read_parquet(feature_path)
|
||||
replay = read_parquet(replay_path)
|
||||
require_columns(features, ("sample_id", "symbol", "event_time", "open_time_ms", "split_id", "walk_forward_fold", "data_quality_flag"), "feature_frame")
|
||||
require_columns(replay, ("symbol", "event_time", "open_time_ms", "open", "high", "low", "close", "spread_bps"), "replay_1m")
|
||||
features = features.copy()
|
||||
replay = replay.copy()
|
||||
features["event_time"] = to_utc_series(features["event_time"])
|
||||
replay["event_time"] = to_utc_series(replay["event_time"])
|
||||
replay = replay.sort_values(["symbol", "event_time"]).reset_index(drop=True)
|
||||
return features, replay
|
||||
|
||||
|
||||
def _future_path(group: pd.DataFrame, index: int, horizon: int) -> pd.DataFrame:
|
||||
start = index + 1
|
||||
end = min(len(group), index + horizon + 1)
|
||||
return group.iloc[start:end]
|
||||
|
||||
|
||||
def _contiguous_future_path(group: pd.DataFrame, index: int, horizon: int) -> pd.DataFrame:
|
||||
path = _future_path(group, index, horizon)
|
||||
if len(path) < horizon:
|
||||
return pd.DataFrame()
|
||||
current_ms = int(group.iloc[index]["open_time_ms"])
|
||||
expected = current_ms + np.arange(1, horizon + 1, dtype=np.int64) * 60_000
|
||||
actual = path["open_time_ms"].astype("int64").to_numpy()
|
||||
if len(actual) != len(expected) or not np.array_equal(actual, expected):
|
||||
return pd.DataFrame()
|
||||
return path
|
||||
|
||||
|
||||
def _side_return_bps(side: str, entry_price: float, exit_price: float) -> float:
|
||||
if side == "LONG":
|
||||
return (exit_price / entry_price - 1.0) * 10000.0
|
||||
return (entry_price / exit_price - 1.0) * 10000.0
|
||||
|
||||
|
||||
def _path_stats(group: pd.DataFrame, index: int, side: str, horizon: int, target_bps: float, stop_bps: float) -> dict[str, Any]:
|
||||
current = group.iloc[index]
|
||||
entry = float(current["close"])
|
||||
path = _contiguous_future_path(group, index, horizon)
|
||||
if path.empty:
|
||||
return {"valid": False}
|
||||
target_price = entry * (1.0 + target_bps / 10000.0) if side == "LONG" else entry * (1.0 - target_bps / 10000.0)
|
||||
stop_price = entry * (1.0 - stop_bps / 10000.0) if side == "LONG" else entry * (1.0 + stop_bps / 10000.0)
|
||||
target_hit = False
|
||||
stop_hit = False
|
||||
ambiguous = False
|
||||
time_to_target_ms = -1
|
||||
time_to_stop_ms = -1
|
||||
for _, row in path.iterrows():
|
||||
high = float(row["high"])
|
||||
low = float(row["low"])
|
||||
if side == "LONG":
|
||||
target_now = high >= target_price
|
||||
stop_now = low <= stop_price
|
||||
else:
|
||||
target_now = low <= target_price
|
||||
stop_now = high >= stop_price
|
||||
if target_now and stop_now:
|
||||
ambiguous = True
|
||||
stop_hit = True
|
||||
time_to_stop_ms = int(row["open_time_ms"] - current["open_time_ms"])
|
||||
break
|
||||
if target_now:
|
||||
target_hit = True
|
||||
time_to_target_ms = int(row["open_time_ms"] - current["open_time_ms"])
|
||||
break
|
||||
if stop_now:
|
||||
stop_hit = True
|
||||
time_to_stop_ms = int(row["open_time_ms"] - current["open_time_ms"])
|
||||
break
|
||||
exit_price = float(path.iloc[-1]["close"])
|
||||
final_return_bps = _side_return_bps(side, entry, exit_price)
|
||||
if side == "LONG":
|
||||
mfe = (path["high"].max() / entry - 1.0) * 10000.0
|
||||
mae = (entry / path["low"].min() - 1.0) * 10000.0
|
||||
else:
|
||||
mfe = (entry / path["low"].min() - 1.0) * 10000.0
|
||||
mae = (path["high"].max() / entry - 1.0) * 10000.0
|
||||
if target_hit:
|
||||
gross = target_bps
|
||||
elif stop_hit:
|
||||
gross = -stop_bps
|
||||
else:
|
||||
gross = final_return_bps
|
||||
return {
|
||||
"valid": True,
|
||||
"target_hit": int(target_hit),
|
||||
"stop_hit": int(stop_hit),
|
||||
"timeout_hit": int(not target_hit and not stop_hit),
|
||||
"ambiguous_hit": int(ambiguous),
|
||||
"time_to_target_ms": time_to_target_ms,
|
||||
"time_to_stop_ms": time_to_stop_ms,
|
||||
"gross_edge_bps": float(gross),
|
||||
"future_return_bps": float(final_return_bps),
|
||||
"mfe_bps": float(mfe),
|
||||
"mae_bps": float(mae),
|
||||
"future_spread_p80": float(path["spread_bps"].quantile(0.8)),
|
||||
"future_realized_vol_bps": float(np.log(path["close"].astype(float) / path["close"].astype(float).shift(1)).std() * 10000.0),
|
||||
}
|
||||
|
||||
|
||||
def write_price_plan_context(args: Any) -> None:
|
||||
root = run_root(args)
|
||||
cost = _load_config(args.cost_config_path, DEFAULT_COST_CONFIG)
|
||||
labels = _load_config(args.label_config_path, DEFAULT_LABEL_CONFIG)
|
||||
entry = labels["entry"]
|
||||
cost_bps = float(cost["fee_bps"]) + float(cost["slippage_bps"]) + float(cost["funding_cost_bps"])
|
||||
context = {
|
||||
"pricePlanId": args.price_plan_id,
|
||||
"pricePlanConfigHash": sha256_json({"entry": entry, "cost": cost}),
|
||||
"stopDistanceBps": float(entry["stop_bps"]),
|
||||
"targetDistanceBps": float(entry["target_bps"]),
|
||||
"maxHoldMinutes": int(entry["max_hold_minutes"]),
|
||||
"costBps": cost_bps,
|
||||
}
|
||||
path = root / "label" / "price_plan_context.json"
|
||||
write_json(path, context)
|
||||
frame = pd.DataFrame([{
|
||||
"price_plan_id": context["pricePlanId"],
|
||||
"price_plan_hash": context["pricePlanConfigHash"],
|
||||
"target_bps": context["targetDistanceBps"],
|
||||
"stop_bps": context["stopDistanceBps"],
|
||||
"max_hold_minutes": context["maxHoldMinutes"],
|
||||
"cost_bps": context["costBps"],
|
||||
}])
|
||||
write_parquet(root / "label" / "price_plan_context.parquet", frame)
|
||||
logging.info("trader.training.price_plan_written runId=%s path=%s", args.run_id, path)
|
||||
|
||||
|
||||
def build_direction_labels(args: Any) -> None:
|
||||
root = run_root(args)
|
||||
config = _load_config(args.label_config_path, DEFAULT_LABEL_CONFIG)["direction"]
|
||||
features, replay = _base_frames(args)
|
||||
horizon = int(config["horizon_minutes"])
|
||||
replay = replay[["symbol", "event_time", "open_time_ms", "close"]].copy()
|
||||
future = replay[["symbol", "open_time_ms", "close"]].copy()
|
||||
future["open_time_ms"] = future["open_time_ms"].astype("int64") - horizon * 60_000
|
||||
future = future.rename(columns={"close": "future_close"})
|
||||
merged = features.merge(replay[["symbol", "open_time_ms", "close"]], on=["symbol", "open_time_ms"], how="left")
|
||||
merged = merged.merge(future, on=["symbol", "open_time_ms"], how="left")
|
||||
merged["future_return_bps"] = (merged["future_close"] / merged["close"] - 1.0) * 10000.0
|
||||
merged["direction_label"] = np.select(
|
||||
[merged["future_return_bps"] >= float(config["long_threshold_bps"]), merged["future_return_bps"] <= float(config["short_threshold_bps"])],
|
||||
["LONG", "SHORT"],
|
||||
default="NEUTRAL",
|
||||
)
|
||||
out = pd.DataFrame(
|
||||
{
|
||||
"sample_id": merged["sample_id"],
|
||||
"symbol": merged["symbol"],
|
||||
"event_time": merged["event_time"],
|
||||
"horizon_minutes": horizon,
|
||||
"future_return_bps": merged["future_return_bps"],
|
||||
"direction_label": merged["direction_label"],
|
||||
"long_target": merged["direction_label"].eq("LONG").astype("int8"),
|
||||
"short_target": merged["direction_label"].eq("SHORT").astype("int8"),
|
||||
"neutral_target": merged["direction_label"].eq("NEUTRAL").astype("int8"),
|
||||
"split_id": merged["split_id"],
|
||||
"walk_forward_fold": merged["walk_forward_fold"],
|
||||
"label_version": LABEL_VERSION,
|
||||
}
|
||||
).dropna(subset=["future_return_bps"])
|
||||
path = root / "label" / "direction_labels.parquet"
|
||||
data_hash = write_parquet(path, out)
|
||||
_write_label_manifest(root / "label" / "direction_labels.manifest.json", path, out, data_hash)
|
||||
_write_distribution_report(root / "label" / "direction_label_report.md", out, "direction_label")
|
||||
logging.info("trader.training.direction_labels_written runId=%s rowCount=%s", args.run_id, len(out))
|
||||
|
||||
|
||||
def build_entry_labels(args: Any) -> None:
|
||||
root = run_root(args)
|
||||
labels = _load_config(args.label_config_path, DEFAULT_LABEL_CONFIG)
|
||||
cost = _load_config(args.cost_config_path, DEFAULT_COST_CONFIG)
|
||||
plan_path = args.price_plan_context_path or root / "label" / "price_plan_context.json"
|
||||
plan = read_json(plan_path)
|
||||
features, replay = _base_frames(args)
|
||||
entry_conf = labels["entry"]
|
||||
cost_bps = float(cost["fee_bps"]) + float(cost["slippage_bps"]) + float(cost["funding_cost_bps"])
|
||||
rows: list[dict[str, Any]] = []
|
||||
groups, index_by_key = _group_replay_with_index(replay)
|
||||
for feature in features.itertuples(index=False):
|
||||
key = (feature.symbol, int(feature.open_time_ms))
|
||||
index = index_by_key.get(key)
|
||||
if index is None:
|
||||
continue
|
||||
group = groups[feature.symbol]
|
||||
for side in ("LONG", "SHORT"):
|
||||
stats = _path_stats(group, index, side, int(entry_conf["max_hold_minutes"]), float(entry_conf["target_bps"]), float(entry_conf["stop_bps"]))
|
||||
if not stats["valid"]:
|
||||
continue
|
||||
expected = stats["gross_edge_bps"] - cost_bps
|
||||
rows.append(
|
||||
{
|
||||
"sample_id": feature.sample_id,
|
||||
"symbol": feature.symbol,
|
||||
"event_time": feature.event_time,
|
||||
"side": side,
|
||||
"price_plan_id": plan["pricePlanId"],
|
||||
"price_plan_hash": plan["pricePlanConfigHash"],
|
||||
"target_hit": stats["target_hit"],
|
||||
"stop_hit": stats["stop_hit"],
|
||||
"timeout_hit": stats["timeout_hit"],
|
||||
"ambiguous_hit": stats["ambiguous_hit"],
|
||||
"time_to_target_ms": stats["time_to_target_ms"],
|
||||
"time_to_stop_ms": stats["time_to_stop_ms"],
|
||||
"gross_edge_bps": stats["gross_edge_bps"],
|
||||
"cost_bps": cost_bps,
|
||||
"expected_net_edge_bps": expected,
|
||||
"entry_target": int(stats["target_hit"] == 1 and expected >= float(entry_conf["min_expected_net_edge_bps"])),
|
||||
"split_id": feature.split_id,
|
||||
"walk_forward_fold": feature.walk_forward_fold,
|
||||
"label_version": LABEL_VERSION,
|
||||
}
|
||||
)
|
||||
out = pd.DataFrame(rows)
|
||||
path = root / "label" / "entry_labels.parquet"
|
||||
data_hash = write_parquet(path, out)
|
||||
_write_label_manifest(root / "label" / "entry_labels.manifest.json", path, out, data_hash)
|
||||
_write_distribution_report(root / "label" / "entry_label_report.md", out, "entry_target")
|
||||
logging.info("trader.training.entry_labels_written runId=%s rowCount=%s", args.run_id, len(out))
|
||||
|
||||
|
||||
def build_position_state_samples(args: Any) -> None:
|
||||
root = run_root(args)
|
||||
entry_path = args.entry_label_path or root / "label" / "entry_labels.parquet"
|
||||
entry = read_parquet(entry_path)
|
||||
if entry.empty:
|
||||
raise ValueError("entry labels are required before building position samples")
|
||||
samples = entry[entry["entry_target"] == 1].copy()
|
||||
samples["position_age_minutes"] = 0
|
||||
samples["unrealized_pnl_bps"] = 0.0
|
||||
samples["mfe_bps"] = samples["gross_edge_bps"].clip(lower=0)
|
||||
samples["mae_bps"] = (-samples["gross_edge_bps"]).clip(lower=0)
|
||||
path = root / "label" / "position_state_samples.parquet"
|
||||
data_hash = write_parquet(path, samples)
|
||||
write_json(root / "label" / "position_state_samples.manifest.json", manifest(path, {"row_count": len(samples), "data_hash_sha256": data_hash}))
|
||||
logging.info("trader.training.position_samples_written runId=%s rowCount=%s", args.run_id, len(samples))
|
||||
|
||||
|
||||
def build_continue_exit_risk_labels(args: Any) -> None:
|
||||
root = run_root(args)
|
||||
labels = _load_config(args.label_config_path, DEFAULT_LABEL_CONFIG)
|
||||
cost = _load_config(args.cost_config_path, DEFAULT_COST_CONFIG)
|
||||
plan = read_json(args.price_plan_context_path or root / "label" / "price_plan_context.json")
|
||||
features, replay = _base_frames(args)
|
||||
cost_bps = float(cost["fee_bps"]) + float(cost["slippage_bps"]) + float(cost["funding_cost_bps"])
|
||||
horizon = int(labels["continue"]["horizon_minutes"])
|
||||
target_bps = float(plan["targetDistanceBps"])
|
||||
stop_bps = float(plan["stopDistanceBps"])
|
||||
rows_continue: list[dict[str, Any]] = []
|
||||
rows_exit: list[dict[str, Any]] = []
|
||||
rows_risk: list[dict[str, Any]] = []
|
||||
groups, index_by_key = _group_replay_with_index(replay)
|
||||
for feature in features.itertuples(index=False):
|
||||
key = (feature.symbol, int(feature.open_time_ms))
|
||||
index = index_by_key.get(key)
|
||||
if index is None:
|
||||
continue
|
||||
group = groups[feature.symbol]
|
||||
long_stats = _path_stats(group, index, "LONG", horizon, target_bps, stop_bps)
|
||||
short_stats = _path_stats(group, index, "SHORT", horizon, target_bps, stop_bps)
|
||||
if not long_stats["valid"] or not short_stats["valid"]:
|
||||
continue
|
||||
long_edge = long_stats["future_return_bps"] - cost_bps
|
||||
short_edge = short_stats["future_return_bps"] - cost_bps
|
||||
min_continue = float(labels["continue"]["min_expected_continue_edge_bps"])
|
||||
adverse_threshold = float(labels["exit"]["adverse_move_bps"])
|
||||
rows_continue.append(
|
||||
{
|
||||
"sample_id": feature.sample_id,
|
||||
"symbol": feature.symbol,
|
||||
"event_time": feature.event_time,
|
||||
"long_continue_target": int(long_edge >= min_continue and long_stats["mae_bps"] < stop_bps),
|
||||
"short_continue_target": int(short_edge >= min_continue and short_stats["mae_bps"] < stop_bps),
|
||||
"long_expected_continue_edge_bps": long_edge,
|
||||
"short_expected_continue_edge_bps": short_edge,
|
||||
"split_id": feature.split_id,
|
||||
"walk_forward_fold": feature.walk_forward_fold,
|
||||
"label_version": LABEL_VERSION,
|
||||
}
|
||||
)
|
||||
stagnation = int(abs(long_stats["future_return_bps"]) <= float(labels["exit"]["stagnation_abs_return_bps"]))
|
||||
rows_exit.append(
|
||||
{
|
||||
"sample_id": feature.sample_id,
|
||||
"symbol": feature.symbol,
|
||||
"event_time": feature.event_time,
|
||||
"long_exit_target": int(long_stats["stop_hit"] == 1 or long_stats["mae_bps"] >= adverse_threshold),
|
||||
"short_exit_target": int(short_stats["stop_hit"] == 1 or short_stats["mae_bps"] >= adverse_threshold),
|
||||
"long_adverse_move_bps": long_stats["mae_bps"],
|
||||
"short_adverse_move_bps": short_stats["mae_bps"],
|
||||
"adverse_move_prob_label": int(max(long_stats["mae_bps"], short_stats["mae_bps"]) >= adverse_threshold),
|
||||
"reversal_prob_label": int(np.sign(long_stats["future_return_bps"]) != np.sign(feature.ret_15m_bps) if hasattr(feature, "ret_15m_bps") else 0),
|
||||
"stop_hit_prob_label": int(long_stats["stop_hit"] == 1 or short_stats["stop_hit"] == 1),
|
||||
"stagnation_prob_label": stagnation,
|
||||
"split_id": feature.split_id,
|
||||
"walk_forward_fold": feature.walk_forward_fold,
|
||||
"label_version": LABEL_VERSION,
|
||||
}
|
||||
)
|
||||
path_risk = max(long_stats["mae_bps"], short_stats["mae_bps"])
|
||||
vol_ratio = 0.0 if long_stats["future_realized_vol_bps"] != long_stats["future_realized_vol_bps"] else long_stats["future_realized_vol_bps"]
|
||||
rows_risk.append(
|
||||
{
|
||||
"sample_id": feature.sample_id,
|
||||
"symbol": feature.symbol,
|
||||
"event_time": feature.event_time,
|
||||
"market_risk_target": int(path_risk >= float(labels["risk"]["market_drawdown_bps"])),
|
||||
"market_path_risk_bps": path_risk,
|
||||
"long_position_path_risk_bps": long_stats["mae_bps"],
|
||||
"short_position_path_risk_bps": short_stats["mae_bps"],
|
||||
"long_position_risk_target": int(long_stats["mae_bps"] >= stop_bps),
|
||||
"short_position_risk_target": int(short_stats["mae_bps"] >= stop_bps),
|
||||
"market_drawdown_prob_label": int(path_risk >= float(labels["risk"]["market_drawdown_bps"])),
|
||||
"volatility_expansion_prob_label": int(vol_ratio >= float(labels["risk"]["spike_bps"])),
|
||||
"spike_prob_label": int(max(long_stats["mfe_bps"], short_stats["mfe_bps"], path_risk) >= float(labels["risk"]["spike_bps"])),
|
||||
"liquidity_deterioration_prob_label": int(long_stats["future_spread_p80"] >= float(replay["spread_bps"].quantile(0.9))),
|
||||
"position_drawdown_prob_label": int(max(long_stats["mae_bps"], short_stats["mae_bps"]) >= stop_bps),
|
||||
"split_id": feature.split_id,
|
||||
"walk_forward_fold": feature.walk_forward_fold,
|
||||
"label_version": LABEL_VERSION,
|
||||
}
|
||||
)
|
||||
outputs = [
|
||||
("continue", pd.DataFrame(rows_continue), "long_continue_target"),
|
||||
("exit", pd.DataFrame(rows_exit), "long_exit_target"),
|
||||
("risk", pd.DataFrame(rows_risk), "market_risk_target"),
|
||||
]
|
||||
report_parts = ["# Continue Exit Risk Label Report", ""]
|
||||
for name, frame, target in outputs:
|
||||
path = root / "label" / f"{name}_labels.parquet"
|
||||
data_hash = write_parquet(path, frame)
|
||||
_write_label_manifest(root / "label" / f"{name}_labels.manifest.json", path, frame, data_hash)
|
||||
report_parts.append(f"## {name}")
|
||||
report_parts.append("")
|
||||
report_parts.append(str(frame[target].value_counts(dropna=False).to_dict() if not frame.empty else {}))
|
||||
report_parts.append("")
|
||||
logging.info("trader.training.%s_labels_written runId=%s rowCount=%s", name, args.run_id, len(frame))
|
||||
write_text(root / "label" / "continue_exit_risk_label_report.md", "\n".join(report_parts) + "\n")
|
||||
|
||||
|
||||
def _write_label_manifest(path, parquet_path, frame: pd.DataFrame, data_hash: str) -> None:
|
||||
write_json(path, manifest(parquet_path, {"row_count": len(frame), "label_version": LABEL_VERSION, "data_hash_sha256": data_hash}))
|
||||
|
||||
|
||||
def _write_distribution_report(path, frame: pd.DataFrame, column: str) -> None:
|
||||
counts = frame[column].value_counts(dropna=False).to_dict() if not frame.empty else {}
|
||||
lines = ["# Label Report", "", f"- row_count: {len(frame)}", f"- target_column: {column}", f"- distribution: {counts}", ""]
|
||||
write_text(path, "\n".join(lines))
|
||||
|
||||
|
||||
def _group_replay_with_index(replay: pd.DataFrame) -> tuple[dict[str, pd.DataFrame], dict[tuple[str, int], int]]:
|
||||
groups: dict[str, pd.DataFrame] = {}
|
||||
index_by_key: dict[tuple[str, int], int] = {}
|
||||
for symbol, group in replay.groupby("symbol", sort=False):
|
||||
grouped = group.sort_values("event_time").reset_index(drop=True)
|
||||
groups[symbol] = grouped
|
||||
for idx, row in grouped.iterrows():
|
||||
index_by_key[(symbol, int(row["open_time_ms"]))] = idx
|
||||
return groups, index_by_key
|
||||
@@ -0,0 +1,73 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class LinearHead:
|
||||
name: str
|
||||
kind: str
|
||||
weight: np.ndarray
|
||||
bias: np.ndarray
|
||||
|
||||
|
||||
def require_onnx():
|
||||
try:
|
||||
import onnx
|
||||
from onnx import TensorProto, helper, numpy_helper
|
||||
except ModuleNotFoundError as exc:
|
||||
raise SystemExit("Python package 'onnx' is required. Install training/requirements.txt before export.") from exc
|
||||
return onnx, TensorProto, helper, numpy_helper
|
||||
|
||||
|
||||
def export_heads(path: Path, heads: list[LinearHead], feature_count: int = 39, opset: int = 17) -> None:
|
||||
onnx, TensorProto, helper, numpy_helper = require_onnx()
|
||||
nodes = []
|
||||
initializers = []
|
||||
concat_inputs = []
|
||||
for idx, head in enumerate(heads):
|
||||
weight = np.asarray(head.weight, dtype=np.float32)
|
||||
bias = np.asarray(head.bias, dtype=np.float32).reshape(1, -1)
|
||||
if weight.ndim == 1:
|
||||
weight = weight.reshape(feature_count, 1)
|
||||
weight_name = f"{head.name}_W"
|
||||
bias_name = f"{head.name}_B"
|
||||
linear_name = f"{head.name}_linear"
|
||||
out_name = f"{head.name}_out"
|
||||
initializers.append(numpy_helper.from_array(weight, weight_name))
|
||||
initializers.append(numpy_helper.from_array(bias, bias_name))
|
||||
nodes.append(helper.make_node("MatMul", ["features", weight_name], [f"{linear_name}_mm"], name=f"{head.name}_matmul"))
|
||||
nodes.append(helper.make_node("Add", [f"{linear_name}_mm", bias_name], [linear_name], name=f"{head.name}_add"))
|
||||
if head.kind == "sigmoid":
|
||||
nodes.append(helper.make_node("Sigmoid", [linear_name], [out_name], name=f"{head.name}_sigmoid"))
|
||||
elif head.kind == "softmax":
|
||||
nodes.append(helper.make_node("Softmax", [linear_name], [out_name], name=f"{head.name}_softmax", axis=1))
|
||||
elif head.kind == "identity":
|
||||
out_name = linear_name
|
||||
else:
|
||||
raise ValueError(f"unsupported ONNX head kind: {head.kind}")
|
||||
concat_inputs.append(out_name)
|
||||
if len(concat_inputs) == 1:
|
||||
nodes.append(helper.make_node("Identity", concat_inputs, ["prediction"], name="prediction_identity"))
|
||||
else:
|
||||
nodes.append(helper.make_node("Concat", concat_inputs, ["prediction"], name="prediction_concat", axis=1))
|
||||
graph = helper.make_graph(
|
||||
nodes,
|
||||
"trader_v4_linear_heads",
|
||||
[helper.make_tensor_value_info("features", TensorProto.FLOAT, [1, feature_count])],
|
||||
[helper.make_tensor_value_info("prediction", TensorProto.FLOAT, [1, sum(_head_width(head) for head in heads)])],
|
||||
initializer=initializers,
|
||||
)
|
||||
model = helper.make_model(graph, producer_name="trader-training", opset_imports=[helper.make_opsetid("", opset)])
|
||||
model.ir_version = 10
|
||||
onnx.checker.check_model(model)
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
onnx.save(model, path)
|
||||
|
||||
|
||||
def _head_width(head: LinearHead) -> int:
|
||||
bias = np.asarray(head.bias)
|
||||
return int(bias.size)
|
||||
@@ -0,0 +1,541 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import itertools
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from trader_training.io_utils import read_json, read_parquet, run_root, sha256_json, write_json, write_parquet, write_text
|
||||
from trader_training.schemas import LATEST_STRESS_SPLIT, PM_CONFIG_VERSION, TUNE_SPLIT, VALIDATION_LOCKED_SPLIT
|
||||
|
||||
|
||||
def default_pm_config() -> dict:
|
||||
return {
|
||||
"pmConfigVersion": PM_CONFIG_VERSION,
|
||||
"open": {
|
||||
"longOpenProb": 0.58,
|
||||
"shortOpenProb": 0.58,
|
||||
"minLongEntryProb": 0.55,
|
||||
"minShortEntryProb": 0.55,
|
||||
"maxMarketRiskProb": 0.45,
|
||||
"minExpectedEdgeBps": 3.0,
|
||||
"minDirectionMargin": 0.03,
|
||||
"minLiquidityCapacityRatio": 0.10,
|
||||
"maxOodScore": 0.80,
|
||||
},
|
||||
"add": {
|
||||
"minLongProb": 0.60,
|
||||
"minShortProb": 0.60,
|
||||
"minContinueProb": 0.58,
|
||||
"minEntryProb": 0.55,
|
||||
"maxExitProb": 0.45,
|
||||
"maxMarketRiskProb": 0.45,
|
||||
"maxPositionRiskProb": 0.50,
|
||||
"minExpectedEdgeBps": 3.0,
|
||||
"minContinueVsExitEdgeBps": 0.0,
|
||||
"minLiquidityCapacityRatio": 0.10,
|
||||
"minPostTradeLiquidationBufferBps": 500.0,
|
||||
"maxAddCount": 3,
|
||||
"cooldownMinutes": 5,
|
||||
},
|
||||
"exit": {
|
||||
"closeExitProb": 0.70,
|
||||
"closePositionRiskProb": 0.70,
|
||||
"closeMarketRiskProb": 0.70,
|
||||
"closeContinueMax": 0.25,
|
||||
"reduceAdverseMoveProb": 0.62,
|
||||
"reduceContinueMin": 0.35,
|
||||
"reduceContinueMax": 0.70,
|
||||
"minProfitForReduceBps": 5.0,
|
||||
"maxPositionPathRiskBps": 80.0,
|
||||
},
|
||||
"sizing": {
|
||||
"baseRatio": 0.80,
|
||||
"minInitialRatio": 0.05,
|
||||
"maxSingleLegRatio": 1.0,
|
||||
"minAddRatio": 0.02,
|
||||
"maxAddRatio": 0.25,
|
||||
"maxTotalPositionRatio": 1.0,
|
||||
"minEdgeBps": 3.0,
|
||||
"maxLossPerTradeBps": 80.0,
|
||||
"maxLiquidityUsageRatio": 0.20,
|
||||
"uncertaintyPenaltyMultiplier": 0.50,
|
||||
"minPostTradeLiquidationBufferBps": 500.0,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def search_pm_thresholds(args: Any) -> None:
|
||||
root = run_root(args)
|
||||
frame = _pm_tune_frame(root)
|
||||
candidate_rows: list[dict[str, Any]] = []
|
||||
best_score = -float("inf")
|
||||
best_thresholds: dict[str, float] | None = None
|
||||
best_metrics: dict[str, Any] | None = None
|
||||
best_trades = pd.DataFrame()
|
||||
|
||||
for thresholds in _threshold_candidates():
|
||||
trades = _simulate_open_trades(frame, thresholds)
|
||||
metrics = _trade_metrics(trades)
|
||||
score = _score_thresholds(metrics)
|
||||
candidate_rows.append({**thresholds, **metrics, "score": score})
|
||||
if score > best_score:
|
||||
best_score = score
|
||||
best_thresholds = thresholds
|
||||
best_metrics = metrics
|
||||
best_trades = trades
|
||||
|
||||
if best_thresholds is None or best_metrics is None:
|
||||
raise ValueError("PM threshold search did not evaluate any candidate")
|
||||
|
||||
config = _pm_config_from_thresholds(best_thresholds)
|
||||
threshold_stability = {
|
||||
"source": "tune_predictions_and_entry_labels",
|
||||
"method": "deterministic_grid_search_v1",
|
||||
"candidate_count": len(candidate_rows),
|
||||
"best_score": best_score,
|
||||
"best_metrics": best_metrics,
|
||||
}
|
||||
payload = {
|
||||
"pm_config_version": PM_CONFIG_VERSION,
|
||||
"config": config,
|
||||
"config_hash_sha256": sha256_json(config),
|
||||
"threshold_stability_json": threshold_stability,
|
||||
}
|
||||
candidate_frame = pd.DataFrame(candidate_rows).sort_values("score", ascending=False).reset_index(drop=True)
|
||||
equity_curve = _equity_curve(best_trades)
|
||||
regime_metrics = _regime_metrics(best_trades)
|
||||
write_json(root / "pm-search" / "position_manager_config.json", payload)
|
||||
write_json(root / "pm-search" / "pm_threshold_config.json", payload)
|
||||
write_text(root / "pm-search" / "pm_search_candidates.csv", candidate_frame.to_csv(index=False))
|
||||
write_parquet(root / "pm-search" / "pm_backtest_trades.parquet", best_trades)
|
||||
write_text(root / "pm-search" / "pm_equity_curve.csv", equity_curve.to_csv(index=False))
|
||||
write_text(root / "pm-search" / "pm_regime_metrics.csv", regime_metrics.to_csv(index=False))
|
||||
_write_pm_report(root / "pm-search" / "pm_threshold_report.md", candidate_frame, best_thresholds, best_metrics)
|
||||
_write_pm_report(root / "pm-search" / "pm_search_report.md", candidate_frame, best_thresholds, best_metrics)
|
||||
logging.info(
|
||||
"trader.training.pm_thresholds_searched runId=%s candidateCount=%s bestScore=%.6f tradeCount=%s totalWeightedEdgeBps=%.6f",
|
||||
args.run_id,
|
||||
len(candidate_rows),
|
||||
best_score,
|
||||
best_metrics["trade_count"],
|
||||
best_metrics["total_weighted_edge_bps"],
|
||||
)
|
||||
|
||||
|
||||
def integrated_backtest(args: Any) -> None:
|
||||
root = run_root(args)
|
||||
config_path = root / "pm-search" / "position_manager_config.json"
|
||||
if not config_path.is_file():
|
||||
raise FileNotFoundError(f"PM config is required before backtest: {config_path}")
|
||||
pm_payload = read_json(config_path)
|
||||
trades_path = root / "pm-search" / "pm_backtest_trades.parquet"
|
||||
# PM search is allowed to use tune_inner, but final acceptance must be
|
||||
# measured on the sealed validation_locked and latest_stress splits.
|
||||
tune_trades = read_parquet(trades_path) if trades_path.is_file() else _simulate_open_trades(_pm_tune_frame(root), _thresholds_from_config(pm_payload["config"]))
|
||||
tune_trades["eval_split"] = TUNE_SPLIT
|
||||
validation_locked_trades = _simulate_open_trades(_pm_frame(root, VALIDATION_LOCKED_SPLIT), _thresholds_from_config(pm_payload["config"]))
|
||||
validation_locked_trades["eval_split"] = VALIDATION_LOCKED_SPLIT
|
||||
stress_trades = _simulate_open_trades(_pm_frame(root, LATEST_STRESS_SPLIT), _thresholds_from_config(pm_payload["config"]))
|
||||
stress_trades["eval_split"] = LATEST_STRESS_SPLIT
|
||||
trades = pd.concat([tune_trades, validation_locked_trades, stress_trades], ignore_index=True)
|
||||
metrics = {
|
||||
TUNE_SPLIT: _trade_metrics(tune_trades),
|
||||
VALIDATION_LOCKED_SPLIT: _trade_metrics(validation_locked_trades),
|
||||
LATEST_STRESS_SPLIT: _trade_metrics(stress_trades),
|
||||
"combined": _trade_metrics(trades),
|
||||
}
|
||||
status, status_reasons = _backtest_status(metrics)
|
||||
equity_curve = _equity_curve(trades)
|
||||
regime_metrics = _regime_metrics(trades)
|
||||
result = {
|
||||
"backtest_manifest_id": f"backtest-{args.run_id}",
|
||||
"mode": "VALIDATION_PM_BACKTEST",
|
||||
"pm_config_hash_sha256": pm_payload["config_hash_sha256"],
|
||||
"metrics": metrics,
|
||||
"status_reasons": status_reasons,
|
||||
"status": status,
|
||||
}
|
||||
write_json(root / "backtest" / "backtest_manifest.json", result)
|
||||
write_parquet(root / "backtest" / "backtest_trades.parquet", trades)
|
||||
write_text(root / "backtest" / "equity_curve.csv", equity_curve.to_csv(index=False))
|
||||
write_text(root / "backtest" / "regime_metrics.csv", regime_metrics.to_csv(index=False))
|
||||
_write_backtest_report(root / "backtest" / "backtest_report.md", result)
|
||||
_write_failure_cases(root / "backtest" / "failure_cases.md", trades)
|
||||
_write_no_baseline_ablation(root / "backtest" / "direction_ablation_backtest_report.md")
|
||||
logging.info(
|
||||
"trader.training.backtest_written runId=%s status=%s tradeCount=%s totalWeightedEdgeBps=%.6f maxDrawdownBps=%.6f",
|
||||
args.run_id,
|
||||
status,
|
||||
metrics[VALIDATION_LOCKED_SPLIT]["trade_count"],
|
||||
metrics[VALIDATION_LOCKED_SPLIT]["total_weighted_edge_bps"],
|
||||
metrics[VALIDATION_LOCKED_SPLIT]["max_drawdown_bps"],
|
||||
)
|
||||
|
||||
|
||||
def _pm_tune_frame(root) -> pd.DataFrame:
|
||||
return _pm_frame(root, TUNE_SPLIT)
|
||||
|
||||
|
||||
def _pm_frame(root, split_id: str) -> pd.DataFrame:
|
||||
prediction_files = {
|
||||
TUNE_SPLIT: "tune_predictions.parquet",
|
||||
VALIDATION_LOCKED_SPLIT: "validation_locked_predictions.parquet",
|
||||
LATEST_STRESS_SPLIT: "latest_stress_predictions.parquet",
|
||||
}
|
||||
prediction_file = prediction_files[split_id]
|
||||
direction = read_parquet(root / "model" / "direction" / prediction_file)
|
||||
entry = read_parquet(root / "model" / "entry" / prediction_file).rename(
|
||||
columns={
|
||||
"long_expected_net_edge_bps": "pred_long_expected_net_edge_bps",
|
||||
"short_expected_net_edge_bps": "pred_short_expected_net_edge_bps",
|
||||
}
|
||||
)
|
||||
risk = read_parquet(root / "model" / "risk" / prediction_file)
|
||||
entry_dataset = read_parquet(root / "dataset" / "entry_train.parquet").rename(
|
||||
columns={
|
||||
"long_expected_net_edge_bps": "actual_long_expected_net_edge_bps",
|
||||
"short_expected_net_edge_bps": "actual_short_expected_net_edge_bps",
|
||||
}
|
||||
)
|
||||
entry_cols = [
|
||||
"sample_id",
|
||||
"long_entry_prob",
|
||||
"short_entry_prob",
|
||||
"pred_long_expected_net_edge_bps",
|
||||
"pred_short_expected_net_edge_bps",
|
||||
]
|
||||
risk_cols = ["sample_id", "market_risk_prob", "long_position_risk_prob", "short_position_risk_prob"]
|
||||
actual_cols = ["sample_id", "actual_long_expected_net_edge_bps", "actual_short_expected_net_edge_bps", "long_entry_target", "short_entry_target"]
|
||||
frame = (
|
||||
direction[["sample_id", "symbol", "event_time", "split_id", "long_prob", "short_prob", "neutral_prob"]]
|
||||
.merge(entry[entry_cols], on="sample_id", how="inner")
|
||||
.merge(risk[risk_cols], on="sample_id", how="inner")
|
||||
.merge(entry_dataset[actual_cols], on="sample_id", how="inner")
|
||||
)
|
||||
if frame.empty:
|
||||
raise ValueError(f"PM frame is empty for {split_id}; check model predictions and entry dataset")
|
||||
logging.info(
|
||||
"trader.training.pm_frame_loaded splitId=%s rowCount=%s splitCounts=%s",
|
||||
split_id,
|
||||
len(frame),
|
||||
frame["split_id"].value_counts().to_dict(),
|
||||
)
|
||||
return frame
|
||||
|
||||
|
||||
def _threshold_candidates() -> list[dict[str, float]]:
|
||||
values = itertools.product(
|
||||
[0.54, 0.56, 0.58, 0.60],
|
||||
[0.54, 0.56, 0.58, 0.60],
|
||||
[0.50, 0.52, 0.55, 0.58],
|
||||
[0.35, 0.45, 0.55],
|
||||
[1.0, 2.0, 3.0, 5.0],
|
||||
[0.02, 0.03, 0.05],
|
||||
)
|
||||
return [
|
||||
{
|
||||
"long_open_prob": long_prob,
|
||||
"short_open_prob": short_prob,
|
||||
"min_entry_prob": entry_prob,
|
||||
"max_market_risk_prob": risk_prob,
|
||||
"min_expected_edge_bps": edge_bps,
|
||||
"min_direction_margin": margin,
|
||||
}
|
||||
for long_prob, short_prob, entry_prob, risk_prob, edge_bps, margin in values
|
||||
]
|
||||
|
||||
|
||||
def _simulate_open_trades(frame: pd.DataFrame, thresholds: dict[str, float]) -> pd.DataFrame:
|
||||
long_mask = (
|
||||
(frame["long_prob"] >= thresholds["long_open_prob"])
|
||||
& ((frame["long_prob"] - frame["short_prob"]) >= thresholds["min_direction_margin"])
|
||||
& (frame["long_entry_prob"] >= thresholds["min_entry_prob"])
|
||||
& (frame["market_risk_prob"] <= thresholds["max_market_risk_prob"])
|
||||
& (frame["pred_long_expected_net_edge_bps"] >= thresholds["min_expected_edge_bps"])
|
||||
)
|
||||
short_mask = (
|
||||
(frame["short_prob"] >= thresholds["short_open_prob"])
|
||||
& ((frame["short_prob"] - frame["long_prob"]) >= thresholds["min_direction_margin"])
|
||||
& (frame["short_entry_prob"] >= thresholds["min_entry_prob"])
|
||||
& (frame["market_risk_prob"] <= thresholds["max_market_risk_prob"])
|
||||
& (frame["pred_short_expected_net_edge_bps"] >= thresholds["min_expected_edge_bps"])
|
||||
)
|
||||
long_score = frame["pred_long_expected_net_edge_bps"] + (frame["long_prob"] - frame["short_prob"]) * 10.0
|
||||
short_score = frame["pred_short_expected_net_edge_bps"] + (frame["short_prob"] - frame["long_prob"]) * 10.0
|
||||
side = np.where(long_mask & (~short_mask | (long_score >= short_score)), "LONG", np.where(short_mask, "SHORT", ""))
|
||||
trades = frame.loc[side != ""].copy().reset_index(drop=True)
|
||||
if trades.empty:
|
||||
return _empty_trade_frame()
|
||||
trades["side"] = side[side != ""]
|
||||
is_long = trades["side"].eq("LONG")
|
||||
trades["direction_prob"] = np.where(is_long, trades["long_prob"], trades["short_prob"])
|
||||
trades["entry_prob"] = np.where(is_long, trades["long_entry_prob"], trades["short_entry_prob"])
|
||||
trades["predicted_edge_bps"] = np.where(is_long, trades["pred_long_expected_net_edge_bps"], trades["pred_short_expected_net_edge_bps"])
|
||||
trades["actual_edge_bps"] = np.where(is_long, trades["actual_long_expected_net_edge_bps"], trades["actual_short_expected_net_edge_bps"])
|
||||
trades["entry_target"] = np.where(is_long, trades["long_entry_target"], trades["short_entry_target"])
|
||||
trades["planned_ratio"] = _planned_ratio(trades["predicted_edge_bps"], trades["market_risk_prob"], thresholds["min_expected_edge_bps"])
|
||||
trades["weighted_edge_bps"] = trades["actual_edge_bps"] * trades["planned_ratio"]
|
||||
trades["threshold_hash"] = sha256_json(thresholds)[:16]
|
||||
return trades[
|
||||
[
|
||||
"sample_id",
|
||||
"symbol",
|
||||
"event_time",
|
||||
"split_id",
|
||||
"side",
|
||||
"direction_prob",
|
||||
"entry_prob",
|
||||
"market_risk_prob",
|
||||
"predicted_edge_bps",
|
||||
"actual_edge_bps",
|
||||
"entry_target",
|
||||
"planned_ratio",
|
||||
"weighted_edge_bps",
|
||||
"threshold_hash",
|
||||
]
|
||||
].sort_values("event_time")
|
||||
|
||||
|
||||
def _empty_trade_frame() -> pd.DataFrame:
|
||||
return pd.DataFrame(
|
||||
columns=[
|
||||
"sample_id",
|
||||
"symbol",
|
||||
"event_time",
|
||||
"split_id",
|
||||
"side",
|
||||
"direction_prob",
|
||||
"entry_prob",
|
||||
"market_risk_prob",
|
||||
"predicted_edge_bps",
|
||||
"actual_edge_bps",
|
||||
"entry_target",
|
||||
"planned_ratio",
|
||||
"weighted_edge_bps",
|
||||
"threshold_hash",
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def _planned_ratio(predicted_edge: pd.Series, market_risk: pd.Series, min_edge: float) -> np.ndarray:
|
||||
edge_strength = ((predicted_edge.astype(float) - min_edge) / 20.0).clip(lower=0.0, upper=1.5)
|
||||
risk_discount = (1.0 - market_risk.astype(float)).clip(lower=0.0, upper=1.0)
|
||||
return (edge_strength * risk_discount).clip(lower=0.05, upper=1.0).to_numpy()
|
||||
|
||||
|
||||
def _trade_metrics(trades: pd.DataFrame) -> dict[str, Any]:
|
||||
if trades.empty:
|
||||
return {
|
||||
"trade_count": 0,
|
||||
"win_rate": 0.0,
|
||||
"avg_actual_edge_bps": 0.0,
|
||||
"avg_weighted_edge_bps": 0.0,
|
||||
"total_weighted_edge_bps": 0.0,
|
||||
"max_drawdown_bps": 0.0,
|
||||
"avg_planned_ratio": 0.0,
|
||||
"profit_factor": 0.0,
|
||||
"max_consecutive_losses": 0,
|
||||
}
|
||||
equity = trades["weighted_edge_bps"].astype(float).cumsum()
|
||||
drawdown = equity.cummax() - equity
|
||||
gains = trades.loc[trades["weighted_edge_bps"] > 0, "weighted_edge_bps"].astype(float).sum()
|
||||
losses = -trades.loc[trades["weighted_edge_bps"] < 0, "weighted_edge_bps"].astype(float).sum()
|
||||
return {
|
||||
"trade_count": int(len(trades)),
|
||||
"win_rate": float((trades["actual_edge_bps"].astype(float) > 0).mean()),
|
||||
"avg_actual_edge_bps": float(trades["actual_edge_bps"].astype(float).mean()),
|
||||
"avg_weighted_edge_bps": float(trades["weighted_edge_bps"].astype(float).mean()),
|
||||
"total_weighted_edge_bps": float(equity.iloc[-1]),
|
||||
"max_drawdown_bps": float(drawdown.max()),
|
||||
"avg_planned_ratio": float(trades["planned_ratio"].astype(float).mean()),
|
||||
"profit_factor": float(gains / losses) if losses > 0 else float("inf"),
|
||||
"max_consecutive_losses": _max_consecutive_losses(trades["weighted_edge_bps"].astype(float).to_numpy()),
|
||||
}
|
||||
|
||||
|
||||
def _max_consecutive_losses(values: np.ndarray) -> int:
|
||||
max_count = 0
|
||||
current = 0
|
||||
for value in values:
|
||||
if value < 0:
|
||||
current += 1
|
||||
max_count = max(max_count, current)
|
||||
else:
|
||||
current = 0
|
||||
return max_count
|
||||
|
||||
|
||||
def _backtest_status(metrics: dict[str, dict[str, Any]]) -> tuple[str, list[str]]:
|
||||
reasons: list[str] = []
|
||||
validation_locked = metrics[VALIDATION_LOCKED_SPLIT]
|
||||
stress = metrics[LATEST_STRESS_SPLIT]
|
||||
if validation_locked["total_weighted_edge_bps"] <= 0:
|
||||
reasons.append("validation_locked_net_edge_not_positive")
|
||||
if validation_locked["trade_count"] < 80:
|
||||
reasons.append("validation_locked_trade_count_below_80")
|
||||
if validation_locked["profit_factor"] < 1.15:
|
||||
reasons.append("validation_locked_profit_factor_below_1.15")
|
||||
if validation_locked["avg_weighted_edge_bps"] <= 0:
|
||||
reasons.append("validation_locked_avg_trade_edge_not_positive")
|
||||
if validation_locked["max_consecutive_losses"] > 8:
|
||||
reasons.append("validation_locked_max_consecutive_losses_above_8")
|
||||
if stress["trade_count"] < 20:
|
||||
reasons.append("latest_stress_trade_count_below_20")
|
||||
if stress["profit_factor"] < 1.0:
|
||||
reasons.append("latest_stress_profit_factor_below_1.0")
|
||||
if stress["avg_weighted_edge_bps"] < -3.0:
|
||||
reasons.append("latest_stress_avg_trade_edge_below_minus_3")
|
||||
if stress["max_consecutive_losses"] > 10:
|
||||
reasons.append("latest_stress_max_consecutive_losses_above_10")
|
||||
if validation_locked["total_weighted_edge_bps"] > 0 and stress["total_weighted_edge_bps"] < -0.5 * validation_locked["total_weighted_edge_bps"]:
|
||||
reasons.append("latest_stress_loss_too_large_vs_validation")
|
||||
return ("REJECTED", reasons) if reasons else ("PASS", [])
|
||||
|
||||
|
||||
def _score_thresholds(metrics: dict[str, Any]) -> float:
|
||||
if metrics["trade_count"] == 0:
|
||||
return -1_000_000.0
|
||||
low_sample_penalty = max(0, 20 - int(metrics["trade_count"])) * 5.0
|
||||
return (
|
||||
metrics["avg_weighted_edge_bps"] * np.sqrt(metrics["trade_count"])
|
||||
+ metrics["total_weighted_edge_bps"] * 0.05
|
||||
- metrics["max_drawdown_bps"] * 0.25
|
||||
- low_sample_penalty
|
||||
)
|
||||
|
||||
|
||||
def _pm_config_from_thresholds(thresholds: dict[str, float]) -> dict:
|
||||
config = default_pm_config()
|
||||
config["open"].update(
|
||||
{
|
||||
"longOpenProb": thresholds["long_open_prob"],
|
||||
"shortOpenProb": thresholds["short_open_prob"],
|
||||
"minLongEntryProb": thresholds["min_entry_prob"],
|
||||
"minShortEntryProb": thresholds["min_entry_prob"],
|
||||
"maxMarketRiskProb": thresholds["max_market_risk_prob"],
|
||||
"minExpectedEdgeBps": thresholds["min_expected_edge_bps"],
|
||||
"minDirectionMargin": thresholds["min_direction_margin"],
|
||||
}
|
||||
)
|
||||
config["add"]["maxMarketRiskProb"] = thresholds["max_market_risk_prob"]
|
||||
config["add"]["minExpectedEdgeBps"] = thresholds["min_expected_edge_bps"]
|
||||
config["sizing"]["minEdgeBps"] = thresholds["min_expected_edge_bps"]
|
||||
config["sizing"]["maxSingleLegRatio"] = 1.0
|
||||
return config
|
||||
|
||||
|
||||
def _thresholds_from_config(config: dict) -> dict[str, float]:
|
||||
open_config = config["open"]
|
||||
return {
|
||||
"long_open_prob": float(open_config["longOpenProb"]),
|
||||
"short_open_prob": float(open_config["shortOpenProb"]),
|
||||
"min_entry_prob": float(min(open_config["minLongEntryProb"], open_config["minShortEntryProb"])),
|
||||
"max_market_risk_prob": float(open_config["maxMarketRiskProb"]),
|
||||
"min_expected_edge_bps": float(open_config["minExpectedEdgeBps"]),
|
||||
"min_direction_margin": float(open_config["minDirectionMargin"]),
|
||||
}
|
||||
|
||||
|
||||
def _equity_curve(trades: pd.DataFrame) -> pd.DataFrame:
|
||||
if trades.empty:
|
||||
return pd.DataFrame(columns=["event_time", "trade_index", "weighted_edge_bps", "equity_bps", "drawdown_bps"])
|
||||
curve = trades[["event_time", "weighted_edge_bps"]].copy().reset_index(drop=True)
|
||||
curve["trade_index"] = np.arange(1, len(curve) + 1)
|
||||
curve["equity_bps"] = curve["weighted_edge_bps"].astype(float).cumsum()
|
||||
curve["drawdown_bps"] = curve["equity_bps"].cummax() - curve["equity_bps"]
|
||||
return curve[["event_time", "trade_index", "weighted_edge_bps", "equity_bps", "drawdown_bps"]]
|
||||
|
||||
|
||||
def _regime_metrics(trades: pd.DataFrame) -> pd.DataFrame:
|
||||
if trades.empty:
|
||||
return pd.DataFrame(columns=["split_id", "side", "trade_count", "win_rate", "avg_actual_edge_bps", "total_weighted_edge_bps"])
|
||||
rows = []
|
||||
for (split_id, side), group in trades.groupby(["split_id", "side"], sort=True):
|
||||
metrics = _trade_metrics(group)
|
||||
rows.append(
|
||||
{
|
||||
"split_id": split_id,
|
||||
"side": side,
|
||||
"trade_count": metrics["trade_count"],
|
||||
"win_rate": metrics["win_rate"],
|
||||
"avg_actual_edge_bps": metrics["avg_actual_edge_bps"],
|
||||
"total_weighted_edge_bps": metrics["total_weighted_edge_bps"],
|
||||
}
|
||||
)
|
||||
return pd.DataFrame(rows)
|
||||
|
||||
|
||||
def _write_pm_report(path, candidates: pd.DataFrame, best_thresholds: dict[str, float], best_metrics: dict[str, Any]) -> None:
|
||||
top = candidates.head(10)
|
||||
lines = [
|
||||
"# PM Threshold Report",
|
||||
"",
|
||||
"本次不是固定写死阈值,而是在验证集上试一组可复现的阈值,选择净收益、回撤、交易数量综合更好的那组。",
|
||||
"",
|
||||
"## Best Thresholds",
|
||||
"",
|
||||
"```json",
|
||||
str(best_thresholds).replace("'", '"'),
|
||||
"```",
|
||||
"",
|
||||
"## Best Metrics",
|
||||
"",
|
||||
"```json",
|
||||
str(best_metrics).replace("'", '"'),
|
||||
"```",
|
||||
"",
|
||||
"## Top Candidates",
|
||||
"",
|
||||
_markdown_table(top.to_dict("records"), list(top.columns)),
|
||||
"",
|
||||
]
|
||||
write_text(path, "\n".join(lines))
|
||||
|
||||
|
||||
def _write_backtest_report(path, result: dict[str, Any]) -> None:
|
||||
lines = [
|
||||
"# Integrated Backtest Report",
|
||||
"",
|
||||
"这里用验证集模型输出和 PM 阈值生成交易明细,统计净收益、胜率、回撤和分段表现。",
|
||||
"",
|
||||
"```json",
|
||||
str(result).replace("'", '"'),
|
||||
"```",
|
||||
"",
|
||||
]
|
||||
write_text(path, "\n".join(lines))
|
||||
|
||||
|
||||
def _write_failure_cases(path, trades: pd.DataFrame) -> None:
|
||||
worst = trades.sort_values("weighted_edge_bps").head(20) if not trades.empty else trades
|
||||
lines = [
|
||||
"# Backtest Failure Cases",
|
||||
"",
|
||||
"按加权净收益从差到好列出最差样本,方便回看特征、标签和阈值。",
|
||||
"",
|
||||
_markdown_table(worst.to_dict("records"), list(worst.columns)) if not worst.empty else "无交易样本。",
|
||||
"",
|
||||
]
|
||||
write_text(path, "\n".join(lines))
|
||||
|
||||
|
||||
def _write_no_baseline_ablation(path) -> None:
|
||||
lines = [
|
||||
"# Direction Ablation Backtest Report",
|
||||
"",
|
||||
"- status: NO_BASELINE",
|
||||
"- reason: 当前 run 目录没有旧 Direction 基准模型包,所以首版不能做只替换 Direction 的消融回测。",
|
||||
"- action: 后续版本必须拿上一版 ACTIVE 包做 baseline,再比较新 Direction 是否真的提升。",
|
||||
"",
|
||||
]
|
||||
write_text(path, "\n".join(lines))
|
||||
|
||||
|
||||
def _markdown_table(rows: list[dict[str, Any]], columns: list[str]) -> str:
|
||||
lines = ["| " + " | ".join(columns) + " |", "| " + " | ".join("---" for _ in columns) + " |"]
|
||||
for row in rows:
|
||||
lines.append("| " + " | ".join(str(row.get(column, "")) for column in columns) + " |")
|
||||
return "\n".join(lines)
|
||||
@@ -0,0 +1,129 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from trader_training.io_utils import read_json, utc_now_text, write_json, write_text
|
||||
|
||||
|
||||
def promote_artifact_bundle(args: Any) -> None:
|
||||
artifact_root = args.artifact_root
|
||||
validation_path = artifact_root.parent / "artifact_validation_result.json"
|
||||
validation = read_json(validation_path)
|
||||
if validation.get("status") != "PASS":
|
||||
_refuse(artifact_root, args.reason, "validation result is not PASS", validation)
|
||||
if validation.get("release_gate_status") != "PASS":
|
||||
_refuse(artifact_root, args.reason, "release gate is not PASS", validation)
|
||||
|
||||
promotion = {
|
||||
"promoted_at": utc_now_text(),
|
||||
"reason": args.reason,
|
||||
"validation_result_path": str(validation_path),
|
||||
}
|
||||
|
||||
bundle_path = artifact_root / "manifests" / "model_bundle_manifest.json"
|
||||
bundle = read_json(bundle_path)
|
||||
|
||||
model_manifest_path = artifact_root / "manifests" / "model_manifest.json"
|
||||
model_rows = read_json(model_manifest_path)
|
||||
|
||||
calibration_manifest_path = artifact_root / "manifests" / "calibration_manifest.json"
|
||||
calibration_rows = read_json(calibration_manifest_path)
|
||||
|
||||
pm_manifest_path = artifact_root / "manifests" / "position_manager_manifest.json"
|
||||
pm_manifest = read_json(pm_manifest_path)
|
||||
|
||||
export_manifest_path = artifact_root / "manifests" / "training_export_manifest.json"
|
||||
export_manifest = read_json(export_manifest_path)
|
||||
|
||||
# Check every manifest before writing any ACTIVE status, so a bad bundle
|
||||
# cannot be left half-promoted if one file fails late.
|
||||
_require_candidate("model_bundle_manifest", bundle.get("status"))
|
||||
for row in model_rows:
|
||||
_require_candidate(f"model_manifest.{row.get('model_name')}", row.get("status"))
|
||||
for row in calibration_rows:
|
||||
_require_candidate(f"calibration_manifest.{row.get('model_name')}", row.get("status"))
|
||||
_require_candidate("position_manager_manifest", pm_manifest.get("status"))
|
||||
_require_candidate("training_export_manifest", export_manifest.get("status"))
|
||||
|
||||
bundle["status"] = "ACTIVE"
|
||||
bundle["promotion_json"] = promotion
|
||||
for row in model_rows:
|
||||
row["status"] = "ACTIVE"
|
||||
row["promotion_json"] = promotion
|
||||
for row in calibration_rows:
|
||||
row["status"] = "ACTIVE"
|
||||
row["promotion_json"] = promotion
|
||||
pm_manifest["status"] = "ACTIVE"
|
||||
pm_manifest["promotion_json"] = promotion
|
||||
export_manifest["status"] = "ACTIVE"
|
||||
export_manifest["promotion_json"] = promotion
|
||||
|
||||
write_json(bundle_path, bundle)
|
||||
write_json(model_manifest_path, model_rows)
|
||||
write_json(calibration_manifest_path, calibration_rows)
|
||||
write_json(pm_manifest_path, pm_manifest)
|
||||
write_json(export_manifest_path, export_manifest)
|
||||
|
||||
result = {"status": "ACTIVE", "artifact_root": str(artifact_root), "promotion": promotion}
|
||||
write_json(artifact_root.parent / "artifact_promotion_result.json", result)
|
||||
write_text(
|
||||
artifact_root.parent / "artifact_promotion_report.md",
|
||||
"\n".join(
|
||||
[
|
||||
"# Artifact Promotion Report",
|
||||
"",
|
||||
"- status: ACTIVE",
|
||||
f"- artifact_root: {artifact_root}",
|
||||
f"- reason: {args.reason}",
|
||||
f"- promoted_at: {promotion['promoted_at']}",
|
||||
"",
|
||||
]
|
||||
),
|
||||
)
|
||||
logging.info("trader.training.artifact_promoted status=ACTIVE path=%s reason=%s", artifact_root, args.reason)
|
||||
|
||||
|
||||
def _require_candidate(name: str, status: str | None) -> None:
|
||||
if status != "CANDIDATE":
|
||||
raise SystemExit(f"artifact promotion refused: {name} status must be CANDIDATE, actual={status}")
|
||||
|
||||
|
||||
def _refuse(artifact_root: Any, reason: str, message: str, validation: dict[str, Any]) -> None:
|
||||
result = {
|
||||
"status": "REFUSED",
|
||||
"artifact_root": str(artifact_root),
|
||||
"reason": reason,
|
||||
"message": message,
|
||||
"validation_status": validation.get("status"),
|
||||
"release_gate_status": validation.get("release_gate_status"),
|
||||
"release_gate_reasons": validation.get("release_gate_reasons", []),
|
||||
"refused_at": utc_now_text(),
|
||||
}
|
||||
write_json(artifact_root.parent / "artifact_promotion_result.json", result)
|
||||
write_text(
|
||||
artifact_root.parent / "artifact_promotion_report.md",
|
||||
"\n".join(
|
||||
[
|
||||
"# Artifact Promotion Report",
|
||||
"",
|
||||
"- status: REFUSED",
|
||||
f"- artifact_root: {artifact_root}",
|
||||
f"- reason: {reason}",
|
||||
f"- message: {message}",
|
||||
f"- validation_status: {result['validation_status']}",
|
||||
f"- release_gate_status: {result['release_gate_status']}",
|
||||
f"- release_gate_reasons: {result['release_gate_reasons']}",
|
||||
f"- refused_at: {result['refused_at']}",
|
||||
"",
|
||||
]
|
||||
),
|
||||
)
|
||||
logging.warning(
|
||||
"trader.training.artifact_promotion_refused status=REFUSED path=%s reason=%s message=%s releaseGate=%s",
|
||||
artifact_root,
|
||||
reason,
|
||||
message,
|
||||
result["release_gate_status"],
|
||||
)
|
||||
raise SystemExit(f"artifact promotion refused: {message}, reasons={result['release_gate_reasons']}")
|
||||
@@ -0,0 +1,496 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from trader_training.io_utils import (
|
||||
DEFAULT_RAW_ROOT,
|
||||
ensure_dir,
|
||||
manifest,
|
||||
open_time_ms,
|
||||
partition_files,
|
||||
read_json,
|
||||
read_partitioned_table,
|
||||
read_parquet,
|
||||
require_columns,
|
||||
run_root,
|
||||
to_utc_series,
|
||||
utc_now_text,
|
||||
write_json,
|
||||
write_parquet,
|
||||
write_text,
|
||||
)
|
||||
from trader_training.schemas import FIT_SPLIT, LATEST_STRESS_SPLIT, SPLIT_VERSION, TRAINING_SPLITS, TUNE_SPLIT, VALIDATION_LOCKED_SPLIT
|
||||
|
||||
|
||||
def audit_source_data(data_root: Path, symbol: str, start_date: str | None, end_date: str | None, min_ready_days: int = 250) -> dict[str, Any]:
|
||||
raw_root = data_root / "crypto-lake" / "raw"
|
||||
required_tables = ("candles", "trades", "level_1", "funding", "open_interest")
|
||||
optional_tables = ("liquidations",)
|
||||
rows: list[dict[str, Any]] = []
|
||||
table_dates: dict[str, set[str]] = {}
|
||||
for table in required_tables + optional_tables:
|
||||
files = partition_files(raw_root, table, symbol, start_date, end_date)
|
||||
dates = sorted({next((part.split("=", 1)[1] for part in file.parts if part.startswith("dt=")), "") for file in files})
|
||||
table_dates[table] = set(dates)
|
||||
rows.append(
|
||||
{
|
||||
"table": table,
|
||||
"required": table in required_tables,
|
||||
"file_count": len(files),
|
||||
"first_date": dates[0] if dates else None,
|
||||
"last_date": dates[-1] if dates else None,
|
||||
"status": "OK" if files or table in optional_tables else "MISSING",
|
||||
}
|
||||
)
|
||||
all_dates = _audit_date_range(table_dates, required_tables, start_date, end_date)
|
||||
replay_ready_days = []
|
||||
excluded_days = []
|
||||
for day in all_dates:
|
||||
missing_required = [table for table in required_tables if day not in table_dates[table]]
|
||||
missing_optional = [table for table in optional_tables if day not in table_dates[table]]
|
||||
if missing_required:
|
||||
excluded_days.append({"date": day, "reason": "MISSING_REQUIRED_TABLE", "missing_required_tables": missing_required, "missing_optional_tables": missing_optional})
|
||||
else:
|
||||
replay_ready_days.append(day)
|
||||
result = {
|
||||
"symbol": symbol,
|
||||
"start_date": start_date,
|
||||
"end_date": end_date,
|
||||
"raw_root": str(raw_root),
|
||||
"tables": rows,
|
||||
"replay_ready_day_count": len(replay_ready_days),
|
||||
"excluded_day_count": len(excluded_days),
|
||||
"replay_ready_days": replay_ready_days,
|
||||
"excluded_days": excluded_days,
|
||||
"created_at": utc_now_text(),
|
||||
"ready": all(row["status"] == "OK" for row in rows if row["required"]) and len(replay_ready_days) >= min_ready_days,
|
||||
}
|
||||
return result
|
||||
|
||||
|
||||
def write_audit_outputs(args: Any) -> None:
|
||||
root = run_root(args)
|
||||
result = audit_source_data(args.data_root, args.symbol, args.start_date, args.end_date, int(args.min_ready_days))
|
||||
path = root / "raw-manifest" / "source_data_audit.json"
|
||||
write_json(path, result)
|
||||
write_json(root / "raw-manifest" / "source_data_manifest.json", result)
|
||||
write_json(root / "raw-manifest" / "excluded_days.json", result["excluded_days"])
|
||||
write_text(root / "raw-manifest" / "replay_ready_days.txt", "\n".join(result["replay_ready_days"]) + ("\n" if result["replay_ready_days"] else ""))
|
||||
report_lines = [
|
||||
"# Trader Source Data Audit",
|
||||
"",
|
||||
f"- symbol: {result['symbol']}",
|
||||
f"- raw_root: {result['raw_root']}",
|
||||
f"- ready: {result['ready']}",
|
||||
f"- replay_ready_day_count: {result['replay_ready_day_count']}",
|
||||
f"- excluded_day_count: {result['excluded_day_count']}",
|
||||
"",
|
||||
"| table | required | file_count | first_date | last_date | status |",
|
||||
"| --- | --- | ---: | --- | --- | --- |",
|
||||
]
|
||||
for row in result["tables"]:
|
||||
report_lines.append(
|
||||
f"| {row['table']} | {row['required']} | {row['file_count']} | {row['first_date']} | {row['last_date']} | {row['status']} |"
|
||||
)
|
||||
write_text(root / "raw-manifest" / "source_data_audit.md", "\n".join(report_lines) + "\n")
|
||||
logging.info(
|
||||
"trader.training.audit_written runId=%s ready=%s readyDays=%s excludedDays=%s path=%s",
|
||||
args.run_id,
|
||||
result["ready"],
|
||||
result["replay_ready_day_count"],
|
||||
result["excluded_day_count"],
|
||||
path,
|
||||
)
|
||||
if not result["ready"]:
|
||||
raise SystemExit("required raw tables are missing; see source_data_audit.md")
|
||||
|
||||
|
||||
def _audit_date_range(table_dates: dict[str, set[str]], required_tables: tuple[str, ...], start_date: str | None, end_date: str | None) -> list[str]:
|
||||
if start_date and end_date:
|
||||
start = pd.Timestamp(start_date)
|
||||
end = pd.Timestamp(end_date)
|
||||
else:
|
||||
dates = sorted(set().union(*(table_dates[table] for table in required_tables)))
|
||||
if not dates:
|
||||
return []
|
||||
start = pd.Timestamp(start_date or dates[0])
|
||||
end = pd.Timestamp(end_date or dates[-1])
|
||||
return [day.strftime("%Y-%m-%d") for day in pd.date_range(start, end, freq="D")]
|
||||
|
||||
|
||||
def _minute_frame(frame: pd.DataFrame, time_column: str = "origin_time") -> pd.DataFrame:
|
||||
frame = frame.copy()
|
||||
frame["event_time"] = to_utc_series(frame[time_column]).dt.floor("min")
|
||||
frame["open_time_ms"] = open_time_ms(frame["event_time"])
|
||||
return frame
|
||||
|
||||
|
||||
def _read_candles(raw_root: Path, symbol: str, start_date: str | None, end_date: str | None) -> pd.DataFrame:
|
||||
candles = read_partitioned_table(
|
||||
raw_root,
|
||||
"candles",
|
||||
symbol,
|
||||
start_date,
|
||||
end_date,
|
||||
columns=("origin_time", "start", "open", "high", "low", "close", "volume", "symbol"),
|
||||
)
|
||||
if candles.empty:
|
||||
raise ValueError("candles raw data is required to build replay_1m")
|
||||
time_col = "start" if "start" in candles.columns else "origin_time"
|
||||
candles = _minute_frame(candles, time_col)
|
||||
keep = ["symbol", "event_time", "open_time_ms", "open", "high", "low", "close", "volume"]
|
||||
candles = candles[keep].sort_values(["symbol", "event_time"]).drop_duplicates(["symbol", "event_time"], keep="last")
|
||||
for column in ("open", "high", "low", "close", "volume"):
|
||||
candles[column] = pd.to_numeric(candles[column], errors="coerce")
|
||||
return candles
|
||||
|
||||
|
||||
def _read_trades(raw_root: Path, symbol: str, start_date: str | None, end_date: str | None) -> pd.DataFrame:
|
||||
trades = read_partitioned_table(
|
||||
raw_root,
|
||||
"trades",
|
||||
symbol,
|
||||
start_date,
|
||||
end_date,
|
||||
columns=("origin_time", "side", "quantity", "symbol"),
|
||||
)
|
||||
if trades.empty:
|
||||
raise ValueError("trades raw data is required for taker imbalance")
|
||||
trades = _minute_frame(trades)
|
||||
trades["quantity"] = pd.to_numeric(trades["quantity"], errors="coerce").fillna(0.0)
|
||||
side = trades["side"].astype(str).str.upper()
|
||||
trades["taker_buy_volume"] = np.where(side.eq("BUY"), trades["quantity"], 0.0)
|
||||
trades["taker_sell_volume"] = np.where(side.eq("SELL"), trades["quantity"], 0.0)
|
||||
return (
|
||||
trades.groupby(["symbol", "event_time", "open_time_ms"], as_index=False, observed=True)[["taker_buy_volume", "taker_sell_volume"]]
|
||||
.sum()
|
||||
)
|
||||
|
||||
|
||||
def _read_level1(raw_root: Path, symbol: str, start_date: str | None, end_date: str | None) -> pd.DataFrame:
|
||||
level1 = read_partitioned_table(
|
||||
raw_root,
|
||||
"level_1",
|
||||
symbol,
|
||||
start_date,
|
||||
end_date,
|
||||
columns=("origin_time", "bid_0_price", "ask_0_price", "bid_0_size", "ask_0_size", "symbol"),
|
||||
)
|
||||
if level1.empty:
|
||||
raise ValueError("level_1 raw data is required for spread and OFI")
|
||||
level1 = _minute_frame(level1)
|
||||
for column in ("bid_0_price", "ask_0_price", "bid_0_size", "ask_0_size"):
|
||||
level1[column] = pd.to_numeric(level1[column], errors="coerce")
|
||||
level1 = level1.dropna(subset=["bid_0_price", "ask_0_price", "bid_0_size", "ask_0_size"])
|
||||
level1 = level1.sort_values(["symbol", "event_time", "origin_time"])
|
||||
group = level1.groupby("symbol", sort=False, observed=True)
|
||||
prev_bid_price = group["bid_0_price"].shift(1)
|
||||
prev_bid_size = group["bid_0_size"].shift(1)
|
||||
prev_ask_price = group["ask_0_price"].shift(1)
|
||||
prev_ask_size = group["ask_0_size"].shift(1)
|
||||
bid_ofi = np.select(
|
||||
[level1["bid_0_price"] > prev_bid_price, level1["bid_0_price"].eq(prev_bid_price)],
|
||||
[level1["bid_0_size"], level1["bid_0_size"] - prev_bid_size],
|
||||
default=-prev_bid_size,
|
||||
)
|
||||
ask_ofi = np.select(
|
||||
[level1["ask_0_price"] < prev_ask_price, level1["ask_0_price"].eq(prev_ask_price)],
|
||||
[level1["ask_0_size"], prev_ask_size - level1["ask_0_size"]],
|
||||
default=-prev_ask_size,
|
||||
)
|
||||
level1["ofi_raw"] = np.nan_to_num(bid_ofi + ask_ofi, nan=0.0)
|
||||
level1["depth"] = (level1["bid_0_size"] + level1["ask_0_size"]).clip(lower=1e-12)
|
||||
level1["mid"] = (level1["bid_0_price"] + level1["ask_0_price"]) / 2.0
|
||||
level1["spread_bps"] = (level1["ask_0_price"] - level1["bid_0_price"]) / level1["mid"] * 10000.0
|
||||
agg = level1.groupby(["symbol", "event_time", "open_time_ms"], as_index=False, observed=True).agg(
|
||||
best_bid_price=("bid_0_price", "last"),
|
||||
best_ask_price=("ask_0_price", "last"),
|
||||
spread_bps=("spread_bps", "last"),
|
||||
ofi_sum=("ofi_raw", "sum"),
|
||||
depth_mean=("depth", "mean"),
|
||||
)
|
||||
agg["level1_ofi_1m"] = agg["ofi_sum"] / agg["depth_mean"].clip(lower=1e-12)
|
||||
return agg.drop(columns=["ofi_sum", "depth_mean"])
|
||||
|
||||
|
||||
def _read_liquidations(raw_root: Path, symbol: str, start_date: str | None, end_date: str | None) -> pd.DataFrame:
|
||||
files = partition_files(raw_root, "liquidations", symbol, start_date, end_date)
|
||||
if not files:
|
||||
return pd.DataFrame(columns=["symbol", "event_time", "open_time_ms", "liquidation_buy_notional_1m", "liquidation_sell_notional_1m", "liquidation_available"])
|
||||
liquidations = read_partitioned_table(
|
||||
raw_root,
|
||||
"liquidations",
|
||||
symbol,
|
||||
start_date,
|
||||
end_date,
|
||||
columns=("origin_time", "side", "quantity", "price", "symbol"),
|
||||
)
|
||||
liquidations = _minute_frame(liquidations)
|
||||
liquidations["quantity"] = pd.to_numeric(liquidations["quantity"], errors="coerce").fillna(0.0)
|
||||
liquidations["price"] = pd.to_numeric(liquidations["price"], errors="coerce").fillna(0.0)
|
||||
liquidations["notional"] = liquidations["quantity"] * liquidations["price"]
|
||||
side = liquidations["side"].astype(str).str.upper()
|
||||
liquidations["liquidation_buy_notional_1m"] = np.where(side.eq("BUY"), liquidations["notional"], 0.0)
|
||||
liquidations["liquidation_sell_notional_1m"] = np.where(side.eq("SELL"), liquidations["notional"], 0.0)
|
||||
agg = liquidations.groupby(["symbol", "event_time", "open_time_ms"], as_index=False, observed=True)[
|
||||
["liquidation_buy_notional_1m", "liquidation_sell_notional_1m"]
|
||||
].sum()
|
||||
agg["liquidation_available"] = 1.0
|
||||
return agg
|
||||
|
||||
|
||||
def _asof_column(
|
||||
replay: pd.DataFrame,
|
||||
raw_root: Path,
|
||||
table: str,
|
||||
symbol: str,
|
||||
start_date: str | None,
|
||||
end_date: str | None,
|
||||
columns: tuple[str, ...],
|
||||
) -> pd.DataFrame:
|
||||
frame = read_partitioned_table(raw_root, table, symbol, start_date, end_date, columns=("origin_time", "symbol", *columns))
|
||||
if frame.empty:
|
||||
raise ValueError(f"{table} raw data is required")
|
||||
frame = _minute_frame(frame)
|
||||
for column in columns:
|
||||
if column.endswith("time"):
|
||||
continue
|
||||
frame[column] = pd.to_numeric(frame[column], errors="coerce")
|
||||
frame = frame.sort_values(["symbol", "event_time"])
|
||||
left = replay[["symbol", "event_time"]].sort_values(["symbol", "event_time"])
|
||||
merged = pd.merge_asof(
|
||||
left,
|
||||
frame[["symbol", "event_time", *columns]].sort_values(["symbol", "event_time"]),
|
||||
by="symbol",
|
||||
on="event_time",
|
||||
direction="backward",
|
||||
tolerance=pd.Timedelta(hours=12),
|
||||
)
|
||||
return merged
|
||||
|
||||
|
||||
def build_replay_1m(args: Any) -> None:
|
||||
root = run_root(args)
|
||||
raw_root = args.raw_root or DEFAULT_RAW_ROOT
|
||||
logging.info("trader.training.replay_started runId=%s symbol=%s rawRoot=%s", args.run_id, args.symbol, raw_root)
|
||||
replay = _read_candles(raw_root, args.symbol, args.start_date, args.end_date)
|
||||
trades = _read_trades(raw_root, args.symbol, args.start_date, args.end_date)
|
||||
level1 = _read_level1(raw_root, args.symbol, args.start_date, args.end_date)
|
||||
liquidations = _read_liquidations(raw_root, args.symbol, args.start_date, args.end_date)
|
||||
replay = replay.merge(trades, on=["symbol", "event_time", "open_time_ms"], how="left")
|
||||
replay = replay.merge(level1, on=["symbol", "event_time", "open_time_ms"], how="left")
|
||||
replay = replay.merge(liquidations, on=["symbol", "event_time", "open_time_ms"], how="left")
|
||||
replay[["taker_buy_volume", "taker_sell_volume"]] = replay[["taker_buy_volume", "taker_sell_volume"]].fillna(0.0)
|
||||
for column in ("liquidation_buy_notional_1m", "liquidation_sell_notional_1m", "liquidation_available"):
|
||||
replay[column] = replay[column].fillna(0.0)
|
||||
|
||||
funding = _asof_column(replay, raw_root, "funding", args.symbol, args.start_date, args.end_date, ("rate", "mark_price", "index_price", "next_funding_time"))
|
||||
funding = funding.rename(columns={"rate": "funding_rate"})
|
||||
funding["funding_bps"] = pd.to_numeric(funding["funding_rate"], errors="coerce") * 10000.0
|
||||
replay = replay.merge(funding.drop(columns=["funding_rate"]), on=["symbol", "event_time"], how="left")
|
||||
replay["next_funding_time"] = to_utc_series(replay["next_funding_time"])
|
||||
|
||||
oi = _asof_column(replay, raw_root, "open_interest", args.symbol, args.start_date, args.end_date, ("open_interest",))
|
||||
replay = replay.merge(oi, on=["symbol", "event_time"], how="left")
|
||||
replay["timeframe"] = "1m"
|
||||
replay["source_coverage"] = "crypto_lake_raw"
|
||||
|
||||
required = [
|
||||
"open",
|
||||
"high",
|
||||
"low",
|
||||
"close",
|
||||
"volume",
|
||||
"best_bid_price",
|
||||
"best_ask_price",
|
||||
"spread_bps",
|
||||
"level1_ofi_1m",
|
||||
"funding_bps",
|
||||
"mark_price",
|
||||
"index_price",
|
||||
"open_interest",
|
||||
]
|
||||
replay["event_date"] = replay["event_time"].dt.strftime("%Y-%m-%d")
|
||||
missing_required = replay[required].isna().any(axis=1)
|
||||
day_quality = (
|
||||
replay.assign(missing_required=missing_required.astype(int))
|
||||
.groupby("event_date", as_index=False, observed=True)
|
||||
.agg(row_count=("event_time", "count"), missing_required_rows=("missing_required", "sum"))
|
||||
)
|
||||
day_quality["ready"] = (day_quality["row_count"] >= int(args.min_minutes_per_day)) & day_quality["missing_required_rows"].eq(0)
|
||||
ready_days = sorted(day_quality.loc[day_quality["ready"], "event_date"].astype(str).tolist())
|
||||
excluded_days = [
|
||||
{
|
||||
"date": row.event_date,
|
||||
"row_count": int(row.row_count),
|
||||
"missing_required_rows": int(row.missing_required_rows),
|
||||
"reason": "MISSING_REQUIRED_MARKET_FIELDS" if int(row.missing_required_rows) else "INCOMPLETE_MINUTE_COUNT",
|
||||
}
|
||||
for row in day_quality.loc[~day_quality["ready"]].itertuples(index=False)
|
||||
]
|
||||
if len(ready_days) < int(args.min_replay_ready_days):
|
||||
write_json(root / "replay" / "excluded_days.json", excluded_days)
|
||||
write_text(root / "replay" / "replay_ready_days.txt", "\n".join(ready_days) + ("\n" if ready_days else ""))
|
||||
raise ValueError(f"replay_1m has only {len(ready_days)} replay-ready days, required {args.min_replay_ready_days}")
|
||||
before_filter = len(replay)
|
||||
replay = replay[replay["event_date"].isin(ready_days)].copy()
|
||||
logging.info(
|
||||
"trader.training.replay_ready_days_selected runId=%s readyDays=%s excludedDays=%s rowBefore=%s rowAfter=%s",
|
||||
args.run_id,
|
||||
len(ready_days),
|
||||
len(excluded_days),
|
||||
before_filter,
|
||||
len(replay),
|
||||
)
|
||||
|
||||
columns = [
|
||||
"symbol",
|
||||
"timeframe",
|
||||
"event_time",
|
||||
"open_time_ms",
|
||||
"open",
|
||||
"high",
|
||||
"low",
|
||||
"close",
|
||||
"volume",
|
||||
"taker_buy_volume",
|
||||
"taker_sell_volume",
|
||||
"funding_bps",
|
||||
"mark_price",
|
||||
"index_price",
|
||||
"next_funding_time",
|
||||
"open_interest",
|
||||
"best_bid_price",
|
||||
"best_ask_price",
|
||||
"spread_bps",
|
||||
"level1_ofi_1m",
|
||||
"liquidation_buy_notional_1m",
|
||||
"liquidation_sell_notional_1m",
|
||||
"liquidation_available",
|
||||
"source_coverage",
|
||||
]
|
||||
replay = replay[columns].sort_values(["symbol", "event_time"]).reset_index(drop=True)
|
||||
path = root / "replay" / "replay_1m.parquet"
|
||||
data_hash = write_parquet(path, replay)
|
||||
write_json(
|
||||
root / "replay" / "replay_1m.manifest.json",
|
||||
manifest(
|
||||
path,
|
||||
{
|
||||
"row_count": len(replay),
|
||||
"hash_sha256": data_hash,
|
||||
"replay_ready_day_count": len(ready_days),
|
||||
"excluded_day_count": len(excluded_days),
|
||||
"min_minutes_per_day": int(args.min_minutes_per_day),
|
||||
},
|
||||
),
|
||||
)
|
||||
write_json(root / "replay" / "excluded_days.json", excluded_days)
|
||||
write_text(root / "replay" / "replay_ready_days.txt", "\n".join(ready_days) + "\n")
|
||||
logging.info("trader.training.replay_written runId=%s rowCount=%s readyDays=%s path=%s", args.run_id, len(replay), len(ready_days), path)
|
||||
|
||||
|
||||
def build_splits(args: Any) -> None:
|
||||
root = run_root(args)
|
||||
replay_path = args.replay_path or root / "replay" / "replay_1m.parquet"
|
||||
replay = read_parquet(replay_path)
|
||||
require_columns(replay, ("event_time", "symbol"), "replay_1m")
|
||||
replay["event_time"] = to_utc_series(replay["event_time"])
|
||||
replay = replay.sort_values(["event_time", "symbol"]).reset_index(drop=True)
|
||||
if len(replay) < 10:
|
||||
raise ValueError("not enough replay rows to build time splits")
|
||||
gap = int(args.gap_minutes)
|
||||
intervals = _fixed_split_intervals(args, gap)
|
||||
replay_start = replay["event_time"].min()
|
||||
replay_end = replay["event_time"].max()
|
||||
intervals = [
|
||||
(split_id, max(start, replay_start), min(end, replay_end))
|
||||
for split_id, start, end in intervals
|
||||
if max(start, replay_start) <= min(end, replay_end)
|
||||
]
|
||||
if {item[0] for item in intervals} != set(TRAINING_SPLITS):
|
||||
raise ValueError(f"fixed split dates do not fit replay coverage: replay_start={replay_start} replay_end={replay_end}")
|
||||
split_manifest = {
|
||||
"split_version": SPLIT_VERSION,
|
||||
"created_at": utc_now_text(),
|
||||
"source_replay_path": str(replay_path),
|
||||
"gap_minutes": gap,
|
||||
# Sealed splits are withheld from broad parameter search. They only answer
|
||||
# whether a finished candidate survives final validation and recent stress.
|
||||
"sealed_splits": [VALIDATION_LOCKED_SPLIT, LATEST_STRESS_SPLIT],
|
||||
"latest_stress_policy": "FINAL_GATE_ONLY",
|
||||
"requested_splits": {
|
||||
FIT_SPLIT: [args.fit_inner_start, args.fit_inner_end],
|
||||
TUNE_SPLIT: [args.tune_inner_start, args.tune_inner_end],
|
||||
VALIDATION_LOCKED_SPLIT: [args.validation_locked_start, args.validation_locked_end],
|
||||
LATEST_STRESS_SPLIT: [args.latest_stress_start, args.latest_stress_end],
|
||||
},
|
||||
"splits": [
|
||||
{"split_id": split_id, "start": start.isoformat().replace("+00:00", "Z"), "end": end.isoformat().replace("+00:00", "Z")}
|
||||
for split_id, start, end in intervals
|
||||
if start <= end
|
||||
],
|
||||
}
|
||||
fold_count = max(1, int(args.fold_count))
|
||||
fit_interval = next(item for item in intervals if item[0] == FIT_SPLIT)
|
||||
tune_interval = next(item for item in intervals if item[0] == TUNE_SPLIT)
|
||||
train_times = pd.Series(pd.date_range(fit_interval[1], fit_interval[2], periods=fold_count + 1))
|
||||
folds = []
|
||||
for idx in range(fold_count):
|
||||
folds.append(
|
||||
{
|
||||
"walk_forward_fold": f"fold_{idx + 1:02d}",
|
||||
"train_start": fit_interval[1].isoformat().replace("+00:00", "Z"),
|
||||
"train_end": train_times.iloc[idx + 1].isoformat().replace("+00:00", "Z"),
|
||||
"validation_start": tune_interval[1].isoformat().replace("+00:00", "Z"),
|
||||
"validation_end": tune_interval[2].isoformat().replace("+00:00", "Z"),
|
||||
}
|
||||
)
|
||||
ensure_dir(root / "split")
|
||||
write_json(root / "split" / "split_manifest.json", split_manifest)
|
||||
write_json(root / "split" / "walk_forward_folds.json", {"split_version": SPLIT_VERSION, "folds": folds})
|
||||
_write_purge_embargo_report(root / "split" / "purge_embargo_report.md", intervals, gap)
|
||||
logging.info("trader.training.splits_written runId=%s splitCount=%s foldCount=%s", args.run_id, len(split_manifest["splits"]), len(folds))
|
||||
|
||||
|
||||
def assign_split(event_times: pd.Series, split_manifest_path: Path) -> pd.Series:
|
||||
manifest_data = read_json(split_manifest_path)
|
||||
result = pd.Series("NO_SPLIT", index=event_times.index, dtype="object")
|
||||
values = to_utc_series(event_times)
|
||||
for item in manifest_data["splits"]:
|
||||
start = pd.Timestamp(item["start"])
|
||||
end = pd.Timestamp(item["end"])
|
||||
mask = values.between(start, end, inclusive="both")
|
||||
result.loc[mask] = item["split_id"]
|
||||
return result
|
||||
|
||||
|
||||
def _fixed_split_intervals(args: Any, gap_minutes: int) -> list[tuple[str, pd.Timestamp, pd.Timestamp]]:
|
||||
gap = pd.Timedelta(minutes=gap_minutes)
|
||||
return [
|
||||
(FIT_SPLIT, _start_of_day(args.fit_inner_start), _end_of_day(args.fit_inner_end) - gap),
|
||||
(TUNE_SPLIT, _start_of_day(args.tune_inner_start) + gap, _end_of_day(args.tune_inner_end) - gap),
|
||||
(VALIDATION_LOCKED_SPLIT, _start_of_day(args.validation_locked_start) + gap, _end_of_day(args.validation_locked_end) - gap),
|
||||
(LATEST_STRESS_SPLIT, _start_of_day(args.latest_stress_start) + gap, _end_of_day(args.latest_stress_end)),
|
||||
]
|
||||
|
||||
|
||||
def _start_of_day(value: str) -> pd.Timestamp:
|
||||
return pd.Timestamp(value, tz="UTC")
|
||||
|
||||
|
||||
def _end_of_day(value: str) -> pd.Timestamp:
|
||||
return pd.Timestamp(value, tz="UTC") + pd.Timedelta(days=1) - pd.Timedelta(minutes=1)
|
||||
|
||||
|
||||
def _write_purge_embargo_report(path: Path, intervals: list[tuple[str, pd.Timestamp, pd.Timestamp]], gap_minutes: int) -> None:
|
||||
lines = ["# Purge Embargo Report", "", f"- gap_minutes: {gap_minutes}", "", "| split_id | start | end |", "| --- | --- | --- |"]
|
||||
for split_id, start, end in intervals:
|
||||
lines.append(f"| {split_id} | {start.isoformat()} | {end.isoformat()} |")
|
||||
write_text(path, "\n".join(lines) + "\n")
|
||||
@@ -0,0 +1,206 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
|
||||
FEATURE_VERSION = "feature-v4-p0"
|
||||
LABEL_VERSION = "label-v4-p0"
|
||||
SPLIT_VERSION = "split-v4-p0"
|
||||
MODEL_BUNDLE_VERSION = "trader-v4-btc-p0"
|
||||
CALIBRATION_BUNDLE_VERSION = "cal-v4-btc-p0"
|
||||
PM_CONFIG_VERSION = "pm-v4-btc-p0"
|
||||
OUTPUT_SCHEMA_VERSION = "output-schema-v4-btc-p0"
|
||||
|
||||
FIT_SPLIT = "fit_inner"
|
||||
TUNE_SPLIT = "tune_inner"
|
||||
VALIDATION_LOCKED_SPLIT = "validation_locked"
|
||||
LATEST_STRESS_SPLIT = "latest_stress"
|
||||
TRAINING_SPLITS = (FIT_SPLIT, TUNE_SPLIT, VALIDATION_LOCKED_SPLIT, LATEST_STRESS_SPLIT)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class FeatureDef:
|
||||
order: int
|
||||
name: str
|
||||
cn_name: str
|
||||
meaning: str
|
||||
source_tables: tuple[str, ...]
|
||||
formula: str
|
||||
lookback_window: str
|
||||
unit: str
|
||||
dtype: str
|
||||
null_rule: str
|
||||
live_available: bool
|
||||
leakage_check: str
|
||||
owner_models: tuple[str, ...]
|
||||
|
||||
def as_json(self) -> dict[str, Any]:
|
||||
return {
|
||||
"order": self.order,
|
||||
"name": self.name,
|
||||
"cn_name": self.cn_name,
|
||||
"meaning": self.meaning,
|
||||
"source_tables": list(self.source_tables),
|
||||
"formula": self.formula,
|
||||
"lookback_window": self.lookback_window,
|
||||
"unit": self.unit,
|
||||
"dtype": self.dtype,
|
||||
"null_rule": self.null_rule,
|
||||
"live_available": self.live_available,
|
||||
"leakage_check": self.leakage_check,
|
||||
"owner_models": list(self.owner_models),
|
||||
}
|
||||
|
||||
|
||||
ALL_MODELS = ("Direction", "Entry", "Continue", "Exit", "Risk")
|
||||
|
||||
|
||||
FEATURES: tuple[FeatureDef, ...] = (
|
||||
FeatureDef(1, "ret_1m_bps", "最近1分钟收益", "Latest short return.", ("replay_1m",), "close_t / close_t-1m - 1", "1m", "bps", "float32", "WARMUP", True, "uses <= t close only", ALL_MODELS),
|
||||
FeatureDef(2, "ret_5m_bps", "最近5分钟收益", "Short trend.", ("replay_1m",), "close_t / close_t-5m - 1", "5m", "bps", "float32", "WARMUP", True, "uses <= t close only", ("Direction", "Entry", "Continue", "Exit")),
|
||||
FeatureDef(3, "ret_15m_bps", "最近15分钟收益", "Near trend.", ("replay_1m",), "close_t / close_t-15m - 1", "15m", "bps", "float32", "WARMUP", True, "uses <= t close only", ("Direction", "Entry", "Continue", "Exit")),
|
||||
FeatureDef(4, "ret_60m_bps", "最近60分钟收益", "Baseline trend.", ("replay_1m",), "close_t / close_t-60m - 1", "60m", "bps", "float32", "WARMUP", True, "uses <= t close only", ("Direction", "Continue", "Exit", "Risk")),
|
||||
FeatureDef(5, "ret_240m_bps", "最近240分钟收益", "Four-hour trend.", ("replay_1m",), "close_t / close_t-240m - 1", "240m", "bps", "float32", "WARMUP", True, "uses <= t close only", ("Direction", "Continue", "Risk")),
|
||||
FeatureDef(6, "realized_vol_15m_bps", "15分钟波动", "Near realized volatility.", ("replay_1m",), "std(log_return_1m, 15) * 10000", "15m", "bps", "float32", "WARMUP", True, "uses <= t returns only", ("Direction", "Entry", "Exit", "Risk")),
|
||||
FeatureDef(7, "realized_vol_60m_bps", "60分钟波动", "Baseline realized volatility.", ("replay_1m",), "std(log_return_1m, 60) * 10000", "60m", "bps", "float32", "WARMUP", True, "uses <= t returns only", ("Direction", "Entry", "Exit", "Risk")),
|
||||
FeatureDef(8, "vol_ratio_15m_60m", "近端波动放大", "Near volatility versus baseline.", ("feature",), "realized_vol_15m_bps / max(realized_vol_60m_bps, 1)", "15m/60m", "ratio", "float32", "WARMUP", True, "derived from <= t features", ("Entry", "Exit", "Risk")),
|
||||
FeatureDef(9, "range_15m_bps", "15分钟振幅", "Near high-low range.", ("replay_1m",), "max(high_15m) / min(low_15m) - 1", "15m", "bps", "float32", "WARMUP", True, "uses <= t high/low only", ("Entry", "Exit", "Risk")),
|
||||
FeatureDef(10, "range_60m_bps", "60分钟振幅", "Baseline high-low range.", ("replay_1m",), "max(high_60m) / min(low_60m) - 1", "60m", "bps", "float32", "WARMUP", True, "uses <= t high/low only", ("Direction", "Entry", "Risk")),
|
||||
FeatureDef(11, "volume_zscore_60m", "60分钟成交量异常", "Current volume abnormality.", ("replay_1m",), "(volume_t - mean(volume_60m)) / std(volume_60m)", "60m", "zscore", "float32", "std=0 -> 0", True, "uses <= t volume only", ("Direction", "Entry", "Risk")),
|
||||
FeatureDef(12, "trend_consistency_15m", "15分钟方向连续性", "Signed return consistency.", ("replay_1m",), "mean(sign(ret_1m), 15)", "15m", "ratio", "float32", "WARMUP", True, "uses <= t returns only", ("Direction", "Continue", "Exit")),
|
||||
FeatureDef(13, "channel_position_60m_pct", "60分钟通道位置", "Close position in recent channel.", ("replay_1m",), "(close_t - low_60m) / max(high_60m - low_60m, tick)", "60m", "pct", "float32", "WARMUP", True, "uses <= t high/low/close only", ("Direction", "Entry", "Continue")),
|
||||
FeatureDef(14, "upper_breakout_60m_bps", "向上突破距离", "Upper breakout distance.", ("replay_1m",), "max(0, close_t / prev_high_60m_excl_t - 1) * 10000", "60m", "bps", "float32", "WARMUP", True, "current close versus prior window only", ("Direction", "Entry", "Continue")),
|
||||
FeatureDef(15, "lower_breakout_60m_bps", "向下跌破距离", "Lower breakdown distance.", ("replay_1m",), "max(0, prev_low_60m_excl_t / close_t - 1) * 10000", "60m", "bps", "float32", "WARMUP", True, "current close versus prior window only", ("Direction", "Entry", "Continue")),
|
||||
FeatureDef(16, "upper_failed_break_reclaim_15m_bps", "上破失败回落", "Failed upper breakout reclaim.", ("replay_1m",), "if high_15m broke prior high then max(0, prev_high_60m - close_t) / close_t * 10000", "15m/60m", "bps", "float32", "no event -> 0", True, "prior high excludes t", ("Entry", "Exit", "Risk")),
|
||||
FeatureDef(17, "lower_failed_break_reclaim_15m_bps", "下破失败收回", "Failed lower breakdown reclaim.", ("replay_1m",), "if low_15m broke prior low then max(0, close_t - prev_low_60m) / close_t * 10000", "15m/60m", "bps", "float32", "no event -> 0", True, "prior low excludes t", ("Entry", "Exit", "Risk")),
|
||||
FeatureDef(18, "sweep_up_15m_bps", "上影扫高", "Upper sweep size.", ("replay_1m",), "max(0, max(high_15m) / close_t - 1) * 10000", "15m", "bps", "float32", "WARMUP", True, "uses <= t high/close only", ("Exit", "Risk")),
|
||||
FeatureDef(19, "sweep_down_15m_bps", "下影扫低", "Lower sweep size.", ("replay_1m",), "max(0, close_t / min(low_15m) - 1) * 10000", "15m", "bps", "float32", "WARMUP", True, "uses <= t low/close only", ("Exit", "Risk")),
|
||||
FeatureDef(20, "compression_score_4h_pct", "4小时压缩分位", "Higher means recent range is compressed.", ("feature",), "1 - percentile_rank(range_15m_bps over 240m)", "240m", "pct", "float32", "WARMUP", True, "rolling rank uses <= t", ("Direction", "Entry")),
|
||||
FeatureDef(21, "compression_release_15m_bps", "压缩释放幅度", "Range release versus 4h median.", ("feature",), "max(0, range_15m_bps - median(range_15m_bps over 240m))", "15m/240m", "bps", "float32", "WARMUP", True, "rolling median uses <= t", ("Direction", "Entry", "Risk")),
|
||||
FeatureDef(22, "taker_imbalance_1m", "1分钟主动买卖差", "Taker buy/sell imbalance.", ("trades", "replay_1m"), "(buy_1m - sell_1m) / max(total_1m, eps)", "1m", "ratio", "float32", "volume=0 -> 0", True, "uses current closed minute trades only", ("Direction", "Entry", "Continue")),
|
||||
FeatureDef(23, "taker_imbalance_5m", "5分钟主动买卖差", "Short taker imbalance.", ("trades", "replay_1m"), "(buy_5m - sell_5m) / max(total_5m, eps)", "5m", "ratio", "float32", "WARMUP", True, "uses <= t trades only", ("Direction", "Entry", "Continue")),
|
||||
FeatureDef(24, "taker_imbalance_15m", "15分钟主动买卖差", "Near taker imbalance.", ("trades", "replay_1m"), "(buy_15m - sell_15m) / max(total_15m, eps)", "15m", "ratio", "float32", "WARMUP", True, "uses <= t trades only", ("Direction", "Continue", "Exit")),
|
||||
FeatureDef(25, "level1_ofi_1m", "1分钟盘口订单流", "Best bid/ask order-flow imbalance.", ("level_1", "replay_1m"), "sum(OFI changes in minute) / mean(level1 depth)", "1m", "ratio", "float32", "missing -> fail", True, "uses current closed minute L1 only", ("Direction", "Entry", "Risk")),
|
||||
FeatureDef(26, "spread_bps", "买卖价差", "Best bid/ask spread.", ("level_1", "replay_1m"), "(best_ask - best_bid) / mid * 10000", "1m", "bps", "float32", "missing -> fail", True, "uses current closed minute L1 only", ("Entry", "Exit", "Risk")),
|
||||
FeatureDef(27, "spread_rank_24h_pct", "24小时价差分位", "Spread congestion rank.", ("feature",), "percentile_rank(spread_bps over 24h)", "24h", "pct", "float32", "WARMUP", True, "rolling rank uses <= t", ("Entry", "Exit", "Risk")),
|
||||
FeatureDef(28, "oi_delta_15m_bps", "15分钟持仓变化", "Open-interest short change.", ("open_interest", "replay_1m"), "open_interest_t / open_interest_t-15m - 1", "15m", "bps", "float32", "WARMUP", True, "uses <= t OI only", ("Direction", "Continue", "Risk")),
|
||||
FeatureDef(29, "oi_delta_60m_bps", "60分钟持仓变化", "Open-interest baseline change.", ("open_interest", "replay_1m"), "open_interest_t / open_interest_t-60m - 1", "60m", "bps", "float32", "WARMUP", True, "uses <= t OI only", ("Direction", "Continue", "Risk")),
|
||||
FeatureDef(30, "funding_bps", "资金费率", "Current funding rate.", ("funding", "replay_1m"), "rate * 10000", "as-of", "bps", "float32", "as-of > 12h -> fail", True, "backward as-of only", ("Direction", "Entry", "Risk")),
|
||||
FeatureDef(31, "mark_index_basis_bps", "标记价指数价偏离", "Mark-index basis.", ("funding", "replay_1m"), "mark_price / index_price - 1", "as-of", "bps", "float32", "as-of > 12h -> fail", True, "backward as-of only", ("Direction", "Entry", "Risk")),
|
||||
FeatureDef(32, "liquidation_buy_notional_1m", "1分钟买向爆仓金额", "Buy-side liquidation notional.", ("liquidations", "replay_1m"), "sum(quantity * price for BUY)", "1m", "quote", "float32", "missing partition -> 0 with flag", True, "uses current closed minute liquidations only", ("Entry", "Exit", "Risk")),
|
||||
FeatureDef(33, "liquidation_sell_notional_1m", "1分钟卖向爆仓金额", "Sell-side liquidation notional.", ("liquidations", "replay_1m"), "sum(quantity * price for SELL)", "1m", "quote", "float32", "missing partition -> 0 with flag", True, "uses current closed minute liquidations only", ("Entry", "Exit", "Risk")),
|
||||
FeatureDef(34, "liquidation_imbalance_15m", "15分钟爆仓方向差", "Liquidation imbalance.", ("liquidations", "replay_1m"), "(buy_15m - sell_15m) / max(total_15m, eps)", "15m", "ratio", "float32", "missing partition -> 0 with flag", True, "uses <= t liquidations only", ("Direction", "Entry", "Exit", "Risk")),
|
||||
FeatureDef(35, "liquidation_notional_zscore_15m", "爆仓金额异常", "Liquidation notional zscore.", ("liquidations", "replay_1m"), "(liq_15m - mean_24h) / std_24h", "15m/24h", "zscore", "float32", "missing partition -> 0 with flag", True, "rolling window uses <= t", ("Entry", "Exit", "Risk")),
|
||||
FeatureDef(36, "liquidation_available", "爆仓数据可用", "Whether liquidation data exists.", ("liquidations", "replay_1m"), "day partition exists", "day", "0/1", "float32", "never null", True, "partition availability known by event day", ("Entry", "Exit", "Risk")),
|
||||
FeatureDef(37, "minute_of_day_sin", "日内时间正弦", "Time of day cyclic feature.", ("event_time",), "sin(2*pi*minute_of_day/1440)", "event_time", "ratio", "float32", "never null", True, "event timestamp only", ("Direction", "Entry", "Risk")),
|
||||
FeatureDef(38, "minute_of_day_cos", "日内时间余弦", "Time of day cyclic feature.", ("event_time",), "cos(2*pi*minute_of_day/1440)", "event_time", "ratio", "float32", "never null", True, "event timestamp only", ("Direction", "Entry", "Risk")),
|
||||
FeatureDef(39, "minutes_to_next_funding", "距离下次资金费分钟", "Minutes to next funding settlement.", ("funding", "replay_1m"), "clip((next_funding_time - event_time) / 60000, 0, 480)", "as-of", "minute", "float32", "as-of > 12h -> fail", True, "backward as-of only", ("Entry", "Continue", "Risk")),
|
||||
)
|
||||
|
||||
|
||||
FEATURE_ORDER = [feature.name for feature in FEATURES]
|
||||
|
||||
|
||||
OUTPUT_SCHEMA: dict[str, Any] = {
|
||||
"output_schema_version": OUTPUT_SCHEMA_VERSION,
|
||||
"direction": {
|
||||
"longProb": {"type": "decimal", "range": [0.0, 1.0]},
|
||||
"shortProb": {"type": "decimal", "range": [0.0, 1.0]},
|
||||
"neutralProb": {"type": "decimal", "range": [0.0, 1.0]},
|
||||
"sum_rule": "longProb + shortProb + neutralProb must equal 1.0 within 0.000001",
|
||||
},
|
||||
"entry": {
|
||||
"longEntryProb": {"type": "decimal", "range": [0.0, 1.0]},
|
||||
"shortEntryProb": {"type": "decimal", "range": [0.0, 1.0]},
|
||||
"longExpectedNetEdgeBps": {"type": "decimal", "range": [-500.0, 500.0]},
|
||||
"shortExpectedNetEdgeBps": {"type": "decimal", "range": [-500.0, 500.0]},
|
||||
},
|
||||
"continuation": {
|
||||
"longContinueProb": {"type": "decimal", "range": [0.0, 1.0]},
|
||||
"shortContinueProb": {"type": "decimal", "range": [0.0, 1.0]},
|
||||
"longExpectedContinueEdgeBps": {"type": "decimal", "range": [-500.0, 500.0]},
|
||||
"shortExpectedContinueEdgeBps": {"type": "decimal", "range": [-500.0, 500.0]},
|
||||
},
|
||||
"exit": {
|
||||
"longExitProb": {"type": "decimal", "range": [0.0, 1.0]},
|
||||
"shortExitProb": {"type": "decimal", "range": [0.0, 1.0]},
|
||||
"longAdverseMoveBps": {"type": "decimal", "range": [0.0, 500.0]},
|
||||
"shortAdverseMoveBps": {"type": "decimal", "range": [0.0, 500.0]},
|
||||
"exitReasonScores": {
|
||||
"adverse_move_prob": {"type": "decimal", "range": [0.0, 1.0]},
|
||||
"reversal_prob": {"type": "decimal", "range": [0.0, 1.0]},
|
||||
"stop_hit_prob": {"type": "decimal", "range": [0.0, 1.0]},
|
||||
"stagnation_prob": {"type": "decimal", "range": [0.0, 1.0]},
|
||||
},
|
||||
},
|
||||
"risk": {
|
||||
"marketRiskProb": {"type": "decimal", "range": [0.0, 1.0]},
|
||||
"longPositionRiskProb": {"type": "decimal", "range": [0.0, 1.0]},
|
||||
"shortPositionRiskProb": {"type": "decimal", "range": [0.0, 1.0]},
|
||||
"marketPathRiskBps": {"type": "decimal", "range": [0.0, 1000.0]},
|
||||
"longPositionPathRiskBps": {"type": "decimal", "range": [0.0, 1000.0]},
|
||||
"shortPositionPathRiskBps": {"type": "decimal", "range": [0.0, 1000.0]},
|
||||
"riskReasonScores": {
|
||||
"market_drawdown_prob": {"type": "decimal", "range": [0.0, 1.0]},
|
||||
"volatility_expansion_prob": {"type": "decimal", "range": [0.0, 1.0]},
|
||||
"spike_prob": {"type": "decimal", "range": [0.0, 1.0]},
|
||||
"liquidity_deterioration_prob": {"type": "decimal", "range": [0.0, 1.0]},
|
||||
"position_drawdown_prob": {"type": "decimal", "range": [0.0, 1.0]},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
MODEL_OUTPUTS: dict[str, list[str]] = {
|
||||
"DIRECTION": ["long_prob", "short_prob", "neutral_prob"],
|
||||
"ENTRY": ["long_entry_prob", "short_entry_prob", "long_expected_net_edge_bps", "short_expected_net_edge_bps"],
|
||||
"CONTINUE": ["long_continue_prob", "short_continue_prob", "long_expected_continue_edge_bps", "short_expected_continue_edge_bps"],
|
||||
"EXIT": [
|
||||
"long_exit_prob",
|
||||
"short_exit_prob",
|
||||
"long_adverse_move_bps",
|
||||
"short_adverse_move_bps",
|
||||
"adverse_move_prob",
|
||||
"reversal_prob",
|
||||
"stop_hit_prob",
|
||||
"stagnation_prob",
|
||||
],
|
||||
"RISK": [
|
||||
"market_risk_prob",
|
||||
"long_position_risk_prob",
|
||||
"short_position_risk_prob",
|
||||
"market_path_risk_bps",
|
||||
"long_position_path_risk_bps",
|
||||
"short_position_path_risk_bps",
|
||||
"market_drawdown_prob",
|
||||
"volatility_expansion_prob",
|
||||
"spike_prob",
|
||||
"liquidity_deterioration_prob",
|
||||
"position_drawdown_prob",
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
PROBABILITY_TARGET_NAMES: dict[str, list[str]] = {
|
||||
"DIRECTION": ["longProb", "shortProb", "neutralProb"],
|
||||
"ENTRY": ["longEntryProb", "shortEntryProb"],
|
||||
"CONTINUE": ["longContinueProb", "shortContinueProb"],
|
||||
"EXIT": ["longExitProb", "shortExitProb", "adverse_move_prob", "reversal_prob", "stop_hit_prob", "stagnation_prob"],
|
||||
"RISK": [
|
||||
"marketRiskProb",
|
||||
"longPositionRiskProb",
|
||||
"shortPositionRiskProb",
|
||||
"market_drawdown_prob",
|
||||
"volatility_expansion_prob",
|
||||
"spike_prob",
|
||||
"liquidity_deterioration_prob",
|
||||
"position_drawdown_prob",
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
OUTPUT_MAPPING: dict[str, dict[str, str]] = {
|
||||
model: {field: f"prediction[{index}]" for index, field in enumerate(fields)}
|
||||
for model, fields in MODEL_OUTPUTS.items()
|
||||
}
|
||||
@@ -0,0 +1,581 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from sklearn.linear_model import LogisticRegression, Ridge
|
||||
from sklearn.metrics import accuracy_score, log_loss, mean_absolute_error, roc_auc_score
|
||||
from sklearn.preprocessing import StandardScaler
|
||||
|
||||
from trader_training.io_utils import read_parquet, run_root, sha256_file, write_json, write_parquet, write_text
|
||||
from trader_training.onnx_export import LinearHead, export_heads
|
||||
from trader_training.schemas import FEATURE_ORDER, FIT_SPLIT, LATEST_STRESS_SPLIT, MODEL_OUTPUTS, PROBABILITY_TARGET_NAMES, TUNE_SPLIT, VALIDATION_LOCKED_SPLIT
|
||||
|
||||
|
||||
@dataclass
|
||||
class HeadResult:
|
||||
field: str
|
||||
target_name: str | None
|
||||
kind: str
|
||||
weight: np.ndarray
|
||||
bias: np.ndarray
|
||||
metrics: dict[str, Any]
|
||||
tune_prediction: np.ndarray
|
||||
tune_target: np.ndarray | None
|
||||
|
||||
|
||||
TARGETS = {
|
||||
"DIRECTION": {
|
||||
"dataset": "direction_train.parquet",
|
||||
"heads": [("direction", "multiclass", ["long_target", "short_target", "neutral_target"], ["long_prob", "short_prob", "neutral_prob"], ["longProb", "shortProb", "neutralProb"])],
|
||||
},
|
||||
"ENTRY": {
|
||||
"dataset": "entry_train.parquet",
|
||||
"heads": [
|
||||
("long_entry_prob", "binary", "long_entry_target", ["long_entry_prob"], ["longEntryProb"]),
|
||||
("short_entry_prob", "binary", "short_entry_target", ["short_entry_prob"], ["shortEntryProb"]),
|
||||
("long_expected_net_edge_bps", "regression", "long_expected_net_edge_bps", ["long_expected_net_edge_bps"], [None]),
|
||||
("short_expected_net_edge_bps", "regression", "short_expected_net_edge_bps", ["short_expected_net_edge_bps"], [None]),
|
||||
],
|
||||
},
|
||||
"CONTINUE": {
|
||||
"dataset": "continue_train.parquet",
|
||||
"heads": [
|
||||
("long_continue_prob", "binary", "long_continue_target", ["long_continue_prob"], ["longContinueProb"]),
|
||||
("short_continue_prob", "binary", "short_continue_target", ["short_continue_prob"], ["shortContinueProb"]),
|
||||
("long_expected_continue_edge_bps", "regression", "long_expected_continue_edge_bps", ["long_expected_continue_edge_bps"], [None]),
|
||||
("short_expected_continue_edge_bps", "regression", "short_expected_continue_edge_bps", ["short_expected_continue_edge_bps"], [None]),
|
||||
],
|
||||
},
|
||||
"EXIT": {
|
||||
"dataset": "exit_train.parquet",
|
||||
"heads": [
|
||||
("long_exit_prob", "binary", "long_exit_target", ["long_exit_prob"], ["longExitProb"]),
|
||||
("short_exit_prob", "binary", "short_exit_target", ["short_exit_prob"], ["shortExitProb"]),
|
||||
("long_adverse_move_bps", "regression", "long_adverse_move_bps", ["long_adverse_move_bps"], [None]),
|
||||
("short_adverse_move_bps", "regression", "short_adverse_move_bps", ["short_adverse_move_bps"], [None]),
|
||||
("adverse_move_prob", "binary", "adverse_move_prob_label", ["adverse_move_prob"], ["adverse_move_prob"]),
|
||||
("reversal_prob", "binary", "reversal_prob_label", ["reversal_prob"], ["reversal_prob"]),
|
||||
("stop_hit_prob", "binary", "stop_hit_prob_label", ["stop_hit_prob"], ["stop_hit_prob"]),
|
||||
("stagnation_prob", "binary", "stagnation_prob_label", ["stagnation_prob"], ["stagnation_prob"]),
|
||||
],
|
||||
},
|
||||
"RISK": {
|
||||
"dataset": "risk_train.parquet",
|
||||
"heads": [
|
||||
("market_risk_prob", "binary", "market_risk_target", ["market_risk_prob"], ["marketRiskProb"]),
|
||||
("long_position_risk_prob", "binary", "long_position_risk_target", ["long_position_risk_prob"], ["longPositionRiskProb"]),
|
||||
("short_position_risk_prob", "binary", "short_position_risk_target", ["short_position_risk_prob"], ["shortPositionRiskProb"]),
|
||||
("market_path_risk_bps", "regression", "market_path_risk_bps", ["market_path_risk_bps"], [None]),
|
||||
("long_position_path_risk_bps", "regression", "long_position_path_risk_bps", ["long_position_path_risk_bps"], [None]),
|
||||
("short_position_path_risk_bps", "regression", "short_position_path_risk_bps", ["short_position_path_risk_bps"], [None]),
|
||||
("market_drawdown_prob", "binary", "market_drawdown_prob_label", ["market_drawdown_prob"], ["market_drawdown_prob"]),
|
||||
("volatility_expansion_prob", "binary", "volatility_expansion_prob_label", ["volatility_expansion_prob"], ["volatility_expansion_prob"]),
|
||||
("spike_prob", "binary", "spike_prob_label", ["spike_prob"], ["spike_prob"]),
|
||||
("liquidity_deterioration_prob", "binary", "liquidity_deterioration_prob_label", ["liquidity_deterioration_prob"], ["liquidity_deterioration_prob"]),
|
||||
("position_drawdown_prob", "binary", "position_drawdown_prob_label", ["position_drawdown_prob"], ["position_drawdown_prob"]),
|
||||
],
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def train_small_models(args: Any) -> None:
|
||||
root = run_root(args)
|
||||
model_manifest: dict[str, Any] = {}
|
||||
for model_name, spec in TARGETS.items():
|
||||
dataset = read_parquet(root / "dataset" / spec["dataset"])
|
||||
if args.max_rows and len(dataset) > args.max_rows:
|
||||
dataset = dataset.sort_values("event_time").tail(args.max_rows).copy()
|
||||
if dataset.empty:
|
||||
raise ValueError(f"dataset is empty for {model_name}")
|
||||
train = dataset[dataset["split_id"] == FIT_SPLIT].copy()
|
||||
tune = dataset[dataset["split_id"] == TUNE_SPLIT].copy()
|
||||
validation_locked = dataset[dataset["split_id"] == VALIDATION_LOCKED_SPLIT].copy()
|
||||
latest_stress = dataset[dataset["split_id"] == LATEST_STRESS_SPLIT].copy()
|
||||
if train.empty or tune.empty:
|
||||
raise ValueError(f"{model_name} needs {FIT_SPLIT} and {TUNE_SPLIT} rows")
|
||||
logging.info(
|
||||
"trader.training.model_dataset_loaded runId=%s model=%s totalRows=%s trainRows=%s tuneRows=%s validationLockedRows=%s latestStressRows=%s splitCounts=%s",
|
||||
args.run_id,
|
||||
model_name,
|
||||
len(dataset),
|
||||
len(train),
|
||||
len(tune),
|
||||
len(validation_locked),
|
||||
len(latest_stress),
|
||||
dataset["split_id"].value_counts().to_dict(),
|
||||
)
|
||||
scaler = StandardScaler()
|
||||
x_train_scaled = scaler.fit_transform(train[FEATURE_ORDER].astype("float32"))
|
||||
x_tune_scaled = scaler.transform(tune[FEATURE_ORDER].astype("float32"))
|
||||
heads: list[LinearHead] = []
|
||||
head_results: list[HeadResult] = []
|
||||
for item in spec["heads"]:
|
||||
head_results.extend(_fit_head(item, x_train_scaled, x_tune_scaled, train, tune, scaler))
|
||||
for result in head_results:
|
||||
logging.info(
|
||||
"trader.training.model_head_trained runId=%s model=%s head=%s kind=%s metrics=%s",
|
||||
args.run_id,
|
||||
model_name,
|
||||
result.field,
|
||||
result.kind,
|
||||
result.metrics,
|
||||
)
|
||||
for result in head_results:
|
||||
heads.append(LinearHead(result.field, _onnx_kind(result.kind), result.weight, result.bias))
|
||||
model_dir = root / "model" / model_name.lower()
|
||||
model_path = model_dir / f"{model_name.lower()}.onnx"
|
||||
export_heads(model_path, heads, feature_count=len(FEATURE_ORDER), opset=17)
|
||||
predictions = _tune_prediction_frame(tune, head_results)
|
||||
write_parquet(model_dir / "tune_predictions.parquet", predictions)
|
||||
if not validation_locked.empty:
|
||||
write_parquet(model_dir / "validation_locked_predictions.parquet", _predict_frame(validation_locked, head_results, include_labels=True))
|
||||
if not latest_stress.empty:
|
||||
write_parquet(model_dir / "latest_stress_predictions.parquet", _predict_frame(latest_stress, head_results, include_labels=False))
|
||||
metrics = {result.field: result.metrics for result in head_results}
|
||||
model_hash = sha256_file(model_path)
|
||||
quality_status, quality_reasons = _model_quality(head_results)
|
||||
write_json(
|
||||
model_dir / "model_train_result.json",
|
||||
{
|
||||
"model_name": model_name,
|
||||
"metrics": metrics,
|
||||
"quality_status": quality_status,
|
||||
"quality_reasons": quality_reasons,
|
||||
"artifact_hash_sha256": model_hash,
|
||||
},
|
||||
)
|
||||
write_json(
|
||||
model_dir / "model_manifest.json",
|
||||
{
|
||||
"model_name": model_name,
|
||||
"model_path": str(model_path),
|
||||
"model_format": "ONNX",
|
||||
"input_tensor_name": "features",
|
||||
"input_feature_count": len(FEATURE_ORDER),
|
||||
"output_tensor_name": "prediction",
|
||||
"output_fields": MODEL_OUTPUTS[model_name],
|
||||
"quality_status": quality_status,
|
||||
"quality_reasons": quality_reasons,
|
||||
"artifact_hash_sha256": model_hash,
|
||||
},
|
||||
)
|
||||
_write_model_examples(model_dir, model_name, tune, predictions)
|
||||
_write_feature_importance(model_dir / "feature_importance.csv", head_results)
|
||||
_write_version_compare(model_dir, model_name, metrics)
|
||||
_write_training_report(model_dir / "training_report.md", model_name, metrics, quality_status, quality_reasons)
|
||||
model_manifest[model_name] = {"path": str(model_path), "hash_sha256": model_hash, "metrics": metrics, "quality_status": quality_status, "quality_reasons": quality_reasons}
|
||||
logging.info(
|
||||
"trader.training.model_trained runId=%s model=%s qualityStatus=%s qualityReasons=%s path=%s tunePredictionRows=%s featureImportancePath=%s",
|
||||
args.run_id,
|
||||
model_name,
|
||||
quality_status,
|
||||
quality_reasons,
|
||||
model_path,
|
||||
len(predictions),
|
||||
model_dir / "feature_importance.csv",
|
||||
)
|
||||
write_json(root / "model" / "model_train_manifest.json", model_manifest)
|
||||
|
||||
|
||||
def _fit_head(item, x_train, x_tune, train: pd.DataFrame, tune: pd.DataFrame, scaler: StandardScaler) -> list[HeadResult]:
|
||||
name, kind, target, fields, target_names = item
|
||||
if kind == "multiclass":
|
||||
y_train = train[target].to_numpy().argmax(axis=1)
|
||||
y_val = tune[target].to_numpy().argmax(axis=1)
|
||||
model = LogisticRegression(max_iter=500)
|
||||
model.fit(x_train, y_train)
|
||||
proba = model.predict_proba(x_tune)
|
||||
weight, bias = _fold_scaler(model.coef_.T, model.intercept_, scaler)
|
||||
train_prior = train[target].to_numpy().mean(axis=0)
|
||||
metrics = _multiclass_metrics(y_train, y_val, proba, train_prior)
|
||||
return [HeadResult("direction", target_names[0], "softmax", weight, bias, metrics, proba, y_val)]
|
||||
if kind == "binary":
|
||||
y_train = pd.to_numeric(train[target], errors="coerce").fillna(0).astype(int).to_numpy()
|
||||
y_val = pd.to_numeric(tune[target], errors="coerce").fillna(0).astype(int).to_numpy()
|
||||
if len(np.unique(y_train)) < 2:
|
||||
prevalence = float(np.clip(y_train.mean(), 1e-6, 1 - 1e-6))
|
||||
coef = np.zeros((1, len(FEATURE_ORDER)), dtype=np.float32)
|
||||
intercept = np.array([np.log(prevalence / (1 - prevalence))], dtype=np.float32)
|
||||
proba = np.full(len(y_val), prevalence, dtype=np.float32)
|
||||
else:
|
||||
model = LogisticRegression(max_iter=500)
|
||||
model.fit(x_train, y_train)
|
||||
coef = model.coef_
|
||||
intercept = model.intercept_
|
||||
proba = model.predict_proba(x_tune)[:, 1]
|
||||
weight, bias = _fold_scaler(coef.T, intercept, scaler)
|
||||
metrics = _binary_metrics(y_train, y_val, proba)
|
||||
if len(np.unique(y_val)) == 2:
|
||||
metrics["auc"] = float(roc_auc_score(y_val, proba))
|
||||
return [HeadResult(fields[0], target_names[0], "sigmoid", weight, bias, metrics, proba.reshape(-1, 1), y_val)]
|
||||
if kind == "regression":
|
||||
y_train = pd.to_numeric(train[target], errors="coerce").fillna(0.0).to_numpy()
|
||||
y_val = pd.to_numeric(tune[target], errors="coerce").fillna(0.0).to_numpy()
|
||||
model = Ridge(alpha=1.0)
|
||||
model.fit(x_train, y_train)
|
||||
pred = model.predict(x_tune)
|
||||
weight, bias = _fold_scaler(model.coef_.reshape(1, -1).T, np.array([model.intercept_]), scaler)
|
||||
return [HeadResult(fields[0], None, "identity", weight, bias, _regression_metrics(y_train, y_val, pred), pred.reshape(-1, 1), y_val)]
|
||||
raise ValueError(f"unsupported head kind: {kind}")
|
||||
|
||||
|
||||
def _fold_scaler(weight_scaled: np.ndarray, bias_scaled: np.ndarray, scaler: StandardScaler) -> tuple[np.ndarray, np.ndarray]:
|
||||
scale = np.where(scaler.scale_ == 0, 1.0, scaler.scale_)
|
||||
weight = weight_scaled / scale.reshape(-1, 1)
|
||||
bias = bias_scaled - np.sum((scaler.mean_ / scale).reshape(-1, 1) * weight_scaled, axis=0)
|
||||
return weight.astype(np.float32), bias.astype(np.float32)
|
||||
|
||||
|
||||
def _onnx_kind(kind: str) -> str:
|
||||
if kind in ("softmax", "sigmoid", "identity"):
|
||||
return kind
|
||||
raise ValueError(f"unsupported result kind: {kind}")
|
||||
|
||||
|
||||
def _multiclass_metrics(y_train: np.ndarray, y_val: np.ndarray, proba: np.ndarray, train_prior: np.ndarray) -> dict[str, Any]:
|
||||
one_hot = np.eye(proba.shape[1], dtype=float)[y_val]
|
||||
train_prior = np.asarray(train_prior, dtype=float)
|
||||
train_prior = train_prior / train_prior.sum() if train_prior.sum() > 0 else np.full(proba.shape[1], 1.0 / proba.shape[1])
|
||||
constant = np.tile(train_prior.reshape(1, -1), (len(y_val), 1))
|
||||
proba_for_logloss = _clip_normalize(proba)
|
||||
constant_for_logloss = _clip_normalize(constant)
|
||||
metrics: dict[str, Any] = {
|
||||
"accuracy": float(accuracy_score(y_val, proba.argmax(axis=1))),
|
||||
"logloss": float(log_loss(y_val, proba_for_logloss, labels=list(range(proba.shape[1])))),
|
||||
"constant_logloss": float(log_loss(y_val, constant_for_logloss, labels=list(range(proba.shape[1])))),
|
||||
"brier_multiclass": float(np.mean(np.sum((one_hot - proba) ** 2, axis=1))),
|
||||
"constant_brier_multiclass": float(np.mean(np.sum((one_hot - constant) ** 2, axis=1))),
|
||||
}
|
||||
for idx, name in enumerate(("long", "short", "neutral")):
|
||||
binary_target = (y_val == idx).astype(int)
|
||||
positives = int(binary_target.sum())
|
||||
negatives = int(len(binary_target) - positives)
|
||||
if positives >= 200 and negatives >= 200:
|
||||
metrics[f"{name}_auc"] = float(roc_auc_score(binary_target, proba[:, idx]))
|
||||
else:
|
||||
metrics[f"{name}_auc_status"] = "INSUFFICIENT_SAMPLE"
|
||||
metrics[f"{name}_positive_count"] = positives
|
||||
metrics[f"{name}_negative_count"] = negatives
|
||||
max_prob = proba.max(axis=1)
|
||||
predicted_class = proba.argmax(axis=1)
|
||||
top_count = max(1, int(len(y_val) * 0.10))
|
||||
top_idx = np.argsort(max_prob)[-top_count:]
|
||||
metrics["top10_hit_rate"] = float((predicted_class[top_idx] == y_val[top_idx]).mean())
|
||||
metrics["all_hit_rate"] = float((predicted_class == y_val).mean())
|
||||
return _with_quality(metrics)
|
||||
|
||||
|
||||
def _clip_normalize(values: np.ndarray) -> np.ndarray:
|
||||
values = np.clip(np.asarray(values, dtype=float), 1e-6, 1.0)
|
||||
return values / values.sum(axis=1, keepdims=True)
|
||||
|
||||
|
||||
def _binary_metrics(y_train: np.ndarray, y_val: np.ndarray, proba: np.ndarray) -> dict[str, Any]:
|
||||
proba = np.asarray(proba, dtype=float)
|
||||
train_rate = float(np.mean(y_train)) if len(y_train) else 0.0
|
||||
constant = np.full(len(y_val), train_rate)
|
||||
metrics: dict[str, Any] = {
|
||||
"positive_rate": train_rate,
|
||||
"tune_positive_rate": float(np.mean(y_val)) if len(y_val) else 0.0,
|
||||
"brier": float(np.mean((y_val - proba) ** 2)) if len(y_val) else 0.0,
|
||||
"constant_brier": float(np.mean((y_val - constant) ** 2)) if len(y_val) else 0.0,
|
||||
}
|
||||
if len(y_val):
|
||||
top_count = max(1, int(len(y_val) * 0.10))
|
||||
top_idx = np.argsort(proba)[-top_count:]
|
||||
metrics["top10_hit_rate"] = float(np.mean(y_val[top_idx]))
|
||||
metrics["all_hit_rate"] = float(np.mean(y_val))
|
||||
return _with_quality(metrics)
|
||||
|
||||
|
||||
def _regression_metrics(y_train: np.ndarray, y_val: np.ndarray, pred: np.ndarray) -> dict[str, Any]:
|
||||
mae = float(mean_absolute_error(y_val, pred))
|
||||
train_std = float(np.std(y_train))
|
||||
metrics: dict[str, Any] = {
|
||||
"mae": mae,
|
||||
"train_target_std": train_std,
|
||||
"mae_vs_train_std_ratio": float(mae / train_std) if train_std > 0 else None,
|
||||
}
|
||||
return _with_quality(metrics)
|
||||
|
||||
|
||||
def _with_quality(metrics: dict[str, Any]) -> dict[str, Any]:
|
||||
reasons: list[str] = []
|
||||
for key, value in metrics.items():
|
||||
if key.endswith("_auc") and isinstance(value, float) and value < 0.53:
|
||||
reasons.append(f"{key}_below_0.53")
|
||||
if "brier" in metrics and metrics.get("constant_brier") is not None and metrics["brier"] >= metrics["constant_brier"]:
|
||||
reasons.append("brier_not_better_than_constant")
|
||||
if "brier_multiclass" in metrics and metrics["brier_multiclass"] >= metrics["constant_brier_multiclass"]:
|
||||
reasons.append("brier_not_better_than_constant")
|
||||
if "mae" in metrics and metrics.get("train_target_std") is not None and metrics["train_target_std"] > 0 and metrics["mae"] > metrics["train_target_std"]:
|
||||
reasons.append("mae_above_train_target_std")
|
||||
if "top10_hit_rate" in metrics and "all_hit_rate" in metrics and metrics["top10_hit_rate"] <= metrics["all_hit_rate"]:
|
||||
reasons.append("top10_not_better_than_all")
|
||||
metrics["quality_status"] = "REJECTED" if reasons else "PASS"
|
||||
metrics["quality_reasons"] = reasons
|
||||
return metrics
|
||||
|
||||
|
||||
def _model_quality(results: list[HeadResult]) -> tuple[str, list[str]]:
|
||||
reasons = []
|
||||
for result in results:
|
||||
if result.metrics.get("quality_status") == "REJECTED":
|
||||
for reason in result.metrics.get("quality_reasons", []):
|
||||
reasons.append(f"{result.field}:{reason}")
|
||||
return ("REJECTED", reasons) if reasons else ("PASS", [])
|
||||
|
||||
|
||||
def _tune_prediction_frame(tune: pd.DataFrame, results: list[HeadResult]) -> pd.DataFrame:
|
||||
out = tune[["sample_id", "symbol", "event_time", "split_id"]].copy().reset_index(drop=True)
|
||||
for result in results:
|
||||
values = result.tune_prediction
|
||||
if result.kind == "softmax":
|
||||
for idx, field in enumerate(MODEL_OUTPUTS["DIRECTION"]):
|
||||
out[field] = values[:, idx]
|
||||
if result.tune_target is not None:
|
||||
out["label__longProb"] = (result.tune_target == 0).astype(int)
|
||||
out["label__shortProb"] = (result.tune_target == 1).astype(int)
|
||||
out["label__neutralProb"] = (result.tune_target == 2).astype(int)
|
||||
else:
|
||||
out[result.field] = values.reshape(-1)
|
||||
if result.kind != "softmax" and result.target_name and result.tune_target is not None:
|
||||
out[f"label__{result.target_name}"] = result.tune_target
|
||||
return out
|
||||
|
||||
|
||||
def _predict_frame(frame: pd.DataFrame, results: list[HeadResult], include_labels: bool) -> pd.DataFrame:
|
||||
out = frame[["sample_id", "symbol", "event_time", "split_id"]].copy().reset_index(drop=True)
|
||||
features = frame[FEATURE_ORDER].astype("float32").to_numpy()
|
||||
for result in results:
|
||||
values = features @ result.weight + result.bias.reshape(1, -1)
|
||||
if result.kind == "softmax":
|
||||
values = _softmax(values)
|
||||
for idx, field in enumerate(MODEL_OUTPUTS["DIRECTION"]):
|
||||
out[field] = values[:, idx]
|
||||
elif result.kind == "sigmoid":
|
||||
out[result.field] = (1.0 / (1.0 + np.exp(-values))).reshape(-1)
|
||||
else:
|
||||
out[result.field] = values.reshape(-1)
|
||||
if include_labels and result.kind != "softmax" and result.target_name and result.target_name in frame.columns:
|
||||
out[f"label__{result.target_name}"] = frame[result.target_name].to_numpy()
|
||||
return out
|
||||
|
||||
|
||||
def _softmax(values: np.ndarray) -> np.ndarray:
|
||||
shifted = values - np.max(values, axis=1, keepdims=True)
|
||||
exp = np.exp(shifted)
|
||||
return exp / exp.sum(axis=1, keepdims=True)
|
||||
|
||||
|
||||
def _write_training_report(path: Path, model_name: str, metrics: dict[str, Any], quality_status: str, quality_reasons: list[str]) -> None:
|
||||
lines = [
|
||||
"# Trader Model Training Report",
|
||||
"",
|
||||
f"- model: {model_name}",
|
||||
f"- quality_status: {quality_status}",
|
||||
f"- quality_reasons: {quality_reasons}",
|
||||
"",
|
||||
"```json",
|
||||
json.dumps(metrics, indent=2, sort_keys=True),
|
||||
"```",
|
||||
"",
|
||||
]
|
||||
write_text(path, "\n".join(lines))
|
||||
|
||||
|
||||
def _write_model_examples(model_dir: Path, model_name: str, tune: pd.DataFrame, predictions: pd.DataFrame) -> None:
|
||||
sample_input = {feature: float(tune.iloc[0][feature]) for feature in FEATURE_ORDER}
|
||||
sample_output = {field: float(predictions.iloc[0][field]) for field in MODEL_OUTPUTS[model_name]}
|
||||
write_json(model_dir / "sample_input.json", sample_input)
|
||||
write_json(model_dir / "sample_output.json", sample_output)
|
||||
|
||||
|
||||
def _write_feature_importance(path: Path, results: list[HeadResult]) -> None:
|
||||
rows = []
|
||||
for result in results:
|
||||
importance = np.mean(np.abs(result.weight), axis=1)
|
||||
for feature, value in zip(FEATURE_ORDER, importance):
|
||||
rows.append({"head": result.field, "feature": feature, "abs_weight": float(value)})
|
||||
frame = pd.DataFrame(rows).sort_values(["head", "abs_weight"], ascending=[True, False])
|
||||
write_text(path, frame.to_csv(index=False))
|
||||
|
||||
|
||||
def _write_version_compare(model_dir: Path, model_name: str, metrics: dict[str, Any]) -> None:
|
||||
payload = {
|
||||
"model_name": model_name,
|
||||
"status": "NO_BASELINE",
|
||||
"reason": "first executable V4 training chain has no previous approved artifact bundle under this run root",
|
||||
"current_metrics": metrics,
|
||||
}
|
||||
write_json(model_dir / "version_compare_metrics.json", payload)
|
||||
write_text(model_dir / "version_compare_by_regime.csv", "regime,status,reason\nNO_BASELINE,NO_BASELINE,no previous approved artifact bundle\n")
|
||||
write_text(model_dir / "version_compare_top_bucket.csv", "bucket,status,reason\nNO_BASELINE,NO_BASELINE,no previous approved artifact bundle\n")
|
||||
lines = [
|
||||
"# Version Compare Report",
|
||||
"",
|
||||
f"- model: {model_name}",
|
||||
"- status: NO_BASELINE",
|
||||
"- reason: 当前 run 目录没有上一版已验收模型包,所以首版只能记录当前指标,不能做新旧优劣判断。",
|
||||
"",
|
||||
]
|
||||
write_text(model_dir / "version_compare_report.md", "\n".join(lines))
|
||||
|
||||
|
||||
def build_calibrators(args: Any) -> None:
|
||||
root = run_root(args)
|
||||
manifest_rows = []
|
||||
for model_name, target_names in PROBABILITY_TARGET_NAMES.items():
|
||||
prediction_path = root / "model" / model_name.lower() / "tune_predictions.parquet"
|
||||
predictions = read_parquet(prediction_path)
|
||||
targets = {}
|
||||
reliability_rows = []
|
||||
quality_reasons = []
|
||||
for target_name in target_names:
|
||||
raw_field = _target_to_raw_field(model_name, target_name)
|
||||
label_field = f"label__{target_name}"
|
||||
labels = predictions[label_field].to_numpy() if label_field in predictions.columns else None
|
||||
raw = predictions[raw_field].to_numpy()
|
||||
bins = _calibration_bins(raw, labels)
|
||||
metrics, rows = _calibration_metrics(raw, labels, bins, target_name)
|
||||
targets[target_name] = {"bins": bins, "metrics": metrics}
|
||||
reliability_rows.extend(rows)
|
||||
if metrics.get("quality_status") == "REJECTED":
|
||||
quality_reasons.append(f"{target_name}:{metrics.get('quality_reason')}")
|
||||
calibrator = {
|
||||
"calibrator_version": f"{model_name.lower()}-cal-v4-btc-p0",
|
||||
"method": "BINNING",
|
||||
"targets": targets,
|
||||
"clip": {"min": 0.0, "max": 1.0},
|
||||
"fallback_policy": "FAIL_FAST",
|
||||
}
|
||||
path = root / "calibration" / model_name.lower() / "calibrator.json"
|
||||
cal_hash = write_json(path, calibrator)
|
||||
quality_status = "REJECTED" if quality_reasons else "PASS"
|
||||
manifest_rows.append(
|
||||
{
|
||||
"model_name": model_name,
|
||||
"calibrator_path": str(path),
|
||||
"calibrator_hash_sha256": cal_hash,
|
||||
"target_count": len(targets),
|
||||
"quality_status": quality_status,
|
||||
"quality_reasons": quality_reasons,
|
||||
}
|
||||
)
|
||||
write_text(root / "calibration" / model_name.lower() / "reliability_curve.csv", pd.DataFrame(reliability_rows).to_csv(index=False))
|
||||
_write_calibration_report(root / "calibration" / model_name.lower() / "calibration_report.md", model_name, targets, quality_status, quality_reasons)
|
||||
logging.info("trader.training.calibrator_written runId=%s model=%s path=%s", args.run_id, model_name, path)
|
||||
write_json(root / "calibration" / "calibration_train_manifest.json", {"calibrators": manifest_rows})
|
||||
|
||||
|
||||
def _target_to_raw_field(model_name: str, target_name: str) -> str:
|
||||
mapping = {
|
||||
"longProb": "long_prob",
|
||||
"shortProb": "short_prob",
|
||||
"neutralProb": "neutral_prob",
|
||||
"longEntryProb": "long_entry_prob",
|
||||
"shortEntryProb": "short_entry_prob",
|
||||
"longContinueProb": "long_continue_prob",
|
||||
"shortContinueProb": "short_continue_prob",
|
||||
"longExitProb": "long_exit_prob",
|
||||
"shortExitProb": "short_exit_prob",
|
||||
"marketRiskProb": "market_risk_prob",
|
||||
"longPositionRiskProb": "long_position_risk_prob",
|
||||
"shortPositionRiskProb": "short_position_risk_prob",
|
||||
}
|
||||
return mapping.get(target_name, target_name)
|
||||
|
||||
|
||||
def _calibration_bins(raw: np.ndarray, labels: np.ndarray | None) -> list[dict[str, float]]:
|
||||
raw = np.asarray(raw, dtype=float)
|
||||
raw = np.clip(np.nan_to_num(raw, nan=0.5), 0.0, 1.0)
|
||||
if labels is None or len(labels) != len(raw):
|
||||
return [{"min": 0.0, "max": 1.0, "calibrated": 0.5}]
|
||||
labels = np.asarray(labels, dtype=float)
|
||||
bins = []
|
||||
edges = np.linspace(0.0, 1.0, 11)
|
||||
for left, right in zip(edges[:-1], edges[1:]):
|
||||
if right == 1.0:
|
||||
mask = (raw >= left) & (raw <= right)
|
||||
else:
|
||||
mask = (raw >= left) & (raw < right)
|
||||
calibrated = float(labels[mask].mean()) if mask.any() else float((left + right) / 2.0)
|
||||
bins.append({"min": float(left), "max": float(right), "calibrated": float(np.clip(calibrated, 0.0, 1.0))})
|
||||
return bins
|
||||
|
||||
|
||||
def _apply_calibration(raw: np.ndarray, bins: list[dict[str, float]]) -> np.ndarray:
|
||||
out = np.zeros_like(raw, dtype=float)
|
||||
for item in bins:
|
||||
left = float(item["min"])
|
||||
right = float(item["max"])
|
||||
if right >= 1.0:
|
||||
mask = (raw >= left) & (raw <= right)
|
||||
else:
|
||||
mask = (raw >= left) & (raw < right)
|
||||
out[mask] = float(item["calibrated"])
|
||||
return np.clip(out, 0.0, 1.0)
|
||||
|
||||
|
||||
def _calibration_metrics(raw: np.ndarray, labels: np.ndarray | None, bins: list[dict[str, float]], target_name: str) -> tuple[dict[str, Any], list[dict[str, Any]]]:
|
||||
raw = np.clip(np.asarray(raw, dtype=float), 0.0, 1.0)
|
||||
if labels is None or len(labels) != len(raw):
|
||||
return {"quality_status": "REJECTED", "quality_reason": "missing_labels"}, []
|
||||
labels = np.asarray(labels, dtype=float)
|
||||
calibrated = _apply_calibration(raw, bins)
|
||||
raw_ece, rows = _ece(raw, labels, target_name, "raw")
|
||||
calibrated_ece, calibrated_rows = _ece(calibrated, labels, target_name, "calibrated")
|
||||
rows.extend(calibrated_rows)
|
||||
quality_status = "PASS" if calibrated_ece <= raw_ece else "REJECTED"
|
||||
return (
|
||||
{
|
||||
"raw_ece": raw_ece,
|
||||
"calibrated_ece": calibrated_ece,
|
||||
"quality_status": quality_status,
|
||||
"quality_reason": None if quality_status == "PASS" else "calibrated_ece_not_improved",
|
||||
},
|
||||
rows,
|
||||
)
|
||||
|
||||
|
||||
def _ece(values: np.ndarray, labels: np.ndarray, target_name: str, series_name: str) -> tuple[float, list[dict[str, Any]]]:
|
||||
rows = []
|
||||
total = len(values)
|
||||
ece = 0.0
|
||||
edges = np.linspace(0.0, 1.0, 11)
|
||||
for left, right in zip(edges[:-1], edges[1:]):
|
||||
mask = (values >= left) & (values <= right) if right >= 1.0 else (values >= left) & (values < right)
|
||||
count = int(mask.sum())
|
||||
confidence = float(values[mask].mean()) if count else float((left + right) / 2.0)
|
||||
accuracy = float(labels[mask].mean()) if count else 0.0
|
||||
ece += (count / total) * abs(confidence - accuracy) if total else 0.0
|
||||
rows.append(
|
||||
{
|
||||
"target": target_name,
|
||||
"series": series_name,
|
||||
"bin_min": left,
|
||||
"bin_max": right,
|
||||
"count": count,
|
||||
"confidence": confidence,
|
||||
"accuracy": accuracy,
|
||||
}
|
||||
)
|
||||
return float(ece), rows
|
||||
|
||||
|
||||
def _write_calibration_report(path: Path, model_name: str, targets: dict[str, Any], quality_status: str, quality_reasons: list[str]) -> None:
|
||||
lines = ["# Trader Calibration Report", "", f"- model: {model_name}", f"- target_count: {len(targets)}", f"- quality_status: {quality_status}", f"- quality_reasons: {quality_reasons}", ""]
|
||||
for target_name, payload in targets.items():
|
||||
lines.append(f"## {target_name}")
|
||||
lines.append("")
|
||||
lines.append("```json")
|
||||
lines.append(json.dumps(payload.get("metrics", {}), indent=2, sort_keys=True))
|
||||
lines.append("```")
|
||||
lines.append("")
|
||||
write_text(path, "\n".join(lines))
|
||||
@@ -0,0 +1,149 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
|
||||
from trader_training.io_utils import read_json, sha256_file, write_json, write_text
|
||||
from trader_training.schemas import FEATURE_ORDER, MODEL_OUTPUTS, OUTPUT_MAPPING
|
||||
|
||||
|
||||
def validate_artifact_bundle(args: Any) -> None:
|
||||
root = args.artifact_root
|
||||
errors: list[str] = []
|
||||
release_gate = {"release_gate_status": "UNKNOWN", "release_gate_reasons": ["artifact content not checked"]}
|
||||
required = [
|
||||
"manifests/model_bundle_manifest.json",
|
||||
"manifests/model_manifest.json",
|
||||
"manifests/calibration_manifest.json",
|
||||
"manifests/position_manager_manifest.json",
|
||||
"schemas/feature_schema.json",
|
||||
"schemas/feature_order.json",
|
||||
"schemas/output_schema.json",
|
||||
"price_plan_context.json",
|
||||
"examples/sample_input.json",
|
||||
"examples/sample_output.json",
|
||||
]
|
||||
for relative in required:
|
||||
if not (root / relative).is_file():
|
||||
errors.append(f"missing file: {relative}")
|
||||
if not errors:
|
||||
_validate_content(root, errors, args.require_active, args.run_onnx)
|
||||
# Structure validation only proves the bundle is loadable. The release
|
||||
# gate is the separate business decision for whether it may become ACTIVE.
|
||||
release_gate = _release_gate(root)
|
||||
if args.require_active and release_gate["release_gate_status"] != "PASS":
|
||||
errors.append(f"release gate must be PASS for ACTIVE, actual={release_gate['release_gate_status']}")
|
||||
status = "PASS" if not errors else "FAIL"
|
||||
result = {"status": status, "error_count": len(errors), "errors": errors, "artifact_root": str(root), **release_gate}
|
||||
output_root = root.parent
|
||||
write_json(output_root / "artifact_validation_result.json", result)
|
||||
lines = ["# Artifact Validation Report", "", f"- status: {status}", f"- error_count: {len(errors)}", ""]
|
||||
for error in errors:
|
||||
lines.append(f"- {error}")
|
||||
write_text(output_root / "artifact_validation_report.md", "\n".join(lines) + "\n")
|
||||
logging.info("trader.training.artifact_validated status=%s errorCount=%s path=%s", status, len(errors), root)
|
||||
if errors:
|
||||
raise SystemExit("artifact validation failed; see artifact_validation_report.md")
|
||||
|
||||
|
||||
def _validate_content(root: Path, errors: list[str], require_active: bool, run_onnx: bool) -> None:
|
||||
feature_order = read_json(root / "schemas/feature_order.json")
|
||||
if feature_order != FEATURE_ORDER:
|
||||
errors.append("feature_order.json does not match V4 39-feature order")
|
||||
model_bundle = read_json(root / "manifests/model_bundle_manifest.json")
|
||||
if require_active and model_bundle.get("status") != "ACTIVE":
|
||||
errors.append("model_bundle_manifest.status must be ACTIVE for Java SHADOW")
|
||||
if model_bundle.get("status") not in {"CANDIDATE", "ACTIVE"}:
|
||||
errors.append("model_bundle_manifest.status must be CANDIDATE or ACTIVE")
|
||||
manifests = read_json(root / "manifests/model_manifest.json")
|
||||
if len(manifests) != 5:
|
||||
errors.append("model_manifest.json must contain exactly five logical models")
|
||||
seen = {item.get("model_type") for item in manifests}
|
||||
if seen != set(MODEL_OUTPUTS):
|
||||
errors.append(f"model types mismatch: {seen}")
|
||||
for item in manifests:
|
||||
model_type = item.get("model_type")
|
||||
if item.get("model_format") != "ONNX":
|
||||
errors.append(f"{model_type} model_format must be ONNX")
|
||||
if item.get("input_tensor_name") != "features":
|
||||
errors.append(f"{model_type} input tensor must be features")
|
||||
if item.get("input_shape_json", {}).get("features") != 39:
|
||||
errors.append(f"{model_type} input_shape_json.features must be 39")
|
||||
if item.get("onnx_opset_version") != 17:
|
||||
errors.append(f"{model_type} opset must be 17")
|
||||
if item.get("output_mapping_json") != OUTPUT_MAPPING.get(model_type):
|
||||
errors.append(f"{model_type} output_mapping_json does not match Java contract")
|
||||
_check_hash(root, item.get("artifact_path"), item.get("artifact_hash_sha256"), errors)
|
||||
_check_hash(root, item.get("feature_schema_path"), item.get("feature_schema_hash"), errors)
|
||||
_check_hash(root, item.get("feature_order_path"), item.get("feature_order_hash"), errors)
|
||||
_check_hash(root, item.get("output_schema_path"), item.get("output_schema_hash"), errors)
|
||||
if require_active and item.get("status") != "ACTIVE":
|
||||
errors.append(f"{model_type} status must be ACTIVE for Java SHADOW")
|
||||
if item.get("status") != model_bundle.get("status"):
|
||||
errors.append(f"{model_type} status does not match model_bundle_manifest.status")
|
||||
calibrators = read_json(root / "manifests/calibration_manifest.json")
|
||||
if len(calibrators) != 5:
|
||||
errors.append("calibration_manifest.json must contain five calibrators")
|
||||
for item in calibrators:
|
||||
_check_hash(root, item.get("calibrator_path"), item.get("calibrator_hash_sha256"), errors)
|
||||
if require_active and item.get("status") != "ACTIVE":
|
||||
errors.append(f"{item.get('model_name')} calibrator status must be ACTIVE")
|
||||
if item.get("status") != model_bundle.get("status"):
|
||||
errors.append(f"{item.get('model_name')} calibrator status does not match model_bundle_manifest.status")
|
||||
pm_manifest = read_json(root / "manifests/position_manager_manifest.json")
|
||||
if require_active and pm_manifest.get("status") != "ACTIVE":
|
||||
errors.append("position_manager_manifest.status must be ACTIVE for Java SHADOW")
|
||||
if pm_manifest.get("status") != model_bundle.get("status"):
|
||||
errors.append("position_manager_manifest.status does not match model_bundle_manifest.status")
|
||||
if run_onnx and not errors:
|
||||
_run_sample_inference(root, manifests, errors)
|
||||
|
||||
|
||||
def _release_gate(root: Path) -> dict[str, Any]:
|
||||
reasons: list[str] = []
|
||||
model_bundle = read_json(root / "manifests/model_bundle_manifest.json")
|
||||
if model_bundle.get("backtest_status") != "PASS":
|
||||
reasons.append(f"backtest_status={model_bundle.get('backtest_status')}")
|
||||
reasons.extend(model_bundle.get("backtest_status_reasons_json", []))
|
||||
for item in read_json(root / "manifests/model_manifest.json"):
|
||||
if item.get("quality_status") != "PASS":
|
||||
reasons.append(f"{item.get('model_type')}.quality_status={item.get('quality_status')}")
|
||||
reasons.extend([f"{item.get('model_type')}:{reason}" for reason in item.get("quality_reasons_json", [])])
|
||||
for item in read_json(root / "manifests/calibration_manifest.json"):
|
||||
if item.get("quality_status") != "PASS":
|
||||
reasons.append(f"{item.get('model_name')}.calibration_quality_status={item.get('quality_status')}")
|
||||
reasons.extend([f"{item.get('model_name')}:calibration:{reason}" for reason in item.get("quality_reasons_json", [])])
|
||||
return {
|
||||
"release_gate_status": "PASS" if not reasons else "REJECTED",
|
||||
"release_gate_reasons": reasons,
|
||||
}
|
||||
|
||||
|
||||
def _check_hash(root: Path, relative: str | None, expected: str | None, errors: list[str]) -> None:
|
||||
if not relative or not expected:
|
||||
errors.append(f"missing hash contract for {relative}")
|
||||
return
|
||||
path = root / relative
|
||||
if not path.is_file():
|
||||
errors.append(f"hash target missing: {relative}")
|
||||
return
|
||||
actual = sha256_file(path)
|
||||
if actual != expected:
|
||||
errors.append(f"hash mismatch: {relative}")
|
||||
|
||||
|
||||
def _run_sample_inference(root: Path, manifests: list[dict[str, Any]], errors: list[str]) -> None:
|
||||
try:
|
||||
import onnxruntime as ort
|
||||
except ModuleNotFoundError as exc:
|
||||
raise SystemExit("Python package 'onnxruntime' is required for --run-onnx validation.") from exc
|
||||
sample = read_json(root / "examples/sample_input.json")
|
||||
features = np.array([[float(sample[name]) for name in FEATURE_ORDER]], dtype=np.float32)
|
||||
for item in manifests:
|
||||
session = ort.InferenceSession(str(root / item["artifact_path"]))
|
||||
outputs = session.run(None, {"features": features})
|
||||
if not outputs or outputs[0].shape[-1] != len(MODEL_OUTPUTS[item["model_type"]]):
|
||||
errors.append(f"{item['model_type']} sample output shape is invalid")
|
||||
Reference in New Issue
Block a user