Add Direction label and PM probe diagnostics
This commit is contained in:
@@ -36,8 +36,8 @@ def diagnose_training_run(args: Any) -> None:
|
||||
|
||||
|
||||
def _label_summary(root) -> dict[str, Any]:
|
||||
direction = read_parquet(root / "label" / "direction_labels.parquet")
|
||||
entry = read_parquet(root / "label" / "entry_labels.parquet")
|
||||
direction = read_parquet(root / "dataset" / "direction_train.parquet")
|
||||
entry = read_parquet(root / "dataset" / "entry_train.parquet")
|
||||
summary: dict[str, Any] = {}
|
||||
for split_id in DIAGNOSTIC_SPLITS:
|
||||
direction_split = direction[direction["split_id"].eq(split_id)].copy()
|
||||
@@ -45,28 +45,57 @@ def _label_summary(root) -> dict[str, Any]:
|
||||
item: dict[str, Any] = {"direction": {}, "entry": {}}
|
||||
if not direction_split.empty:
|
||||
item["direction"] = {
|
||||
"source": "dataset/direction_train.parquet",
|
||||
"rows": len(direction_split),
|
||||
"label_ratio": direction_split["direction_label"].value_counts(normalize=True).round(6).to_dict(),
|
||||
"label_ratio": _direction_target_ratio(direction_split),
|
||||
"future_return_bps_quantile": _quantiles(direction_split["future_return_bps"], (0.01, 0.05, 0.25, 0.5, 0.75, 0.95, 0.99)),
|
||||
}
|
||||
if not entry_split.empty:
|
||||
if "actual_plan_net_edge_bps" not in entry_split.columns:
|
||||
raise ValueError("entry_labels is missing actual_plan_net_edge_bps for diagnostics")
|
||||
grouped = entry_split.groupby("side", observed=False)
|
||||
required = {
|
||||
"long_entry_target",
|
||||
"short_entry_target",
|
||||
"long_actual_plan_net_edge_bps",
|
||||
"short_actual_plan_net_edge_bps",
|
||||
}
|
||||
missing = sorted(required - set(entry_split.columns))
|
||||
if missing:
|
||||
raise ValueError(f"entry_train is missing columns required by diagnostics: {missing}")
|
||||
item["entry"] = {
|
||||
"source": "dataset/entry_train.parquet",
|
||||
"rows": len(entry_split),
|
||||
"target_rate_by_side": grouped["entry_target"].mean().round(6).to_dict(),
|
||||
"target_rate_by_side": {
|
||||
"LONG": float(entry_split["long_entry_target"].astype(float).mean()),
|
||||
"SHORT": float(entry_split["short_entry_target"].astype(float).mean()),
|
||||
},
|
||||
"edge_column": "actual_plan_net_edge_bps",
|
||||
"edge_mean_by_side": grouped["actual_plan_net_edge_bps"].mean().round(6).to_dict(),
|
||||
"edge_mean_by_side": {
|
||||
"LONG": float(entry_split["long_actual_plan_net_edge_bps"].astype(float).mean()),
|
||||
"SHORT": float(entry_split["short_actual_plan_net_edge_bps"].astype(float).mean()),
|
||||
},
|
||||
"edge_quantile_by_side": {
|
||||
str(side): _quantiles(group["actual_plan_net_edge_bps"], (0.05, 0.5, 0.95))
|
||||
for side, group in grouped
|
||||
"LONG": _quantiles(entry_split["long_actual_plan_net_edge_bps"], (0.05, 0.5, 0.95)),
|
||||
"SHORT": _quantiles(entry_split["short_actual_plan_net_edge_bps"], (0.05, 0.5, 0.95)),
|
||||
},
|
||||
}
|
||||
summary[split_id] = item
|
||||
return summary
|
||||
|
||||
|
||||
def _direction_target_ratio(frame: pd.DataFrame) -> dict[str, float]:
|
||||
required = {"long_target", "short_target", "neutral_target"}
|
||||
missing = sorted(required - set(frame.columns))
|
||||
if missing:
|
||||
raise ValueError(f"direction_train is missing target columns required by diagnostics: {missing}")
|
||||
rows = len(frame)
|
||||
if rows == 0:
|
||||
return {"LONG": 0.0, "SHORT": 0.0, "NEUTRAL": 0.0}
|
||||
return {
|
||||
"LONG": float(frame["long_target"].astype(float).mean()),
|
||||
"SHORT": float(frame["short_target"].astype(float).mean()),
|
||||
"NEUTRAL": float(frame["neutral_target"].astype(float).mean()),
|
||||
}
|
||||
|
||||
|
||||
def _pm_summary(root) -> dict[str, Any]:
|
||||
summary: dict[str, Any] = {}
|
||||
config_path = root / "pm-search" / "position_manager_config.json"
|
||||
@@ -259,7 +288,7 @@ def _diagnostic_conclusion(pm_summary: dict[str, Any]) -> dict[str, Any]:
|
||||
if validation.get("avg_weighted_edge_bps", 0.0) <= 0 and stress.get("avg_weighted_edge_bps", 0.0) <= 0:
|
||||
return {
|
||||
"status": "PRICE_PLAN_OR_ENTRY_NOT_TRADABLE",
|
||||
"plain_reason": "按固定止盈止损真实收益算,验证集和压力集选出来的交易平均都不赚钱。",
|
||||
"plain_reason": "按当前价格计划真实收益算,验证集和压力集选出来的交易平均都不赚钱。",
|
||||
"next_action": "优先重新搜索价格计划,再重建 Entry 标签和模型;不要只放松 PM 阈值。",
|
||||
}
|
||||
return {
|
||||
@@ -296,10 +325,12 @@ def _markdown_report(payload: dict[str, Any]) -> str:
|
||||
lines.append("")
|
||||
if direction:
|
||||
lines.append(f"- Direction 行数: {direction['rows']}")
|
||||
lines.append(f"- Direction 来源: `{direction['source']}`")
|
||||
lines.append(f"- Direction 标签比例: `{direction['label_ratio']}`")
|
||||
lines.append(f"- 45 分钟未来收益分位: `{direction['future_return_bps_quantile']}`")
|
||||
if entry:
|
||||
lines.append(f"- Entry 行数: {entry['rows']}")
|
||||
lines.append(f"- Entry 来源: `{entry['source']}`")
|
||||
lines.append(f"- Entry 命中率: `{entry['target_rate_by_side']}`")
|
||||
lines.append(f"- Entry 平均净收益: `{entry['edge_mean_by_side']}`")
|
||||
lines.append("")
|
||||
@@ -312,6 +343,7 @@ def _markdown_report(payload: dict[str, Any]) -> str:
|
||||
lines.append(f"- 当前阈值: `{item['active_thresholds']}`")
|
||||
lines.append(f"- 当前阈值选中交易: `{item['selected_trade_metrics']}`")
|
||||
lines.append(f"- 网格里有交易的候选数: {item['grid_search_any_trade']['candidates_with_trade']} / {item['grid_search_any_trade']['candidate_count']}")
|
||||
lines.extend(_score_distribution_markdown(item["score_distribution"]))
|
||||
lines.append("")
|
||||
for side in ("long", "short"):
|
||||
lines.append(f"#### {side.upper()}")
|
||||
@@ -334,6 +366,33 @@ def _markdown_report(payload: dict[str, Any]) -> str:
|
||||
return "\n".join(lines) + "\n"
|
||||
|
||||
|
||||
def _score_distribution_markdown(distribution: dict[str, dict[str, float]]) -> list[str]:
|
||||
watched_columns = [
|
||||
"long_prob",
|
||||
"short_prob",
|
||||
"long_entry_prob",
|
||||
"short_entry_prob",
|
||||
"market_risk_prob",
|
||||
"pred_long_expected_net_edge_bps",
|
||||
"pred_short_expected_net_edge_bps",
|
||||
]
|
||||
lines = ["", "#### 分数分布", "", "| 字段 | 最小 | 5% | 中位数 | 95% | 最大 |", "| --- | ---: | ---: | ---: | ---: | ---: |"]
|
||||
for column in watched_columns:
|
||||
quantiles = distribution.get(column)
|
||||
if not quantiles:
|
||||
continue
|
||||
lines.append(
|
||||
"| "
|
||||
+ column
|
||||
+ f" | {quantiles.get('0.0', 0.0):.4f}"
|
||||
+ f" | {quantiles.get('0.05', 0.0):.4f}"
|
||||
+ f" | {quantiles.get('0.5', 0.0):.4f}"
|
||||
+ f" | {quantiles.get('0.95', 0.0):.4f}"
|
||||
+ f" | {quantiles.get('1.0', 0.0):.4f} |"
|
||||
)
|
||||
return lines
|
||||
|
||||
|
||||
def _jsonable(value: Any) -> Any:
|
||||
if isinstance(value, dict):
|
||||
return {str(key): _jsonable(item) for key, item in value.items()}
|
||||
|
||||
Reference in New Issue
Block a user