"""Source-data audit, 1-minute replay construction, and time-split generation."""
from __future__ import annotations

import json
import logging
from pathlib import Path
from typing import Any

import numpy as np
import pandas as pd

from trader_training.io_utils import (
    DEFAULT_RAW_ROOT,
    ensure_dir,
    manifest,
    open_time_ms,
    partition_files,
    read_json,
    read_partitioned_table,
    read_parquet,
    require_columns,
    run_root,
    to_utc_series,
    utc_now_text,
    write_json,
    write_parquet,
    write_text,
)
from trader_training.schemas import FIT_SPLIT, LATEST_STRESS_SPLIT, SPLIT_VERSION, TRAINING_SPLITS, TUNE_SPLIT, VALIDATION_LOCKED_SPLIT


def _lines_text(lines: list[str]) -> str:
    """Join *lines* with newlines, newline-terminated; empty input gives an empty string."""
    return "\n".join(lines) + ("\n" if lines else "")


def _partition_date(file: Path) -> str | None:
    """Return the value of the first ``dt=<value>`` path component of *file*, or None if absent."""
    for part in file.parts:
        if part.startswith("dt="):
            return part.split("=", 1)[1]
    return None


def _iso_z(value: pd.Timestamp) -> str:
    """Format *value* as ISO-8601, replacing a ``+00:00`` offset with ``Z``."""
    return value.isoformat().replace("+00:00", "Z")


def audit_source_data(data_root: Path, symbol: str, start_date: str | None, end_date: str | None, min_ready_days: int = 250) -> dict[str, Any]:
    """Audit which raw partitions exist per table and which days have every required table.

    Args:
        data_root: Directory containing ``crypto-lake/raw``.
        symbol: Symbol passed through to ``partition_files``.
        start_date: Optional inclusive lower bound for the audited range.
        end_date: Optional inclusive upper bound for the audited range.
        min_ready_days: Minimum number of replay-ready days for ``ready`` to be true.

    Returns:
        A dict with per-table rows, the replay-ready and excluded days, and a
        ``ready`` flag that is true only when every required table has files
        and at least ``min_ready_days`` days are replay-ready.
    """
    raw_root = data_root / "crypto-lake" / "raw"
    required_tables = ("candles", "trades", "level_1", "funding", "open_interest")
    optional_tables = ("liquidations",)
    rows: list[dict[str, Any]] = []
    table_dates: dict[str, set[str]] = {}
    for table in required_tables + optional_tables:
        files = partition_files(raw_root, table, symbol, start_date, end_date)
        # Files without a usable dt= component are counted in file_count but must not
        # contribute an empty-string "date": it would corrupt first_date and the
        # derived date range.
        dates = sorted({date for date in map(_partition_date, files) if date})
        table_dates[table] = set(dates)
        rows.append(
            {
                "table": table,
                "required": table in required_tables,
                "file_count": len(files),
                "first_date": dates[0] if dates else None,
                "last_date": dates[-1] if dates else None,
                # Optional tables never block readiness, even when absent.
                "status": "OK" if files or table in optional_tables else "MISSING",
            }
        )

    all_dates = _audit_date_range(table_dates, required_tables, start_date, end_date)
    replay_ready_days: list[str] = []
    excluded_days: list[dict[str, Any]] = []
    for day in all_dates:
        missing_required = [table for table in required_tables if day not in table_dates[table]]
        if not missing_required:
            replay_ready_days.append(day)
            continue
        missing_optional = [table for table in optional_tables if day not in table_dates[table]]
        excluded_days.append(
            {
                "date": day,
                "reason": "MISSING_REQUIRED_TABLE",
                "missing_required_tables": missing_required,
                "missing_optional_tables": missing_optional,
            }
        )

    tables_ok = all(row["status"] == "OK" for row in rows if row["required"])
    return {
        "symbol": symbol,
        "start_date": start_date,
        "end_date": end_date,
        "raw_root": str(raw_root),
        "tables": rows,
        "replay_ready_day_count": len(replay_ready_days),
        "excluded_day_count": len(excluded_days),
        "replay_ready_days": replay_ready_days,
        "excluded_days": excluded_days,
        "created_at": utc_now_text(),
        "ready": tables_ok and len(replay_ready_days) >= min_ready_days,
    }


