From 6be4bb976a284017c738ec76bcee37300aeb7d61 Mon Sep 17 00:00:00 2001 From: Codex Date: Sun, 28 Jun 2026 09:27:59 +0800 Subject: [PATCH] Add Entry opportunity training diagnostics --- training/scripts/11_train_small_models.py | 7 + .../30_diagnose_good_trade_structure.py | 26 ++ training/tests/test_training_contract.py | 42 ++ .../trader_training/good_trade_structure.py | 363 ++++++++++++++++++ training/trader_training/training.py | 40 +- 5 files changed, 468 insertions(+), 10 deletions(-) create mode 100644 training/scripts/30_diagnose_good_trade_structure.py create mode 100644 training/trader_training/good_trade_structure.py diff --git a/training/scripts/11_train_small_models.py b/training/scripts/11_train_small_models.py index 70ff17c..b41815f 100644 --- a/training/scripts/11_train_small_models.py +++ b/training/scripts/11_train_small_models.py @@ -11,6 +11,13 @@ def main() -> None: parser = argparse.ArgumentParser() add_common_args(parser) parser.add_argument("--max-rows", type=int, default=0) + parser.add_argument( + "--conditional-entry-source", + choices=("none", "direction_label", "side_opportunity"), + default="none", + help="Entry 训练样本人群来源:不筛选、按 Direction 标签筛选、或按本方向未来机会阈值筛选。", + ) + parser.add_argument("--conditional-entry-opportunity-bps", type=float, default=40.0) parser.add_argument("--conditional-entry-direction-labels", action="store_true") parser.add_argument("--conditional-entry-min-fit-rows", type=int, default=1000) args = parser.parse_args() diff --git a/training/scripts/30_diagnose_good_trade_structure.py b/training/scripts/30_diagnose_good_trade_structure.py new file mode 100644 index 0000000..f99a7f3 --- /dev/null +++ b/training/scripts/30_diagnose_good_trade_structure.py @@ -0,0 +1,26 @@ +from __future__ import annotations + +import argparse + +import _bootstrap # noqa: F401 +from trader_training.good_trade_structure import diagnose_good_trade_structure +from trader_training.io_utils import add_common_args, setup_logging + + +def _float_tuple(value: str) -> tuple[float, ...]: + return tuple(float(item.strip()) for item in value.split(",") if item.strip()) + + +def main() -> None: + parser = argparse.ArgumentParser(description="Diagnose whether existing features separate good and bad Entry trades.") + add_common_args(parser) + parser.add_argument("--min-good-edge-bps", type=float, default=3.0) + parser.add_argument("--bad-edge-bps", type=float, default=-3.0) + parser.add_argument("--top-fractions", type=_float_tuple, default=(0.01, 0.05, 0.10)) + args = parser.parse_args() + setup_logging() + diagnose_good_trade_structure(args) + + +if __name__ == "__main__": + main() diff --git a/training/tests/test_training_contract.py b/training/tests/test_training_contract.py index 6854c4b..315d21e 100644 --- a/training/tests/test_training_contract.py +++ b/training/tests/test_training_contract.py @@ -20,6 +20,7 @@ from trader_training.dynamic_exit_search import search_dynamic_exit_plans from trader_training.entry_condition_pair_screen import screen_entry_condition_pairs from trader_training.entry_feature_screen import _bucket_edges, _screen_edge_column from trader_training.entry_mae_label_diagnostic import diagnose_entry_mae_labels +from trader_training.good_trade_structure import _side_frame, _top_fraction_metrics from trader_training.io_utils import read_json, write_json from trader_training.labels import ENTRY_LABEL_METHOD, _path_stats_for_group, build_entry_labels from trader_training.nonlinear_pm_probe import _expanded_threshold_candidates @@ -97,6 +98,27 @@ class TrainingContractTest(unittest.TestCase): self.assertEqual("ALL_FIT_ROWS", default_filter) self.assertEqual([True, True, True, True], default_mask.tolist()) + def test_conditional_entry_training_can_use_side_opportunity_rows(self) -> None: + train = pd.DataFrame( + { + "long_max_achievable_net_edge_bps": [45.0, 10.0, 60.0, 39.0], + "short_max_achievable_net_edge_bps": [8.0, 42.0, 15.0, 80.0], + } + ) + args = Namespace( + conditional_entry_source="side_opportunity", + conditional_entry_direction_labels=False, + conditional_entry_opportunity_bps=40.0, + ) + + long_mask, long_filter = _head_train_mask("ENTRY", "long_entry_prob", train, args) + short_mask, short_filter = _head_train_mask("ENTRY", "short_expected_net_edge_bps", train, args) + + self.assertEqual("SIDE_OPPORTUNITY_LONG_GE_40_BPS_FIT_ROWS", long_filter) + self.assertEqual([True, False, True, False], long_mask.tolist()) + self.assertEqual("SIDE_OPPORTUNITY_SHORT_GE_40_BPS_FIT_ROWS", short_filter) + self.assertEqual([False, True, False, True], short_mask.tolist()) + def test_direction_opportunity_labels_choose_clear_path_opportunity(self) -> None: labels = _opportunity_labels( np.array([45.0, 10.0, 45.0, 42.0, np.nan]), @@ -142,6 +164,26 @@ class TrainingContractTest(unittest.TestCase): self.assertEqual("dataset/entry_train.parquet", summary["fit_inner"]["entry"]["source"]) self.assertEqual(0.5, summary["fit_inner"]["entry"]["target_rate_by_side"]["LONG"]) + def test_good_trade_structure_builds_side_frame_and_top_metrics(self) -> None: + dataset = pd.DataFrame( + { + "sample_id": ["s1", "s2", "s3"], + "split_id": ["fit_inner", "fit_inner", "fit_inner"], + "long_actual_plan_net_edge_bps": [4.0, -5.0, 1.0], + "short_actual_plan_net_edge_bps": [-5.0, 6.0, -1.0], + **{feature: [0.1, 0.2, 0.3] for feature in FEATURE_ORDER}, + } + ) + + frame = _side_frame(dataset, "LONG", min_good_edge_bps=3.0, bad_edge_bps=-3.0) + metrics = _top_fraction_metrics(frame, np.array([0.9, 0.1, 0.2]), 1 / 3) + + self.assertEqual([1, 0, 0], frame["good_trade"].tolist()) + self.assertEqual([0, 1, 0], frame["bad_trade"].tolist()) + self.assertEqual(1, metrics["rows"]) + self.assertEqual(1.0, metrics["good_rate"]) + self.assertEqual(4.0, metrics["avg_edge_bps"]) + def test_entry_feature_screen_keeps_zero_inflated_event_features(self) -> None: values = np.concatenate((np.zeros(5000), np.linspace(1.0, 100.0, 500))) edges = _bucket_edges(values) diff --git a/training/trader_training/good_trade_structure.py b/training/trader_training/good_trade_structure.py new file mode 100644 index 0000000..8274cb1 --- /dev/null +++ b/training/trader_training/good_trade_structure.py @@ -0,0 +1,363 @@ +from __future__ import annotations + +import logging +from typing import Any + +import numpy as np +import pandas as pd +from sklearn.ensemble import HistGradientBoostingClassifier +from sklearn.metrics import roc_auc_score + +from trader_training.io_utils import read_parquet, run_root, write_json, write_text +from trader_training.schemas import FEATURE_ORDER, FIT_SPLIT, LATEST_STRESS_SPLIT, TUNE_SPLIT, VALIDATION_LOCKED_SPLIT + + +ALL_SPLITS = (FIT_SPLIT, TUNE_SPLIT, VALIDATION_LOCKED_SPLIT, LATEST_STRESS_SPLIT) +EVAL_SPLITS = (TUNE_SPLIT, VALIDATION_LOCKED_SPLIT, LATEST_STRESS_SPLIT) + + +def diagnose_good_trade_structure(args: Any) -> None: + root = run_root(args) + dataset = read_parquet(root / "dataset" / "entry_train.parquet") + min_good_edge_bps = float(args.min_good_edge_bps) + bad_edge_bps = float(args.bad_edge_bps) + top_fractions = tuple(float(item) for item in args.top_fractions) + _require_columns(dataset) + + side_frames = { + side: _side_frame(dataset, side, min_good_edge_bps, bad_edge_bps) + for side in ("LONG", "SHORT") + } + split_summary = pd.concat([_split_summary(frame, side) for side, frame in side_frames.items()], ignore_index=True) + feature_rows = pd.concat([_feature_candidates(frame, side, top_fractions) for side, frame in side_frames.items()], ignore_index=True) + model_rows = pd.concat([_tree_model_top_rows(frame, side, top_fractions) for side, frame in side_frames.items()], ignore_index=True) + result = { + "run_id": args.run_id, + "min_good_edge_bps": min_good_edge_bps, + "bad_edge_bps": bad_edge_bps, + "feature_count": len(FEATURE_ORDER), + "feature_candidate_count": int(len(feature_rows)), + "stable_feature_count": int(feature_rows["stable_auc"].sum()) if not feature_rows.empty else 0, + "stable_positive_top_feature_count": int(feature_rows["stable_positive_top_edge"].sum()) if not feature_rows.empty else 0, + "tree_model_verdict": _tree_verdict(model_rows), + } + out_dir = root / "diagnostics" + write_json(out_dir / "good_trade_structure_result.json", _jsonable(result)) + write_text(out_dir / "good_trade_split_summary.csv", split_summary.to_csv(index=False)) + write_text(out_dir / "good_trade_feature_candidates.csv", feature_rows.to_csv(index=False)) + write_text(out_dir / "good_trade_tree_model_top.csv", model_rows.to_csv(index=False)) + write_text(out_dir / "good_trade_structure_report.md", _markdown_report(result, split_summary, feature_rows, model_rows)) + logging.info( + "trader.training.good_trade_structure_written runId=%s stableFeatureCount=%s stablePositiveTopFeatureCount=%s treeVerdict=%s", + args.run_id, + result["stable_feature_count"], + result["stable_positive_top_feature_count"], + result["tree_model_verdict"]["status"], + ) + + +def _require_columns(dataset: pd.DataFrame) -> None: + required = { + "split_id", + *FEATURE_ORDER, + "long_actual_plan_net_edge_bps", + "short_actual_plan_net_edge_bps", + } + missing = sorted(required - set(dataset.columns)) + if missing: + raise ValueError(f"good trade structure diagnostic missing required columns: {missing}") + + +def _side_frame(dataset: pd.DataFrame, side: str, min_good_edge_bps: float, bad_edge_bps: float) -> pd.DataFrame: + edge_col = "long_actual_plan_net_edge_bps" if side == "LONG" else "short_actual_plan_net_edge_bps" + frame = dataset[["sample_id", "split_id", edge_col, *FEATURE_ORDER]].copy() + frame = frame.rename(columns={edge_col: "actual_edge_bps"}) + frame["side"] = side + frame["actual_edge_bps"] = pd.to_numeric(frame["actual_edge_bps"], errors="coerce") + frame["good_trade"] = frame["actual_edge_bps"].ge(min_good_edge_bps).astype("int8") + frame["breakeven_trade"] = frame["actual_edge_bps"].ge(0.0).astype("int8") + frame["bad_trade"] = frame["actual_edge_bps"].le(bad_edge_bps).astype("int8") + return frame.dropna(subset=["actual_edge_bps"]).reset_index(drop=True) + + +def _split_summary(frame: pd.DataFrame, side: str) -> pd.DataFrame: + rows: list[dict[str, Any]] = [] + for split_id in ALL_SPLITS: + part = frame[frame["split_id"].eq(split_id)] + if part.empty: + continue + edge = part["actual_edge_bps"].astype(float) + rows.append( + { + "side": side, + "split_id": split_id, + "rows": int(len(part)), + "good_rate": float(part["good_trade"].mean()), + "breakeven_rate": float(part["breakeven_trade"].mean()), + "bad_rate": float(part["bad_trade"].mean()), + "avg_edge_bps": float(edge.mean()), + "p50_edge_bps": float(edge.quantile(0.50)), + "p90_edge_bps": float(edge.quantile(0.90)), + "p99_edge_bps": float(edge.quantile(0.99)), + } + ) + return pd.DataFrame(rows) + + +def _feature_candidates(frame: pd.DataFrame, side: str, top_fractions: tuple[float, ...]) -> pd.DataFrame: + rows: list[dict[str, Any]] = [] + tune = frame[frame["split_id"].eq(TUNE_SPLIT)] + for feature in FEATURE_ORDER: + tune_auc = _raw_auc(tune, feature) + if tune_auc is None: + continue + direction = "HIGH" if tune_auc >= 0.5 else "LOW" + row: dict[str, Any] = { + "side": side, + "feature": feature, + "better_when": direction, + "tune_raw_auc": float(tune_auc), + } + directional_aucs = [] + top_edges = [] + top_good_rates = [] + for split_id in EVAL_SPLITS: + part = frame[frame["split_id"].eq(split_id)] + directional_auc = _directional_auc(part, feature, direction) + top_metrics = _feature_top_metrics(part, feature, direction, top_fractions[0]) + row[f"{split_id}_directional_auc"] = directional_auc + row[f"{split_id}_top{_fraction_label(top_fractions[0])}_rows"] = top_metrics["rows"] + row[f"{split_id}_top{_fraction_label(top_fractions[0])}_good_rate"] = top_metrics["good_rate"] + row[f"{split_id}_top{_fraction_label(top_fractions[0])}_avg_edge_bps"] = top_metrics["avg_edge_bps"] + if directional_auc is not None: + directional_aucs.append(float(directional_auc)) + if top_metrics["rows"] > 0: + top_edges.append(float(top_metrics["avg_edge_bps"])) + top_good_rates.append(float(top_metrics["good_rate"])) + row["min_eval_directional_auc"] = min(directional_aucs) if directional_aucs else np.nan + row["min_top_avg_edge_bps"] = min(top_edges) if top_edges else np.nan + row["min_top_good_rate"] = min(top_good_rates) if top_good_rates else np.nan + row["stable_auc"] = bool(len(directional_aucs) == len(EVAL_SPLITS) and min(directional_aucs) >= 0.53) + row["stable_positive_top_edge"] = bool(len(top_edges) == len(EVAL_SPLITS) and min(top_edges) > 0.0) + row["score"] = ( + float(row["min_eval_directional_auc"]) * 10.0 + + float(row["min_top_avg_edge_bps"]) * 0.10 + + (2.0 if row["stable_auc"] else 0.0) + + (3.0 if row["stable_positive_top_edge"] else 0.0) + if np.isfinite(row["min_eval_directional_auc"]) and np.isfinite(row["min_top_avg_edge_bps"]) + else -999.0 + ) + rows.append(row) + if not rows: + return pd.DataFrame() + return pd.DataFrame(rows).sort_values("score", ascending=False).reset_index(drop=True) + + +def _raw_auc(frame: pd.DataFrame, feature: str) -> float | None: + values = pd.to_numeric(frame[feature], errors="coerce").replace([np.inf, -np.inf], np.nan) + working = pd.DataFrame({"x": values, "y": frame["good_trade"].astype(int)}).dropna() + if len(working) < 1000 or working["x"].nunique() < 2 or working["y"].nunique() < 2: + return None + return float(roc_auc_score(working["y"].to_numpy(), working["x"].to_numpy())) + + +def _directional_auc(frame: pd.DataFrame, feature: str, direction: str) -> float | None: + auc = _raw_auc(frame, feature) + if auc is None: + return None + return float(auc if direction == "HIGH" else 1.0 - auc) + + +def _feature_top_metrics(frame: pd.DataFrame, feature: str, direction: str, fraction: float) -> dict[str, Any]: + values = pd.to_numeric(frame[feature], errors="coerce").replace([np.inf, -np.inf], np.nan) + working = pd.DataFrame( + { + "x": values, + "good_trade": frame["good_trade"].astype(int), + "actual_edge_bps": frame["actual_edge_bps"].astype(float), + } + ).dropna() + if working.empty: + return {"rows": 0, "good_rate": 0.0, "avg_edge_bps": 0.0} + ascending = direction == "LOW" + top = working.sort_values("x", ascending=ascending).head(max(1, int(len(working) * fraction))) + return { + "rows": int(len(top)), + "good_rate": float(top["good_trade"].mean()), + "avg_edge_bps": float(top["actual_edge_bps"].mean()), + } + + +def _tree_model_top_rows(frame: pd.DataFrame, side: str, top_fractions: tuple[float, ...]) -> pd.DataFrame: + train = frame[frame["split_id"].eq(FIT_SPLIT)].copy() + if train.empty or train["good_trade"].nunique() < 2: + return pd.DataFrame() + model = HistGradientBoostingClassifier( + max_iter=180, + learning_rate=0.04, + max_leaf_nodes=31, + l2_regularization=0.02, + early_stopping=True, + random_state=71 if side == "LONG" else 73, + ) + model.fit(_x(train), train["good_trade"].astype(int).to_numpy()) + rows: list[dict[str, Any]] = [] + for split_id in EVAL_SPLITS: + part = frame[frame["split_id"].eq(split_id)].copy() + if part.empty: + continue + proba = model.predict_proba(_x(part))[:, 1] + auc = _model_auc(part["good_trade"].astype(int).to_numpy(), proba) + for fraction in top_fractions: + metrics = _top_fraction_metrics(part, proba, fraction) + rows.append( + { + "side": side, + "split_id": split_id, + "model": "HistGradientBoostingClassifier", + "auc": auc, + "top_fraction": fraction, + **metrics, + } + ) + return pd.DataFrame(rows) + + +def _model_auc(y_true: np.ndarray, proba: np.ndarray) -> float | None: + if len(np.unique(y_true)) < 2: + return None + return float(roc_auc_score(y_true, proba)) + + +def _top_fraction_metrics(frame: pd.DataFrame, score: np.ndarray, fraction: float) -> dict[str, Any]: + working = frame[["good_trade", "actual_edge_bps"]].copy() + working["score"] = score + top = working.sort_values("score", ascending=False).head(max(1, int(len(working) * fraction))) + return { + "rows": int(len(top)), + "good_rate": float(top["good_trade"].mean()), + "avg_edge_bps": float(top["actual_edge_bps"].mean()), + "p50_edge_bps": float(top["actual_edge_bps"].quantile(0.50)), + "p90_edge_bps": float(top["actual_edge_bps"].quantile(0.90)), + } + + +def _tree_verdict(model_rows: pd.DataFrame) -> dict[str, Any]: + if model_rows.empty: + return {"status": "NO_MODEL_ROWS", "reason": "没有足够样本训练树模型诊断。"} + top10 = model_rows[model_rows["top_fraction"].eq(0.10)].copy() + if top10.empty: + return {"status": "NO_TOP10_ROWS", "reason": "没有 top10 诊断结果。"} + grouped = top10.groupby("side", observed=False) + promising = [] + for side, part in grouped: + if set(part["split_id"]) >= set(EVAL_SPLITS) and part["avg_edge_bps"].min() > 0.0 and part["auc"].min() >= 0.56: + promising.append(str(side)) + if promising: + return {"status": "PROMISING_TREE_STRUCTURE", "reason": f"树模型 top10 在这些方向三段为正: {promising}"} + return {"status": "NO_STABLE_TREE_STRUCTURE", "reason": "树模型 top10 也没有在 tune/validation/latest 三段同时转正。"} + + +def _x(frame: pd.DataFrame) -> np.ndarray: + return frame[FEATURE_ORDER].apply(pd.to_numeric, errors="coerce").replace([np.inf, -np.inf], np.nan).astype("float32").to_numpy() + + +def _markdown_report(result: dict[str, Any], split_summary: pd.DataFrame, feature_rows: pd.DataFrame, model_rows: pd.DataFrame) -> str: + top_fraction = 0.10 + lines = [ + "# 好单结构诊断报告", + "", + "这份报告只看一件事:现有 54 个特征能不能把真实赚钱单和亏钱单分开。", + "", + f"- run_id: `{result['run_id']}`", + f"- 好单定义: 当前价格计划真实净收益 >= `{result['min_good_edge_bps']}` bps", + f"- 坏单辅助定义: 当前价格计划真实净收益 <= `{result['bad_edge_bps']}` bps", + f"- 树模型诊断结论: `{result['tree_model_verdict']['status']}`", + f"- 结论说明: {result['tree_model_verdict']['reason']}", + "", + "## 基础分布", + "", + _markdown_table(split_summary), + "", + "## 单特征分辨力", + "", + f"- 稳定 AUC 特征数: `{result['stable_feature_count']}`", + f"- top {_fraction_label(top_fraction)} 平均收益三段都为正的特征数: `{result['stable_positive_top_feature_count']}`", + "", + ] + feature_display = _feature_display(feature_rows, top_fraction) + lines.append(_markdown_table(feature_display.head(25))) + lines.extend(["", "## 树模型 top 分桶", ""]) + model_display = model_rows.sort_values(["side", "top_fraction", "split_id"]).copy() if not model_rows.empty else pd.DataFrame() + lines.append(_markdown_table(model_display)) + lines.extend( + [ + "", + "## 文件", + "", + "- `diagnostics/good_trade_split_summary.csv`: 好单/坏单基础分布。", + "- `diagnostics/good_trade_feature_candidates.csv`: 单特征分辨力明细。", + "- `diagnostics/good_trade_tree_model_top.csv`: 树模型 top 分桶明细。", + "", + ] + ) + return "\n".join(lines) + + +def _feature_display(feature_rows: pd.DataFrame, top_fraction: float) -> pd.DataFrame: + if feature_rows.empty: + return pd.DataFrame() + label = _fraction_label(top_fraction) + columns = [ + "side", + "feature", + "better_when", + "min_eval_directional_auc", + f"{TUNE_SPLIT}_top{label}_avg_edge_bps", + f"{VALIDATION_LOCKED_SPLIT}_top{label}_avg_edge_bps", + f"{LATEST_STRESS_SPLIT}_top{label}_avg_edge_bps", + "min_top_avg_edge_bps", + "min_top_good_rate", + "stable_auc", + "stable_positive_top_edge", + "score", + ] + return feature_rows[[column for column in columns if column in feature_rows.columns]].copy() + + +def _markdown_table(frame: pd.DataFrame) -> str: + if frame.empty: + return "_无_" + columns = list(frame.columns) + lines = ["| " + " | ".join(columns) + " |", "| " + " | ".join(["---"] * len(columns)) + " |"] + for _, row in frame.iterrows(): + lines.append("| " + " | ".join(_format_cell(row[column]) for column in columns) + " |") + return "\n".join(lines) + + +def _format_cell(value: Any) -> str: + if value is None or pd.isna(value): + return "" + if isinstance(value, (float, np.floating)): + return f"{float(value):.6g}" + if isinstance(value, (bool, np.bool_)): + return "true" if bool(value) else "false" + return str(value) + + +def _fraction_label(fraction: float) -> str: + return str(int(round(fraction * 100))) + + +def _jsonable(value: Any) -> Any: + if isinstance(value, dict): + return {str(key): _jsonable(item) for key, item in value.items()} + if isinstance(value, list): + return [_jsonable(item) for item in value] + if isinstance(value, (np.integer,)): + return int(value) + if isinstance(value, (np.floating,)): + return float(value) + if isinstance(value, np.ndarray): + return value.tolist() + return value diff --git a/training/trader_training/training.py b/training/trader_training/training.py index 400437d..8e3c723 100644 --- a/training/trader_training/training.py +++ b/training/trader_training/training.py @@ -89,7 +89,7 @@ def train_small_models(args: Any) -> None: model_manifest: dict[str, Any] = {} for model_name, spec in TARGETS.items(): dataset = read_parquet(root / "dataset" / spec["dataset"]) - if model_name == "ENTRY" and _conditional_entry_enabled(args): + if model_name == "ENTRY" and _conditional_entry_source(args) == "direction_label": dataset = _attach_direction_fit_labels(root, dataset) if args.max_rows and len(dataset) > args.max_rows: dataset = dataset.sort_values("event_time").tail(args.max_rows).copy() @@ -189,7 +189,17 @@ def train_small_models(args: Any) -> None: def _conditional_entry_enabled(args: Any) -> bool: - return bool(getattr(args, "conditional_entry_direction_labels", False)) + return _conditional_entry_source(args) != "none" + + +def _conditional_entry_source(args: Any) -> str: + source = str(getattr(args, "conditional_entry_source", "none") or "none").strip().lower() + if bool(getattr(args, "conditional_entry_direction_labels", False)): + source = "direction_label" + allowed = {"none", "direction_label", "side_opportunity"} + if source not in allowed: + raise ValueError(f"unsupported conditional Entry source: {source}") + return source def _attach_direction_fit_labels(root: Path, entry_dataset: pd.DataFrame) -> pd.DataFrame: @@ -213,19 +223,29 @@ def _attach_direction_fit_labels(root: Path, entry_dataset: pd.DataFrame) -> pd. def _head_train_mask(model_name: str, head_name: str, train: pd.DataFrame, args: Any) -> tuple[np.ndarray, str]: - if model_name != "ENTRY" or not _conditional_entry_enabled(args): + source = _conditional_entry_source(args) + if model_name != "ENTRY" or source == "none": return np.ones(len(train), dtype=bool), "ALL_FIT_ROWS" if head_name.startswith("long_"): - condition_column = "long_target" - filter_name = "DIRECTION_LABEL_LONG_FIT_ROWS" + side = "LONG" + direction_label_column = "long_target" + opportunity_column = "long_max_achievable_net_edge_bps" elif head_name.startswith("short_"): - condition_column = "short_target" - filter_name = "DIRECTION_LABEL_SHORT_FIT_ROWS" + side = "SHORT" + direction_label_column = "short_target" + opportunity_column = "short_max_achievable_net_edge_bps" else: return np.ones(len(train), dtype=bool), "ALL_FIT_ROWS" - if condition_column not in train.columns: - raise ValueError(f"conditional Entry training requires {condition_column} for head {head_name}") - mask = pd.to_numeric(train[condition_column], errors="coerce").fillna(0).astype(int).eq(1).to_numpy() + if source == "direction_label": + if direction_label_column not in train.columns: + raise ValueError(f"conditional Entry training requires {direction_label_column} for head {head_name}") + mask = pd.to_numeric(train[direction_label_column], errors="coerce").fillna(0).astype(int).eq(1).to_numpy() + return mask, f"DIRECTION_LABEL_{side}_FIT_ROWS" + threshold = float(getattr(args, "conditional_entry_opportunity_bps", 40.0) or 40.0) + if opportunity_column not in train.columns: + raise ValueError(f"side opportunity Entry training requires {opportunity_column} for head {head_name}") + mask = pd.to_numeric(train[opportunity_column], errors="coerce").ge(threshold).fillna(False).to_numpy() + filter_name = f"SIDE_OPPORTUNITY_{side}_GE_{threshold:g}_BPS_FIT_ROWS" return mask, filter_name