Files
quant-trader-service/training/trader_training/replay.py
T
Codex 9acb3460a1 Improve Trader V4 training pipeline
Align entry labels with max future edge, tune direction labeling, and harden regression evaluation.

Add training diagnostics, price-plan search, feature screening, and nonlinear benchmark scripts.
2026-06-27 19:57:29 +08:00

551 lines
24 KiB
Python

from __future__ import annotations
import json
import logging
from pathlib import Path
from typing import Any
import numpy as np
import pandas as pd
from trader_training.io_utils import (
DEFAULT_RAW_ROOT,
ensure_dir,
manifest,
open_time_ms,
partition_files,
read_json,
read_partitioned_table,
read_parquet,
require_columns,
run_root,
to_utc_series,
utc_now_text,
write_json,
write_parquet,
write_text,
)
from trader_training.schemas import FIT_SPLIT, LATEST_STRESS_SPLIT, SPLIT_VERSION, TRAINING_SPLITS, TUNE_SPLIT, VALIDATION_LOCKED_SPLIT
def audit_source_data(data_root: Path, symbol: str, start_date: str | None, end_date: str | None, min_ready_days: int = 250) -> dict[str, Any]:
    """Audit which raw tables and calendar days are available for ``symbol``.

    Partition files are looked up under ``data_root / "crypto-lake" / "raw"``
    for five required tables and one optional table. A day is replay-ready
    when every required table has a ``dt=<day>`` partition for that day.

    Args:
        data_root: Directory that contains the ``crypto-lake/raw`` tree.
        symbol: Symbol passed through to ``partition_files``.
        start_date: Optional first day (``YYYY-MM-DD``) of the audit window.
        end_date: Optional last day (``YYYY-MM-DD``) of the audit window.
        min_ready_days: Minimum number of replay-ready days for ``ready``.

    Returns:
        A dict with one row per table, the replay-ready and excluded days,
        and a ``ready`` flag that is true only when every required table has
        at least one file and at least ``min_ready_days`` days are ready.
    """
    raw_root = data_root / "crypto-lake" / "raw"
    required_tables = ("candles", "trades", "level_1", "funding", "open_interest")
    optional_tables = ("liquidations",)
    rows: list[dict[str, Any]] = []
    table_dates: dict[str, set[str]] = {}
    for table in required_tables + optional_tables:
        files = partition_files(raw_root, table, symbol, start_date, end_date)
        partition_dates = (
            next((part.split("=", 1)[1] for part in file.parts if part.startswith("dt=")), "")
            for file in files
        )
        # A file without a usable ``dt=`` component yields "". Drop it so it
        # cannot become ``first_date`` or be parsed as a timestamp when the
        # audit window is derived from the discovered dates.
        dates = sorted({date_text for date_text in partition_dates if date_text})
        table_dates[table] = set(dates)
        rows.append(
            {
                "table": table,
                "required": table in required_tables,
                "file_count": len(files),
                "first_date": dates[0] if dates else None,
                "last_date": dates[-1] if dates else None,
                "status": "OK" if files or table in optional_tables else "MISSING",
            }
        )
    all_dates = _audit_date_range(table_dates, required_tables, start_date, end_date)
    replay_ready_days = []
    excluded_days = []
    for day in all_dates:
        missing_required = [table for table in required_tables if day not in table_dates[table]]
        missing_optional = [table for table in optional_tables if day not in table_dates[table]]
        if missing_required:
            excluded_days.append(
                {
                    "date": day,
                    "reason": "MISSING_REQUIRED_TABLE",
                    "missing_required_tables": missing_required,
                    "missing_optional_tables": missing_optional,
                }
            )
        else:
            # Missing optional tables never exclude a day.
            replay_ready_days.append(day)
    required_tables_ok = all(row["status"] == "OK" for row in rows if row["required"])
    return {
        "symbol": symbol,
        "start_date": start_date,
        "end_date": end_date,
        "raw_root": str(raw_root),
        "tables": rows,
        "replay_ready_day_count": len(replay_ready_days),
        "excluded_day_count": len(excluded_days),
        "replay_ready_days": replay_ready_days,
        "excluded_days": excluded_days,
        "created_at": utc_now_text(),
        "ready": required_tables_ok and len(replay_ready_days) >= min_ready_days,
    }
