99 lines
3.8 KiB
Python
99 lines
3.8 KiB
Python
from __future__ import annotations
|
|
|
|
import sys
|
|
import tempfile
|
|
import unittest
|
|
from argparse import Namespace
|
|
from pathlib import Path
|
|
|
|
import pandas as pd
|
|
|
|
# Directory two levels above this file. It is put at the front of sys.path
# (only if not already present) before the `trader_training` imports below;
# presumably so the package resolves when this test file is run directly.
TRAINING_ROOT = Path(__file__).resolve().parents[1]
if str(TRAINING_ROOT) not in sys.path:
    sys.path.insert(0, str(TRAINING_ROOT))
|
|
|
|
from trader_training.io_utils import write_json
|
|
from trader_training.labels import build_continue_exit_risk_labels
|
|
|
|
|
|
class ContinueLabelsTest(unittest.TestCase):
    """Run ``build_continue_exit_risk_labels`` against a tiny synthetic run."""

    def test_continue_label_uses_first_price_plan_barrier_not_later_mae(self) -> None:
        """Check the long continue label when an up-move precedes a later dip.

        Fixture written into a temporary data root:

        * one feature row (``s0``) at the first replay minute;
        * five one-minute replay bars where bar 1 has a high of 100.30 and
          bar 2 has a low of 99.80, everything else flat at 100.0;
        * a label config and a price-plan context that both use a 20 bps
          target, an 8 bps stop and a 3 minute hold, with ``costBps`` 6.5.

        The test expects ``long_continue_target`` to be 1 and
        ``long_expected_continue_edge_bps`` to be 13.5 in the first row of
        the ``continue_labels.parquet`` file read back from the run's label
        directory.
        NOTE(review): 13.5 equals the 20.0 target distance minus the 6.5
        cost -- presumably how the builder nets the edge; confirm against
        the labels module.
        """
        with tempfile.TemporaryDirectory() as tmp:
            data_root = Path(tmp)
            run_root = data_root / "trader-v4" / "runs" / "unit-continue"
            feature_path = run_root / "feature" / "feature_frame.parquet"
            replay_path = run_root / "replay" / "replay_1m.parquet"
            plan_path = run_root / "label" / "price_plan_context.json"
            config_path = data_root / "label_config.json"

            # Same creation order as the paths are declared; no exist_ok, so a
            # pre-existing directory would raise.
            for directory in (feature_path.parent, replay_path.parent, plan_path.parent):
                directory.mkdir(parents=True)

            times = pd.date_range("2026-01-01", periods=5, freq="min", tz="UTC")

            # Single sample anchored on the first replay bar.
            feature_frame = pd.DataFrame(
                {
                    "sample_id": ["s0"],
                    "symbol": "BTC-USDT-PERP",
                    "event_time": [times[0]],
                    "open_time_ms": [0],
                    "split_id": ["fit_inner"],
                    "walk_forward_fold": [0],
                    "data_quality_flag": ["OK"],
                    "spread_bps": [1.0],
                    "spread_rank_24h_pct": [0.1],
                    "realized_vol_15m_bps": [2.0],
                }
            )
            feature_frame.to_parquet(feature_path, index=False)

            # Bar 1 spikes up (high 100.30); bar 2 dips (low 99.80) afterwards.
            replay_frame = pd.DataFrame(
                {
                    "event_time": times,
                    "open_time_ms": [0, 60_000, 120_000, 180_000, 240_000],
                    "symbol": "BTC-USDT-PERP",
                    "open": [100.0, 100.0, 100.0, 100.0, 100.0],
                    "high": [100.0, 100.30, 100.0, 100.0, 100.0],
                    "low": [100.0, 100.0, 99.80, 100.0, 100.0],
                    "close": [100.0, 100.10, 100.0, 100.0, 100.0],
                    "spread_bps": [1.0, 1.1, 1.2, 1.3, 1.4],
                }
            )
            replay_frame.to_parquet(replay_path, index=False)

            label_config = {
                "continue": {"horizon_minutes": 3, "min_expected_continue_edge_bps": 5.0},
                "entry": {"target_bps": 20.0, "stop_bps": 8.0, "max_hold_minutes": 3},
            }
            write_json(config_path, label_config)

            price_plan_context = {
                "pricePlanId": "unit-plan",
                "pricePlanConfigHash": "unit-hash",
                "targetDistanceBps": 20.0,
                "stopDistanceBps": 8.0,
                "maxHoldMinutes": 3,
                "costBps": 6.5,
                "entryLabelMethod": "PRICE_PLAN_OUTCOME_V1",
            }
            write_json(plan_path, price_plan_context)

            builder_args = Namespace(
                data_root=data_root,
                run_id="unit-continue",
                feature_path=feature_path,
                replay_path=replay_path,
                label_config_path=config_path,
                cost_config_path=None,
                price_plan_context_path=plan_path,
            )
            build_continue_exit_risk_labels(builder_args)

            # Must be read inside the `with`: the directory is removed on exit.
            labels = pd.read_parquet(run_root / "label" / "continue_labels.parquet")
            row = labels.iloc[0]
            self.assertEqual(1, int(row["long_continue_target"]))
            self.assertAlmostEqual(
                13.5,
                float(row["long_expected_continue_edge_bps"]),
                places=6,
            )
|
|
|
|
|
|
# Run this module's tests when the file is executed as a script.
if __name__ == "__main__":
    unittest.main()
|