Implement Trader V4 training artifact pipeline
This commit is contained in:
@@ -0,0 +1,112 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
import pandas as pd
|
||||
|
||||
from trader_training.io_utils import manifest, read_parquet, require_columns, run_root, write_json, write_parquet, write_text
|
||||
from trader_training.schemas import FEATURE_ORDER, TRAINING_SPLITS
|
||||
|
||||
|
||||
def _feature_base(root, args) -> pd.DataFrame:
    """Load the feature frame and keep only the rows eligible for training.

    The frame is read from ``args.feature_path`` when that value is truthy,
    otherwise from ``root / "feature" / "feature_frame.parquet"``.  A row is
    kept when its ``data_quality_flag`` is ``"OK"`` or ``"PARTIAL_OPTIONAL"``
    and its ``split_id`` is one of ``TRAINING_SPLITS``.

    Args:
        root: Run root directory; joined with ``/`` to build the default path.
        args: Object exposing a ``feature_path`` attribute (may be falsy).

    Returns:
        A copy of the rows that pass both filters.

    Raises:
        ValueError: If no row passes both filters.
    """
    source = args.feature_path
    if not source:
        source = root / "feature" / "feature_frame.parquet"
    frame = read_parquet(source)

    # Identifier, split and quality columns plus every model feature are
    # checked up front via require_columns.
    required = ("sample_id", "split_id", "data_quality_flag", *FEATURE_ORDER)
    require_columns(frame, required, "feature_frame")

    quality_ok = frame["data_quality_flag"].isin(["OK", "PARTIAL_OPTIONAL"])
    in_training_split = frame["split_id"].isin(TRAINING_SPLITS)
    trainable = frame[quality_ok & in_training_split].copy()
    if trainable.empty:
        raise ValueError("no trainable feature rows; check feature quality report")
    return trainable
|
||||
|
||||
|
||||
def build_train_datasets(args: Any) -> None:
    """Join trainable features with each label set and write the datasets.

    For the direction, entry, continue, exit and risk label files (in that
    order) the labels are read, reduced to the columns needed, inner-joined
    to the trainable feature rows on ``sample_id`` and written to
    ``<root>/dataset/<name>_train.parquet``.  Afterwards a
    ``dataset_manifest.json`` and a ``dataset_quality_report.md`` are written
    to the same directory and an info message is logged.

    Args:
        args: Object providing ``run_id``, the optional ``*_label_path``
            overrides read below, and whatever ``run_root`` and
            ``_feature_base`` need.
    """
    root = run_root(args)
    feature = _feature_base(root, args)
    dataset_dir = root / "dataset"
    manifests = {}

    def load_labels(override, filename):
        # An explicit path on ``args`` wins over the default under root/label.
        return read_parquet(override or root / "label" / filename)

    def emit(name, filename, labels):
        # Inner join: only samples present in both features and labels remain.
        joined = feature.merge(labels, on="sample_id", how="inner")
        manifests[name] = _write_dataset(dataset_dir / filename, joined)

    direction = load_labels(args.direction_label_path, "direction_labels.parquet")
    direction_cols = ["sample_id", "long_target", "short_target", "neutral_target", "future_return_bps"]
    emit("direction", "direction_train.parquet", direction[direction_cols])

    # Entry labels are reshaped by _entry_pivot before the join.
    entry = load_labels(args.entry_label_path, "entry_labels.parquet")
    emit("entry", "entry_train.parquet", _entry_pivot(entry))

    continuation = load_labels(args.continue_label_path, "continue_labels.parquet")
    continue_cols = [
        "sample_id",
        "long_continue_target",
        "short_continue_target",
        "long_expected_continue_edge_bps",
        "short_expected_continue_edge_bps",
    ]
    emit("continue", "continue_train.parquet", continuation[continue_cols])

    exit_labels = load_labels(args.exit_label_path, "exit_labels.parquet")
    exit_cols = [
        "sample_id",
        "long_exit_target",
        "short_exit_target",
        "long_adverse_move_bps",
        "short_adverse_move_bps",
        "adverse_move_prob_label",
        "reversal_prob_label",
        "stop_hit_prob_label",
        "stagnation_prob_label",
    ]
    emit("exit", "exit_train.parquet", exit_labels[exit_cols])

    risk = load_labels(args.risk_label_path, "risk_labels.parquet")
    risk_cols = [
        "sample_id",
        "market_risk_target",
        "market_path_risk_bps",
        "long_position_path_risk_bps",
        "short_position_path_risk_bps",
        "long_position_risk_target",
        "short_position_risk_target",
        "market_drawdown_prob_label",
        "volatility_expansion_prob_label",
        "spike_prob_label",
        "liquidity_deterioration_prob_label",
        "position_drawdown_prob_label",
    ]
    emit("risk", "risk_train.parquet", risk[risk_cols])

    write_json(dataset_dir / "dataset_manifest.json", {"datasets": manifests})
    _write_dataset_report(dataset_dir / "dataset_quality_report.md", manifests)
    logging.info("trader.training.datasets_written runId=%s datasets=%s", args.run_id, sorted(manifests))
