Files
quant-trader-service/training/trader_training/direction_opportunity_dataset.py
T

141 lines
6.2 KiB
Python
Raw Normal View History

from __future__ import annotations
import logging
from typing import Any
import numpy as np
import pandas as pd
from trader_training.io_utils import manifest, read_parquet, require_columns, run_root, write_json, write_parquet, write_text
from trader_training.schemas import FEATURE_ORDER
def build_direction_opportunity_dataset(args: Any) -> None:
    """Relabel the Direction training set with path-opportunity targets.

    Joins the Direction dataset with the long/short edge columns of the Entry
    dataset on ``sample_id``, derives ``long_target`` / ``short_target`` /
    ``neutral_target`` from those edges via ``_opportunity_labels``, and writes
    the relabeled parquet plus ``direction_opportunity_dataset_result.json`` and
    ``direction_opportunity_dataset_report.md`` under ``<run_root>/dataset``.

    By default the output path is the same file as the default Direction input,
    so the dataset is relabeled in place. Targets and edge columns left over
    from a previous run are dropped before the join, which makes re-running safe.

    Args:
        args: Parsed CLI namespace. Reads ``run_id``, ``direction_dataset_path``,
            ``entry_dataset_path``, ``output_path``, ``opportunity_bps``,
            ``min_advantage_bps``, ``long_edge_column``, ``short_edge_column``
            and ``label_method``, plus whatever ``run_root`` consumes.

    Raises:
        ValueError: if the join drops Direction rows (sample_ids missing from
            the Entry dataset).
        pandas.errors.MergeError: if ``sample_id`` is not unique on either side.
    """
    root = run_root(args)
    dataset_dir = root / "dataset"
    direction_path = args.direction_dataset_path or dataset_dir / "direction_train.parquet"
    entry_path = args.entry_dataset_path or dataset_dir / "entry_train.parquet"
    # Defaults to the Direction input file itself: the dataset is relabeled in place.
    output_path = args.output_path or dataset_dir / "direction_train.parquet"
    opportunity_bps = float(args.opportunity_bps)
    min_advantage_bps = float(args.min_advantage_bps)
    long_edge_column = str(args.long_edge_column)
    short_edge_column = str(args.short_edge_column)
    label_method = str(args.label_method)
    target_columns = ["long_target", "short_target", "neutral_target"]
    # Deduplicated so that identical long/short column names do not produce
    # duplicate frame columns.
    edge_columns = list(dict.fromkeys([long_edge_column, short_edge_column]))

    direction = read_parquet(direction_path)
    entry = read_parquet(entry_path)
    require_columns(direction, ("sample_id", "split_id", *FEATURE_ORDER), "direction_train")
    require_columns(entry, ("sample_id", long_edge_column, short_edge_column), "entry_train")

    # Drop stale targets and any edge columns the Direction frame already carries
    # (e.g. written by a previous in-place run). Otherwise merge() would rename the
    # overlapping columns to *_x / *_y and the edge lookups below would fail.
    # The Entry dataset is the source of truth for the edges.
    base = direction.drop(columns=[*target_columns, *edge_columns], errors="ignore")
    opportunity = entry[["sample_id", *edge_columns]].copy()
    merged = base.merge(opportunity, on="sample_id", how="inner", validate="one_to_one")
    if len(merged) != len(direction):
        raise ValueError(f"direction opportunity dataset lost rows: before={len(direction)} after={len(merged)}")

    labels = _opportunity_labels(
        pd.to_numeric(merged[long_edge_column], errors="coerce").to_numpy(dtype="float64"),
        pd.to_numeric(merged[short_edge_column], errors="coerce").to_numpy(dtype="float64"),
        opportunity_bps,
        min_advantage_bps,
    )
    for column in target_columns:
        merged[column] = labels[column]

    # Other Direction columns (such as future_return_bps, if present) are kept in
    # their original order for diagnostics. The three target columns are the
    # training labels; the edge columns they were derived from come last.
    excluded = {*target_columns, *edge_columns}
    ordered = [column for column in direction.columns if column in merged.columns and column not in excluded]
    ordered += target_columns + edge_columns
    out = merged[ordered].copy()

    data_hash = write_parquet(output_path, out)
    result = {
        "dataset": manifest(
            output_path,
            {
                "row_count": len(out),
                "feature_count": len(FEATURE_ORDER),
                "data_hash_sha256": data_hash,
                # Cast numpy scalars to plain str/int so the payload is JSON-native
                # (same treatment as target_counts below).
                "split_counts": {str(split): int(count) for split, count in out["split_id"].value_counts().items()},
            },
        ),
        "label_method": label_method,
        "long_edge_column": long_edge_column,
        "short_edge_column": short_edge_column,
        "opportunity_bps": opportunity_bps,
        "min_advantage_bps": min_advantage_bps,
        "target_counts": {
            "long": int(out["long_target"].sum()),
            "short": int(out["short_target"].sum()),
            "neutral": int(out["neutral_target"].sum()),
        },
        "target_rates_by_split": _target_rates_by_split(out),
    }
    write_json(dataset_dir / "direction_opportunity_dataset_result.json", result)
    write_text(dataset_dir / "direction_opportunity_dataset_report.md", _markdown_report(result))
    logging.info(
        "trader.training.direction_opportunity_dataset_written runId=%s opportunityBps=%.6f minAdvantageBps=%.6f rowCount=%s outputPath=%s",
        args.run_id,
        opportunity_bps,
        min_advantage_bps,
        len(out),
        output_path,
    )