def write_audit_outputs(args: Any) -> None:
    """Run the source-data audit and write its reports under ``raw-manifest``.

    Writes the audit result as JSON (twice, under two file names), the
    excluded days as JSON, the replay-ready days as a text file and a
    Markdown summary table, then logs a one-line summary.

    Args:
        args: Namespace providing ``data_root``, ``symbol``, ``start_date``,
            ``end_date``, ``min_ready_days`` and ``run_id``; it is also
            passed to ``run_root`` to resolve the output directory.

    Raises:
        SystemExit: When the audit is not ready. The message distinguishes
            missing required tables from too few replay-ready days.
    """
    root = run_root(args)
    min_ready_days = int(args.min_ready_days)
    result = audit_source_data(args.data_root, args.symbol, args.start_date, args.end_date, min_ready_days)
    manifest_dir = root / "raw-manifest"
    path = manifest_dir / "source_data_audit.json"
    write_json(path, result)
    write_json(manifest_dir / "source_data_manifest.json", result)
    write_json(manifest_dir / "excluded_days.json", result["excluded_days"])
    write_text(
        manifest_dir / "replay_ready_days.txt",
        "\n".join(result["replay_ready_days"]) + ("\n" if result["replay_ready_days"] else ""),
    )
    report_lines = [
        "# Trader Source Data Audit",
        "",
        f"- symbol: {result['symbol']}",
        f"- raw_root: {result['raw_root']}",
        f"- ready: {result['ready']}",
        f"- replay_ready_day_count: {result['replay_ready_day_count']}",
        f"- excluded_day_count: {result['excluded_day_count']}",
        "",
        "| table | required | file_count | first_date | last_date | status |",
        "| --- | --- | ---: | --- | --- | --- |",
    ]
    report_lines.extend(
        f"| {row['table']} | {row['required']} | {row['file_count']} | {row['first_date']} | {row['last_date']} | {row['status']} |"
        for row in result["tables"]
    )
    write_text(manifest_dir / "source_data_audit.md", "\n".join(report_lines) + "\n")
    logging.info(
        "trader.training.audit_written runId=%s ready=%s readyDays=%s excludedDays=%s path=%s",
        args.run_id,
        result["ready"],
        result["replay_ready_day_count"],
        result["excluded_day_count"],
        path,
    )
    if result["ready"]:
        return
    # ``ready`` can be false for two different reasons; report the real one
    # instead of always blaming missing tables.
    missing_tables = [row["table"] for row in result["tables"] if row["required"] and row["status"] != "OK"]
    if missing_tables:
        raise SystemExit("required raw tables are missing; see source_data_audit.md")
    raise SystemExit(
        f"only {result['replay_ready_day_count']} replay-ready days, required {min_ready_days}; see source_data_audit.md"
    )
def _audit_date_range(table_dates: dict[str, set[str]], required_tables: tuple[str, ...], start_date: str | None, end_date: str | None) -> list[str]:
if start_date and end_date:
start = pd.Timestamp(start_date)
end = pd.Timestamp(end_date)
else:
dates = sorted(set().union(*(table_dates[table] for table in required_tables)))
if not dates:
return []
start = pd.Timestamp(start_date or dates[0])
end = pd.Timestamp(end_date or dates[-1])
return [day.strftime("%Y-%m-%d") for day in pd.date_range(start, end, freq="D")]
def _minute_frame(frame: pd.DataFrame, time_column: str = "origin_time") -> pd.DataFrame:
    """Return a copy of ``frame`` with minute-bucket time columns added.

    ``event_time`` is ``time_column`` passed through ``to_utc_series`` and
    floored to the minute; ``open_time_ms`` is derived from ``event_time``
    by the ``open_time_ms`` helper. The input frame is not modified.
    """
    minute_bucket = to_utc_series(frame[time_column]).dt.floor("min")
    stamped = frame.copy()
    stamped["event_time"] = minute_bucket
    stamped["open_time_ms"] = open_time_ms(stamped["event_time"])
    return stamped
