# quant-trader-service/training/tests/test_continue_labels.py

from __future__ import annotations
import sys
import tempfile
import unittest
from argparse import Namespace
from pathlib import Path
import pandas as pd
# The parent of this file's directory is put at the front of sys.path (once)
# so that the `trader_training` imports below resolve when the test module is
# run directly.
TRAINING_ROOT = Path(__file__).resolve().parents[1]
if str(TRAINING_ROOT) not in sys.path:
    sys.path.insert(0, str(TRAINING_ROOT))
from trader_training.io_utils import write_json
from trader_training.labels import build_continue_exit_risk_labels
class ContinueLabelsTest(unittest.TestCase):
    """Runs ``build_continue_exit_risk_labels`` against a tiny on-disk fixture."""

    def test_continue_label_uses_first_price_plan_barrier_not_later_mae(self) -> None:
        # Fixture shape: one feature sample at t0 and five 1-minute replay bars
        # priced at 100.0. Bar 1 spikes to a high of 100.30 and bar 2 dips to a
        # low of 99.80, so the upside excursion comes before the downside one.
        # The price plan uses a 20 bps target, an 8 bps stop and 6.5 bps cost.
        with tempfile.TemporaryDirectory() as scratch_dir:
            data_root = Path(scratch_dir)
            run_root = data_root / "trader-v4" / "runs" / "unit-continue"
            feature_path = run_root / "feature" / "feature_frame.parquet"
            replay_path = run_root / "replay" / "replay_1m.parquet"
            plan_path = run_root / "label" / "price_plan_context.json"
            config_path = data_root / "label_config.json"

            # Each artifact lives in its own sub-directory of the run root.
            for artifact_path in (feature_path, replay_path, plan_path):
                artifact_path.parent.mkdir(parents=True)

            bar_times = pd.date_range("2026-01-01", periods=5, freq="min", tz="UTC")

            feature_frame = pd.DataFrame(
                {
                    "sample_id": ["s0"],
                    "symbol": "BTC-USDT-PERP",
                    "event_time": [bar_times[0]],
                    "open_time_ms": [0],
                    "split_id": ["fit_inner"],
                    "walk_forward_fold": [0],
                    "data_quality_flag": ["OK"],
                    "spread_bps": [1.0],
                    "spread_rank_24h_pct": [0.1],
                    "realized_vol_15m_bps": [2.0],
                }
            )
            feature_frame.to_parquet(feature_path, index=False)

            replay_frame = pd.DataFrame(
                {
                    "event_time": bar_times,
                    "open_time_ms": [minute * 60_000 for minute in range(5)],
                    "symbol": "BTC-USDT-PERP",
                    "open": [100.0] * 5,
                    "high": [100.0, 100.30, 100.0, 100.0, 100.0],
                    "low": [100.0, 100.0, 99.80, 100.0, 100.0],
                    "close": [100.0, 100.10, 100.0, 100.0, 100.0],
                    "spread_bps": [1.0, 1.1, 1.2, 1.3, 1.4],
                }
            )
            replay_frame.to_parquet(replay_path, index=False)

            label_config = {
                "continue": {"horizon_minutes": 3, "min_expected_continue_edge_bps": 5.0},
                "entry": {"target_bps": 20.0, "stop_bps": 8.0, "max_hold_minutes": 3},
            }
            write_json(config_path, label_config)

            price_plan_context = {
                "pricePlanId": "unit-plan",
                "pricePlanConfigHash": "unit-hash",
                "targetDistanceBps": 20.0,
                "stopDistanceBps": 8.0,
                "maxHoldMinutes": 3,
                "costBps": 6.5,
                "entryLabelMethod": "PRICE_PLAN_OUTCOME_V1",
            }
            write_json(plan_path, price_plan_context)

            builder_args = Namespace(
                data_root=data_root,
                run_id="unit-continue",
                feature_path=feature_path,
                replay_path=replay_path,
                label_config_path=config_path,
                cost_config_path=None,
                price_plan_context_path=plan_path,
            )
            build_continue_exit_risk_labels(builder_args)

            labels = pd.read_parquet(run_root / "label" / "continue_labels.parquet")
            first_label = labels.iloc[0]

            self.assertEqual(1, int(first_label["long_continue_target"]))
            # NOTE(review): 13.5 equals targetDistanceBps (20.0) minus costBps
            # (6.5); presumably that is how the edge is derived -- confirm
            # against trader_training.labels.
            self.assertAlmostEqual(
                13.5,
                float(first_label["long_expected_continue_edge_bps"]),
                places=6,
            )
# Run the test case with the unittest runner when this file is executed as a script.
if __name__ == "__main__":
    unittest.main()