Align Continue labels with price plan outcomes

This commit is contained in:
Codex
2026-06-27 23:53:58 +08:00
parent 38a728c00b
commit 87849a66a7
4 changed files with 170 additions and 24 deletions
+98
View File
@@ -0,0 +1,98 @@
from __future__ import annotations
import sys
import tempfile
import unittest
from argparse import Namespace
from pathlib import Path
import pandas as pd
TRAINING_ROOT = Path(__file__).resolve().parents[1]
if str(TRAINING_ROOT) not in sys.path:
sys.path.insert(0, str(TRAINING_ROOT))
from trader_training.io_utils import write_json
from trader_training.labels import build_continue_exit_risk_labels
class ContinueLabelsTest(unittest.TestCase):
def test_continue_label_uses_first_price_plan_barrier_not_later_mae(self) -> None:
with tempfile.TemporaryDirectory() as tmp:
data_root = Path(tmp)
run_root = data_root / "trader-v4" / "runs" / "unit-continue"
feature_path = run_root / "feature" / "feature_frame.parquet"
replay_path = run_root / "replay" / "replay_1m.parquet"
plan_path = run_root / "label" / "price_plan_context.json"
config_path = data_root / "label_config.json"
feature_path.parent.mkdir(parents=True)
replay_path.parent.mkdir(parents=True)
plan_path.parent.mkdir(parents=True)
times = pd.date_range("2026-01-01", periods=5, freq="min", tz="UTC")
pd.DataFrame(
{
"sample_id": ["s0"],
"symbol": "BTC-USDT-PERP",
"event_time": [times[0]],
"open_time_ms": [0],
"split_id": ["fit_inner"],
"walk_forward_fold": [0],
"data_quality_flag": ["OK"],
"spread_bps": [1.0],
"spread_rank_24h_pct": [0.1],
"realized_vol_15m_bps": [2.0],
}
).to_parquet(feature_path, index=False)
pd.DataFrame(
{
"event_time": times,
"open_time_ms": [0, 60_000, 120_000, 180_000, 240_000],
"symbol": "BTC-USDT-PERP",
"open": [100.0, 100.0, 100.0, 100.0, 100.0],
"high": [100.0, 100.30, 100.0, 100.0, 100.0],
"low": [100.0, 100.0, 99.80, 100.0, 100.0],
"close": [100.0, 100.10, 100.0, 100.0, 100.0],
"spread_bps": [1.0, 1.1, 1.2, 1.3, 1.4],
}
).to_parquet(replay_path, index=False)
write_json(
config_path,
{
"continue": {"horizon_minutes": 3, "min_expected_continue_edge_bps": 5.0},
"entry": {"target_bps": 20.0, "stop_bps": 8.0, "max_hold_minutes": 3},
},
)
write_json(
plan_path,
{
"pricePlanId": "unit-plan",
"pricePlanConfigHash": "unit-hash",
"targetDistanceBps": 20.0,
"stopDistanceBps": 8.0,
"maxHoldMinutes": 3,
"costBps": 6.5,
"entryLabelMethod": "PRICE_PLAN_OUTCOME_V1",
},
)
build_continue_exit_risk_labels(
Namespace(
data_root=data_root,
run_id="unit-continue",
feature_path=feature_path,
replay_path=replay_path,
label_config_path=config_path,
cost_config_path=None,
price_plan_context_path=plan_path,
)
)
labels = pd.read_parquet(run_root / "label" / "continue_labels.parquet")
row = labels.iloc[0]
self.assertEqual(1, int(row["long_continue_target"]))
self.assertAlmostEqual(13.5, float(row["long_expected_continue_edge_bps"]), places=6)
if __name__ == "__main__":
unittest.main()
@@ -33,7 +33,9 @@ class StateContinueExperimentTest(unittest.TestCase):
"high_since_entry": 100.2,
"low_since_entry": 99.95,
"future_return_bps": 12.0,
"mae_bps": 3.0,
"gross_edge_bps": 12.0,
"mae_bps": 20.0,
"stop_hit": 0,
"entry_predicted_edge_bps": 8.5,
"entry_direction_prob": 0.64,
"add_count": 0.0,
+64 -21
View File
@@ -25,9 +25,15 @@ from trader_training.schemas import LABEL_VERSION
DEFAULT_LABEL_CONFIG = {
"direction": {"horizon_minutes": 45, "long_threshold_bps": 5.0, "short_threshold_bps": -5.0},
"entry": {"max_hold_minutes": 45, "target_bps": 12.0, "stop_bps": 8.0, "min_expected_net_edge_bps": 3.0},
"continue": {"horizon_minutes": 30, "min_expected_continue_edge_bps": 2.0},
"exit": {"horizon_minutes": 30, "adverse_move_bps": 8.0, "stagnation_abs_return_bps": 2.0},
"risk": {"horizon_minutes": 30, "market_drawdown_bps": 12.0, "vol_expansion_ratio": 1.6, "spike_bps": 20.0},
"continue": {"horizon_minutes": 45, "min_expected_continue_edge_bps": 5.0},
"exit": {"horizon_minutes": 45, "adverse_move_bps": 20.0, "stagnation_abs_return_bps": 5.0},
"risk": {
"horizon_minutes": 45,
"market_drawdown_bps": 60.0,
"position_path_risk_bps": 35.0,
"vol_expansion_ratio": 1.8,
"spike_bps": 80.0,
},
}
@@ -37,7 +43,7 @@ DEFAULT_COST_CONFIG = {
"funding_cost_bps": 0.5,
}
ENTRY_LABEL_METHOD = "MAX_FUTURE_EDGE_V1"
ENTRY_LABEL_METHOD = "PRICE_PLAN_OUTCOME_V1"
def _load_config(path, default):
@@ -53,6 +59,13 @@ def _load_config(path, default):
return merged
def _config_number(config: dict[str, Any], keys: tuple[str, ...], default: float) -> float:
for key in keys:
if key in config:
return float(config[key])
return default
def _base_frames(args: Any) -> tuple[pd.DataFrame, pd.DataFrame]:
root = run_root(args)
feature_path = args.feature_path or root / "feature" / "feature_frame.parquet"
@@ -144,16 +157,16 @@ def _path_stats_for_group(group: pd.DataFrame, side: str, horizon: int, target_b
target_window = future_high >= target_price[:, None]
stop_window = future_low <= stop_price[:, None]
future_return_bps = (exit_price / entry - 1.0) * 10000.0
mfe_bps = (high_max / entry - 1.0) * 10000.0
mae_bps = (entry / low_min - 1.0) * 10000.0
mfe_bps = np.maximum((high_max / entry - 1.0) * 10000.0, 0.0)
mae_bps = np.maximum((entry / low_min - 1.0) * 10000.0, 0.0)
else:
target_price = entry * (1.0 - target_bps / 10000.0)
stop_price = entry * (1.0 + stop_bps / 10000.0)
target_window = future_low <= target_price[:, None]
stop_window = future_high >= stop_price[:, None]
future_return_bps = (entry / exit_price - 1.0) * 10000.0
mfe_bps = (entry / low_min - 1.0) * 10000.0
mae_bps = (high_max / entry - 1.0) * 10000.0
mfe_bps = np.maximum((entry / low_min - 1.0) * 10000.0, 0.0)
mae_bps = np.maximum((high_max / entry - 1.0) * 10000.0, 0.0)
target_any, first_target_idx = _first_hit_index(target_window)
stop_any, first_stop_idx = _first_hit_index(stop_window)
@@ -310,8 +323,8 @@ def build_entry_labels(args: Any) -> None:
merged = features[feature_columns].merge(stats, on=["symbol", "open_time_ms"], how="inner")
merged["max_achievable_gross_edge_bps"] = merged["mfe_bps"]
merged["max_achievable_net_edge_bps"] = merged["max_achievable_gross_edge_bps"] - cost_bps
merged["expected_net_edge_bps"] = merged["max_achievable_net_edge_bps"]
merged["entry_target"] = (merged["max_achievable_net_edge_bps"] >= float(entry_conf["min_expected_net_edge_bps"])).astype("int8")
merged["expected_net_edge_bps"] = merged["gross_edge_bps"] - cost_bps
merged["entry_target"] = (merged["expected_net_edge_bps"] >= float(entry_conf["min_expected_net_edge_bps"])).astype("int8")
merged["price_plan_id"] = plan["pricePlanId"]
merged["price_plan_hash"] = plan["pricePlanConfigHash"]
merged["cost_bps"] = cost_bps
@@ -403,9 +416,22 @@ def build_continue_exit_risk_labels(args: Any) -> None:
min_continue = float(labels["continue"]["min_expected_continue_edge_bps"])
adverse_threshold = float(labels["exit"]["adverse_move_bps"])
current_vol = merged["realized_vol_15m_bps"].astype(float).fillna(0.0).clip(lower=1.0)
risk_config = labels["risk"]
market_risk_threshold = _config_number(
risk_config,
("market_path_risk_threshold_bps", "market_drawdown_bps"),
60.0,
)
position_risk_threshold = _config_number(
risk_config,
("position_path_risk_threshold_bps", "position_path_risk_bps"),
35.0,
)
spike_threshold = _config_number(risk_config, ("spike_1m_threshold_bps", "spike_bps"), 80.0)
vol_expansion_ratio = _config_number(risk_config, ("vol_expansion_ratio",), 1.8)
long_edge = merged["long_future_return_bps"] - cost_bps
short_edge = merged["short_future_return_bps"] - cost_bps
long_edge = merged["long_gross_edge_bps"] - cost_bps
short_edge = merged["short_gross_edge_bps"] - cost_bps
path_risk = np.maximum(merged["long_mae_bps"], merged["short_mae_bps"])
max_path_move = np.maximum.reduce([merged["long_mfe_bps"], merged["short_mfe_bps"], path_risk])
if "ret_15m_bps" in merged.columns:
@@ -413,7 +439,9 @@ def build_continue_exit_risk_labels(args: Any) -> None:
else:
reversal = pd.Series(0, index=merged.index, dtype="int8")
future_vol = merged["long_future_realized_vol_bps"].fillna(0.0)
volatility_expansion = future_vol >= current_vol * float(labels["risk"]["vol_expansion_ratio"])
volatility_expansion = future_vol >= current_vol * vol_expansion_ratio
spike = max_path_move >= spike_threshold
market_risk = (path_risk >= market_risk_threshold) | spike | volatility_expansion
liquidity_deterioration = merged["spread_rank_24h_pct"].astype(float).fillna(0.0) >= 0.90
rows_continue = pd.DataFrame(
@@ -421,8 +449,8 @@ def build_continue_exit_risk_labels(args: Any) -> None:
"sample_id": merged["sample_id"],
"symbol": merged["symbol"],
"event_time": merged["event_time"],
"long_continue_target": ((long_edge >= min_continue) & (merged["long_mae_bps"] < stop_bps)).astype("int8"),
"short_continue_target": ((short_edge >= min_continue) & (merged["short_mae_bps"] < stop_bps)).astype("int8"),
"long_continue_target": ((long_edge >= min_continue) & (merged["long_stop_hit"] == 0)).astype("int8"),
"short_continue_target": ((short_edge >= min_continue) & (merged["short_stop_hit"] == 0)).astype("int8"),
"long_expected_continue_edge_bps": long_edge,
"short_expected_continue_edge_bps": short_edge,
"split_id": merged["split_id"],
@@ -453,17 +481,17 @@ def build_continue_exit_risk_labels(args: Any) -> None:
"sample_id": merged["sample_id"],
"symbol": merged["symbol"],
"event_time": merged["event_time"],
"market_risk_target": (path_risk >= float(labels["risk"]["market_drawdown_bps"])).astype("int8"),
"market_risk_target": market_risk.astype("int8"),
"market_path_risk_bps": path_risk,
"long_position_path_risk_bps": merged["long_mae_bps"],
"short_position_path_risk_bps": merged["short_mae_bps"],
"long_position_risk_target": (merged["long_mae_bps"] >= stop_bps).astype("int8"),
"short_position_risk_target": (merged["short_mae_bps"] >= stop_bps).astype("int8"),
"market_drawdown_prob_label": (path_risk >= float(labels["risk"]["market_drawdown_bps"])).astype("int8"),
"long_position_risk_target": ((merged["long_mae_bps"] >= position_risk_threshold) | (merged["long_stop_hit"] == 1)).astype("int8"),
"short_position_risk_target": ((merged["short_mae_bps"] >= position_risk_threshold) | (merged["short_stop_hit"] == 1)).astype("int8"),
"market_drawdown_prob_label": (path_risk >= market_risk_threshold).astype("int8"),
"volatility_expansion_prob_label": volatility_expansion.astype("int8"),
"spike_prob_label": (max_path_move >= float(labels["risk"]["spike_bps"])).astype("int8"),
"spike_prob_label": spike.astype("int8"),
"liquidity_deterioration_prob_label": liquidity_deterioration.astype("int8"),
"position_drawdown_prob_label": (path_risk >= stop_bps).astype("int8"),
"position_drawdown_prob_label": (path_risk >= position_risk_threshold).astype("int8"),
"split_id": merged["split_id"],
"walk_forward_fold": merged["walk_forward_fold"],
"label_version": LABEL_VERSION,
@@ -475,6 +503,21 @@ def build_continue_exit_risk_labels(args: Any) -> None:
("risk", pd.DataFrame(rows_risk), "market_risk_target"),
]
report_parts = ["# Continue Exit Risk Label Report", ""]
report_parts.extend(
[
"## Risk Thresholds",
"",
str(
{
"market_risk_threshold_bps": market_risk_threshold,
"position_risk_threshold_bps": position_risk_threshold,
"spike_threshold_bps": spike_threshold,
"vol_expansion_ratio": vol_expansion_ratio,
}
),
"",
]
)
for name, frame, target in outputs:
path = root / "label" / f"{name}_labels.parquet"
data_hash = write_parquet(path, frame)
@@ -380,8 +380,9 @@ def _state_rows_for_age(frame: pd.DataFrame, stop_bps: float, target_bps: float,
target_price = np.where(long_mask, entry_price * (1.0 + target_bps / 10000.0), entry_price * (1.0 - target_bps / 10000.0))
distance_to_stop = np.where(long_mask, (current_price / stop_price - 1.0) * 10000.0, (stop_price / current_price - 1.0) * 10000.0)
distance_to_target = np.where(long_mask, (target_price / current_price - 1.0) * 10000.0, (current_price / target_price - 1.0) * 10000.0)
expected_edge = frame["future_return_bps"].astype(float) - cost_bps
continue_target = ((expected_edge >= min_continue_edge_bps) & (frame["mae_bps"].astype(float) < stop_bps)).astype("int8")
# Continue must score the first price-plan outcome from the current state, not the raw horizon close.
expected_edge = frame["gross_edge_bps"].astype(float) - cost_bps
continue_target = ((expected_edge >= min_continue_edge_bps) & (frame["stop_hit"].astype(int) == 0)).astype("int8")
out = frame[
[
@@ -601,6 +602,8 @@ def _source_manifest(
"uses_same_round_model_prediction_as_feature": False,
"entry_predicted_edge_bps": "baseline frozen ENTRY ONNX output selected by side",
"entry_direction_prob": "baseline frozen DIRECTION ONNX output selected by side",
"expected_continue_edge_bps": "price-plan gross edge minus cost; target/stop/timeout outcome is respected",
"continue_target": "expected_continue_edge_bps >= threshold and stop is not the first path barrier",
"path_features": "position path shape and side-adjusted market pressure at current state time",
"add_count": "synthetic first-position diagnostic, fixed to 0",
"minutes_since_last_add": "synthetic first-position diagnostic, fixed to 9999",