Use actual plan edge in OFI diagnostics
This commit is contained in:
@@ -337,6 +337,23 @@ def _load_direction_dataset(baseline_root: Path, feature: pd.DataFrame) -> pd.Da
|
||||
|
||||
|
||||
def _load_entry_dataset(baseline_root: Path, feature: pd.DataFrame) -> pd.DataFrame:
|
||||
dataset_path = baseline_root / "dataset" / "entry_train.parquet"
|
||||
if dataset_path.is_file():
|
||||
labels = read_parquet(dataset_path)
|
||||
required = {
|
||||
"sample_id",
|
||||
"long_entry_target",
|
||||
"short_entry_target",
|
||||
"long_actual_plan_net_edge_bps",
|
||||
"short_actual_plan_net_edge_bps",
|
||||
}
|
||||
missing = sorted(required.difference(labels.columns))
|
||||
if missing:
|
||||
raise ValueError(f"entry_train dataset missing columns: {missing}")
|
||||
dataset = feature.merge(labels[list(required)], on="sample_id", how="inner")
|
||||
logging.info("trader.training.ofi_entry_dataset_loaded source=entry_train rowCount=%s", len(dataset))
|
||||
return dataset
|
||||
|
||||
labels = read_parquet(baseline_root / "label" / "entry_labels.parquet")
|
||||
required = {"sample_id", "side", "entry_target", "expected_net_edge_bps"}
|
||||
missing = sorted(required.difference(labels.columns))
|
||||
@@ -350,7 +367,7 @@ def _load_entry_dataset(baseline_root: Path, feature: pd.DataFrame) -> pd.DataFr
|
||||
)
|
||||
pivot = long.merge(short, on="sample_id", how="inner")
|
||||
dataset = feature.merge(pivot, on="sample_id", how="inner")
|
||||
logging.info("trader.training.ofi_entry_dataset_loaded rowCount=%s", len(dataset))
|
||||
logging.info("trader.training.ofi_entry_dataset_loaded source=entry_labels_legacy rowCount=%s", len(dataset))
|
||||
return dataset
|
||||
|
||||
|
||||
@@ -407,8 +424,8 @@ def _train_entry(frame: pd.DataFrame, feature_columns: list[str]) -> tuple[dict[
|
||||
specs = [
|
||||
("long_entry_prob", "binary", "long_entry_target"),
|
||||
("short_entry_prob", "binary", "short_entry_target"),
|
||||
("long_expected_net_edge_bps", "regression", "long_expected_net_edge_bps"),
|
||||
("short_expected_net_edge_bps", "regression", "short_expected_net_edge_bps"),
|
||||
("long_actual_plan_net_edge_bps", "regression", "long_actual_plan_net_edge_bps"),
|
||||
("short_actual_plan_net_edge_bps", "regression", "short_actual_plan_net_edge_bps"),
|
||||
]
|
||||
results: dict[str, Any] = {"feature_count": len(feature_columns), "feature_hash": sha256_json(feature_columns)}
|
||||
split_predictions: dict[str, pd.DataFrame] = {
|
||||
@@ -785,8 +802,8 @@ def _model_compare_report(args: Any, baseline_root: Path, results: dict[str, Any
|
||||
f"| Direction | neutral_auc | {baseline_direction.get('neutral_auc')} |",
|
||||
f"| Entry | long_auc | {baseline_entry['long_entry_prob'].get('auc')} |",
|
||||
f"| Entry | short_auc | {baseline_entry['short_entry_prob'].get('auc')} |",
|
||||
f"| Entry | long_edge_mae_ratio | {baseline_entry['long_expected_net_edge_bps'].get('mae_vs_constant_ratio')} |",
|
||||
f"| Entry | short_edge_mae_ratio | {baseline_entry['short_expected_net_edge_bps'].get('mae_vs_constant_ratio')} |",
|
||||
f"| Entry | long_exported_edge_mae_ratio | {baseline_entry['long_expected_net_edge_bps'].get('mae_vs_constant_ratio')} |",
|
||||
f"| Entry | short_exported_edge_mae_ratio | {baseline_entry['short_expected_net_edge_bps'].get('mae_vs_constant_ratio')} |",
|
||||
"",
|
||||
"## Diagnostic Direction Result",
|
||||
"",
|
||||
@@ -817,7 +834,7 @@ def _model_compare_report(args: Any, baseline_root: Path, results: dict[str, Any
|
||||
lines.append(
|
||||
f"| {head} | {feature_set_name} | {split_id} | {metric.get('auc')} | {metric.get('brier_vs_constant_ratio')} | {metric.get('top10_hit_rate')} |"
|
||||
)
|
||||
for head in ("long_expected_net_edge_bps", "short_expected_net_edge_bps"):
|
||||
for head in ("long_actual_plan_net_edge_bps", "short_actual_plan_net_edge_bps"):
|
||||
for split_id in EVAL_SPLITS:
|
||||
metric = entry.get(head, {}).get(split_id, {})
|
||||
lines.append(f"| {head} | {feature_set_name} | {split_id} | {metric.get('mae_vs_constant_ratio')} | | |")
|
||||
@@ -827,6 +844,7 @@ def _model_compare_report(args: Any, baseline_root: Path, results: dict[str, Any
|
||||
"## Verdict Rule",
|
||||
"",
|
||||
"只有 `market_plus_ofi` 在 validation_locked 和 latest_stress 上同时好过 `market_only`,才进入正式特征链路。",
|
||||
"Entry 的收益回归诊断使用 `actual_plan_net_edge_bps`,也就是真实按价格计划出场后的净收益。",
|
||||
"",
|
||||
]
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user