def write_audit_outputs(args: Any) -> None:
    """Run the source-data audit and write its JSON, text and markdown outputs.

    Outputs go under ``<run_root>/raw-manifest``.

    Raises:
        SystemExit: When the audit is not ready, either because a required
            table has no files or because too few days are replay-ready.
    """
    root = run_root(args)
    result = audit_source_data(args.data_root, args.symbol, args.start_date, args.end_date, int(args.min_ready_days))
    out_dir = root / "raw-manifest"
    path = out_dir / "source_data_audit.json"
    write_json(path, result)
    write_json(out_dir / "source_data_manifest.json", result)
    write_json(out_dir / "excluded_days.json", result["excluded_days"])
    write_text(out_dir / "replay_ready_days.txt", _lines_text(result["replay_ready_days"]))

    report_lines = [
        "# Trader Source Data Audit",
        "",
        f"- symbol: {result['symbol']}",
        f"- raw_root: {result['raw_root']}",
        f"- ready: {result['ready']}",
        f"- replay_ready_day_count: {result['replay_ready_day_count']}",
        f"- excluded_day_count: {result['excluded_day_count']}",
        "",
        "| table | required | file_count | first_date | last_date | status |",
        "| --- | --- | ---: | --- | --- | --- |",
    ]
    report_lines.extend(
        f"| {row['table']} | {row['required']} | {row['file_count']} | {row['first_date']} | {row['last_date']} | {row['status']} |"
        for row in result["tables"]
    )
    write_text(out_dir / "source_data_audit.md", "\n".join(report_lines) + "\n")
    logging.info(
        "trader.training.audit_written runId=%s ready=%s readyDays=%s excludedDays=%s path=%s",
        args.run_id,
        result["ready"],
        result["replay_ready_day_count"],
        result["excluded_day_count"],
        path,
    )
    if result["ready"]:
        return
    # "ready" can be false for two different reasons; report the one that applies.
    tables_missing = any(row["required"] and row["status"] != "OK" for row in result["tables"])
    if tables_missing:
        raise SystemExit("required raw tables are missing; see source_data_audit.md")
    raise SystemExit(
        f"only {result['replay_ready_day_count']} replay-ready days, required {int(args.min_ready_days)}; see source_data_audit.md"
    )


def _audit_date_range(table_dates: dict[str, set[str]], required_tables: tuple[str, ...], start_date: str | None, end_date: str | None) -> list[str]:
    """Return every calendar day (``YYYY-MM-DD``) in the audited range, inclusive.

    Explicit bounds win; a missing bound falls back to the earliest/latest date
    seen in any required table. Returns an empty list when a bound is missing
    and no required table has any date.
    """
    if start_date and end_date:
        first = pd.Timestamp(start_date)
        last = pd.Timestamp(end_date)
    else:
        observed = sorted(set().union(*(table_dates[table] for table in required_tables)))
        if not observed:
            return []
        first = pd.Timestamp(start_date or observed[0])
        last = pd.Timestamp(end_date or observed[-1])
    return [day.strftime("%Y-%m-%d") for day in pd.date_range(first, last, freq="D")]


def _minute_frame(frame: pd.DataFrame, time_column: str = "origin_time") -> pd.DataFrame:
    """Return a copy of *frame* with ``event_time`` (minute-floored) and ``open_time_ms`` added."""
    frame = frame.copy()
    frame["event_time"] = to_utc_series(frame[time_column]).dt.floor("min")
    frame["open_time_ms"] = open_time_ms(frame["event_time"])
    return frame


def _read_candles(raw_root: Path, symbol: str, start_date: str | None, end_date: str | None) -> pd.DataFrame:
    """Load candles as one row per (symbol, minute) with numeric OHLCV columns.

    Raises:
        ValueError: When no candle rows are found.
    """
    candles = read_partitioned_table(
        raw_root,
        "candles",
        symbol,
        start_date,
        end_date,
        columns=("origin_time", "start", "open", "high", "low", "close", "volume", "symbol"),
    )
    if candles.empty:
        raise ValueError("candles raw data is required to build replay_1m")
    # Prefer the candle's own start column when present.
    time_col = "start" if "start" in candles.columns else "origin_time"
    candles = _minute_frame(candles, time_col)
    keep = ["symbol", "event_time", "open_time_ms", "open", "high", "low", "close", "volume"]
    candles = candles[keep].sort_values(["symbol", "event_time"]).drop_duplicates(["symbol", "event_time"], keep="last")
    for column in ("open", "high", "low", "close", "volume"):
        candles[column] = pd.to_numeric(candles[column], errors="coerce")
    return candles


