141 lines
6.2 KiB
Python
141 lines
6.2 KiB
Python
|
|
from __future__ import annotations
|
||
|
|
|
||
|
|
import logging
|
||
|
|
from typing import Any
|
||
|
|
|
||
|
|
import numpy as np
|
||
|
|
import pandas as pd
|
||
|
|
|
||
|
|
from trader_training.io_utils import manifest, read_parquet, require_columns, run_root, write_json, write_parquet, write_text
|
||
|
|
from trader_training.schemas import FEATURE_ORDER
|
||
|
|
|
||
|
|
|
||
|
|
def build_direction_opportunity_dataset(args: Any) -> None:
    """Rebuild the Direction training set with path-opportunity targets.

    Joins the Direction dataset to the Entry dataset one-to-one on ``sample_id``
    and replaces ``long_target`` / ``short_target`` / ``neutral_target`` with
    labels computed by ``_opportunity_labels`` from the Entry edge columns. The
    relabelled parquet is written to ``output_path``. A JSON result
    (``direction_opportunity_dataset_result.json``) and a Markdown report
    (``direction_opportunity_dataset_report.md``) are written under
    ``<run root>/dataset``.

    Args:
        args: Parsed command-line namespace. Attributes read here:
            - ``direction_dataset_path``, ``entry_dataset_path`` and
              ``output_path``: each falls back to a file under
              ``<run root>/dataset`` when falsy.
            - ``opportunity_bps``, ``min_advantage_bps``,
              ``long_edge_column``, ``short_edge_column``, ``label_method``
              and ``run_id``.
            ``run_root(args)`` may read further attributes.

    Raises:
        ValueError: if ``long_edge_column`` equals ``short_edge_column``, or if
            the join drops Direction rows (a ``sample_id`` with no Entry match).
            pandas' ``MergeError`` (a ``ValueError`` subclass) is raised when
            ``sample_id`` is duplicated on either side.
    """
    target_columns = ["long_target", "short_target", "neutral_target"]

    root = run_root(args)
    dataset_dir = root / "dataset"
    direction_path = args.direction_dataset_path or dataset_dir / "direction_train.parquet"
    entry_path = args.entry_dataset_path or dataset_dir / "entry_train.parquet"
    # The default output is the default Direction input. By default the input
    # file is overwritten in place, so a later run reads this function's output.
    output_path = args.output_path or dataset_dir / "direction_train.parquet"
    opportunity_bps = float(args.opportunity_bps)
    min_advantage_bps = float(args.min_advantage_bps)
    long_edge_column = str(args.long_edge_column)
    short_edge_column = str(args.short_edge_column)
    label_method = str(args.label_method)
    edge_columns = [long_edge_column, short_edge_column]
    if long_edge_column == short_edge_column:
        # Selecting the same column twice would yield a DataFrame rather than a
        # Series below, and comparing a side against itself is meaningless.
        raise ValueError(f"long_edge_column and short_edge_column must differ: {long_edge_column!r}")

    direction = read_parquet(direction_path)
    entry = read_parquet(entry_path)
    require_columns(direction, ("sample_id", "split_id", *FEATURE_ORDER), "direction_train")
    require_columns(entry, ("sample_id", *edge_columns), "entry_train")

    # Entry is the single source of truth for the edge columns. Drop stale
    # copies (and the old targets) from the Direction side. On a re-run with
    # default paths, the Direction file already carries the edge columns, and
    # the merge would otherwise rename them to "<col>_x" / "<col>_y" and break
    # the column lookups below.
    base = direction.drop(columns=[*target_columns, *edge_columns], errors="ignore")
    opportunity = entry[["sample_id", *edge_columns]]
    merged = base.merge(opportunity, on="sample_id", how="inner", validate="one_to_one")
    if len(merged) != len(direction):
        raise ValueError(f"direction opportunity dataset lost rows: before={len(direction)} after={len(merged)}")

    def edge_values(column: str) -> np.ndarray:
        # Non-numeric entries become NaN. na_value also covers nullable
        # (pd.NA) dtypes, which to_numpy(dtype="float64") would otherwise reject.
        return pd.to_numeric(merged[column], errors="coerce").to_numpy(dtype="float64", na_value=np.nan)

    labels = _opportunity_labels(
        edge_values(long_edge_column),
        edge_values(short_edge_column),
        opportunity_bps,
        min_advantage_bps,
    )
    for column in target_columns:
        merged[column] = labels[column]

    # Keep future_return_bps as a diagnostic field; the three target columns
    # are the training labels. Output order: original Direction columns, then
    # targets, then the edge columns used for labelling.
    out = merged[[*base.columns, *target_columns, *edge_columns]].copy()
    data_hash = write_parquet(output_path, out)
    result = {
        "dataset": manifest(
            output_path,
            {
                "row_count": len(out),
                "feature_count": len(FEATURE_ORDER),
                "data_hash_sha256": data_hash,
                # String keys, consistent with target_rates_by_split and always
                # valid as JSON object keys.
                "split_counts": {str(split_id): int(count) for split_id, count in out["split_id"].value_counts().items()},
            },
        ),
        "label_method": label_method,
        "long_edge_column": long_edge_column,
        "short_edge_column": short_edge_column,
        "opportunity_bps": opportunity_bps,
        "min_advantage_bps": min_advantage_bps,
        "target_counts": {
            "long": int(out["long_target"].sum()),
            "short": int(out["short_target"].sum()),
            "neutral": int(out["neutral_target"].sum()),
        },
        "target_rates_by_split": _target_rates_by_split(out),
    }
    write_json(dataset_dir / "direction_opportunity_dataset_result.json", result)
    write_text(dataset_dir / "direction_opportunity_dataset_report.md", _markdown_report(result))
    logging.info(
        "trader.training.direction_opportunity_dataset_written runId=%s opportunityBps=%.6f minAdvantageBps=%.6f rowCount=%s outputPath=%s",
        args.run_id,
        opportunity_bps,
        min_advantage_bps,
        len(out),
        output_path,
    )