def _opportunity_labels(long_edge: np.ndarray, short_edge: np.ndarray, opportunity_bps: float, min_advantage_bps: float) -> dict[str, np.ndarray]:
    """Derive one-hot long/short/neutral labels from per-sample edges.

    A side "wins" a sample when its edge is at least ``opportunity_bps`` and
    beats the other side's edge by at least ``min_advantage_bps``. Every other
    sample is neutral.

    Missing (NaN) edges are treated as ``-inf`` (no opportunity). A sample whose
    only valid side clears the threshold is therefore labeled for that side.

    With ``min_advantage_bps <= 0`` both sides can satisfy the rule at once
    (e.g. equal edges and ``min_advantage_bps == 0``). Such ambiguous samples are
    labeled neutral so that the three targets stay mutually exclusive.

    Args:
        long_edge: Per-sample long-side edge in bps (array-like, coerced to float64).
        short_edge: Per-sample short-side edge in bps, same length as ``long_edge``.
        opportunity_bps: Minimum edge for a side to count as tradable.
        min_advantage_bps: Minimum margin over the opposite side's edge.

    Returns:
        Mapping of ``long_target`` / ``short_target`` / ``neutral_target`` to
        int8 arrays with exactly one 1 per sample.
    """
    long_clean = np.nan_to_num(np.asarray(long_edge, dtype="float64"), nan=-np.inf)
    short_clean = np.nan_to_num(np.asarray(short_edge, dtype="float64"), nan=-np.inf)
    # When both edges are missing, -inf - (-inf) yields NaN. The comparisons below
    # are then False (neutral), which is intended, so silence numpy's warning.
    with np.errstate(invalid="ignore"):
        long_wins = (long_clean >= opportunity_bps) & ((long_clean - short_clean) >= min_advantage_bps)
        short_wins = (short_clean >= opportunity_bps) & ((short_clean - long_clean) >= min_advantage_bps)
    # Only reachable when min_advantage_bps <= 0. Resolve ties to neutral rather
    # than emitting two positive targets for one sample.
    ambiguous = long_wins & short_wins
    long_wins = long_wins & ~ambiguous
    short_wins = short_wins & ~ambiguous
    neutral = ~(long_wins | short_wins)
    return {
        "long_target": long_wins.astype("int8"),
        "short_target": short_wins.astype("int8"),
        "neutral_target": neutral.astype("int8"),
    }
def _target_rates_by_split(frame: pd.DataFrame) -> dict[str, dict[str, float]]:
    """Summarize label rates for each ``split_id`` group.

    Returns ``{str(split_id): {"rows", "long_rate", "short_rate",
    "neutral_rate"}}`` with all values as floats. Rates are the mean of the
    corresponding 0/1 target column. An empty group reports rates of 0.0
    instead of NaN.
    """

    def _mean_or_zero(group: pd.DataFrame, column: str) -> float:
        # An empty group would give a NaN mean; report 0.0 instead.
        return float(group[column].mean()) if len(group) else 0.0

    return {
        str(key): {
            "rows": float(len(group)),
            "long_rate": _mean_or_zero(group, "long_target"),
            "short_rate": _mean_or_zero(group, "short_target"),
            "neutral_rate": _mean_or_zero(group, "neutral_target"),
        }
        for key, group in frame.groupby("split_id", observed=False)
    }
def _markdown_report(result: dict[str, Any]) -> str:
    """Render the dataset result dict as a Markdown report.

    Sections, in order:
    - the labeling parameters and row count;
    - the overall label counts;
    - a per-split table of label rates (4 decimals).

    The returned text ends with a newline.
    """
    parameter_keys = ("label_method", "long_edge_column", "short_edge_column", "opportunity_bps", "min_advantage_bps")
    summary = [
        "# Direction 机会标签数据集报告",
        "",
        "这份数据集把 Direction 目标从“未来收盘收益方向”改为“未来路径里哪边有可交易空间”。",
        "",
        *(f"- {key}: `{result[key]}`" for key in parameter_keys),
        f"- row_count: `{result['dataset']['row_count']}`",
        "",
    ]
    counts = result["target_counts"]
    count_section = [
        "## 标签数量",
        "",
        *(f"- {side}: `{counts[side]}`" for side in ("long", "short", "neutral")),
        "",
    ]
    table = [
        "## 分段比例",
        "",
        "| split | rows | long | short | neutral |",
        "| --- | ---: | ---: | ---: | ---: |",
    ]
    table += [
        f"| {split_id} | {int(rates['rows'])} | {rates['long_rate']:.4f} | {rates['short_rate']:.4f} | {rates['neutral_rate']:.4f} |"
        for split_id, rates in result["target_rates_by_split"].items()
    ]
    # The trailing empty element makes the joined text end with a newline.
    return "\n".join(summary + count_section + table + [""])