def _read_trades(raw_root: Path, symbol: str, start_date: str | None, end_date: str | None) -> pd.DataFrame:
    """Aggregate trades into per-minute taker buy/sell volume.

    Raises:
        ValueError: When no trade rows are found.
    """
    trades = read_partitioned_table(
        raw_root,
        "trades",
        symbol,
        start_date,
        end_date,
        columns=("origin_time", "side", "quantity", "symbol"),
    )
    if trades.empty:
        raise ValueError("trades raw data is required for taker imbalance")
    trades = _minute_frame(trades)
    trades["quantity"] = pd.to_numeric(trades["quantity"], errors="coerce").fillna(0.0)
    side = trades["side"].astype(str).str.upper()
    # Rows whose side is neither BUY nor SELL contribute to neither column.
    trades["taker_buy_volume"] = np.where(side.eq("BUY"), trades["quantity"], 0.0)
    trades["taker_sell_volume"] = np.where(side.eq("SELL"), trades["quantity"], 0.0)
    return (
        trades.groupby(["symbol", "event_time", "open_time_ms"], as_index=False, observed=True)[["taker_buy_volume", "taker_sell_volume"]]
        .sum()
    )


def _read_level1(raw_root: Path, symbol: str, start_date: str | None, end_date: str | None) -> pd.DataFrame:
    """Aggregate top-of-book updates into per-minute quotes, spread and normalised OFI.

    Returns one row per (symbol, minute) with the last best bid/ask, the last
    spread in basis points, and ``level1_ofi_1m`` (summed order-flow imbalance
    divided by the mean top-of-book depth of that minute).

    Raises:
        ValueError: When no level_1 rows are found.
    """
    level1 = read_partitioned_table(
        raw_root,
        "level_1",
        symbol,
        start_date,
        end_date,
        columns=("origin_time", "bid_0_price", "ask_0_price", "bid_0_size", "ask_0_size", "symbol"),
    )
    if level1.empty:
        raise ValueError("level_1 raw data is required for spread and OFI")
    level1 = _minute_frame(level1)
    quote_columns = ["bid_0_price", "ask_0_price", "bid_0_size", "ask_0_size"]
    for column in quote_columns:
        level1[column] = pd.to_numeric(level1[column], errors="coerce")
    level1 = level1.dropna(subset=quote_columns)
    level1 = level1.sort_values(["symbol", "event_time", "origin_time"])

    group = level1.groupby("symbol", sort=False, observed=True)
    prev_bid_price = group["bid_0_price"].shift(1)
    prev_bid_size = group["bid_0_size"].shift(1)
    prev_ask_price = group["ask_0_price"].shift(1)
    prev_ask_size = group["ask_0_size"].shift(1)
    # Bid side: price up -> new size is added demand; same price -> size delta;
    # price down -> the previous bid size was removed.
    bid_ofi = np.select(
        [level1["bid_0_price"] > prev_bid_price, level1["bid_0_price"].eq(prev_bid_price)],
        [level1["bid_0_size"], level1["bid_0_size"] - prev_bid_size],
        default=-prev_bid_size,
    )
    # Ask side, signed so that positive means buy pressure (same convention as the
    # same-price branch): price down -> new supply (negative); same price ->
    # reduction of ask size (positive); price up -> previous ask size was consumed
    # (positive). The price-change branches previously had the opposite sign,
    # contradicting the same-price branch.
    ask_ofi = np.select(
        [level1["ask_0_price"] < prev_ask_price, level1["ask_0_price"].eq(prev_ask_price)],
        [-level1["ask_0_size"], prev_ask_size - level1["ask_0_size"]],
        default=prev_ask_size,
    )
    # The first update per symbol has no predecessor; its NaN contribution becomes 0.
    level1["ofi_raw"] = np.nan_to_num(bid_ofi + ask_ofi, nan=0.0)
    level1["depth"] = (level1["bid_0_size"] + level1["ask_0_size"]).clip(lower=1e-12)
    level1["mid"] = (level1["bid_0_price"] + level1["ask_0_price"]) / 2.0
    level1["spread_bps"] = (level1["ask_0_price"] - level1["bid_0_price"]) / level1["mid"] * 10000.0

    agg = level1.groupby(["symbol", "event_time", "open_time_ms"], as_index=False, observed=True).agg(
        best_bid_price=("bid_0_price", "last"),
        best_ask_price=("ask_0_price", "last"),
        spread_bps=("spread_bps", "last"),
        ofi_sum=("ofi_raw", "sum"),
        depth_mean=("depth", "mean"),
    )
    agg["level1_ofi_1m"] = agg["ofi_sum"] / agg["depth_mean"].clip(lower=1e-12)
    return agg.drop(columns=["ofi_sum", "depth_mean"])


