"""Assemble per-model training datasets by joining trainable features with label tables.

Each dataset is the filtered feature frame inner-joined on ``sample_id`` with one
label table. It is written as parquet under ``<run_root>/dataset``, together with
a JSON manifest (``dataset_manifest.json``) and a Markdown table report
(``dataset_quality_report.md``).
"""

from __future__ import annotations

import logging
from typing import Any

import pandas as pd

from trader_training.io_utils import (
    manifest,
    read_parquet,
    require_columns,
    run_root,
    write_json,
    write_parquet,
    write_text,
)
from trader_training.schemas import FEATURE_ORDER, TRAINING_SPLITS

# Module-level logger. Unlike module-level ``logging.info``, this never implicitly
# calls ``logging.basicConfig`` on the root logger.
logger = logging.getLogger(__name__)

# Feature rows with one of these ``data_quality_flag`` values may enter training.
_TRAINABLE_QUALITY_FLAGS = ("OK", "PARTIAL_OPTIONAL")

# Label columns copied into each dataset. ``sample_id`` (the join key) is added
# by ``_load_labels`` and is not listed here.
_DIRECTION_LABEL_COLUMNS = (
    "long_target",
    "short_target",
    "neutral_target",
    "future_return_bps",
)
_CONTINUE_LABEL_COLUMNS = (
    "long_continue_target",
    "short_continue_target",
    "long_expected_continue_edge_bps",
    "short_expected_continue_edge_bps",
)
_EXIT_LABEL_COLUMNS = (
    "long_exit_target",
    "short_exit_target",
    "long_adverse_move_bps",
    "short_adverse_move_bps",
    "adverse_move_prob_label",
    "reversal_prob_label",
    "stop_hit_prob_label",
    "stagnation_prob_label",
)
_RISK_LABEL_COLUMNS = (
    "market_risk_target",
    "market_path_risk_bps",
    "long_position_path_risk_bps",
    "short_position_path_risk_bps",
    "long_position_risk_target",
    "short_position_risk_target",
    "market_drawdown_prob_label",
    "volatility_expansion_prob_label",
    "spike_prob_label",
    "liquidity_deterioration_prob_label",
    "position_drawdown_prob_label",
)


def _feature_base(root, args: Any) -> pd.DataFrame:
    """Load the feature frame and keep only the rows eligible for training.

    A row is kept when its ``data_quality_flag`` is in ``_TRAINABLE_QUALITY_FLAGS``
    and its ``split_id`` is in ``TRAINING_SPLITS``.

    Args:
        root: Run root directory, as returned by ``run_root``.
        args: Run arguments. ``args.feature_path``, when truthy, overrides the
            default ``<root>/feature/feature_frame.parquet``.

    Returns:
        A copy of the filtered feature frame. The original index is preserved.

    Raises:
        ValueError: If no row survives the filters.
    """
    path = args.feature_path or root / "feature" / "feature_frame.parquet"
    frame = read_parquet(path)
    require_columns(
        frame,
        ("sample_id", "split_id", "data_quality_flag", *FEATURE_ORDER),
        "feature_frame",
    )
    keep = frame["data_quality_flag"].isin(_TRAINABLE_QUALITY_FLAGS) & frame["split_id"].isin(
        TRAINING_SPLITS
    )
    trainable = frame[keep].copy()
    if trainable.empty:
        raise ValueError("no trainable feature rows; check feature quality report")
    return trainable


def _load_labels(path, columns: tuple[str, ...], name: str) -> pd.DataFrame:
    """Read a label parquet, validate its schema, and return ``sample_id`` plus *columns*.

    Uses the same ``require_columns`` check that ``_entry_pivot`` applies, so a
    schema mismatch reports the named frame instead of a bare ``KeyError``.
    """
    labels = read_parquet(path)
    wanted = ("sample_id", *columns)
    require_columns(labels, wanted, name)
    return labels[list(wanted)]


def _join_features(feature: pd.DataFrame, labels: pd.DataFrame) -> pd.DataFrame:
    """Inner-join features with labels on ``sample_id``.

    Rows without a match on both sides are dropped.
    """
    # NOTE(review): duplicate sample_ids in either frame multiply rows here;
    # confirm that sample_id is unique per label table.
    return feature.merge(labels, on="sample_id", how="inner")