def _read_candles(raw_root: Path, symbol: str, start_date: str | None, end_date: str | None) -> pd.DataFrame:
    """Load raw candles as one numeric OHLCV row per symbol and minute.

    Raises:
        ValueError: When no candle rows are found for the requested window.
    """
    raw = read_partitioned_table(
        raw_root,
        "candles",
        symbol,
        start_date,
        end_date,
        columns=("origin_time", "start", "open", "high", "low", "close", "volume", "symbol"),
    )
    if raw.empty:
        raise ValueError("candles raw data is required to build replay_1m")
    # Bucket by the "start" column when present, otherwise by origin_time.
    time_source = "origin_time" if "start" not in raw.columns else "start"
    stamped = _minute_frame(raw, time_source)
    price_columns = ("open", "high", "low", "close", "volume")
    ordered = stamped[["symbol", "event_time", "open_time_ms", *price_columns]].sort_values(["symbol", "event_time"])
    # When several rows land in the same minute, the last one wins.
    candles = ordered.drop_duplicates(["symbol", "event_time"], keep="last")
    for name in price_columns:
        candles[name] = pd.to_numeric(candles[name], errors="coerce")
    return candles
def _read_trades(raw_root: Path, symbol: str, start_date: str | None, end_date: str | None) -> pd.DataFrame:
    """Aggregate raw trades into per-minute taker buy and sell volume.

    Raises:
        ValueError: When no trade rows are found for the requested window.
    """
    raw = read_partitioned_table(
        raw_root,
        "trades",
        symbol,
        start_date,
        end_date,
        columns=("origin_time", "side", "quantity", "symbol"),
    )
    if raw.empty:
        raise ValueError("trades raw data is required for taker imbalance")
    trades = _minute_frame(raw)
    # Unparseable quantities count as zero volume.
    trades["quantity"] = pd.to_numeric(trades["quantity"], errors="coerce").fillna(0.0)
    side_text = trades["side"].astype(str).str.upper()
    trades["taker_buy_volume"] = trades["quantity"].where(side_text.eq("BUY"), 0.0)
    trades["taker_sell_volume"] = trades["quantity"].where(side_text.eq("SELL"), 0.0)
    per_minute = trades.groupby(["symbol", "event_time", "open_time_ms"], as_index=False, observed=True)
    return per_minute[["taker_buy_volume", "taker_sell_volume"]].sum()
def _read_level1(raw_root: Path, symbol: str, start_date: str | None, end_date: str | None) -> pd.DataFrame:
    """Aggregate top-of-book quotes into per-minute spread and order-flow imbalance.

    For every quote update the order-flow imbalance (OFI) contribution is
    computed against the previous update of the same symbol:

    * bid side: ``+size`` when the bid rises, ``size - prev_size`` when it is
      unchanged, ``-prev_size`` when it falls;
    * ask side: ``-size`` when the ask falls, ``prev_size - size`` when it is
      unchanged, ``+prev_size`` when it rises.

    The per-minute sum is divided by the minute's mean top-of-book depth.

    Returns:
        One row per symbol and minute with ``best_bid_price``,
        ``best_ask_price``, ``spread_bps`` (last quote of the minute) and
        ``level1_ofi_1m``.

    Raises:
        ValueError: When no level_1 rows are found for the requested window.
    """
    level1 = read_partitioned_table(
        raw_root,
        "level_1",
        symbol,
        start_date,
        end_date,
        columns=("origin_time", "bid_0_price", "ask_0_price", "bid_0_size", "ask_0_size", "symbol"),
    )
    if level1.empty:
        raise ValueError("level_1 raw data is required for spread and OFI")
    level1 = _minute_frame(level1)
    quote_columns = ["bid_0_price", "ask_0_price", "bid_0_size", "ask_0_size"]
    for column in quote_columns:
        level1[column] = pd.to_numeric(level1[column], errors="coerce")
    level1 = level1.dropna(subset=quote_columns)
    level1 = level1.sort_values(["symbol", "event_time", "origin_time"])
    group = level1.groupby("symbol", sort=False, observed=True)
    prev_bid_price = group["bid_0_price"].shift(1)
    prev_bid_size = group["bid_0_size"].shift(1)
    prev_ask_price = group["ask_0_price"].shift(1)
    prev_ask_size = group["ask_0_size"].shift(1)
    bid_ofi = np.select(
        [level1["bid_0_price"] > prev_bid_price, level1["bid_0_price"].eq(prev_bid_price)],
        [level1["bid_0_size"], level1["bid_0_size"] - prev_bid_size],
        default=-prev_bid_size,
    )
    # Ask side mirrors the bid side with the opposite sign: a lower ask adds
    # sell pressure (negative), a higher ask means the previous ask queue was
    # removed (positive). The price-move branches previously had these signs
    # flipped, contradicting the unchanged-price branch.
    # NOTE(review): confirm the live feature computation uses the same sign
    # convention before retraining on this output.
    ask_ofi = np.select(
        [level1["ask_0_price"] < prev_ask_price, level1["ask_0_price"].eq(prev_ask_price)],
        [-level1["ask_0_size"], prev_ask_size - level1["ask_0_size"]],
        default=prev_ask_size,
    )
    # The first update of each symbol has no predecessor; its NaN becomes 0.
    level1["ofi_raw"] = np.nan_to_num(bid_ofi + ask_ofi, nan=0.0)
    level1["depth"] = (level1["bid_0_size"] + level1["ask_0_size"]).clip(lower=1e-12)
    level1["mid"] = (level1["bid_0_price"] + level1["ask_0_price"]) / 2.0
    level1["spread_bps"] = (level1["ask_0_price"] - level1["bid_0_price"]) / level1["mid"] * 10000.0
    agg = level1.groupby(["symbol", "event_time", "open_time_ms"], as_index=False, observed=True).agg(
        best_bid_price=("bid_0_price", "last"),
        best_ask_price=("ask_0_price", "last"),
        spread_bps=("spread_bps", "last"),
        ofi_sum=("ofi_raw", "sum"),
        depth_mean=("depth", "mean"),
    )
    agg["level1_ofi_1m"] = agg["ofi_sum"] / agg["depth_mean"].clip(lower=1e-12)
    return agg.drop(columns=["ofi_sum", "depth_mean"])
