Files

113 lines
4.9 KiB
Python
Raw Permalink Normal View History

from __future__ import annotations
import logging
from typing import Any
import pandas as pd
from trader_training.io_utils import manifest, read_parquet, require_columns, run_root, write_json, write_parquet, write_text
from trader_training.schemas import FEATURE_ORDER, TRAINING_SPLITS
def _feature_base(root, args) -> pd.DataFrame:
    """Load the feature frame and keep only the rows usable for training.

    Reads ``args.feature_path`` when it is set, otherwise
    ``<root>/feature/feature_frame.parquet``. Checks that the id, split,
    quality and every ``FEATURE_ORDER`` column are present. Keeps rows whose
    ``data_quality_flag`` is ``OK`` or ``PARTIAL_OPTIONAL`` and whose
    ``split_id`` is one of ``TRAINING_SPLITS``.

    Args:
        root: run root directory (joined with ``/``, so presumably a ``Path``).
        args: run configuration; only ``feature_path`` is read here.

    Returns:
        A copy of the surviving rows, with the original index preserved.

    Raises:
        ValueError: when no row passes both filters.
    """
    path = args.feature_path or root / "feature" / "feature_frame.parquet"
    frame = read_parquet(path)
    require_columns(frame, ("sample_id", "split_id", "data_quality_flag", *FEATURE_ORDER), "feature_frame")
    # Combine both filters into one mask so the frame is sliced and copied
    # once, instead of copying an intermediate that is immediately re-filtered.
    usable = frame["data_quality_flag"].isin(["OK", "PARTIAL_OPTIONAL"]) & frame["split_id"].isin(TRAINING_SPLITS)
    trainable = frame[usable].copy()
    if trainable.empty:
        raise ValueError("no trainable feature rows; check feature quality report")
    return trainable
# Label columns copied into each training dataset; "sample_id" is the join key.
_DIRECTION_LABEL_COLUMNS = (
    "sample_id",
    "long_target",
    "short_target",
    "neutral_target",
    "future_return_bps",
)
_CONTINUE_LABEL_COLUMNS = (
    "sample_id",
    "long_continue_target",
    "short_continue_target",
    "long_expected_continue_edge_bps",
    "short_expected_continue_edge_bps",
)
_EXIT_LABEL_COLUMNS = (
    "sample_id",
    "long_exit_target",
    "short_exit_target",
    "long_adverse_move_bps",
    "short_adverse_move_bps",
    "adverse_move_prob_label",
    "reversal_prob_label",
    "stop_hit_prob_label",
    "stagnation_prob_label",
)
_RISK_LABEL_COLUMNS = (
    "sample_id",
    "market_risk_target",
    "market_path_risk_bps",
    "long_position_path_risk_bps",
    "short_position_path_risk_bps",
    "long_position_risk_target",
    "short_position_risk_target",
    "market_drawdown_prob_label",
    "volatility_expansion_prob_label",
    "spike_prob_label",
    "liquidity_deterioration_prob_label",
    "position_drawdown_prob_label",
)


def _join_labels(
    feature: pd.DataFrame,
    labels: pd.DataFrame,
    columns: tuple[str, ...],
    label_name: str,
) -> pd.DataFrame:
    """Validate ``columns`` on ``labels`` and inner-join them onto ``feature`` by ``sample_id``."""
    # Same up-front check _entry_pivot applies to the entry labels, so a
    # missing label column is reported against a named frame instead of
    # surfacing as a bare KeyError from the column selection below.
    require_columns(labels, columns, label_name)
    # pandas needs a list (not a tuple) to select multiple columns.
    return feature.merge(labels[list(columns)], on="sample_id", how="inner")