|
||||
|
||||
|
||||
def _entry_pivot(entry: pd.DataFrame) -> pd.DataFrame:
    """Turn per-side entry label rows into one row per sample.

    Rows whose ``side`` is ``"LONG"`` and rows whose ``side`` is ``"SHORT"``
    are selected separately, their ``entry_target`` and
    ``expected_net_edge_bps`` columns are prefixed with ``long_`` / ``short_``,
    and the two halves are inner-joined on ``sample_id``.

    Args:
        entry: Entry label frame; its required columns are checked with
            ``require_columns`` before use.

    Returns:
        A frame with ``sample_id`` plus the four prefixed columns, containing
        only samples that have both a LONG and a SHORT row.
    """
    require_columns(entry, ("sample_id", "side", "entry_target", "expected_net_edge_bps"), "entry_labels")

    value_cols = ["sample_id", "entry_target", "expected_net_edge_bps"]

    long_rows = entry.loc[entry["side"] == "LONG", value_cols]
    long_side = long_rows.rename(
        columns={"entry_target": "long_entry_target", "expected_net_edge_bps": "long_expected_net_edge_bps"}
    )

    short_rows = entry.loc[entry["side"] == "SHORT", value_cols]
    short_side = short_rows.rename(
        columns={"entry_target": "short_entry_target", "expected_net_edge_bps": "short_expected_net_edge_bps"}
    )

    return long_side.merge(short_side, on="sample_id", how="inner")
|
||||
|
||||
|
||||
def _write_dataset(path, frame: pd.DataFrame) -> dict:
    """Write ``frame`` to ``path`` as parquet and build its manifest entry.

    Args:
        path: Destination of the parquet file.
        frame: Dataset to persist.

    Returns:
        Whatever ``manifest`` returns for ``path`` and a payload holding the
        row count, ``len(FEATURE_ORDER)``, the value returned by
        ``write_parquet`` (stored as ``data_hash_sha256`` — presumably a
        SHA-256 digest, per the key name) and per-``split_id`` row counts
        (empty when the frame has no ``split_id`` column).
    """
    # Write first; the value it returns is recorded in the manifest payload.
    data_hash = write_parquet(path, frame)

    if "split_id" in frame.columns:
        split_counts = frame["split_id"].value_counts().to_dict()
    else:
        split_counts = {}

    details = {
        "row_count": len(frame),
        "feature_count": len(FEATURE_ORDER),
        "data_hash_sha256": data_hash,
        "split_counts": split_counts,
    }
    return manifest(path, details)
|
||||
|
||||
|
||||
def _write_dataset_report(path, manifests: dict) -> None:
    """Write a Markdown table summarising every dataset manifest.

    Args:
        path: Destination of the report file.
        manifests: Mapping of dataset name to a manifest entry providing
            ``row_count`` and ``data_hash_sha256``.
    """
    header = [
        "# Trader Dataset Quality Report",
        "",
        "| dataset | rows | hash |",
        "| --- | ---: | --- |",
    ]
    # One table row per dataset, in the mapping's iteration order.
    rows = [
        f"| {name} | {item['row_count']} | {item['data_hash_sha256']} |"
        for name, item in manifests.items()
    ]
    write_text(path, "\n".join(header + rows) + "\n")
|
||||
Reference in New Issue
Block a user