Implement Trader V4 training artifact pipeline
This commit is contained in:
@@ -0,0 +1,146 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
import tempfile
|
||||
import unittest
|
||||
from argparse import Namespace
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
TRAINING_ROOT = Path(__file__).resolve().parents[1]
|
||||
if str(TRAINING_ROOT) not in sys.path:
|
||||
sys.path.insert(0, str(TRAINING_ROOT))
|
||||
|
||||
from trader_training.onnx_export import LinearHead, export_heads
|
||||
from trader_training.io_utils import read_json, write_json
|
||||
from trader_training.promote import promote_artifact_bundle
|
||||
from trader_training.replay import build_splits
|
||||
from trader_training.schemas import FEATURE_ORDER, LATEST_STRESS_SPLIT, MODEL_OUTPUTS, OUTPUT_MAPPING, TRAINING_SPLITS, VALIDATION_LOCKED_SPLIT
|
||||
|
||||
|
||||
class TrainingContractTest(unittest.TestCase):
|
||||
def test_feature_order_is_v4_contract_size(self) -> None:
|
||||
self.assertEqual(39, len(FEATURE_ORDER))
|
||||
self.assertEqual(len(FEATURE_ORDER), len(set(FEATURE_ORDER)))
|
||||
self.assertEqual("ret_1m_bps", FEATURE_ORDER[0])
|
||||
self.assertEqual("minutes_to_next_funding", FEATURE_ORDER[-1])
|
||||
|
||||
def test_output_mapping_matches_model_outputs(self) -> None:
|
||||
for model_name, fields in MODEL_OUTPUTS.items():
|
||||
self.assertEqual(set(fields), set(OUTPUT_MAPPING[model_name]))
|
||||
self.assertEqual([f"prediction[{idx}]" for idx in range(len(fields))], [OUTPUT_MAPPING[model_name][field] for field in fields])
|
||||
|
||||
def test_split_builder_uses_locked_validation_contract(self) -> None:
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
data_root = Path(tmp)
|
||||
replay_path = data_root / "replay_1m.parquet"
|
||||
frame = pd.DataFrame(
|
||||
{
|
||||
"event_time": pd.date_range("2025-06-20", "2026-06-19", freq="D", tz="UTC"),
|
||||
"symbol": "BTC-USDT-PERP",
|
||||
}
|
||||
)
|
||||
frame.to_parquet(replay_path, index=False)
|
||||
|
||||
build_splits(
|
||||
Namespace(
|
||||
data_root=data_root,
|
||||
run_id="unit-split",
|
||||
replay_path=replay_path,
|
||||
fit_inner_start="2025-06-20",
|
||||
fit_inner_end="2026-01-15",
|
||||
tune_inner_start="2026-01-16",
|
||||
tune_inner_end="2026-02-28",
|
||||
validation_locked_start="2026-03-01",
|
||||
validation_locked_end="2026-04-30",
|
||||
latest_stress_start="2026-05-01",
|
||||
latest_stress_end="2026-06-19",
|
||||
gap_minutes=0,
|
||||
fold_count=2,
|
||||
)
|
||||
)
|
||||
|
||||
manifest = read_json(data_root / "trader-v4" / "runs" / "unit-split" / "split" / "split_manifest.json")
|
||||
self.assertEqual(set(TRAINING_SPLITS), {item["split_id"] for item in manifest["splits"]})
|
||||
self.assertEqual([VALIDATION_LOCKED_SPLIT, LATEST_STRESS_SPLIT], manifest["sealed_splits"])
|
||||
self.assertEqual("FINAL_GATE_ONLY", manifest["latest_stress_policy"])
|
||||
|
||||
def test_exported_onnx_accepts_java_feature_shape(self) -> None:
|
||||
import onnxruntime as ort
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
path = Path(tmp) / "direction.onnx"
|
||||
export_heads(
|
||||
path,
|
||||
[
|
||||
LinearHead(
|
||||
"direction",
|
||||
"softmax",
|
||||
np.zeros((39, 3), dtype=np.float32),
|
||||
np.array([0.1, 0.2, 0.3], dtype=np.float32),
|
||||
)
|
||||
],
|
||||
)
|
||||
session = ort.InferenceSession(str(path))
|
||||
output = session.run(None, {"features": np.zeros((1, 39), dtype=np.float32)})[0]
|
||||
self.assertEqual((1, 3), output.shape)
|
||||
self.assertAlmostEqual(1.0, float(output.sum()), places=6)
|
||||
|
||||
def test_promotion_requires_passed_validation_and_marks_all_manifests_active(self) -> None:
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
root = Path(tmp) / "artifact_bundle"
|
||||
manifest_dir = root / "manifests"
|
||||
manifest_dir.mkdir(parents=True)
|
||||
write_json(root.parent / "artifact_validation_result.json", {"status": "PASS", "release_gate_status": "PASS", "release_gate_reasons": []})
|
||||
write_json(manifest_dir / "model_bundle_manifest.json", {"status": "CANDIDATE"})
|
||||
write_json(manifest_dir / "model_manifest.json", [{"model_name": "DIRECTION", "status": "CANDIDATE"}])
|
||||
write_json(manifest_dir / "calibration_manifest.json", [{"model_name": "DIRECTION", "status": "CANDIDATE"}])
|
||||
write_json(manifest_dir / "position_manager_manifest.json", {"status": "CANDIDATE"})
|
||||
write_json(manifest_dir / "training_export_manifest.json", {"status": "CANDIDATE"})
|
||||
|
||||
promote_artifact_bundle(Namespace(artifact_root=root, reason="unit test"))
|
||||
|
||||
self.assertEqual("ACTIVE", read_json(manifest_dir / "model_bundle_manifest.json")["status"])
|
||||
self.assertEqual("ACTIVE", read_json(manifest_dir / "model_manifest.json")[0]["status"])
|
||||
self.assertEqual("ACTIVE", read_json(manifest_dir / "calibration_manifest.json")[0]["status"])
|
||||
self.assertEqual("ACTIVE", read_json(manifest_dir / "position_manager_manifest.json")["status"])
|
||||
self.assertEqual("ACTIVE", read_json(manifest_dir / "training_export_manifest.json")["status"])
|
||||
|
||||
def test_promotion_refuses_failed_validation(self) -> None:
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
root = Path(tmp) / "artifact_bundle"
|
||||
(root / "manifests").mkdir(parents=True)
|
||||
write_json(root.parent / "artifact_validation_result.json", {"status": "FAIL"})
|
||||
with self.assertRaises(SystemExit):
|
||||
promote_artifact_bundle(Namespace(artifact_root=root, reason="unit test"))
|
||||
result = read_json(root.parent / "artifact_promotion_result.json")
|
||||
self.assertEqual("REFUSED", result["status"])
|
||||
self.assertEqual("validation result is not PASS", result["message"])
|
||||
|
||||
def test_promotion_refuses_failed_release_gate_and_overwrites_stale_result(self) -> None:
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
root = Path(tmp) / "artifact_bundle"
|
||||
(root / "manifests").mkdir(parents=True)
|
||||
write_json(root.parent / "artifact_promotion_result.json", {"status": "ACTIVE"})
|
||||
write_json(
|
||||
root.parent / "artifact_validation_result.json",
|
||||
{
|
||||
"status": "PASS",
|
||||
"release_gate_status": "REJECTED",
|
||||
"release_gate_reasons": ["backtest_status=REJECTED"],
|
||||
},
|
||||
)
|
||||
|
||||
with self.assertRaises(SystemExit):
|
||||
promote_artifact_bundle(Namespace(artifact_root=root, reason="unit test"))
|
||||
|
||||
result = read_json(root.parent / "artifact_promotion_result.json")
|
||||
self.assertEqual("REFUSED", result["status"])
|
||||
self.assertEqual("release gate is not PASS", result["message"])
|
||||
self.assertEqual(["backtest_status=REJECTED"], result["release_gate_reasons"])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
Reference in New Issue
Block a user