diff --git a/training/tests/test_training_contract.py b/training/tests/test_training_contract.py index 98a5f9d..be61aec 100644 --- a/training/tests/test_training_contract.py +++ b/training/tests/test_training_contract.py @@ -18,7 +18,7 @@ from trader_training.dynamic_exit_search import search_dynamic_exit_plans from trader_training.entry_feature_screen import _bucket_edges, _screen_edge_column from trader_training.io_utils import read_json, write_json from trader_training.labels import ENTRY_LABEL_METHOD, _path_stats_for_group, build_entry_labels -from trader_training.ofi_feature_experiment import l1_snapshot_diff_ofi_quote +from trader_training.ofi_feature_experiment import _load_entry_dataset, l1_snapshot_diff_ofi_quote from trader_training.promote import promote_artifact_bundle from trader_training.replay import build_splits from trader_training.schemas import FEATURE_ORDER, LATEST_STRESS_SPLIT, MODEL_OUTPUTS, OUTPUT_MAPPING, TRAINING_SPLITS, VALIDATION_LOCKED_SPLIT @@ -117,6 +117,38 @@ class TrainingContractTest(unittest.TestCase): self.assertEqual(1, result["candidate_count"]) self.assertTrue((run_root / "dynamic-exit-search" / "dynamic_exit_search_report.md").is_file()) + def test_ofi_entry_dataset_uses_actual_plan_edge(self) -> None: + with tempfile.TemporaryDirectory() as tmp: + baseline_root = Path(tmp) + dataset_path = baseline_root / "dataset" / "entry_train.parquet" + dataset_path.parent.mkdir(parents=True) + pd.DataFrame( + { + "sample_id": ["s1"], + "long_entry_target": [1], + "short_entry_target": [0], + "long_actual_plan_net_edge_bps": [4.0], + "short_actual_plan_net_edge_bps": [-7.0], + } + ).to_parquet(dataset_path, index=False) + feature = pd.DataFrame( + { + "sample_id": ["s1"], + "symbol": ["BTC-USDT-PERP"], + "event_time": pd.to_datetime(["2026-01-01T00:00:00Z"]), + "open_time_ms": [0], + "split_id": ["fit_inner"], + "walk_forward_fold": [0], + "data_quality_flag": ["OK"], + } + ) + + dataset = _load_entry_dataset(baseline_root, feature) + + self.assertIn("long_actual_plan_net_edge_bps", dataset.columns) + self.assertNotIn("long_expected_net_edge_bps", dataset.columns) + self.assertEqual(4.0, float(dataset.loc[0, "long_actual_plan_net_edge_bps"])) + def test_split_builder_uses_locked_validation_contract(self) -> None: with tempfile.TemporaryDirectory() as tmp: data_root = Path(tmp) diff --git a/training/trader_training/ofi_feature_experiment.py b/training/trader_training/ofi_feature_experiment.py index c49fe2b..9357de0 100644 --- a/training/trader_training/ofi_feature_experiment.py +++ b/training/trader_training/ofi_feature_experiment.py @@ -337,6 +337,23 @@ def _load_direction_dataset(baseline_root: Path, feature: pd.DataFrame) -> pd.Da def _load_entry_dataset(baseline_root: Path, feature: pd.DataFrame) -> pd.DataFrame: + dataset_path = baseline_root / "dataset" / "entry_train.parquet" + if dataset_path.is_file(): + labels = read_parquet(dataset_path) + required = { + "sample_id", + "long_entry_target", + "short_entry_target", + "long_actual_plan_net_edge_bps", + "short_actual_plan_net_edge_bps", + } + missing = sorted(required.difference(labels.columns)) + if missing: + raise ValueError(f"entry_train dataset missing columns: {missing}") + dataset = feature.merge(labels[list(required)], on="sample_id", how="inner") + logging.info("trader.training.ofi_entry_dataset_loaded source=entry_train rowCount=%s", len(dataset)) + return dataset + labels = read_parquet(baseline_root / "label" / "entry_labels.parquet") required = {"sample_id", "side", "entry_target", "expected_net_edge_bps"} missing = sorted(required.difference(labels.columns)) @@ -350,7 +367,7 @@ def _load_entry_dataset(baseline_root: Path, feature: pd.DataFrame) -> pd.DataFr ) pivot = long.merge(short, on="sample_id", how="inner") dataset = feature.merge(pivot, on="sample_id", how="inner") - logging.info("trader.training.ofi_entry_dataset_loaded rowCount=%s", len(dataset)) + logging.info("trader.training.ofi_entry_dataset_loaded source=entry_labels_legacy rowCount=%s", len(dataset)) return dataset @@ -407,8 +424,8 @@ def _train_entry(frame: pd.DataFrame, feature_columns: list[str]) -> tuple[dict[ specs = [ ("long_entry_prob", "binary", "long_entry_target"), ("short_entry_prob", "binary", "short_entry_target"), - ("long_expected_net_edge_bps", "regression", "long_expected_net_edge_bps"), - ("short_expected_net_edge_bps", "regression", "short_expected_net_edge_bps"), + ("long_actual_plan_net_edge_bps", "regression", "long_actual_plan_net_edge_bps"), + ("short_actual_plan_net_edge_bps", "regression", "short_actual_plan_net_edge_bps"), ] results: dict[str, Any] = {"feature_count": len(feature_columns), "feature_hash": sha256_json(feature_columns)} split_predictions: dict[str, pd.DataFrame] = { @@ -785,8 +802,8 @@ def _model_compare_report(args: Any, baseline_root: Path, results: dict[str, Any f"| Direction | neutral_auc | {baseline_direction.get('neutral_auc')} |", f"| Entry | long_auc | {baseline_entry['long_entry_prob'].get('auc')} |", f"| Entry | short_auc | {baseline_entry['short_entry_prob'].get('auc')} |", - f"| Entry | long_edge_mae_ratio | {baseline_entry['long_expected_net_edge_bps'].get('mae_vs_constant_ratio')} |", - f"| Entry | short_edge_mae_ratio | {baseline_entry['short_expected_net_edge_bps'].get('mae_vs_constant_ratio')} |", + f"| Entry | long_exported_edge_mae_ratio | {baseline_entry['long_expected_net_edge_bps'].get('mae_vs_constant_ratio')} |", + f"| Entry | short_exported_edge_mae_ratio | {baseline_entry['short_expected_net_edge_bps'].get('mae_vs_constant_ratio')} |", "", "## Diagnostic Direction Result", "", @@ -817,7 +834,7 @@ def _model_compare_report(args: Any, baseline_root: Path, results: dict[str, Any lines.append( f"| {head} | {feature_set_name} | {split_id} | {metric.get('auc')} | {metric.get('brier_vs_constant_ratio')} | {metric.get('top10_hit_rate')} |" ) - for head in ("long_expected_net_edge_bps", "short_expected_net_edge_bps"): + for head in ("long_actual_plan_net_edge_bps", "short_actual_plan_net_edge_bps"): for split_id in EVAL_SPLITS: metric = entry.get(head, {}).get(split_id, {}) lines.append(f"| {head} | {feature_set_name} | {split_id} | {metric.get('mae_vs_constant_ratio')} | | |") @@ -827,6 +844,7 @@ def _model_compare_report(args: Any, baseline_root: Path, results: dict[str, Any "## Verdict Rule", "", "只有 `market_plus_ofi` 在 validation_locked 和 latest_stress 上同时好过 `market_only`,才进入正式特征链路。", + "Entry 的收益回归诊断使用 `actual_plan_net_edge_bps`,也就是真实按价格计划出场后的净收益。", "", ] )