From 87849a66a732174e8debbad87821d4344b12f2f4 Mon Sep 17 00:00:00 2001 From: Codex Date: Sat, 27 Jun 2026 23:53:58 +0800 Subject: [PATCH] Align Continue labels with price plan outcomes --- training/tests/test_continue_labels.py | 98 +++++++++++++++++++ .../tests/test_state_continue_experiment.py | 4 +- training/trader_training/labels.py | 85 ++++++++++++---- .../state_continue_experiment.py | 7 +- 4 files changed, 170 insertions(+), 24 deletions(-) create mode 100644 training/tests/test_continue_labels.py diff --git a/training/tests/test_continue_labels.py b/training/tests/test_continue_labels.py new file mode 100644 index 0000000..10e0d2c --- /dev/null +++ b/training/tests/test_continue_labels.py @@ -0,0 +1,98 @@ +from __future__ import annotations + +import sys +import tempfile +import unittest +from argparse import Namespace +from pathlib import Path + +import pandas as pd + +TRAINING_ROOT = Path(__file__).resolve().parents[1] +if str(TRAINING_ROOT) not in sys.path: + sys.path.insert(0, str(TRAINING_ROOT)) + +from trader_training.io_utils import write_json +from trader_training.labels import build_continue_exit_risk_labels + + +class ContinueLabelsTest(unittest.TestCase): + def test_continue_label_uses_first_price_plan_barrier_not_later_mae(self) -> None: + with tempfile.TemporaryDirectory() as tmp: + data_root = Path(tmp) + run_root = data_root / "trader-v4" / "runs" / "unit-continue" + feature_path = run_root / "feature" / "feature_frame.parquet" + replay_path = run_root / "replay" / "replay_1m.parquet" + plan_path = run_root / "label" / "price_plan_context.json" + config_path = data_root / "label_config.json" + feature_path.parent.mkdir(parents=True) + replay_path.parent.mkdir(parents=True) + plan_path.parent.mkdir(parents=True) + + times = pd.date_range("2026-01-01", periods=5, freq="min", tz="UTC") + pd.DataFrame( + { + "sample_id": ["s0"], + "symbol": "BTC-USDT-PERP", + "event_time": [times[0]], + "open_time_ms": [0], + "split_id": ["fit_inner"], + "walk_forward_fold": [0], + "data_quality_flag": ["OK"], + "spread_bps": [1.0], + "spread_rank_24h_pct": [0.1], + "realized_vol_15m_bps": [2.0], + } + ).to_parquet(feature_path, index=False) + pd.DataFrame( + { + "event_time": times, + "open_time_ms": [0, 60_000, 120_000, 180_000, 240_000], + "symbol": "BTC-USDT-PERP", + "open": [100.0, 100.0, 100.0, 100.0, 100.0], + "high": [100.0, 100.30, 100.0, 100.0, 100.0], + "low": [100.0, 100.0, 99.80, 100.0, 100.0], + "close": [100.0, 100.10, 100.0, 100.0, 100.0], + "spread_bps": [1.0, 1.1, 1.2, 1.3, 1.4], + } + ).to_parquet(replay_path, index=False) + write_json( + config_path, + { + "continue": {"horizon_minutes": 3, "min_expected_continue_edge_bps": 5.0}, + "entry": {"target_bps": 20.0, "stop_bps": 8.0, "max_hold_minutes": 3}, + }, + ) + write_json( + plan_path, + { + "pricePlanId": "unit-plan", + "pricePlanConfigHash": "unit-hash", + "targetDistanceBps": 20.0, + "stopDistanceBps": 8.0, + "maxHoldMinutes": 3, + "costBps": 6.5, + "entryLabelMethod": "PRICE_PLAN_OUTCOME_V1", + }, + ) + + build_continue_exit_risk_labels( + Namespace( + data_root=data_root, + run_id="unit-continue", + feature_path=feature_path, + replay_path=replay_path, + label_config_path=config_path, + cost_config_path=None, + price_plan_context_path=plan_path, + ) + ) + + labels = pd.read_parquet(run_root / "label" / "continue_labels.parquet") + row = labels.iloc[0] + self.assertEqual(1, int(row["long_continue_target"])) + self.assertAlmostEqual(13.5, float(row["long_expected_continue_edge_bps"]), places=6) + + +if __name__ == "__main__": + unittest.main() diff --git a/training/tests/test_state_continue_experiment.py b/training/tests/test_state_continue_experiment.py index 765bd7b..0fe6d75 100644 --- a/training/tests/test_state_continue_experiment.py +++ b/training/tests/test_state_continue_experiment.py @@ -33,7 +33,9 @@ class StateContinueExperimentTest(unittest.TestCase): "high_since_entry": 100.2, "low_since_entry": 99.95, "future_return_bps": 12.0, - "mae_bps": 3.0, + "gross_edge_bps": 12.0, + "mae_bps": 20.0, + "stop_hit": 0, "entry_predicted_edge_bps": 8.5, "entry_direction_prob": 0.64, "add_count": 0.0, diff --git a/training/trader_training/labels.py b/training/trader_training/labels.py index 36db829..aa1b9ac 100644 --- a/training/trader_training/labels.py +++ b/training/trader_training/labels.py @@ -25,9 +25,15 @@ from trader_training.schemas import LABEL_VERSION DEFAULT_LABEL_CONFIG = { "direction": {"horizon_minutes": 45, "long_threshold_bps": 5.0, "short_threshold_bps": -5.0}, "entry": {"max_hold_minutes": 45, "target_bps": 12.0, "stop_bps": 8.0, "min_expected_net_edge_bps": 3.0}, - "continue": {"horizon_minutes": 30, "min_expected_continue_edge_bps": 2.0}, - "exit": {"horizon_minutes": 30, "adverse_move_bps": 8.0, "stagnation_abs_return_bps": 2.0}, - "risk": {"horizon_minutes": 30, "market_drawdown_bps": 12.0, "vol_expansion_ratio": 1.6, "spike_bps": 20.0}, + "continue": {"horizon_minutes": 45, "min_expected_continue_edge_bps": 5.0}, + "exit": {"horizon_minutes": 45, "adverse_move_bps": 20.0, "stagnation_abs_return_bps": 5.0}, + "risk": { + "horizon_minutes": 45, + "market_drawdown_bps": 60.0, + "position_path_risk_bps": 35.0, + "vol_expansion_ratio": 1.8, + "spike_bps": 80.0, + }, } @@ -37,7 +43,7 @@ DEFAULT_COST_CONFIG = { "funding_cost_bps": 0.5, } -ENTRY_LABEL_METHOD = "MAX_FUTURE_EDGE_V1" +ENTRY_LABEL_METHOD = "PRICE_PLAN_OUTCOME_V1" def _load_config(path, default): @@ -53,6 +59,13 @@ def _load_config(path, default): return merged +def _config_number(config: dict[str, Any], keys: tuple[str, ...], default: float) -> float: + for key in keys: + if key in config: + return float(config[key]) + return default + + def _base_frames(args: Any) -> tuple[pd.DataFrame, pd.DataFrame]: root = run_root(args) feature_path = args.feature_path or root / "feature" / "feature_frame.parquet" @@ -144,16 +157,16 @@ def _path_stats_for_group(group: pd.DataFrame, side: str, horizon: int, target_b target_window = future_high >= target_price[:, None] stop_window = future_low <= stop_price[:, None] future_return_bps = (exit_price / entry - 1.0) * 10000.0 - mfe_bps = (high_max / entry - 1.0) * 10000.0 - mae_bps = (entry / low_min - 1.0) * 10000.0 + mfe_bps = np.maximum((high_max / entry - 1.0) * 10000.0, 0.0) + mae_bps = np.maximum((entry / low_min - 1.0) * 10000.0, 0.0) else: target_price = entry * (1.0 - target_bps / 10000.0) stop_price = entry * (1.0 + stop_bps / 10000.0) target_window = future_low <= target_price[:, None] stop_window = future_high >= stop_price[:, None] future_return_bps = (entry / exit_price - 1.0) * 10000.0 - mfe_bps = (entry / low_min - 1.0) * 10000.0 - mae_bps = (high_max / entry - 1.0) * 10000.0 + mfe_bps = np.maximum((entry / low_min - 1.0) * 10000.0, 0.0) + mae_bps = np.maximum((high_max / entry - 1.0) * 10000.0, 0.0) target_any, first_target_idx = _first_hit_index(target_window) stop_any, first_stop_idx = _first_hit_index(stop_window) @@ -310,8 +323,8 @@ def build_entry_labels(args: Any) -> None: merged = features[feature_columns].merge(stats, on=["symbol", "open_time_ms"], how="inner") merged["max_achievable_gross_edge_bps"] = merged["mfe_bps"] merged["max_achievable_net_edge_bps"] = merged["max_achievable_gross_edge_bps"] - cost_bps - merged["expected_net_edge_bps"] = merged["max_achievable_net_edge_bps"] - merged["entry_target"] = (merged["max_achievable_net_edge_bps"] >= float(entry_conf["min_expected_net_edge_bps"])).astype("int8") + merged["expected_net_edge_bps"] = merged["gross_edge_bps"] - cost_bps + merged["entry_target"] = (merged["expected_net_edge_bps"] >= float(entry_conf["min_expected_net_edge_bps"])).astype("int8") merged["price_plan_id"] = plan["pricePlanId"] merged["price_plan_hash"] = plan["pricePlanConfigHash"] merged["cost_bps"] = cost_bps @@ -403,9 +416,22 @@ def build_continue_exit_risk_labels(args: Any) -> None: min_continue = float(labels["continue"]["min_expected_continue_edge_bps"]) adverse_threshold = float(labels["exit"]["adverse_move_bps"]) current_vol = merged["realized_vol_15m_bps"].astype(float).fillna(0.0).clip(lower=1.0) + risk_config = labels["risk"] + market_risk_threshold = _config_number( + risk_config, + ("market_path_risk_threshold_bps", "market_drawdown_bps"), + 60.0, + ) + position_risk_threshold = _config_number( + risk_config, + ("position_path_risk_threshold_bps", "position_path_risk_bps"), + 35.0, + ) + spike_threshold = _config_number(risk_config, ("spike_1m_threshold_bps", "spike_bps"), 80.0) + vol_expansion_ratio = _config_number(risk_config, ("vol_expansion_ratio",), 1.8) - long_edge = merged["long_future_return_bps"] - cost_bps - short_edge = merged["short_future_return_bps"] - cost_bps + long_edge = merged["long_gross_edge_bps"] - cost_bps + short_edge = merged["short_gross_edge_bps"] - cost_bps path_risk = np.maximum(merged["long_mae_bps"], merged["short_mae_bps"]) max_path_move = np.maximum.reduce([merged["long_mfe_bps"], merged["short_mfe_bps"], path_risk]) if "ret_15m_bps" in merged.columns: @@ -413,7 +439,9 @@ def build_continue_exit_risk_labels(args: Any) -> None: else: reversal = pd.Series(0, index=merged.index, dtype="int8") future_vol = merged["long_future_realized_vol_bps"].fillna(0.0) - volatility_expansion = future_vol >= current_vol * float(labels["risk"]["vol_expansion_ratio"]) + volatility_expansion = future_vol >= current_vol * vol_expansion_ratio + spike = max_path_move >= spike_threshold + market_risk = (path_risk >= market_risk_threshold) | spike | volatility_expansion liquidity_deterioration = merged["spread_rank_24h_pct"].astype(float).fillna(0.0) >= 0.90 rows_continue = pd.DataFrame( @@ -421,8 +449,8 @@ def build_continue_exit_risk_labels(args: Any) -> None: "sample_id": merged["sample_id"], "symbol": merged["symbol"], "event_time": merged["event_time"], - "long_continue_target": ((long_edge >= min_continue) & (merged["long_mae_bps"] < stop_bps)).astype("int8"), - "short_continue_target": ((short_edge >= min_continue) & (merged["short_mae_bps"] < stop_bps)).astype("int8"), + "long_continue_target": ((long_edge >= min_continue) & (merged["long_stop_hit"] == 0)).astype("int8"), + "short_continue_target": ((short_edge >= min_continue) & (merged["short_stop_hit"] == 0)).astype("int8"), "long_expected_continue_edge_bps": long_edge, "short_expected_continue_edge_bps": short_edge, "split_id": merged["split_id"], @@ -453,17 +481,17 @@ def build_continue_exit_risk_labels(args: Any) -> None: "sample_id": merged["sample_id"], "symbol": merged["symbol"], "event_time": merged["event_time"], - "market_risk_target": (path_risk >= float(labels["risk"]["market_drawdown_bps"])).astype("int8"), + "market_risk_target": market_risk.astype("int8"), "market_path_risk_bps": path_risk, "long_position_path_risk_bps": merged["long_mae_bps"], "short_position_path_risk_bps": merged["short_mae_bps"], - "long_position_risk_target": (merged["long_mae_bps"] >= stop_bps).astype("int8"), - "short_position_risk_target": (merged["short_mae_bps"] >= stop_bps).astype("int8"), - "market_drawdown_prob_label": (path_risk >= float(labels["risk"]["market_drawdown_bps"])).astype("int8"), + "long_position_risk_target": ((merged["long_mae_bps"] >= position_risk_threshold) | (merged["long_stop_hit"] == 1)).astype("int8"), + "short_position_risk_target": ((merged["short_mae_bps"] >= position_risk_threshold) | (merged["short_stop_hit"] == 1)).astype("int8"), + "market_drawdown_prob_label": (path_risk >= market_risk_threshold).astype("int8"), "volatility_expansion_prob_label": volatility_expansion.astype("int8"), - "spike_prob_label": (max_path_move >= float(labels["risk"]["spike_bps"])).astype("int8"), + "spike_prob_label": spike.astype("int8"), "liquidity_deterioration_prob_label": liquidity_deterioration.astype("int8"), - "position_drawdown_prob_label": (path_risk >= stop_bps).astype("int8"), + "position_drawdown_prob_label": (path_risk >= position_risk_threshold).astype("int8"), "split_id": merged["split_id"], "walk_forward_fold": merged["walk_forward_fold"], "label_version": LABEL_VERSION, @@ -475,6 +503,21 @@ def build_continue_exit_risk_labels(args: Any) -> None: ("risk", pd.DataFrame(rows_risk), "market_risk_target"), ] report_parts = ["# Continue Exit Risk Label Report", ""] + report_parts.extend( + [ + "## Risk Thresholds", + "", + str( + { + "market_risk_threshold_bps": market_risk_threshold, + "position_risk_threshold_bps": position_risk_threshold, + "spike_threshold_bps": spike_threshold, + "vol_expansion_ratio": vol_expansion_ratio, + } + ), + "", + ] + ) for name, frame, target in outputs: path = root / "label" / f"{name}_labels.parquet" data_hash = write_parquet(path, frame) diff --git a/training/trader_training/state_continue_experiment.py b/training/trader_training/state_continue_experiment.py index b0ca781..ffaa443 100644 --- a/training/trader_training/state_continue_experiment.py +++ b/training/trader_training/state_continue_experiment.py @@ -380,8 +380,9 @@ def _state_rows_for_age(frame: pd.DataFrame, stop_bps: float, target_bps: float, target_price = np.where(long_mask, entry_price * (1.0 + target_bps / 10000.0), entry_price * (1.0 - target_bps / 10000.0)) distance_to_stop = np.where(long_mask, (current_price / stop_price - 1.0) * 10000.0, (stop_price / current_price - 1.0) * 10000.0) distance_to_target = np.where(long_mask, (target_price / current_price - 1.0) * 10000.0, (current_price / target_price - 1.0) * 10000.0) - expected_edge = frame["future_return_bps"].astype(float) - cost_bps - continue_target = ((expected_edge >= min_continue_edge_bps) & (frame["mae_bps"].astype(float) < stop_bps)).astype("int8") + # Continue must score the first price-plan outcome from the current state, not the raw horizon close. + expected_edge = frame["gross_edge_bps"].astype(float) - cost_bps + continue_target = ((expected_edge >= min_continue_edge_bps) & (frame["stop_hit"].astype(int) == 0)).astype("int8") out = frame[ [ @@ -601,6 +602,8 @@ def _source_manifest( "uses_same_round_model_prediction_as_feature": False, "entry_predicted_edge_bps": "baseline frozen ENTRY ONNX output selected by side", "entry_direction_prob": "baseline frozen DIRECTION ONNX output selected by side", + "expected_continue_edge_bps": "price-plan gross edge minus cost; target/stop/timeout outcome is respected", + "continue_target": "expected_continue_edge_bps >= threshold and stop is not the first path barrier", "path_features": "position path shape and side-adjusted market pressure at current state time", "add_count": "synthetic first-position diagnostic, fixed to 0", "minutes_since_last_add": "synthetic first-position diagnostic, fixed to 9999",