Files
quant-trader-service/training/tests/test_training_contract.py
T
Codex 9acb3460a1 Improve Trader V4 training pipeline
Align entry labels with max future edge, tune direction labeling, and harden regression evaluation.

Add training diagnostics, price-plan search, feature screening, and nonlinear benchmark scripts.
2026-06-27 19:57:29 +08:00

253 lines
12 KiB
Python

from __future__ import annotations
import sys
import tempfile
import unittest
from argparse import Namespace
from pathlib import Path
import numpy as np
import pandas as pd
TRAINING_ROOT = Path(__file__).resolve().parent.parent
# Put the ``training`` directory (the parent of this ``tests`` folder) at the
# front of sys.path so the ``trader_training`` imports resolve against this
# checkout; skip the insert when it is already present.
if str(TRAINING_ROOT) not in sys.path:
    sys.path.insert(0, str(TRAINING_ROOT))
from trader_training.onnx_export import LinearHead, export_heads
from trader_training.io_utils import read_json, write_json
from trader_training.labels import ENTRY_LABEL_METHOD, _path_stats_for_group, build_entry_labels
from trader_training.promote import promote_artifact_bundle
from trader_training.replay import build_splits
from trader_training.schemas import FEATURE_ORDER, LATEST_STRESS_SPLIT, MODEL_OUTPUTS, OUTPUT_MAPPING, TRAINING_SPLITS, VALIDATION_LOCKED_SPLIT
class TrainingContractTest(unittest.TestCase):
    """Contract tests for the Trader V4 training pipeline.

    Each test pins one externally visible behavior: the feature/output
    schema, the split manifest, path-stat and entry labelling, ONNX export,
    and artifact-bundle promotion.
    """

    # Instrument used by every synthetic frame in this suite.
    _SYMBOL = "BTC-USDT-PERP"

    # Bundle manifests used by the promotion tests, in write order, paired
    # with whether the file holds a list of per-model entries (True) or a
    # single JSON object (False).
    _MANIFEST_LAYOUT = (
        ("model_bundle_manifest.json", False),
        ("model_manifest.json", True),
        ("calibration_manifest.json", True),
        ("position_manager_manifest.json", False),
        ("training_export_manifest.json", False),
    )

    @staticmethod
    def _run_root(data_root: Path, run_id: str) -> Path:
        """Return the ``trader-v4/runs/<run_id>`` directory under *data_root*."""
        return data_root / "trader-v4" / "runs" / run_id

    @staticmethod
    def _make_bundle_root(tmp: str) -> Path:
        """Create ``<tmp>/artifact_bundle/manifests`` and return the bundle root."""
        bundle_root = Path(tmp) / "artifact_bundle"
        (bundle_root / "manifests").mkdir(parents=True)
        return bundle_root

    @staticmethod
    def _promote(bundle_root: Path) -> None:
        """Call ``promote_artifact_bundle`` for *bundle_root* with a test reason."""
        promote_artifact_bundle(Namespace(artifact_root=bundle_root, reason="unit test"))

    def _assert_promotion_refused(self, bundle_root: Path, message: str) -> dict:
        """Expect promotion to raise SystemExit and record a REFUSED result.

        Returns the parsed ``artifact_promotion_result.json`` (written next to
        the bundle root) so callers can make further assertions on it.
        """
        with self.assertRaises(SystemExit):
            self._promote(bundle_root)
        result = read_json(bundle_root.parent / "artifact_promotion_result.json")
        self.assertEqual("REFUSED", result["status"])
        self.assertEqual(message, result["message"])
        return result

    def test_feature_order_is_v4_contract_size(self) -> None:
        """FEATURE_ORDER holds 54 distinct names with fixed first and last entries."""
        order = FEATURE_ORDER
        self.assertEqual(54, len(order))
        self.assertEqual(len(order), len(set(order)))
        self.assertEqual("ret_1m_bps", order[0])
        self.assertEqual("book_pressure_reversal_15m", order[-1])

    def test_output_mapping_matches_model_outputs(self) -> None:
        """Each model's fields map, in declaration order, onto ``prediction[i]``."""
        for model_name, fields in MODEL_OUTPUTS.items():
            mapping = OUTPUT_MAPPING[model_name]
            self.assertEqual(set(fields), set(mapping))
            expected_slots = [f"prediction[{position}]" for position, _ in enumerate(fields)]
            actual_slots = [mapping[field] for field in fields]
            self.assertEqual(expected_slots, actual_slots)

    def test_split_builder_uses_locked_validation_contract(self) -> None:
        """build_splits emits every training split and seals validation + stress."""
        with tempfile.TemporaryDirectory() as tmp:
            data_root = Path(tmp)
            replay_path = data_root / "replay_1m.parquet"
            replay = pd.DataFrame(
                {
                    "event_time": pd.date_range("2025-06-20", "2026-06-19", freq="D", tz="UTC"),
                    "symbol": self._SYMBOL,
                }
            )
            replay.to_parquet(replay_path, index=False)

            split_args = Namespace(
                data_root=data_root,
                run_id="unit-split",
                replay_path=replay_path,
                fit_inner_start="2025-06-20",
                fit_inner_end="2026-01-15",
                tune_inner_start="2026-01-16",
                tune_inner_end="2026-02-28",
                validation_locked_start="2026-03-01",
                validation_locked_end="2026-04-30",
                latest_stress_start="2026-05-01",
                latest_stress_end="2026-06-19",
                gap_minutes=0,
                fold_count=2,
            )
            build_splits(split_args)
            manifest_path = self._run_root(data_root, "unit-split") / "split" / "split_manifest.json"
            manifest = read_json(manifest_path)

        emitted_split_ids = {entry["split_id"] for entry in manifest["splits"]}
        self.assertEqual(set(TRAINING_SPLITS), emitted_split_ids)
        self.assertEqual([VALIDATION_LOCKED_SPLIT, LATEST_STRESS_SPLIT], manifest["sealed_splits"])
        self.assertEqual("FINAL_GATE_ONLY", manifest["latest_stress_policy"])

    def test_path_stats_keeps_same_bar_target_stop_as_stop_first(self) -> None:
        """When target and stop are touched in the same bar, the stop wins."""
        bar_count = 6
        # Bar 2 spans 99.70..100.20: +20 bps reaches the 10 bps target and
        # -30 bps breaches the 8 bps stop within the same bar.
        frame = pd.DataFrame(
            {
                "event_time": pd.date_range("2026-01-01", periods=bar_count, freq="min", tz="UTC"),
                "open_time_ms": np.arange(bar_count, dtype=np.int64) * 60_000,
                "symbol": self._SYMBOL,
                "close": [100.0] * bar_count,
                "high": [100.0, 100.05, 100.20, 100.0, 100.0, 100.0],
                "low": [100.0, 99.95, 99.70, 100.0, 100.0, 100.0],
                "spread_bps": [1.0, 1.1, 1.2, 1.3, 1.4, 1.5],
            }
        )
        stats = _path_stats_for_group(frame, "LONG", horizon=3, target_bps=10.0, stop_bps=8.0)
        entry_bar = stats[stats["open_time_ms"] == 0].iloc[0]

        # The stop is recorded, the target is not, and the bar is flagged ambiguous;
        # 120_000 ms is bar 2's open_time_ms.
        expected = {"target_hit": 0, "stop_hit": 1, "ambiguous_hit": 1, "time_to_stop_ms": 120_000}
        for column, value in expected.items():
            self.assertEqual(value, entry_bar[column])
        self.assertAlmostEqual(-8.0, entry_bar["gross_edge_bps"])

    def test_entry_label_uses_max_future_edge_not_fixed_target_hit(self) -> None:
        """A LONG entry is labelled from its best future move, not the fixed target."""
        with tempfile.TemporaryDirectory() as tmp:
            data_root = Path(tmp)
            run_root = self._run_root(data_root, "unit-entry")
            feature_path = run_root / "feature" / "feature_frame.parquet"
            replay_path = run_root / "replay" / "replay_1m.parquet"
            plan_path = run_root / "label" / "price_plan_context.json"
            config_path = data_root / "label_config.json"
            for parquet_path in (feature_path, replay_path):
                parquet_path.parent.mkdir(parents=True)

            times = pd.date_range("2026-01-01", periods=5, freq="min", tz="UTC")
            features = pd.DataFrame(
                {
                    "sample_id": ["s0", "s1"],
                    "symbol": self._SYMBOL,
                    "event_time": times[:2],
                    "open_time_ms": [0, 60_000],
                    "split_id": "fit_inner",
                    "walk_forward_fold": 0,
                    "data_quality_flag": "OK",
                    "spread_bps": 1.0,
                    "spread_rank_24h_pct": 0.1,
                    "realized_vol_15m_bps": 2.0,
                }
            )
            features.to_parquet(feature_path, index=False)

            # Highs climb to 100.20 (+20 bps) within three bars of s0, far short
            # of the 50 bps target; lows never fall below 99.97 (-3 bps).
            replay = pd.DataFrame(
                {
                    "event_time": times,
                    "open_time_ms": np.arange(5, dtype=np.int64) * 60_000,
                    "symbol": self._SYMBOL,
                    "open": [100.0, 100.0, 100.0, 100.0, 100.0],
                    "high": [100.0, 100.05, 100.19, 100.20, 100.0],
                    "low": [100.0, 99.99, 99.98, 99.97, 100.0],
                    "close": [100.0, 100.0, 100.0, 100.0, 100.0],
                    "spread_bps": 1.0,
                }
            )
            replay.to_parquet(replay_path, index=False)

            entry_config = {
                "max_hold_minutes": 3,
                "target_bps": 50.0,
                "stop_bps": 50.0,
                "min_expected_net_edge_bps": 3.0,
            }
            write_json(config_path, {"entry": entry_config})

            price_plan = {
                "pricePlanId": "unit-plan",
                "pricePlanConfigHash": "unit-hash",
                "targetDistanceBps": 50.0,
                "stopDistanceBps": 50.0,
                "maxHoldMinutes": 3,
                "costBps": 6.5,
                "entryLabelMethod": ENTRY_LABEL_METHOD,
            }
            write_json(plan_path, price_plan)

            label_args = Namespace(
                data_root=data_root,
                run_id="unit-entry",
                feature_path=feature_path,
                replay_path=replay_path,
                label_config_path=config_path,
                cost_config_path=None,
                price_plan_context_path=plan_path,
            )
            build_entry_labels(label_args)
            labels = pd.read_parquet(run_root / "label" / "entry_labels.parquet")

        long_s0 = labels[(labels["sample_id"] == "s0") & (labels["side"] == "LONG")].iloc[0]
        self.assertEqual(0, long_s0["target_hit"])
        self.assertEqual(1, long_s0["entry_target"])
        self.assertEqual(ENTRY_LABEL_METHOD, long_s0["label_method"])
        # 13.5 == 20 bps peak move - 6.5 bps costBps from the price plan
        # (presumably the cost source here since cost_config_path is None).
        self.assertAlmostEqual(13.5, long_s0["expected_net_edge_bps"], places=6)
        self.assertAlmostEqual(
            long_s0["mfe_bps"] - long_s0["cost_bps"],
            long_s0["max_achievable_net_edge_bps"],
            places=6,
        )

    def test_exported_onnx_accepts_java_feature_shape(self) -> None:
        """An exported softmax head accepts a (1, len(FEATURE_ORDER)) float32 batch."""
        import onnxruntime as ort

        feature_count = len(FEATURE_ORDER)
        direction_head = LinearHead(
            "direction",
            "softmax",
            np.zeros((feature_count, 3), dtype=np.float32),
            np.array([0.1, 0.2, 0.3], dtype=np.float32),
        )
        with tempfile.TemporaryDirectory() as tmp:
            model_path = Path(tmp) / "direction.onnx"
            export_heads(model_path, [direction_head], feature_count=feature_count)
            session = ort.InferenceSession(str(model_path))
            batch = np.zeros((1, feature_count), dtype=np.float32)
            probabilities = session.run(None, {"features": batch})[0]

        # One row of three class probabilities that sum to one.
        self.assertEqual((1, 3), probabilities.shape)
        self.assertAlmostEqual(1.0, float(probabilities.sum()), places=6)

    def test_promotion_requires_passed_validation_and_marks_all_manifests_active(self) -> None:
        """PASS validation and a PASS release gate flip every manifest to ACTIVE."""
        with tempfile.TemporaryDirectory() as tmp:
            bundle_root = self._make_bundle_root(tmp)
            manifest_dir = bundle_root / "manifests"
            write_json(
                bundle_root.parent / "artifact_validation_result.json",
                {"status": "PASS", "release_gate_status": "PASS", "release_gate_reasons": []},
            )
            for file_name, per_model in self._MANIFEST_LAYOUT:
                if per_model:
                    candidate = [{"model_name": "DIRECTION", "status": "CANDIDATE"}]
                else:
                    candidate = {"status": "CANDIDATE"}
                write_json(manifest_dir / file_name, candidate)

            self._promote(bundle_root)

            for file_name, per_model in self._MANIFEST_LAYOUT:
                promoted = read_json(manifest_dir / file_name)
                status = promoted[0]["status"] if per_model else promoted["status"]
                self.assertEqual("ACTIVE", status)

    def test_promotion_refuses_failed_validation(self) -> None:
        """A FAIL validation result blocks promotion and records the refusal."""
        with tempfile.TemporaryDirectory() as tmp:
            bundle_root = self._make_bundle_root(tmp)
            write_json(bundle_root.parent / "artifact_validation_result.json", {"status": "FAIL"})
            self._assert_promotion_refused(bundle_root, "validation result is not PASS")

    def test_promotion_refuses_failed_release_gate_and_overwrites_stale_result(self) -> None:
        """A rejected release gate blocks promotion and replaces a stale result."""
        with tempfile.TemporaryDirectory() as tmp:
            bundle_root = self._make_bundle_root(tmp)
            # Leftover ACTIVE result; the refusal must overwrite it.
            write_json(bundle_root.parent / "artifact_promotion_result.json", {"status": "ACTIVE"})
            write_json(
                bundle_root.parent / "artifact_validation_result.json",
                {
                    "status": "PASS",
                    "release_gate_status": "REJECTED",
                    "release_gate_reasons": ["backtest_status=REJECTED"],
                },
            )
            result = self._assert_promotion_refused(bundle_root, "release gate is not PASS")
            self.assertEqual(["backtest_status=REJECTED"], result["release_gate_reasons"])
if __name__ == "__main__":
    # Run this module's tests via unittest when the file is executed directly.
    unittest.main()