def _read_liquidations(raw_root: Path, symbol: str, start_date: str | None, end_date: str | None) -> pd.DataFrame:
    """Aggregate liquidation events into per-minute buy and sell notional.

    Returns an empty frame with the expected columns when no liquidation
    partition files exist for the window. Otherwise returns one row per
    symbol and minute that had at least one liquidation row, with
    ``liquidation_available`` set to 1.0.
    """
    notional_columns = ["liquidation_buy_notional_1m", "liquidation_sell_notional_1m"]
    if not partition_files(raw_root, "liquidations", symbol, start_date, end_date):
        return pd.DataFrame(columns=["symbol", "event_time", "open_time_ms", *notional_columns, "liquidation_available"])
    raw = read_partitioned_table(
        raw_root,
        "liquidations",
        symbol,
        start_date,
        end_date,
        columns=("origin_time", "side", "quantity", "price", "symbol"),
    )
    events = _minute_frame(raw)
    # Unparseable quantity or price counts as zero.
    for numeric_name in ("quantity", "price"):
        events[numeric_name] = pd.to_numeric(events[numeric_name], errors="coerce").fillna(0.0)
    events["notional"] = events["quantity"] * events["price"]
    side_text = events["side"].astype(str).str.upper()
    events["liquidation_buy_notional_1m"] = events["notional"].where(side_text.eq("BUY"), 0.0)
    events["liquidation_sell_notional_1m"] = events["notional"].where(side_text.eq("SELL"), 0.0)
    per_minute = events.groupby(["symbol", "event_time", "open_time_ms"], as_index=False, observed=True)
    totals = per_minute[notional_columns].sum()
    totals["liquidation_available"] = 1.0
    return totals
def _asof_column(
    replay: pd.DataFrame,
    raw_root: Path,
    table: str,
    symbol: str,
    start_date: str | None,
    end_date: str | None,
    columns: tuple[str, ...],
) -> pd.DataFrame:
    """Attach the most recent ``table`` values to each replay minute.

    Rows of ``table`` are bucketed to the minute and joined backward onto
    the ``(symbol, event_time)`` pairs of ``replay``; a value is used only
    if it is at most 12 hours older than the replay minute.

    Args:
        replay: Frame providing ``symbol`` and ``event_time``.
        raw_root: Root of the raw partitioned tables.
        table: Raw table to read.
        symbol: Symbol passed through to ``read_partitioned_table``.
        start_date: First day to read.
        end_date: Last day to read.
        columns: Value columns to attach; names ending in ``time`` are left
            unconverted, all others are coerced to numeric.

    Returns:
        A frame with ``symbol``, ``event_time`` and ``columns``.

    Raises:
        ValueError: When ``table`` has no rows for the requested window.
    """
    frame = read_partitioned_table(raw_root, table, symbol, start_date, end_date, columns=("origin_time", "symbol", *columns))
    if frame.empty:
        raise ValueError(f"{table} raw data is required")
    frame = _minute_frame(frame)
    for column in columns:
        if column.endswith("time"):
            continue
        frame[column] = pd.to_numeric(frame[column], errors="coerce")
    # merge_asof needs both sides sorted on the ``on`` key itself, not only
    # within each ``by`` group. Sorting by event_time first keeps that true
    # with more than one symbol. A multi-column sort is stable, so rows that
    # share a minute keep their read order and the last one still wins.
    sort_keys = ["event_time", "symbol"]
    left = replay[["symbol", "event_time"]].sort_values(sort_keys)
    right = frame[["symbol", "event_time", *columns]].sort_values(sort_keys)
    return pd.merge_asof(
        left,
        right,
        by="symbol",
        on="event_time",
        direction="backward",
        tolerance=pd.Timedelta(hours=12),
    )