def build_train_datasets(args: Any) -> None:
    """Build and write the direction, entry, continue, exit and risk training datasets.

    Each dataset is written as soon as its labels are joined. If a later label
    table fails to load, the datasets written before it remain on disk. After all
    five datasets are written, a JSON manifest and a Markdown report are written
    to ``<run_root>/dataset``.

    Args:
        args: Run arguments. ``args.<kind>_label_path`` values, when truthy,
            override the default ``<root>/label/<kind>_labels.parquet`` paths.
            ``args.run_id`` is used for logging.
    """
    root = run_root(args)
    feature = _feature_base(root, args)
    label_dir = root / "label"
    dataset_dir = root / "dataset"
    manifests: dict[str, dict] = {}

    direction = _load_labels(
        args.direction_label_path or label_dir / "direction_labels.parquet",
        _DIRECTION_LABEL_COLUMNS,
        "direction_labels",
    )
    manifests["direction"] = _write_dataset(
        dataset_dir / "direction_train.parquet", _join_features(feature, direction)
    )

    entry = read_parquet(args.entry_label_path or label_dir / "entry_labels.parquet")
    manifests["entry"] = _write_dataset(
        dataset_dir / "entry_train.parquet", _join_features(feature, _entry_pivot(entry))
    )

    continuation = _load_labels(
        args.continue_label_path or label_dir / "continue_labels.parquet",
        _CONTINUE_LABEL_COLUMNS,
        "continue_labels",
    )
    manifests["continue"] = _write_dataset(
        dataset_dir / "continue_train.parquet", _join_features(feature, continuation)
    )

    exit_labels = _load_labels(
        args.exit_label_path or label_dir / "exit_labels.parquet",
        _EXIT_LABEL_COLUMNS,
        "exit_labels",
    )
    manifests["exit"] = _write_dataset(
        dataset_dir / "exit_train.parquet", _join_features(feature, exit_labels)
    )

    risk = _load_labels(
        args.risk_label_path or label_dir / "risk_labels.parquet",
        _RISK_LABEL_COLUMNS,
        "risk_labels",
    )
    manifests["risk"] = _write_dataset(
        dataset_dir / "risk_train.parquet", _join_features(feature, risk)
    )

    write_json(dataset_dir / "dataset_manifest.json", {"datasets": manifests})
    _write_dataset_report(dataset_dir / "dataset_quality_report.md", manifests)
    logger.info(
        "trader.training.datasets_written runId=%s datasets=%s", args.run_id, sorted(manifests)
    )


def _entry_pivot(entry: pd.DataFrame) -> pd.DataFrame:
    """Pivot per-side entry labels into one row per ``sample_id``.

    LONG rows become ``long_entry_target`` / ``long_expected_net_edge_bps``, and
    SHORT rows become the matching ``short_*`` columns. The two sides are
    inner-joined, so a sample that lacks either side is dropped.
    """
    require_columns(entry, ("sample_id", "side", "entry_target", "expected_net_edge_bps"), "entry_labels")

    def _side(side: str, prefix: str) -> pd.DataFrame:
        # Keep one side's rows and prefix its label columns with that side's name.
        rows = entry.loc[entry["side"] == side, ["sample_id", "entry_target", "expected_net_edge_bps"]]
        return rows.rename(
            columns={
                "entry_target": f"{prefix}_entry_target",
                "expected_net_edge_bps": f"{prefix}_expected_net_edge_bps",
            }
        )

    return _side("LONG", "long").merge(_side("SHORT", "short"), on="sample_id", how="inner")


def _write_dataset(path, frame: pd.DataFrame) -> dict:
    """Write *frame* to parquet at *path* and return its manifest entry.

    The entry is built by ``manifest(path, ...)``. It carries the row count, the
    feature count (``len(FEATURE_ORDER)``), the value returned by
    ``write_parquet`` (presumably a SHA-256 of the written data, judging by the
    key name; confirm), and per-``split_id`` row counts.
    """
    if frame.empty:
        # An empty join is still written, to keep existing behavior, but flagged
        # so it is not missed.
        logger.warning("trader.training.dataset_empty path=%s", path)
    data_hash = write_parquet(path, frame)
    split_counts = frame["split_id"].value_counts().to_dict() if "split_id" in frame.columns else {}
    return manifest(
        path,
        {
            "row_count": len(frame),
            "feature_count": len(FEATURE_ORDER),
            "data_hash_sha256": data_hash,
            "split_counts": split_counts,
        },
    )


def _write_dataset_report(path, manifests: dict) -> None:
    """Write a Markdown table listing each dataset's row count and data hash."""
    header = ["# Trader Dataset Quality Report", "", "| dataset | rows | hash |", "| --- | ---: | --- |"]
    rows = [
        f"| {name} | {item['row_count']} | {item['data_hash_sha256']} |"
        for name, item in manifests.items()
    ]
    write_text(path, "\n".join(header + rows) + "\n")