def _read_liquidations(raw_root: Path, symbol: str, start_date: str | None, end_date: str | None) -> pd.DataFrame:
    """Aggregate liquidations into per-minute buy/sell notional.

    Liquidations are optional: when no partition files exist, an empty frame
    with the expected columns is returned. Otherwise each returned minute has
    ``liquidation_available`` set to 1.0.
    """
    files = partition_files(raw_root, "liquidations", symbol, start_date, end_date)
    if not files:
        return pd.DataFrame(columns=["symbol", "event_time", "open_time_ms", "liquidation_buy_notional_1m", "liquidation_sell_notional_1m", "liquidation_available"])
    liquidations = read_partitioned_table(
        raw_root,
        "liquidations",
        symbol,
        start_date,
        end_date,
        columns=("origin_time", "side", "quantity", "price", "symbol"),
    )
    liquidations = _minute_frame(liquidations)
    liquidations["quantity"] = pd.to_numeric(liquidations["quantity"], errors="coerce").fillna(0.0)
    liquidations["price"] = pd.to_numeric(liquidations["price"], errors="coerce").fillna(0.0)
    liquidations["notional"] = liquidations["quantity"] * liquidations["price"]
    side = liquidations["side"].astype(str).str.upper()
    liquidations["liquidation_buy_notional_1m"] = np.where(side.eq("BUY"), liquidations["notional"], 0.0)
    liquidations["liquidation_sell_notional_1m"] = np.where(side.eq("SELL"), liquidations["notional"], 0.0)
    agg = liquidations.groupby(["symbol", "event_time", "open_time_ms"], as_index=False, observed=True)[
        ["liquidation_buy_notional_1m", "liquidation_sell_notional_1m"]
    ].sum()
    agg["liquidation_available"] = 1.0
    return agg


def _asof_column(
    replay: pd.DataFrame,
    raw_root: Path,
    table: str,
    symbol: str,
    start_date: str | None,
    end_date: str | None,
    columns: tuple[str, ...],
) -> pd.DataFrame:
    """As-of join *columns* of a raw table onto the replay's (symbol, event_time) keys.

    Each replay minute receives the most recent raw row at or before it, for the
    same symbol, within a 12-hour tolerance. Columns whose name ends with
    ``time`` are left as-is; all others are coerced to numeric.

    Returns:
        A frame with ``symbol``, ``event_time`` and *columns*.

    Raises:
        ValueError: When the raw table has no rows.
    """
    frame = read_partitioned_table(raw_root, table, symbol, start_date, end_date, columns=("origin_time", "symbol", *columns))
    if frame.empty:
        raise ValueError(f"{table} raw data is required")
    frame = _minute_frame(frame)
    for column in columns:
        if column.endswith("time"):
            continue
        frame[column] = pd.to_numeric(frame[column], errors="coerce")
    # merge_asof needs both sides ordered by the `on` key itself; ordering by
    # (symbol, event_time) only satisfies that when there is a single symbol.
    # A stable sort keeps the original relative order of same-minute rows.
    right = frame[["symbol", "event_time", *columns]].sort_values("event_time", kind="stable")
    left = replay[["symbol", "event_time"]].sort_values("event_time", kind="stable")
    return pd.merge_asof(
        left,
        right,
        by="symbol",
        on="event_time",
        direction="backward",
        tolerance=pd.Timedelta(hours=12),
    )