# Columns that must be non-null on every row of a day for build_replay_1m to
# count that day as replay-ready.
REPLAY_REQUIRED_COLUMNS = [
    "open",
    "high",
    "low",
    "close",
    "volume",
    "best_bid_price",
    "best_ask_price",
    "spread_bps",
    "level1_ofi_1m",
    "funding_bps",
    "mark_price",
    "index_price",
    "open_interest",
]
# Columns, in order, that build_replay_1m selects from each ready day and
# writes to replay_1m.parquet.
REPLAY_OUTPUT_COLUMNS = [
    "symbol",
    "timeframe",
    "event_time",
    "open_time_ms",
    "open",
    "high",
    "low",
    "close",
    "volume",
    "taker_buy_volume",
    "taker_sell_volume",
    "funding_bps",
    "mark_price",
    "index_price",
    "next_funding_time",
    "open_interest",
    "best_bid_price",
    "best_ask_price",
    "spread_bps",
    "level1_ofi_1m",
    "liquidation_buy_notional_1m",
    "liquidation_sell_notional_1m",
    "liquidation_available",
    "source_coverage",
]
def _replay_date_texts(raw_root: Path, symbol: str, start_date: str | None, end_date: str | None) -> list[str]:
if start_date and end_date:
return [day.strftime("%Y-%m-%d") for day in pd.date_range(pd.Timestamp(start_date), pd.Timestamp(end_date), freq="D")]
files = partition_files(raw_root, "candles", symbol, start_date, end_date)
dates = sorted({next((part.split("=", 1)[1] for part in file.parts if part.startswith("dt=")), "") for file in files})
return [date for date in dates if date]
def _previous_date_text(day: str) -> str:
return (pd.Timestamp(day) - pd.Timedelta(days=1)).strftime("%Y-%m-%d")
def _build_replay_day(raw_root: Path, symbol: str, day: str) -> pd.DataFrame:
    """Assemble the 1-minute replay frame for a single day.

    Candles of ``day`` form the base rows. Trades, level-1 and liquidation
    aggregates are left-joined per minute; funding and open interest are
    attached as as-of values.
    """
    minute_keys = ["symbol", "event_time", "open_time_ms"]
    asof_keys = ["symbol", "event_time"]

    candles = _read_candles(raw_root, symbol, day, day)
    on_requested_day = candles["event_time"].dt.strftime("%Y-%m-%d").eq(day)
    replay = candles[on_requested_day].copy()

    minute_tables = (
        _read_trades(raw_root, symbol, day, day),
        _read_level1(raw_root, symbol, day, day),
        _read_liquidations(raw_root, symbol, day, day),
    )
    for minute_table in minute_tables:
        replay = replay.merge(minute_table, on=minute_keys, how="left")

    # Minutes without trades carry zero taker volume.
    taker_columns = ["taker_buy_volume", "taker_sell_volume"]
    replay[taker_columns] = replay[taker_columns].fillna(0.0)
    # NOTE(review): liquidation_available is also zero-filled, so it is 1.0
    # only for minutes that had a liquidation row - confirm that is the
    # intended meaning of the flag.
    for column in ("liquidation_buy_notional_1m", "liquidation_sell_notional_1m", "liquidation_available"):
        replay[column] = pd.to_numeric(replay[column], errors="coerce").fillna(0.0)

    # Funding and open interest are as-of values. Include the previous UTC day so
    # the first minutes of a day can use the last known value without reading the
    # whole training window into memory.
    lookback_start = _previous_date_text(day)

    funding = _asof_column(replay, raw_root, "funding", symbol, lookback_start, day, ("rate", "mark_price", "index_price", "next_funding_time"))
    funding["funding_bps"] = pd.to_numeric(funding["rate"], errors="coerce") * 10000.0
    replay = replay.merge(funding.drop(columns=["rate"]), on=asof_keys, how="left")
    replay["next_funding_time"] = to_utc_series(replay["next_funding_time"])

    open_interest = _asof_column(replay, raw_root, "open_interest", symbol, lookback_start, day, ("open_interest",))
    replay = replay.merge(open_interest, on=asof_keys, how="left")

    replay["timeframe"] = "1m"
    replay["source_coverage"] = "crypto_lake_raw"
    replay["event_date"] = replay["event_time"].dt.strftime("%Y-%m-%d")
    return replay
