Use actual plan edge for Entry PM training
This commit is contained in:
@@ -232,8 +232,8 @@ def _pm_frame(root, split_id: str) -> pd.DataFrame:
|
||||
price_plan = _price_plan_context(root)
|
||||
entry_dataset = read_parquet(root / "dataset" / "entry_train.parquet").rename(
|
||||
columns={
|
||||
"long_expected_net_edge_bps": "actual_long_expected_net_edge_bps",
|
||||
"short_expected_net_edge_bps": "actual_short_expected_net_edge_bps",
|
||||
"long_actual_plan_net_edge_bps": "actual_long_plan_edge_bps",
|
||||
"short_actual_plan_net_edge_bps": "actual_short_plan_edge_bps",
|
||||
}
|
||||
)
|
||||
entry_plan_outcome = _entry_plan_outcome_frame(root)
|
||||
@@ -245,7 +245,10 @@ def _pm_frame(root, split_id: str) -> pd.DataFrame:
|
||||
"pred_short_expected_net_edge_bps",
|
||||
]
|
||||
risk_cols = ["sample_id", "market_risk_prob", "long_position_risk_prob", "short_position_risk_prob"]
|
||||
actual_cols = ["sample_id", "actual_long_expected_net_edge_bps", "actual_short_expected_net_edge_bps", "long_entry_target", "short_entry_target"]
|
||||
actual_cols = ["sample_id", "actual_long_plan_edge_bps", "actual_short_plan_edge_bps", "long_entry_target", "short_entry_target"]
|
||||
missing_actual_cols = sorted(set(actual_cols) - set(entry_dataset.columns))
|
||||
if missing_actual_cols:
|
||||
raise ValueError(f"entry_train is missing actual plan edge columns for PM: {missing_actual_cols}")
|
||||
frame = (
|
||||
direction[["sample_id", "symbol", "event_time", "split_id", "long_prob", "short_prob", "neutral_prob"]]
|
||||
.merge(entry[entry_cols], on="sample_id", how="inner")
|
||||
@@ -257,7 +260,7 @@ def _pm_frame(root, split_id: str) -> pd.DataFrame:
|
||||
raise ValueError(f"PM frame is empty for {split_id}; check model predictions and entry dataset")
|
||||
frame["model_pred_long_expected_net_edge_bps"] = frame["pred_long_expected_net_edge_bps"]
|
||||
frame["model_pred_short_expected_net_edge_bps"] = frame["pred_short_expected_net_edge_bps"]
|
||||
edge_mode = "MODEL_EXPECTED_NET_EDGE"
|
||||
edge_mode = "MODEL_ACTUAL_PLAN_EDGE"
|
||||
if price_plan.get("entryTargetMethod") not in {"OPPORTUNITY_MFE_V1", "OPPORTUNITY_QUALITY_V1"}:
|
||||
frame["pred_long_expected_net_edge_bps"] = _probability_implied_edge(frame["long_entry_prob"], price_plan)
|
||||
frame["pred_short_expected_net_edge_bps"] = _probability_implied_edge(frame["short_entry_prob"], price_plan)
|
||||
@@ -333,9 +336,9 @@ def _threshold_candidates() -> list[dict[str, float]]:
|
||||
values = itertools.product(
|
||||
[0.50, 0.60, 0.70, 1.01],
|
||||
[0.50, 0.60, 0.70, 1.01],
|
||||
[0.03, 0.50, 0.70, 0.85],
|
||||
[0.45, 0.65, 0.85],
|
||||
[0.0, 8.0, 15.0, 25.0],
|
||||
[0.30, 0.50, 0.70, 0.85],
|
||||
[0.45, 0.65],
|
||||
[3.0, 8.0, 15.0, 25.0],
|
||||
[0.02, 0.06, 0.10],
|
||||
)
|
||||
return [
|
||||
@@ -394,7 +397,7 @@ def _simulate_open_trades(
|
||||
trades["entry_prob"] = np.where(is_long, trades["long_entry_prob"], trades["short_entry_prob"])
|
||||
trades["predicted_edge_bps"] = np.where(is_long, trades["pred_long_expected_net_edge_bps"], trades["pred_short_expected_net_edge_bps"])
|
||||
trades["actual_edge_bps"] = np.where(is_long, trades["long_trade_net_edge_bps"], trades["short_trade_net_edge_bps"])
|
||||
trades["label_max_edge_bps"] = np.where(is_long, trades["actual_long_expected_net_edge_bps"], trades["actual_short_expected_net_edge_bps"])
|
||||
trades["label_actual_plan_edge_bps"] = np.where(is_long, trades["actual_long_plan_edge_bps"], trades["actual_short_plan_edge_bps"])
|
||||
trades["entry_target"] = np.where(is_long, trades["long_entry_target"], trades["short_entry_target"])
|
||||
effective_pm_config = pm_config or _pm_config_from_thresholds(thresholds)
|
||||
effective_price_plan = price_plan or DEFAULT_BACKTEST_PRICE_PLAN
|
||||
@@ -417,7 +420,7 @@ def _simulate_open_trades(
|
||||
"entry_prob",
|
||||
"market_risk_prob",
|
||||
"predicted_edge_bps",
|
||||
"label_max_edge_bps",
|
||||
"label_actual_plan_edge_bps",
|
||||
"actual_edge_bps",
|
||||
"entry_target",
|
||||
"time_to_exit_ms",
|
||||
@@ -440,7 +443,7 @@ def _empty_trade_frame() -> pd.DataFrame:
|
||||
"entry_prob",
|
||||
"market_risk_prob",
|
||||
"predicted_edge_bps",
|
||||
"label_max_edge_bps",
|
||||
"label_actual_plan_edge_bps",
|
||||
"actual_edge_bps",
|
||||
"entry_target",
|
||||
"time_to_exit_ms",
|
||||
|
||||
Reference in New Issue
Block a user