def build_replay_1m(args: Any) -> None:
    """Build ``replay_1m.parquet`` from raw tables, keeping only replay-ready days.

    A day is replay-ready when it has at least ``args.min_minutes_per_day`` rows
    and no row is missing a required market field. The parquet file, its
    manifest, the excluded-day list and the ready-day list are written under
    ``<run_root>/replay``.

    Raises:
        ValueError: When a required raw table is empty, or when fewer than
            ``args.min_replay_ready_days`` days are replay-ready (the day lists
            are still written in that case).
    """
    root = run_root(args)
    raw_root = args.raw_root or DEFAULT_RAW_ROOT
    logging.info("trader.training.replay_started runId=%s symbol=%s rawRoot=%s", args.run_id, args.symbol, raw_root)

    minute_keys = ["symbol", "event_time", "open_time_ms"]
    replay = _read_candles(raw_root, args.symbol, args.start_date, args.end_date)
    trades = _read_trades(raw_root, args.symbol, args.start_date, args.end_date)
    level1 = _read_level1(raw_root, args.symbol, args.start_date, args.end_date)
    liquidations = _read_liquidations(raw_root, args.symbol, args.start_date, args.end_date)
    for extra in (trades, level1, liquidations):
        replay = replay.merge(extra, on=minute_keys, how="left")

    replay[["taker_buy_volume", "taker_sell_volume"]] = replay[["taker_buy_volume", "taker_sell_volume"]].fillna(0.0)
    for column in ("liquidation_buy_notional_1m", "liquidation_sell_notional_1m", "liquidation_available"):
        # With no liquidation files the merged columns are object dtype; coerce so
        # the output is always a float column.
        replay[column] = pd.to_numeric(replay[column], errors="coerce").fillna(0.0)

    funding = _asof_column(replay, raw_root, "funding", args.symbol, args.start_date, args.end_date, ("rate", "mark_price", "index_price", "next_funding_time"))
    funding = funding.rename(columns={"rate": "funding_rate"})
    funding["funding_bps"] = pd.to_numeric(funding["funding_rate"], errors="coerce") * 10000.0
    replay = replay.merge(funding.drop(columns=["funding_rate"]), on=["symbol", "event_time"], how="left")
    replay["next_funding_time"] = to_utc_series(replay["next_funding_time"])

    oi = _asof_column(replay, raw_root, "open_interest", args.symbol, args.start_date, args.end_date, ("open_interest",))
    replay = replay.merge(oi, on=["symbol", "event_time"], how="left")
    replay["timeframe"] = "1m"
    replay["source_coverage"] = "crypto_lake_raw"

    required = [
        "open",
        "high",
        "low",
        "close",
        "volume",
        "best_bid_price",
        "best_ask_price",
        "spread_bps",
        "level1_ofi_1m",
        "funding_bps",
        "mark_price",
        "index_price",
        "open_interest",
    ]
    replay["event_date"] = replay["event_time"].dt.strftime("%Y-%m-%d")
    missing_required = replay[required].isna().any(axis=1)
    day_quality = (
        replay.assign(missing_required=missing_required.astype(int))
        .groupby("event_date", as_index=False, observed=True)
        .agg(row_count=("event_time", "count"), missing_required_rows=("missing_required", "sum"))
    )
    day_quality["ready"] = (day_quality["row_count"] >= int(args.min_minutes_per_day)) & day_quality["missing_required_rows"].eq(0)
    ready_days = sorted(day_quality.loc[day_quality["ready"], "event_date"].astype(str).tolist())
    excluded_days = [
        {
            "date": row.event_date,
            "row_count": int(row.row_count),
            "missing_required_rows": int(row.missing_required_rows),
            "reason": "MISSING_REQUIRED_MARKET_FIELDS" if int(row.missing_required_rows) else "INCOMPLETE_MINUTE_COUNT",
        }
        for row in day_quality.loc[~day_quality["ready"]].itertuples(index=False)
    ]

    replay_dir = root / "replay"
    if len(ready_days) < int(args.min_replay_ready_days):
        # Persist the diagnosis before failing so the shortfall can be inspected.
        write_json(replay_dir / "excluded_days.json", excluded_days)
        write_text(replay_dir / "replay_ready_days.txt", _lines_text(ready_days))
        raise ValueError(f"replay_1m has only {len(ready_days)} replay-ready days, required {args.min_replay_ready_days}")

    before_filter = len(replay)
    replay = replay[replay["event_date"].isin(ready_days)].copy()
    logging.info(
        "trader.training.replay_ready_days_selected runId=%s readyDays=%s excludedDays=%s rowBefore=%s rowAfter=%s",
        args.run_id,
        len(ready_days),
        len(excluded_days),
        before_filter,
        len(replay),
    )
    columns = [
        "symbol",
        "timeframe",
        "event_time",
        "open_time_ms",
        "open",
        "high",
        "low",
        "close",
        "volume",
        "taker_buy_volume",
        "taker_sell_volume",
        "funding_bps",
        "mark_price",
        "index_price",
        "next_funding_time",
        "open_interest",
        "best_bid_price",
        "best_ask_price",
        "spread_bps",
        "level1_ofi_1m",
        "liquidation_buy_notional_1m",
        "liquidation_sell_notional_1m",
        "liquidation_available",
        "source_coverage",
    ]
    replay = replay[columns].sort_values(["symbol", "event_time"]).reset_index(drop=True)
    path = replay_dir / "replay_1m.parquet"
    data_hash = write_parquet(path, replay)
    write_json(
        replay_dir / "replay_1m.manifest.json",
        manifest(
            path,
            {
                "row_count": len(replay),
                "hash_sha256": data_hash,
                "replay_ready_day_count": len(ready_days),
                "excluded_day_count": len(excluded_days),
                "min_minutes_per_day": int(args.min_minutes_per_day),
            },
        ),
    )
    write_json(replay_dir / "excluded_days.json", excluded_days)
    # Same formatting as the early-exit path: no stray newline for an empty list.
    write_text(replay_dir / "replay_ready_days.txt", _lines_text(ready_days))
    logging.info("trader.training.replay_written runId=%s rowCount=%s readyDays=%s path=%s", args.run_id, len(replay), len(ready_days), path)