def build_replay_1m(args: Any) -> None:
    """Build ``replay_1m.parquet`` from every replay-ready day in the window.

    Each day is built independently. A day is kept when it has at least
    ``args.min_minutes_per_day`` rows and no missing value in
    ``REPLAY_REQUIRED_COLUMNS``; otherwise it is recorded in
    ``excluded_days.json`` with a reason. A day whose build raises is
    excluded as ``DAY_BUILD_FAILED`` and processing continues.

    Args:
        args: Namespace providing ``run_id``, ``symbol``, ``raw_root``,
            ``start_date``, ``end_date``, ``min_minutes_per_day`` and
            ``min_replay_ready_days``; it is also passed to ``run_root``.

    Raises:
        ValueError: When no candle dates exist, or when fewer than
            ``args.min_replay_ready_days`` days (or no day at all) are
            replay-ready. The day reports are written before raising.
    """
    root = run_root(args)
    raw_root = args.raw_root or DEFAULT_RAW_ROOT
    logging.info("trader.training.replay_started runId=%s symbol=%s rawRoot=%s", args.run_id, args.symbol, raw_root)
    dates = _replay_date_texts(raw_root, args.symbol, args.start_date, args.end_date)
    if not dates:
        raise ValueError("no candle dates are available for replay_1m")
    min_minutes_per_day = int(args.min_minutes_per_day)
    ready_days: list[str] = []
    excluded_days: list[dict[str, Any]] = []
    ready_frames: list[pd.DataFrame] = []
    row_before_filter = 0
    for index, day in enumerate(dates, start=1):
        logging.info("trader.training.replay_day_started runId=%s day=%s index=%s total=%s", args.run_id, day, index, len(dates))
        try:
            day_replay = _build_replay_day(raw_root, args.symbol, day)
        except Exception as exc:
            # Deliberate best-effort: one broken day must not abort the run.
            # Keep the traceback in the log so the cause stays diagnosable.
            excluded_days.append(
                {
                    "date": day,
                    "row_count": 0,
                    "missing_required_rows": 0,
                    "reason": "DAY_BUILD_FAILED",
                    "error": str(exc),
                }
            )
            logging.warning(
                "trader.training.replay_day_failed runId=%s day=%s error=%s",
                args.run_id,
                day,
                exc,
                exc_info=True,
            )
            continue
        row_count = len(day_replay)
        row_before_filter += row_count
        missing_required_rows = int(day_replay[REPLAY_REQUIRED_COLUMNS].isna().any(axis=1).sum())
        ready = row_count >= min_minutes_per_day and missing_required_rows == 0
        if ready:
            ready_days.append(day)
            ready_frames.append(day_replay[REPLAY_OUTPUT_COLUMNS].copy())
        else:
            excluded_days.append(
                {
                    "date": day,
                    "row_count": int(row_count),
                    "missing_required_rows": missing_required_rows,
                    "reason": "MISSING_REQUIRED_MARKET_FIELDS" if missing_required_rows else "INCOMPLETE_MINUTE_COUNT",
                }
            )
        logging.info(
            "trader.training.replay_day_finished runId=%s day=%s ready=%s rows=%s missingRequiredRows=%s",
            args.run_id,
            day,
            ready,
            row_count,
            missing_required_rows,
        )
    replay_dir = root / "replay"
    # Also fail when nothing is ready at all: with a zero threshold the
    # concat below would otherwise raise an opaque error before the day
    # reports are written.
    if not ready_frames or len(ready_days) < int(args.min_replay_ready_days):
        write_json(replay_dir / "excluded_days.json", excluded_days)
        write_text(replay_dir / "replay_ready_days.txt", "\n".join(ready_days) + ("\n" if ready_days else ""))
        raise ValueError(f"replay_1m has only {len(ready_days)} replay-ready days, required {args.min_replay_ready_days}")
    replay = pd.concat(ready_frames, ignore_index=True)
    logging.info(
        "trader.training.replay_ready_days_selected runId=%s readyDays=%s excludedDays=%s rowBefore=%s rowAfter=%s",
        args.run_id,
        len(ready_days),
        len(excluded_days),
        row_before_filter,
        len(replay),
    )
    replay = replay[REPLAY_OUTPUT_COLUMNS].sort_values(["symbol", "event_time"]).reset_index(drop=True)
    path = replay_dir / "replay_1m.parquet"
    data_hash = write_parquet(path, replay)
    write_json(
        replay_dir / "replay_1m.manifest.json",
        manifest(
            path,
            {
                "row_count": len(replay),
                "hash_sha256": data_hash,
                "replay_ready_day_count": len(ready_days),
                "excluded_day_count": len(excluded_days),
                "min_minutes_per_day": min_minutes_per_day,
            },
        ),
    )
    write_json(replay_dir / "excluded_days.json", excluded_days)
    write_text(replay_dir / "replay_ready_days.txt", "\n".join(ready_days) + "\n")
    logging.info("trader.training.replay_written runId=%s rowCount=%s readyDays=%s path=%s", args.run_id, len(replay), len(ready_days), path)