def build_train_datasets(args: Any) -> None:
    """Build the per-head training datasets by joining features with labels.

    Loads the trainable feature rows via ``_feature_base``. Then, for each
    label family (direction, entry, continue, exit, risk) it:

    - reads the label parquet from the matching ``args.*_label_path``
      override, or from ``<run_root>/label/`` when the override is unset;
    - inner-joins the labels to the features on ``sample_id``;
    - writes ``<run_root>/dataset/<name>_train.parquet``.

    Finally it writes ``dataset_manifest.json`` and
    ``dataset_quality_report.md`` summarising every dataset.

    Args:
        args: run configuration. Read here: ``run_id``,
            ``direction_label_path``, ``entry_label_path``,
            ``continue_label_path``, ``exit_label_path`` and
            ``risk_label_path``, plus whatever ``run_root`` and
            ``_feature_base`` consume.

    Raises:
        ValueError: from ``_feature_base`` when no feature rows are trainable.
        Whatever ``require_columns`` raises when a label file lacks an
        expected column.
    """
    root = run_root(args)
    feature = _feature_base(root, args)
    dataset_dir = root / "dataset"
    label_dir = root / "label"
    manifests: dict[str, dict] = {}

    def _emit(name: str, dataset: pd.DataFrame) -> None:
        # An empty inner join is still written (as before), but is now logged
        # so a sample_id mismatch between feature and label runs is visible.
        if dataset.empty:
            logging.warning("trader.training.dataset_empty runId=%s dataset=%s", args.run_id, name)
        manifests[name] = _write_dataset(dataset_dir / f"{name}_train.parquet", dataset)

    direction = read_parquet(args.direction_label_path or label_dir / "direction_labels.parquet")
    _emit("direction", _join_labels(feature, direction, _DIRECTION_LABEL_COLUMNS, "direction_labels"))

    entry = read_parquet(args.entry_label_path or label_dir / "entry_labels.parquet")
    # _entry_pivot validates the entry columns itself and spreads the LONG and
    # SHORT rows into per-side columns before the join.
    _emit("entry", feature.merge(_entry_pivot(entry), on="sample_id", how="inner"))

    continuation = read_parquet(args.continue_label_path or label_dir / "continue_labels.parquet")
    _emit("continue", _join_labels(feature, continuation, _CONTINUE_LABEL_COLUMNS, "continue_labels"))

    exit_labels = read_parquet(args.exit_label_path or label_dir / "exit_labels.parquet")
    _emit("exit", _join_labels(feature, exit_labels, _EXIT_LABEL_COLUMNS, "exit_labels"))

    risk = read_parquet(args.risk_label_path or label_dir / "risk_labels.parquet")
    _emit("risk", _join_labels(feature, risk, _RISK_LABEL_COLUMNS, "risk_labels"))

    write_json(dataset_dir / "dataset_manifest.json", {"datasets": manifests})
    _write_dataset_report(dataset_dir / "dataset_quality_report.md", manifests)
    logging.info("trader.training.datasets_written runId=%s datasets=%s", args.run_id, sorted(manifests))
def _entry_pivot(entry: pd.DataFrame) -> pd.DataFrame:
    """Spread per-side entry labels into one row per ``sample_id``.

    Rows with ``side == "LONG"`` supply ``long_entry_target`` and
    ``long_expected_net_edge_bps``; rows with ``side == "SHORT"`` supply the
    ``short_*`` equivalents. The two sides are inner-joined on ``sample_id``,
    so a sample missing either side is dropped.
    """
    require_columns(entry, ("sample_id", "side", "entry_target", "expected_net_edge_bps"), "entry_labels")
    keep = ["sample_id", "entry_target", "expected_net_edge_bps"]
    renames = {
        "LONG": {"entry_target": "long_entry_target", "expected_net_edge_bps": "long_expected_net_edge_bps"},
        "SHORT": {"entry_target": "short_entry_target", "expected_net_edge_bps": "short_expected_net_edge_bps"},
    }
    long_side, short_side = (
        entry.loc[entry["side"] == side, keep].rename(columns=mapping)
        for side, mapping in renames.items()
    )
    return long_side.merge(short_side, on="sample_id", how="inner")
def _write_dataset(path, frame: pd.DataFrame) -> dict:
    """Write ``frame`` to ``path`` as parquet and return its manifest entry.

    The stats passed to ``manifest`` are:

    - the row count;
    - ``feature_count``, the length of ``FEATURE_ORDER`` (not a count of the
      frame's own columns);
    - the hash returned by ``write_parquet``;
    - per-split row counts, or ``{}`` when the frame has no ``split_id``
      column.
    """
    digest = write_parquet(path, frame)
    if "split_id" in frame.columns:
        split_counts = frame["split_id"].value_counts().to_dict()
    else:
        split_counts = {}
    stats = {
        "row_count": len(frame),
        "feature_count": len(FEATURE_ORDER),
        "data_hash_sha256": digest,
        "split_counts": split_counts,
    }
    return manifest(path, stats)
def _write_dataset_report(path, manifests: dict) -> None:
    """Write a Markdown table listing each dataset's row count and data hash."""
    header = [
        "# Trader Dataset Quality Report",
        "",
        "| dataset | rows | hash |",
        "| --- | ---: | --- |",
    ]
    rows = [
        f"| {name} | {entry['row_count']} | {entry['data_hash_sha256']} |"
        for name, entry in manifests.items()
    ]
    # The report always ends with a trailing newline.
    write_text(path, "\n".join([*header, *rows]) + "\n")