from __future__ import annotations import sys import tempfile import unittest from argparse import Namespace from pathlib import Path import pandas as pd TRAINING_ROOT = Path(__file__).resolve().parents[1] if str(TRAINING_ROOT) not in sys.path: sys.path.insert(0, str(TRAINING_ROOT)) from trader_training.io_utils import write_json from trader_training.labels import build_continue_exit_risk_labels class ContinueLabelsTest(unittest.TestCase): def test_continue_label_uses_first_price_plan_barrier_not_later_mae(self) -> None: with tempfile.TemporaryDirectory() as tmp: data_root = Path(tmp) run_root = data_root / "trader-v4" / "runs" / "unit-continue" feature_path = run_root / "feature" / "feature_frame.parquet" replay_path = run_root / "replay" / "replay_1m.parquet" plan_path = run_root / "label" / "price_plan_context.json" config_path = data_root / "label_config.json" feature_path.parent.mkdir(parents=True) replay_path.parent.mkdir(parents=True) plan_path.parent.mkdir(parents=True) times = pd.date_range("2026-01-01", periods=5, freq="min", tz="UTC") pd.DataFrame( { "sample_id": ["s0"], "symbol": "BTC-USDT-PERP", "event_time": [times[0]], "open_time_ms": [0], "split_id": ["fit_inner"], "walk_forward_fold": [0], "data_quality_flag": ["OK"], "spread_bps": [1.0], "spread_rank_24h_pct": [0.1], "realized_vol_15m_bps": [2.0], } ).to_parquet(feature_path, index=False) pd.DataFrame( { "event_time": times, "open_time_ms": [0, 60_000, 120_000, 180_000, 240_000], "symbol": "BTC-USDT-PERP", "open": [100.0, 100.0, 100.0, 100.0, 100.0], "high": [100.0, 100.30, 100.0, 100.0, 100.0], "low": [100.0, 100.0, 99.80, 100.0, 100.0], "close": [100.0, 100.10, 100.0, 100.0, 100.0], "spread_bps": [1.0, 1.1, 1.2, 1.3, 1.4], } ).to_parquet(replay_path, index=False) write_json( config_path, { "continue": {"horizon_minutes": 3, "min_expected_continue_edge_bps": 5.0}, "entry": {"target_bps": 20.0, "stop_bps": 8.0, "max_hold_minutes": 3}, }, ) write_json( plan_path, { "pricePlanId": "unit-plan", "pricePlanConfigHash": "unit-hash", "targetDistanceBps": 20.0, "stopDistanceBps": 8.0, "maxHoldMinutes": 3, "costBps": 6.5, "entryLabelMethod": "PRICE_PLAN_OUTCOME_V1", }, ) build_continue_exit_risk_labels( Namespace( data_root=data_root, run_id="unit-continue", feature_path=feature_path, replay_path=replay_path, label_config_path=config_path, cost_config_path=None, price_plan_context_path=plan_path, ) ) labels = pd.read_parquet(run_root / "label" / "continue_labels.parquet") row = labels.iloc[0] self.assertEqual(1, int(row["long_continue_target"])) self.assertAlmostEqual(13.5, float(row["long_expected_continue_edge_bps"]), places=6) if __name__ == "__main__": unittest.main()