def build_splits(args: Any) -> None:
    """Write the split manifest, walk-forward folds and purge/embargo report.

    The fixed split intervals derived from ``args`` are clipped to the time
    range covered by the replay. Every split in ``TRAINING_SPLITS`` must
    survive the clipping. Each fold trains from the start of the fit split
    up to an evenly spaced boundary inside it and validates on the whole
    tune split.

    Raises:
        ValueError: When the replay has fewer than 10 rows, or when the
            clipped intervals do not cover exactly ``TRAINING_SPLITS``.
    """

    def _zulu(stamp: pd.Timestamp) -> str:
        # ISO-8601 text with a trailing "Z" instead of "+00:00".
        return stamp.isoformat().replace("+00:00", "Z")

    root = run_root(args)
    replay_path = args.replay_path or root / "replay" / "replay_1m.parquet"
    replay = read_parquet(replay_path)
    require_columns(replay, ("event_time", "symbol"), "replay_1m")
    replay["event_time"] = to_utc_series(replay["event_time"])
    replay = replay.sort_values(["event_time", "symbol"]).reset_index(drop=True)
    if len(replay) < 10:
        raise ValueError("not enough replay rows to build time splits")

    gap = int(args.gap_minutes)
    requested_intervals = _fixed_split_intervals(args, gap)
    replay_start = replay["event_time"].min()
    replay_end = replay["event_time"].max()
    intervals = []
    for split_id, start, end in requested_intervals:
        clipped_start = max(start, replay_start)
        clipped_end = min(end, replay_end)
        # A split that falls entirely outside the replay coverage is dropped.
        if clipped_start <= clipped_end:
            intervals.append((split_id, clipped_start, clipped_end))
    covered_splits = {split_id for split_id, _, _ in intervals}
    if covered_splits != set(TRAINING_SPLITS):
        raise ValueError(f"fixed split dates do not fit replay coverage: replay_start={replay_start} replay_end={replay_end}")

    split_manifest = {
        "split_version": SPLIT_VERSION,
        "created_at": utc_now_text(),
        "source_replay_path": str(replay_path),
        "gap_minutes": gap,
        # Sealed splits are withheld from broad parameter search. They only answer
        # whether a finished candidate survives final validation and recent stress.
        "sealed_splits": [VALIDATION_LOCKED_SPLIT, LATEST_STRESS_SPLIT],
        "latest_stress_policy": "FINAL_GATE_ONLY",
        "requested_splits": {
            FIT_SPLIT: [args.fit_inner_start, args.fit_inner_end],
            TUNE_SPLIT: [args.tune_inner_start, args.tune_inner_end],
            VALIDATION_LOCKED_SPLIT: [args.validation_locked_start, args.validation_locked_end],
            LATEST_STRESS_SPLIT: [args.latest_stress_start, args.latest_stress_end],
        },
        "splits": [{"split_id": split_id, "start": _zulu(start), "end": _zulu(end)} for split_id, start, end in intervals],
    }

    fold_count = max(1, int(args.fold_count))
    _, fit_start, fit_end = next(item for item in intervals if item[0] == FIT_SPLIT)
    _, tune_start, tune_end = next(item for item in intervals if item[0] == TUNE_SPLIT)
    # fold_count + 1 evenly spaced boundaries; fold k trains up to boundary k.
    boundaries = pd.date_range(fit_start, fit_end, periods=fold_count + 1)
    folds = [
        {
            "walk_forward_fold": f"fold_{number:02d}",
            "train_start": _zulu(fit_start),
            "train_end": _zulu(boundaries[number]),
            "validation_start": _zulu(tune_start),
            "validation_end": _zulu(tune_end),
        }
        for number in range(1, fold_count + 1)
    ]

    split_dir = root / "split"
    ensure_dir(split_dir)
    write_json(split_dir / "split_manifest.json", split_manifest)
    write_json(split_dir / "walk_forward_folds.json", {"split_version": SPLIT_VERSION, "folds": folds})
    _write_purge_embargo_report(split_dir / "purge_embargo_report.md", intervals, gap)
    logging.info("trader.training.splits_written runId=%s splitCount=%s foldCount=%s", args.run_id, len(split_manifest["splits"]), len(folds))