|
||
|
|
|
||
|
|
|
||
|
|
def _opportunity_labels(long_edge: np.ndarray, short_edge: np.ndarray, opportunity_bps: float, min_advantage_bps: float) -> dict[str, np.ndarray]:
|
||
|
|
long_clean = np.nan_to_num(long_edge, nan=-np.inf)
|
||
|
|
short_clean = np.nan_to_num(short_edge, nan=-np.inf)
|
||
|
|
long_ok = long_clean >= opportunity_bps
|
||
|
|
short_ok = short_clean >= opportunity_bps
|
||
|
|
long_wins = long_ok & ((long_clean - short_clean) >= min_advantage_bps)
|
||
|
|
short_wins = short_ok & ((short_clean - long_clean) >= min_advantage_bps)
|
||
|
|
neutral = ~(long_wins | short_wins)
|
||
|
|
return {
|
||
|
|
"long_target": long_wins.astype("int8"),
|
||
|
|
"short_target": short_wins.astype("int8"),
|
||
|
|
"neutral_target": neutral.astype("int8"),
|
||
|
|
}
|
||
|
|
|
||
|
|
|
||
|
|
def _target_rates_by_split(frame: pd.DataFrame) -> dict[str, dict[str, float]]:
|
||
|
|
result: dict[str, dict[str, float]] = {}
|
||
|
|
for split_id, part in frame.groupby("split_id", observed=False):
|
||
|
|
rows = len(part)
|
||
|
|
result[str(split_id)] = {
|
||
|
|
"rows": float(rows),
|
||
|
|
"long_rate": float(part["long_target"].mean()) if rows else 0.0,
|
||
|
|
"short_rate": float(part["short_target"].mean()) if rows else 0.0,
|
||
|
|
"neutral_rate": float(part["neutral_target"].mean()) if rows else 0.0,
|
||
|
|
}
|
||
|
|
return result
|
||
|
|
|
||
|
|
|
||
|
|
def _markdown_report(result: dict[str, Any]) -> str:
|
||
|
|
lines = [
|
||
|
|
"# Direction 机会标签数据集报告",
|
||
|
|
"",
|
||
|
|
"这份数据集把 Direction 目标从“未来收盘收益方向”改为“未来路径里哪边有可交易空间”。",
|
||
|
|
"",
|
||
|
|
f"- label_method: `{result['label_method']}`",
|
||
|
|
f"- long_edge_column: `{result['long_edge_column']}`",
|
||
|
|
f"- short_edge_column: `{result['short_edge_column']}`",
|
||
|
|
f"- opportunity_bps: `{result['opportunity_bps']}`",
|
||
|
|
f"- min_advantage_bps: `{result['min_advantage_bps']}`",
|
||
|
|
f"- row_count: `{result['dataset']['row_count']}`",
|
||
|
|
"",
|
||
|
|
"## 标签数量",
|
||
|
|
"",
|
||
|
|
f"- long: `{result['target_counts']['long']}`",
|
||
|
|
f"- short: `{result['target_counts']['short']}`",
|
||
|
|
f"- neutral: `{result['target_counts']['neutral']}`",
|
||
|
|
"",
|
||
|
|
"## 分段比例",
|
||
|
|
"",
|
||
|
|
"| split | rows | long | short | neutral |",
|
||
|
|
"| --- | ---: | ---: | ---: | ---: |",
|
||
|
|
]
|
||
|
|
for split_id, item in result["target_rates_by_split"].items():
|
||
|
|
lines.append(f"| {split_id} | {int(item['rows'])} | {item['long_rate']:.4f} | {item['short_rate']:.4f} | {item['neutral_rate']:.4f} |")
|
||
|
|
lines.append("")
|
||
|
|
return "\n".join(lines)
|