113 lines
4.9 KiB
Python
113 lines
4.9 KiB
Python
from __future__ import annotations
|
|
|
|
import logging
|
|
from typing import Any
|
|
|
|
import pandas as pd
|
|
|
|
from trader_training.io_utils import manifest, read_parquet, require_columns, run_root, write_json, write_parquet, write_text
|
|
from trader_training.schemas import FEATURE_ORDER, TRAINING_SPLITS
|
|
|
|
|
|
def _feature_base(root, args: Any) -> pd.DataFrame:
    """Load the feature frame and keep only the rows usable for training.

    The frame is read from ``args.feature_path`` when that is set, otherwise from
    ``<root>/feature/feature_frame.parquet``. Column validation (``sample_id``,
    ``split_id``, ``data_quality_flag`` and every ``FEATURE_ORDER`` column) is
    delegated to ``require_columns``.

    A row is kept when its ``data_quality_flag`` is ``OK`` or ``PARTIAL_OPTIONAL``
    AND its ``split_id`` is one of ``TRAINING_SPLITS``. The result is an
    independent copy, so callers may modify it freely.

    Raises:
        ValueError: if no row passes both filters.
    """
    source = args.feature_path or root / "feature" / "feature_frame.parquet"
    frame = read_parquet(source)
    require_columns(frame, ("sample_id", "split_id", "data_quality_flag", *FEATURE_ORDER), "feature_frame")

    usable_quality = frame["data_quality_flag"].isin(["OK", "PARTIAL_OPTIONAL"])
    in_training_split = frame["split_id"].isin(TRAINING_SPLITS)
    # One combined mask selects the same rows as filtering twice, with a single copy.
    selected = frame.loc[usable_quality & in_training_split].copy()

    if not selected.empty:
        return selected
    raise ValueError("no trainable feature rows; check feature quality report")
|
|
|
|
|
|
# Label columns copied from each label file into its training dataset.
# "sample_id" is the join key shared with the feature frame.
_DIRECTION_LABEL_COLUMNS = (
    "sample_id",
    "long_target",
    "short_target",
    "neutral_target",
    "future_return_bps",
)
_CONTINUE_LABEL_COLUMNS = (
    "sample_id",
    "long_continue_target",
    "short_continue_target",
    "long_expected_continue_edge_bps",
    "short_expected_continue_edge_bps",
)
_EXIT_LABEL_COLUMNS = (
    "sample_id",
    "long_exit_target",
    "short_exit_target",
    "long_adverse_move_bps",
    "short_adverse_move_bps",
    "adverse_move_prob_label",
    "reversal_prob_label",
    "stop_hit_prob_label",
    "stagnation_prob_label",
)
_RISK_LABEL_COLUMNS = (
    "sample_id",
    "market_risk_target",
    "market_path_risk_bps",
    "long_position_path_risk_bps",
    "short_position_path_risk_bps",
    "long_position_risk_target",
    "short_position_risk_target",
    "market_drawdown_prob_label",
    "volatility_expansion_prob_label",
    "spike_prob_label",
    "liquidity_deterioration_prob_label",
    "position_drawdown_prob_label",
)


def _join_labels(feature: pd.DataFrame, labels: pd.DataFrame, columns: tuple[str, ...], source: str) -> pd.DataFrame:
    """Validate ``labels`` and inner-join the selected ``columns`` onto ``feature``.

    Feature rows with no matching ``sample_id`` in ``labels`` are dropped (inner join).
    Missing label columns are reported by ``require_columns`` under the name ``source``,
    instead of surfacing as a bare ``KeyError`` from the column selection.
    """
    require_columns(labels, columns, source)
    return feature.merge(labels[list(columns)], on="sample_id", how="inner")


def _write_joined(path, name: str, frame: pd.DataFrame) -> dict:
    """Write one joined dataset via ``_write_dataset`` and return its manifest entry.

    Logs a warning when the join produced no rows.
    """
    if frame.empty:
        # The empty dataset is still written, as before. The warning makes a
        # feature/label sample_id mismatch visible instead of silent.
        logging.warning("trader.training.dataset_empty dataset=%s path=%s", name, path)
    return _write_dataset(path, frame)