def assign_split(event_times: pd.Series, split_manifest_path: Path) -> pd.Series:
    """Label each timestamp with the split whose interval contains it.

    Intervals are read from the ``splits`` list of the manifest and are
    inclusive on both ends. Timestamps outside every interval are labelled
    ``NO_SPLIT``; when intervals overlap, the later manifest entry wins.
    """
    manifest_data = read_json(split_manifest_path)
    labels = pd.Series("NO_SPLIT", index=event_times.index, dtype="object")
    stamps = to_utc_series(event_times)
    for entry in manifest_data["splits"]:
        lower = pd.Timestamp(entry["start"])
        upper = pd.Timestamp(entry["end"])
        inside = (stamps >= lower) & (stamps <= upper)
        labels.loc[inside] = entry["split_id"]
    return labels
def _fixed_split_intervals(args: Any, gap_minutes: int) -> list[tuple[str, pd.Timestamp, pd.Timestamp]]:
    """Return the four requested split intervals with the gap applied.

    The gap is added to the start of every split except the fit split and
    subtracted from the end of every split except the latest-stress split.
    """
    gap = pd.Timedelta(minutes=gap_minutes)
    fit = (FIT_SPLIT, _start_of_day(args.fit_inner_start), _end_of_day(args.fit_inner_end) - gap)
    tune = (TUNE_SPLIT, _start_of_day(args.tune_inner_start) + gap, _end_of_day(args.tune_inner_end) - gap)
    locked = (
        VALIDATION_LOCKED_SPLIT,
        _start_of_day(args.validation_locked_start) + gap,
        _end_of_day(args.validation_locked_end) - gap,
    )
    stress = (LATEST_STRESS_SPLIT, _start_of_day(args.latest_stress_start) + gap, _end_of_day(args.latest_stress_end))
    return [fit, tune, locked, stress]
def _start_of_day(value: str) -> pd.Timestamp:
return pd.Timestamp(value, tz="UTC")
def _end_of_day(value: str) -> pd.Timestamp:
return pd.Timestamp(value, tz="UTC") + pd.Timedelta(days=1) - pd.Timedelta(minutes=1)
def _write_purge_embargo_report(path: Path, intervals: list[tuple[str, pd.Timestamp, pd.Timestamp]], gap_minutes: int) -> None:
    """Write a Markdown table of the split intervals and the gap to ``path``."""
    header = [
        "# Purge Embargo Report",
        "",
        f"- gap_minutes: {gap_minutes}",
        "",
        "| split_id | start | end |",
        "| --- | --- | --- |",
    ]
    table_rows = [f"| {split_id} | {start.isoformat()} | {end.isoformat()} |" for split_id, start, end in intervals]
    write_text(path, "\n".join(header + table_rows) + "\n")