Align Continue labels with price plan outcomes

This commit is contained in:
Codex
2026-06-27 23:53:58 +08:00
parent 38a728c00b
commit 87849a66a7
4 changed files with 170 additions and 24 deletions
+64 -21
View File
@@ -25,9 +25,15 @@ from trader_training.schemas import LABEL_VERSION
DEFAULT_LABEL_CONFIG = {
"direction": {"horizon_minutes": 45, "long_threshold_bps": 5.0, "short_threshold_bps": -5.0},
"entry": {"max_hold_minutes": 45, "target_bps": 12.0, "stop_bps": 8.0, "min_expected_net_edge_bps": 3.0},
"continue": {"horizon_minutes": 30, "min_expected_continue_edge_bps": 2.0},
"exit": {"horizon_minutes": 30, "adverse_move_bps": 8.0, "stagnation_abs_return_bps": 2.0},
"risk": {"horizon_minutes": 30, "market_drawdown_bps": 12.0, "vol_expansion_ratio": 1.6, "spike_bps": 20.0},
"continue": {"horizon_minutes": 45, "min_expected_continue_edge_bps": 5.0},
"exit": {"horizon_minutes": 45, "adverse_move_bps": 20.0, "stagnation_abs_return_bps": 5.0},
"risk": {
"horizon_minutes": 45,
"market_drawdown_bps": 60.0,
"position_path_risk_bps": 35.0,
"vol_expansion_ratio": 1.8,
"spike_bps": 80.0,
},
}
@@ -37,7 +43,7 @@ DEFAULT_COST_CONFIG = {
"funding_cost_bps": 0.5,
}
ENTRY_LABEL_METHOD = "MAX_FUTURE_EDGE_V1"
ENTRY_LABEL_METHOD = "PRICE_PLAN_OUTCOME_V1"
def _load_config(path, default):
@@ -53,6 +59,13 @@ def _load_config(path, default):
return merged
def _config_number(config: dict[str, Any], keys: tuple[str, ...], default: float) -> float:
for key in keys:
if key in config:
return float(config[key])
return default
def _base_frames(args: Any) -> tuple[pd.DataFrame, pd.DataFrame]:
root = run_root(args)
feature_path = args.feature_path or root / "feature" / "feature_frame.parquet"
@@ -144,16 +157,16 @@ def _path_stats_for_group(group: pd.DataFrame, side: str, horizon: int, target_b
target_window = future_high >= target_price[:, None]
stop_window = future_low <= stop_price[:, None]
future_return_bps = (exit_price / entry - 1.0) * 10000.0
mfe_bps = (high_max / entry - 1.0) * 10000.0
mae_bps = (entry / low_min - 1.0) * 10000.0
mfe_bps = np.maximum((high_max / entry - 1.0) * 10000.0, 0.0)
mae_bps = np.maximum((entry / low_min - 1.0) * 10000.0, 0.0)
else:
target_price = entry * (1.0 - target_bps / 10000.0)
stop_price = entry * (1.0 + stop_bps / 10000.0)
target_window = future_low <= target_price[:, None]
stop_window = future_high >= stop_price[:, None]
future_return_bps = (entry / exit_price - 1.0) * 10000.0
mfe_bps = (entry / low_min - 1.0) * 10000.0
mae_bps = (high_max / entry - 1.0) * 10000.0
mfe_bps = np.maximum((entry / low_min - 1.0) * 10000.0, 0.0)
mae_bps = np.maximum((high_max / entry - 1.0) * 10000.0, 0.0)
target_any, first_target_idx = _first_hit_index(target_window)
stop_any, first_stop_idx = _first_hit_index(stop_window)
@@ -310,8 +323,8 @@ def build_entry_labels(args: Any) -> None:
merged = features[feature_columns].merge(stats, on=["symbol", "open_time_ms"], how="inner")
merged["max_achievable_gross_edge_bps"] = merged["mfe_bps"]
merged["max_achievable_net_edge_bps"] = merged["max_achievable_gross_edge_bps"] - cost_bps
merged["expected_net_edge_bps"] = merged["max_achievable_net_edge_bps"]
merged["entry_target"] = (merged["max_achievable_net_edge_bps"] >= float(entry_conf["min_expected_net_edge_bps"])).astype("int8")
merged["expected_net_edge_bps"] = merged["gross_edge_bps"] - cost_bps
merged["entry_target"] = (merged["expected_net_edge_bps"] >= float(entry_conf["min_expected_net_edge_bps"])).astype("int8")
merged["price_plan_id"] = plan["pricePlanId"]
merged["price_plan_hash"] = plan["pricePlanConfigHash"]
merged["cost_bps"] = cost_bps
@@ -403,9 +416,22 @@ def build_continue_exit_risk_labels(args: Any) -> None:
min_continue = float(labels["continue"]["min_expected_continue_edge_bps"])
adverse_threshold = float(labels["exit"]["adverse_move_bps"])
current_vol = merged["realized_vol_15m_bps"].astype(float).fillna(0.0).clip(lower=1.0)
risk_config = labels["risk"]
market_risk_threshold = _config_number(
risk_config,
("market_path_risk_threshold_bps", "market_drawdown_bps"),
60.0,
)
position_risk_threshold = _config_number(
risk_config,
("position_path_risk_threshold_bps", "position_path_risk_bps"),
35.0,
)
spike_threshold = _config_number(risk_config, ("spike_1m_threshold_bps", "spike_bps"), 80.0)
vol_expansion_ratio = _config_number(risk_config, ("vol_expansion_ratio",), 1.8)
long_edge = merged["long_future_return_bps"] - cost_bps
short_edge = merged["short_future_return_bps"] - cost_bps
long_edge = merged["long_gross_edge_bps"] - cost_bps
short_edge = merged["short_gross_edge_bps"] - cost_bps
path_risk = np.maximum(merged["long_mae_bps"], merged["short_mae_bps"])
max_path_move = np.maximum.reduce([merged["long_mfe_bps"], merged["short_mfe_bps"], path_risk])
if "ret_15m_bps" in merged.columns:
@@ -413,7 +439,9 @@ def build_continue_exit_risk_labels(args: Any) -> None:
else:
reversal = pd.Series(0, index=merged.index, dtype="int8")
future_vol = merged["long_future_realized_vol_bps"].fillna(0.0)
volatility_expansion = future_vol >= current_vol * float(labels["risk"]["vol_expansion_ratio"])
volatility_expansion = future_vol >= current_vol * vol_expansion_ratio
spike = max_path_move >= spike_threshold
market_risk = (path_risk >= market_risk_threshold) | spike | volatility_expansion
liquidity_deterioration = merged["spread_rank_24h_pct"].astype(float).fillna(0.0) >= 0.90
rows_continue = pd.DataFrame(
@@ -421,8 +449,8 @@ def build_continue_exit_risk_labels(args: Any) -> None:
"sample_id": merged["sample_id"],
"symbol": merged["symbol"],
"event_time": merged["event_time"],
"long_continue_target": ((long_edge >= min_continue) & (merged["long_mae_bps"] < stop_bps)).astype("int8"),
"short_continue_target": ((short_edge >= min_continue) & (merged["short_mae_bps"] < stop_bps)).astype("int8"),
"long_continue_target": ((long_edge >= min_continue) & (merged["long_stop_hit"] == 0)).astype("int8"),
"short_continue_target": ((short_edge >= min_continue) & (merged["short_stop_hit"] == 0)).astype("int8"),
"long_expected_continue_edge_bps": long_edge,
"short_expected_continue_edge_bps": short_edge,
"split_id": merged["split_id"],
@@ -453,17 +481,17 @@ def build_continue_exit_risk_labels(args: Any) -> None:
"sample_id": merged["sample_id"],
"symbol": merged["symbol"],
"event_time": merged["event_time"],
"market_risk_target": (path_risk >= float(labels["risk"]["market_drawdown_bps"])).astype("int8"),
"market_risk_target": market_risk.astype("int8"),
"market_path_risk_bps": path_risk,
"long_position_path_risk_bps": merged["long_mae_bps"],
"short_position_path_risk_bps": merged["short_mae_bps"],
"long_position_risk_target": (merged["long_mae_bps"] >= stop_bps).astype("int8"),
"short_position_risk_target": (merged["short_mae_bps"] >= stop_bps).astype("int8"),
"market_drawdown_prob_label": (path_risk >= float(labels["risk"]["market_drawdown_bps"])).astype("int8"),
"long_position_risk_target": ((merged["long_mae_bps"] >= position_risk_threshold) | (merged["long_stop_hit"] == 1)).astype("int8"),
"short_position_risk_target": ((merged["short_mae_bps"] >= position_risk_threshold) | (merged["short_stop_hit"] == 1)).astype("int8"),
"market_drawdown_prob_label": (path_risk >= market_risk_threshold).astype("int8"),
"volatility_expansion_prob_label": volatility_expansion.astype("int8"),
"spike_prob_label": (max_path_move >= float(labels["risk"]["spike_bps"])).astype("int8"),
"spike_prob_label": spike.astype("int8"),
"liquidity_deterioration_prob_label": liquidity_deterioration.astype("int8"),
"position_drawdown_prob_label": (path_risk >= stop_bps).astype("int8"),
"position_drawdown_prob_label": (path_risk >= position_risk_threshold).astype("int8"),
"split_id": merged["split_id"],
"walk_forward_fold": merged["walk_forward_fold"],
"label_version": LABEL_VERSION,
@@ -475,6 +503,21 @@ def build_continue_exit_risk_labels(args: Any) -> None:
("risk", pd.DataFrame(rows_risk), "market_risk_target"),
]
report_parts = ["# Continue Exit Risk Label Report", ""]
report_parts.extend(
[
"## Risk Thresholds",
"",
str(
{
"market_risk_threshold_bps": market_risk_threshold,
"position_risk_threshold_bps": position_risk_threshold,
"spike_threshold_bps": spike_threshold,
"vol_expansion_ratio": vol_expansion_ratio,
}
),
"",
]
)
for name, frame, target in outputs:
path = root / "label" / f"{name}_labels.parquet"
data_hash = write_parquet(path, frame)