def build_train_datasets(args: Any) -> None:
    """Join each label set onto the trainable feature rows and write the training datasets.

    Each of direction, entry, continue, exit and risk is handled in turn:

    - Read the label parquet, from the matching ``args.<name>_label_path`` when it is set,
      otherwise from ``<run_root>/label/<name>_labels.parquet``.
    - Inner-join it on ``sample_id`` with the rows returned by ``_feature_base``.
    - Write ``<run_root>/dataset/<name>_train.parquet``.

    Finally ``dataset_manifest.json`` and ``dataset_quality_report.md`` are written to the
    same dataset directory, and a summary line is logged.

    Raises:
        ValueError: propagated from ``_feature_base`` when no trainable feature rows remain.
            Missing label columns are reported by ``require_columns``; the exception type
            is defined there.
    """
    root = run_root(args)
    feature = _feature_base(root, args)
    label_dir = root / "label"
    dataset_dir = root / "dataset"
    manifests: dict[str, dict] = {}

    direction = read_parquet(args.direction_label_path or label_dir / "direction_labels.parquet")
    direction_ds = _join_labels(feature, direction, _DIRECTION_LABEL_COLUMNS, "direction_labels")
    manifests["direction"] = _write_joined(dataset_dir / "direction_train.parquet", "direction", direction_ds)

    entry = read_parquet(args.entry_label_path or label_dir / "entry_labels.parquet")
    # _entry_pivot validates entry_labels itself and spreads LONG/SHORT rows into
    # long_*/short_* columns keyed by sample_id.
    entry_ds = feature.merge(_entry_pivot(entry), on="sample_id", how="inner")
    manifests["entry"] = _write_joined(dataset_dir / "entry_train.parquet", "entry", entry_ds)

    continuation = read_parquet(args.continue_label_path or label_dir / "continue_labels.parquet")
    continue_ds = _join_labels(feature, continuation, _CONTINUE_LABEL_COLUMNS, "continue_labels")
    manifests["continue"] = _write_joined(dataset_dir / "continue_train.parquet", "continue", continue_ds)

    exit_labels = read_parquet(args.exit_label_path or label_dir / "exit_labels.parquet")
    exit_ds = _join_labels(feature, exit_labels, _EXIT_LABEL_COLUMNS, "exit_labels")
    manifests["exit"] = _write_joined(dataset_dir / "exit_train.parquet", "exit", exit_ds)

    risk = read_parquet(args.risk_label_path or label_dir / "risk_labels.parquet")
    risk_ds = _join_labels(feature, risk, _RISK_LABEL_COLUMNS, "risk_labels")
    manifests["risk"] = _write_joined(dataset_dir / "risk_train.parquet", "risk", risk_ds)

    write_json(dataset_dir / "dataset_manifest.json", {"datasets": manifests})
    _write_dataset_report(dataset_dir / "dataset_quality_report.md", manifests)
    logging.info("trader.training.datasets_written runId=%s datasets=%s", args.run_id, sorted(manifests))
|
|
|
|
|
|
def _entry_pivot(entry: pd.DataFrame) -> pd.DataFrame:
    """Turn per-side entry labels into one wide frame keyed by ``sample_id``.

    Rows with ``side == "LONG"`` supply ``long_entry_target`` and
    ``long_expected_net_edge_bps``. Rows with ``side == "SHORT"`` supply the matching
    ``short_*`` columns. Rows with any other ``side`` value are ignored.

    The two sides are inner-joined on ``sample_id``, so a sample appears only when it has
    both a LONG and a SHORT row. Column validation is delegated to ``require_columns``.
    """
    require_columns(entry, ("sample_id", "side", "entry_target", "expected_net_edge_bps"), "entry_labels")
    picked = ["sample_id", "entry_target", "expected_net_edge_bps"]

    def one_side(side: str, renames: dict) -> pd.DataFrame:
        # Keep only this side's rows and give its label columns side-specific names.
        return entry.loc[entry["side"] == side, picked].rename(columns=renames)

    long_rows = one_side(
        "LONG",
        {"entry_target": "long_entry_target", "expected_net_edge_bps": "long_expected_net_edge_bps"},
    )
    short_rows = one_side(
        "SHORT",
        {"entry_target": "short_entry_target", "expected_net_edge_bps": "short_expected_net_edge_bps"},
    )
    return long_rows.merge(short_rows, on="sample_id", how="inner")
|
|
|
|
|
|
def _write_dataset(path, frame: pd.DataFrame) -> dict:
    """Write ``frame`` to ``path`` as parquet and return the result of ``manifest``.

    The details passed to ``manifest`` are:

    - the row count;
    - the number of model features (``len(FEATURE_ORDER)``);
    - the hash returned by ``write_parquet``;
    - per-``split_id`` row counts, or an empty dict when the frame has no ``split_id``
      column.
    """
    # The parquet write happens first, exactly as before the stats are gathered.
    digest = write_parquet(path, frame)

    if "split_id" in frame.columns:
        per_split = frame["split_id"].value_counts().to_dict()
    else:
        per_split = {}

    details = {
        "row_count": len(frame),
        "feature_count": len(FEATURE_ORDER),
        "data_hash_sha256": digest,
        "split_counts": per_split,
    }
    return manifest(path, details)
|
|
|
|
|
|
def _write_dataset_report(path, manifests: dict) -> None:
    """Write a markdown table with each dataset's row count and data hash to ``path``.

    Each value in ``manifests`` must provide ``row_count`` and ``data_hash_sha256``. Rows
    follow the iteration order of ``manifests``, and the text ends with a single newline.
    """
    header = [
        "# Trader Dataset Quality Report",
        "",
        "| dataset | rows | hash |",
        "| --- | ---: | --- |",
    ]
    table_rows = [
        f"| {dataset} | {info['row_count']} | {info['data_hash_sha256']} |"
        for dataset, info in manifests.items()
    ]
    write_text(path, "\n".join([*header, *table_rows]) + "\n")
|