Align Continue labels with price plan outcomes
This commit is contained in:
@@ -25,9 +25,15 @@ from trader_training.schemas import LABEL_VERSION
|
||||
DEFAULT_LABEL_CONFIG = {
|
||||
"direction": {"horizon_minutes": 45, "long_threshold_bps": 5.0, "short_threshold_bps": -5.0},
|
||||
"entry": {"max_hold_minutes": 45, "target_bps": 12.0, "stop_bps": 8.0, "min_expected_net_edge_bps": 3.0},
|
||||
"continue": {"horizon_minutes": 30, "min_expected_continue_edge_bps": 2.0},
|
||||
"exit": {"horizon_minutes": 30, "adverse_move_bps": 8.0, "stagnation_abs_return_bps": 2.0},
|
||||
"risk": {"horizon_minutes": 30, "market_drawdown_bps": 12.0, "vol_expansion_ratio": 1.6, "spike_bps": 20.0},
|
||||
"continue": {"horizon_minutes": 45, "min_expected_continue_edge_bps": 5.0},
|
||||
"exit": {"horizon_minutes": 45, "adverse_move_bps": 20.0, "stagnation_abs_return_bps": 5.0},
|
||||
"risk": {
|
||||
"horizon_minutes": 45,
|
||||
"market_drawdown_bps": 60.0,
|
||||
"position_path_risk_bps": 35.0,
|
||||
"vol_expansion_ratio": 1.8,
|
||||
"spike_bps": 80.0,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@@ -37,7 +43,7 @@ DEFAULT_COST_CONFIG = {
|
||||
"funding_cost_bps": 0.5,
|
||||
}
|
||||
|
||||
ENTRY_LABEL_METHOD = "MAX_FUTURE_EDGE_V1"
|
||||
ENTRY_LABEL_METHOD = "PRICE_PLAN_OUTCOME_V1"
|
||||
|
||||
|
||||
def _load_config(path, default):
|
||||
@@ -53,6 +59,13 @@ def _load_config(path, default):
|
||||
return merged
|
||||
|
||||
|
||||
def _config_number(config: dict[str, Any], keys: tuple[str, ...], default: float) -> float:
|
||||
for key in keys:
|
||||
if key in config:
|
||||
return float(config[key])
|
||||
return default
|
||||
|
||||
|
||||
def _base_frames(args: Any) -> tuple[pd.DataFrame, pd.DataFrame]:
|
||||
root = run_root(args)
|
||||
feature_path = args.feature_path or root / "feature" / "feature_frame.parquet"
|
||||
@@ -144,16 +157,16 @@ def _path_stats_for_group(group: pd.DataFrame, side: str, horizon: int, target_b
|
||||
target_window = future_high >= target_price[:, None]
|
||||
stop_window = future_low <= stop_price[:, None]
|
||||
future_return_bps = (exit_price / entry - 1.0) * 10000.0
|
||||
mfe_bps = (high_max / entry - 1.0) * 10000.0
|
||||
mae_bps = (entry / low_min - 1.0) * 10000.0
|
||||
mfe_bps = np.maximum((high_max / entry - 1.0) * 10000.0, 0.0)
|
||||
mae_bps = np.maximum((entry / low_min - 1.0) * 10000.0, 0.0)
|
||||
else:
|
||||
target_price = entry * (1.0 - target_bps / 10000.0)
|
||||
stop_price = entry * (1.0 + stop_bps / 10000.0)
|
||||
target_window = future_low <= target_price[:, None]
|
||||
stop_window = future_high >= stop_price[:, None]
|
||||
future_return_bps = (entry / exit_price - 1.0) * 10000.0
|
||||
mfe_bps = (entry / low_min - 1.0) * 10000.0
|
||||
mae_bps = (high_max / entry - 1.0) * 10000.0
|
||||
mfe_bps = np.maximum((entry / low_min - 1.0) * 10000.0, 0.0)
|
||||
mae_bps = np.maximum((high_max / entry - 1.0) * 10000.0, 0.0)
|
||||
|
||||
target_any, first_target_idx = _first_hit_index(target_window)
|
||||
stop_any, first_stop_idx = _first_hit_index(stop_window)
|
||||
@@ -310,8 +323,8 @@ def build_entry_labels(args: Any) -> None:
|
||||
merged = features[feature_columns].merge(stats, on=["symbol", "open_time_ms"], how="inner")
|
||||
merged["max_achievable_gross_edge_bps"] = merged["mfe_bps"]
|
||||
merged["max_achievable_net_edge_bps"] = merged["max_achievable_gross_edge_bps"] - cost_bps
|
||||
merged["expected_net_edge_bps"] = merged["max_achievable_net_edge_bps"]
|
||||
merged["entry_target"] = (merged["max_achievable_net_edge_bps"] >= float(entry_conf["min_expected_net_edge_bps"])).astype("int8")
|
||||
merged["expected_net_edge_bps"] = merged["gross_edge_bps"] - cost_bps
|
||||
merged["entry_target"] = (merged["expected_net_edge_bps"] >= float(entry_conf["min_expected_net_edge_bps"])).astype("int8")
|
||||
merged["price_plan_id"] = plan["pricePlanId"]
|
||||
merged["price_plan_hash"] = plan["pricePlanConfigHash"]
|
||||
merged["cost_bps"] = cost_bps
|
||||
@@ -403,9 +416,22 @@ def build_continue_exit_risk_labels(args: Any) -> None:
|
||||
min_continue = float(labels["continue"]["min_expected_continue_edge_bps"])
|
||||
adverse_threshold = float(labels["exit"]["adverse_move_bps"])
|
||||
current_vol = merged["realized_vol_15m_bps"].astype(float).fillna(0.0).clip(lower=1.0)
|
||||
risk_config = labels["risk"]
|
||||
market_risk_threshold = _config_number(
|
||||
risk_config,
|
||||
("market_path_risk_threshold_bps", "market_drawdown_bps"),
|
||||
60.0,
|
||||
)
|
||||
position_risk_threshold = _config_number(
|
||||
risk_config,
|
||||
("position_path_risk_threshold_bps", "position_path_risk_bps"),
|
||||
35.0,
|
||||
)
|
||||
spike_threshold = _config_number(risk_config, ("spike_1m_threshold_bps", "spike_bps"), 80.0)
|
||||
vol_expansion_ratio = _config_number(risk_config, ("vol_expansion_ratio",), 1.8)
|
||||
|
||||
long_edge = merged["long_future_return_bps"] - cost_bps
|
||||
short_edge = merged["short_future_return_bps"] - cost_bps
|
||||
long_edge = merged["long_gross_edge_bps"] - cost_bps
|
||||
short_edge = merged["short_gross_edge_bps"] - cost_bps
|
||||
path_risk = np.maximum(merged["long_mae_bps"], merged["short_mae_bps"])
|
||||
max_path_move = np.maximum.reduce([merged["long_mfe_bps"], merged["short_mfe_bps"], path_risk])
|
||||
if "ret_15m_bps" in merged.columns:
|
||||
@@ -413,7 +439,9 @@ def build_continue_exit_risk_labels(args: Any) -> None:
|
||||
else:
|
||||
reversal = pd.Series(0, index=merged.index, dtype="int8")
|
||||
future_vol = merged["long_future_realized_vol_bps"].fillna(0.0)
|
||||
volatility_expansion = future_vol >= current_vol * float(labels["risk"]["vol_expansion_ratio"])
|
||||
volatility_expansion = future_vol >= current_vol * vol_expansion_ratio
|
||||
spike = max_path_move >= spike_threshold
|
||||
market_risk = (path_risk >= market_risk_threshold) | spike | volatility_expansion
|
||||
liquidity_deterioration = merged["spread_rank_24h_pct"].astype(float).fillna(0.0) >= 0.90
|
||||
|
||||
rows_continue = pd.DataFrame(
|
||||
@@ -421,8 +449,8 @@ def build_continue_exit_risk_labels(args: Any) -> None:
|
||||
"sample_id": merged["sample_id"],
|
||||
"symbol": merged["symbol"],
|
||||
"event_time": merged["event_time"],
|
||||
"long_continue_target": ((long_edge >= min_continue) & (merged["long_mae_bps"] < stop_bps)).astype("int8"),
|
||||
"short_continue_target": ((short_edge >= min_continue) & (merged["short_mae_bps"] < stop_bps)).astype("int8"),
|
||||
"long_continue_target": ((long_edge >= min_continue) & (merged["long_stop_hit"] == 0)).astype("int8"),
|
||||
"short_continue_target": ((short_edge >= min_continue) & (merged["short_stop_hit"] == 0)).astype("int8"),
|
||||
"long_expected_continue_edge_bps": long_edge,
|
||||
"short_expected_continue_edge_bps": short_edge,
|
||||
"split_id": merged["split_id"],
|
||||
@@ -453,17 +481,17 @@ def build_continue_exit_risk_labels(args: Any) -> None:
|
||||
"sample_id": merged["sample_id"],
|
||||
"symbol": merged["symbol"],
|
||||
"event_time": merged["event_time"],
|
||||
"market_risk_target": (path_risk >= float(labels["risk"]["market_drawdown_bps"])).astype("int8"),
|
||||
"market_risk_target": market_risk.astype("int8"),
|
||||
"market_path_risk_bps": path_risk,
|
||||
"long_position_path_risk_bps": merged["long_mae_bps"],
|
||||
"short_position_path_risk_bps": merged["short_mae_bps"],
|
||||
"long_position_risk_target": (merged["long_mae_bps"] >= stop_bps).astype("int8"),
|
||||
"short_position_risk_target": (merged["short_mae_bps"] >= stop_bps).astype("int8"),
|
||||
"market_drawdown_prob_label": (path_risk >= float(labels["risk"]["market_drawdown_bps"])).astype("int8"),
|
||||
"long_position_risk_target": ((merged["long_mae_bps"] >= position_risk_threshold) | (merged["long_stop_hit"] == 1)).astype("int8"),
|
||||
"short_position_risk_target": ((merged["short_mae_bps"] >= position_risk_threshold) | (merged["short_stop_hit"] == 1)).astype("int8"),
|
||||
"market_drawdown_prob_label": (path_risk >= market_risk_threshold).astype("int8"),
|
||||
"volatility_expansion_prob_label": volatility_expansion.astype("int8"),
|
||||
"spike_prob_label": (max_path_move >= float(labels["risk"]["spike_bps"])).astype("int8"),
|
||||
"spike_prob_label": spike.astype("int8"),
|
||||
"liquidity_deterioration_prob_label": liquidity_deterioration.astype("int8"),
|
||||
"position_drawdown_prob_label": (path_risk >= stop_bps).astype("int8"),
|
||||
"position_drawdown_prob_label": (path_risk >= position_risk_threshold).astype("int8"),
|
||||
"split_id": merged["split_id"],
|
||||
"walk_forward_fold": merged["walk_forward_fold"],
|
||||
"label_version": LABEL_VERSION,
|
||||
@@ -475,6 +503,21 @@ def build_continue_exit_risk_labels(args: Any) -> None:
|
||||
("risk", pd.DataFrame(rows_risk), "market_risk_target"),
|
||||
]
|
||||
report_parts = ["# Continue Exit Risk Label Report", ""]
|
||||
report_parts.extend(
|
||||
[
|
||||
"## Risk Thresholds",
|
||||
"",
|
||||
str(
|
||||
{
|
||||
"market_risk_threshold_bps": market_risk_threshold,
|
||||
"position_risk_threshold_bps": position_risk_threshold,
|
||||
"spike_threshold_bps": spike_threshold,
|
||||
"vol_expansion_ratio": vol_expansion_ratio,
|
||||
}
|
||||
),
|
||||
"",
|
||||
]
|
||||
)
|
||||
for name, frame, target in outputs:
|
||||
path = root / "label" / f"{name}_labels.parquet"
|
||||
data_hash = write_parquet(path, frame)
|
||||
|
||||
Reference in New Issue
Block a user