def build_splits(args: Any) -> None:
    """Write the split manifest, walk-forward folds and purge/embargo report.

    The fixed split intervals requested in *args* are clipped to the replay's
    time coverage. Each fold trains from the fit-split start up to a growing
    end point and validates on the whole tune split.

    Raises:
        ValueError: When the replay has fewer than 10 rows, or when any
            training split falls entirely outside the replay coverage.
    """
    root = run_root(args)
    replay_path = args.replay_path or root / "replay" / "replay_1m.parquet"
    replay = read_parquet(replay_path)
    require_columns(replay, ("event_time", "symbol"), "replay_1m")
    replay["event_time"] = to_utc_series(replay["event_time"])
    replay = replay.sort_values(["event_time", "symbol"]).reset_index(drop=True)
    if len(replay) < 10:
        raise ValueError("not enough replay rows to build time splits")

    gap = int(args.gap_minutes)
    replay_start = replay["event_time"].min()
    replay_end = replay["event_time"].max()
    # Clip every requested interval to the replay coverage; drop empty ones.
    intervals = []
    for split_id, start, end in _fixed_split_intervals(args, gap):
        clipped_start = max(start, replay_start)
        clipped_end = min(end, replay_end)
        if clipped_start <= clipped_end:
            intervals.append((split_id, clipped_start, clipped_end))
    if {item[0] for item in intervals} != set(TRAINING_SPLITS):
        raise ValueError(f"fixed split dates do not fit replay coverage: replay_start={replay_start} replay_end={replay_end}")

    split_manifest = {
        "split_version": SPLIT_VERSION,
        "created_at": utc_now_text(),
        "source_replay_path": str(replay_path),
        "gap_minutes": gap,
        # Sealed splits are withheld from broad parameter search. They only answer
        # whether a finished candidate survives final validation and recent stress.
        "sealed_splits": [VALIDATION_LOCKED_SPLIT, LATEST_STRESS_SPLIT],
        "latest_stress_policy": "FINAL_GATE_ONLY",
        "requested_splits": {
            FIT_SPLIT: [args.fit_inner_start, args.fit_inner_end],
            TUNE_SPLIT: [args.tune_inner_start, args.tune_inner_end],
            VALIDATION_LOCKED_SPLIT: [args.validation_locked_start, args.validation_locked_end],
            LATEST_STRESS_SPLIT: [args.latest_stress_start, args.latest_stress_end],
        },
        "splits": [
            {"split_id": split_id, "start": _iso_z(start), "end": _iso_z(end)}
            for split_id, start, end in intervals
            if start <= end
        ],
    }

    fold_count = max(1, int(args.fold_count))
    fit_interval = next(item for item in intervals if item[0] == FIT_SPLIT)
    tune_interval = next(item for item in intervals if item[0] == TUNE_SPLIT)
    # fold_count + 1 evenly spaced points over the fit split; fold i trains up to point i + 1.
    train_times = pd.Series(pd.date_range(fit_interval[1], fit_interval[2], periods=fold_count + 1))
    folds = [
        {
            "walk_forward_fold": f"fold_{idx + 1:02d}",
            "train_start": _iso_z(fit_interval[1]),
            "train_end": _iso_z(train_times.iloc[idx + 1]),
            "validation_start": _iso_z(tune_interval[1]),
            "validation_end": _iso_z(tune_interval[2]),
        }
        for idx in range(fold_count)
    ]

    split_dir = root / "split"
    ensure_dir(split_dir)
    write_json(split_dir / "split_manifest.json", split_manifest)
    write_json(split_dir / "walk_forward_folds.json", {"split_version": SPLIT_VERSION, "folds": folds})
    _write_purge_embargo_report(split_dir / "purge_embargo_report.md", intervals, gap)
    logging.info("trader.training.splits_written runId=%s splitCount=%s foldCount=%s", args.run_id, len(split_manifest["splits"]), len(folds))


