Align Continue labels with price plan outcomes
This commit is contained in:
@@ -0,0 +1,98 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
import tempfile
|
||||
import unittest
|
||||
from argparse import Namespace
|
||||
from pathlib import Path
|
||||
|
||||
import pandas as pd
|
||||
|
||||
TRAINING_ROOT = Path(__file__).resolve().parents[1]
|
||||
if str(TRAINING_ROOT) not in sys.path:
|
||||
sys.path.insert(0, str(TRAINING_ROOT))
|
||||
|
||||
from trader_training.io_utils import write_json
|
||||
from trader_training.labels import build_continue_exit_risk_labels
|
||||
|
||||
|
||||
class ContinueLabelsTest(unittest.TestCase):
|
||||
def test_continue_label_uses_first_price_plan_barrier_not_later_mae(self) -> None:
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
data_root = Path(tmp)
|
||||
run_root = data_root / "trader-v4" / "runs" / "unit-continue"
|
||||
feature_path = run_root / "feature" / "feature_frame.parquet"
|
||||
replay_path = run_root / "replay" / "replay_1m.parquet"
|
||||
plan_path = run_root / "label" / "price_plan_context.json"
|
||||
config_path = data_root / "label_config.json"
|
||||
feature_path.parent.mkdir(parents=True)
|
||||
replay_path.parent.mkdir(parents=True)
|
||||
plan_path.parent.mkdir(parents=True)
|
||||
|
||||
times = pd.date_range("2026-01-01", periods=5, freq="min", tz="UTC")
|
||||
pd.DataFrame(
|
||||
{
|
||||
"sample_id": ["s0"],
|
||||
"symbol": "BTC-USDT-PERP",
|
||||
"event_time": [times[0]],
|
||||
"open_time_ms": [0],
|
||||
"split_id": ["fit_inner"],
|
||||
"walk_forward_fold": [0],
|
||||
"data_quality_flag": ["OK"],
|
||||
"spread_bps": [1.0],
|
||||
"spread_rank_24h_pct": [0.1],
|
||||
"realized_vol_15m_bps": [2.0],
|
||||
}
|
||||
).to_parquet(feature_path, index=False)
|
||||
pd.DataFrame(
|
||||
{
|
||||
"event_time": times,
|
||||
"open_time_ms": [0, 60_000, 120_000, 180_000, 240_000],
|
||||
"symbol": "BTC-USDT-PERP",
|
||||
"open": [100.0, 100.0, 100.0, 100.0, 100.0],
|
||||
"high": [100.0, 100.30, 100.0, 100.0, 100.0],
|
||||
"low": [100.0, 100.0, 99.80, 100.0, 100.0],
|
||||
"close": [100.0, 100.10, 100.0, 100.0, 100.0],
|
||||
"spread_bps": [1.0, 1.1, 1.2, 1.3, 1.4],
|
||||
}
|
||||
).to_parquet(replay_path, index=False)
|
||||
write_json(
|
||||
config_path,
|
||||
{
|
||||
"continue": {"horizon_minutes": 3, "min_expected_continue_edge_bps": 5.0},
|
||||
"entry": {"target_bps": 20.0, "stop_bps": 8.0, "max_hold_minutes": 3},
|
||||
},
|
||||
)
|
||||
write_json(
|
||||
plan_path,
|
||||
{
|
||||
"pricePlanId": "unit-plan",
|
||||
"pricePlanConfigHash": "unit-hash",
|
||||
"targetDistanceBps": 20.0,
|
||||
"stopDistanceBps": 8.0,
|
||||
"maxHoldMinutes": 3,
|
||||
"costBps": 6.5,
|
||||
"entryLabelMethod": "PRICE_PLAN_OUTCOME_V1",
|
||||
},
|
||||
)
|
||||
|
||||
build_continue_exit_risk_labels(
|
||||
Namespace(
|
||||
data_root=data_root,
|
||||
run_id="unit-continue",
|
||||
feature_path=feature_path,
|
||||
replay_path=replay_path,
|
||||
label_config_path=config_path,
|
||||
cost_config_path=None,
|
||||
price_plan_context_path=plan_path,
|
||||
)
|
||||
)
|
||||
|
||||
labels = pd.read_parquet(run_root / "label" / "continue_labels.parquet")
|
||||
row = labels.iloc[0]
|
||||
self.assertEqual(1, int(row["long_continue_target"]))
|
||||
self.assertAlmostEqual(13.5, float(row["long_expected_continue_edge_bps"]), places=6)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
Reference in New Issue
Block a user