def assign_split(event_times: pd.Series, split_manifest_path: Path) -> pd.Series:
    """Label each timestamp with the split whose inclusive [start, end] contains it.

    Timestamps outside every split are labelled ``NO_SPLIT``. Splits are applied
    in manifest order, so a later split overwrites an earlier one on overlap.
    """
    manifest_data = read_json(split_manifest_path)
    result = pd.Series("NO_SPLIT", index=event_times.index, dtype="object")
    values = to_utc_series(event_times)
    for item in manifest_data["splits"]:
        start = pd.Timestamp(item["start"])
        end = pd.Timestamp(item["end"])
        mask = values.between(start, end, inclusive="both")
        result.loc[mask] = item["split_id"]
    return result


def _fixed_split_intervals(args: Any, gap_minutes: int) -> list[tuple[str, pd.Timestamp, pd.Timestamp]]:
    """Return the requested split intervals, shrunk by the gap at every interior boundary.

    The first split's start and the last split's end are not shrunk.
    """
    gap = pd.Timedelta(minutes=gap_minutes)
    return [
        (FIT_SPLIT, _start_of_day(args.fit_inner_start), _end_of_day(args.fit_inner_end) - gap),
        (TUNE_SPLIT, _start_of_day(args.tune_inner_start) + gap, _end_of_day(args.tune_inner_end) - gap),
        (VALIDATION_LOCKED_SPLIT, _start_of_day(args.validation_locked_start) + gap, _end_of_day(args.validation_locked_end) - gap),
        (LATEST_STRESS_SPLIT, _start_of_day(args.latest_stress_start) + gap, _end_of_day(args.latest_stress_end)),
    ]


def _start_of_day(value: str) -> pd.Timestamp:
    """Parse *value* as a UTC timestamp."""
    return pd.Timestamp(value, tz="UTC")


def _end_of_day(value: str) -> pd.Timestamp:
    """Return *value* parsed as UTC plus one day minus one minute (the day's last minute for a date)."""
    return pd.Timestamp(value, tz="UTC") + pd.Timedelta(days=1) - pd.Timedelta(minutes=1)


def _write_purge_embargo_report(path: Path, intervals: list[tuple[str, pd.Timestamp, pd.Timestamp]], gap_minutes: int) -> None:
    """Write a markdown table of the split intervals and the configured gap to *path*."""
    lines = ["# Purge Embargo Report", "", f"- gap_minutes: {gap_minutes}", "", "| split_id | start | end |", "| --- | --- | --- |"]
    lines.extend(f"| {split_id} | {start.isoformat()} | {end.isoformat()} |" for split_id, start, end in intervals)
    write_text(path, "\n".join(lines) + "\n")