9acb3460a1
Align entry labels with max future edge, tune direction labeling, and harden regression evaluation. Add training diagnostics, price-plan search, feature screening, and nonlinear benchmark scripts.
551 lines
24 KiB
Python
551 lines
24 KiB
Python
from __future__ import annotations
|
|
|
|
import json
|
|
import logging
|
|
from pathlib import Path
|
|
from typing import Any
|
|
|
|
import numpy as np
|
|
import pandas as pd
|
|
|
|
from trader_training.io_utils import (
|
|
DEFAULT_RAW_ROOT,
|
|
ensure_dir,
|
|
manifest,
|
|
open_time_ms,
|
|
partition_files,
|
|
read_json,
|
|
read_partitioned_table,
|
|
read_parquet,
|
|
require_columns,
|
|
run_root,
|
|
to_utc_series,
|
|
utc_now_text,
|
|
write_json,
|
|
write_parquet,
|
|
write_text,
|
|
)
|
|
from trader_training.schemas import FIT_SPLIT, LATEST_STRESS_SPLIT, SPLIT_VERSION, TRAINING_SPLITS, TUNE_SPLIT, VALIDATION_LOCKED_SPLIT
|
|
|
|
|
|
def audit_source_data(data_root: Path, symbol: str, start_date: str | None, end_date: str | None, min_ready_days: int = 250) -> dict[str, Any]:
    """Audit raw partition coverage for ``symbol`` and classify every calendar day.

    For each required and optional raw table the partition files are listed and
    the ``dt=<date>`` path component of each file is collected. A day is
    "replay ready" when every required table has a partition for it.

    Args:
        data_root: Directory that contains ``crypto-lake/raw``.
        symbol: Symbol passed through to ``partition_files``.
        start_date: Optional inclusive lower date bound (``YYYY-MM-DD``).
        end_date: Optional inclusive upper date bound (``YYYY-MM-DD``).
        min_ready_days: Minimum number of replay-ready days for ``ready`` to be true.

    Returns:
        A JSON-serialisable dict with per-table rows, the ready/excluded day
        lists and the overall ``ready`` flag.
    """
    raw_root = data_root / "crypto-lake" / "raw"
    required_tables = ("candles", "trades", "level_1", "funding", "open_interest")
    optional_tables = ("liquidations",)
    rows: list[dict[str, Any]] = []
    table_dates: dict[str, set[str]] = {}
    for table in required_tables + optional_tables:
        files = partition_files(raw_root, table, symbol, start_date, end_date)
        # A file whose path has no "dt=<date>" component carries no usable date.
        # Drop it here so an empty string never shows up as first_date or as a
        # covered day (the replay date listing filters the same way).
        found: set[str] = set()
        for file in files:
            day = next((part.split("=", 1)[1] for part in file.parts if part.startswith("dt=")), "")
            if day:
                found.add(day)
        dates = sorted(found)
        table_dates[table] = found
        rows.append(
            {
                "table": table,
                "required": table in required_tables,
                "file_count": len(files),
                "first_date": dates[0] if dates else None,
                "last_date": dates[-1] if dates else None,
                # Optional tables never block readiness, even with no files.
                "status": "OK" if files or table in optional_tables else "MISSING",
            }
        )

    all_dates = _audit_date_range(table_dates, required_tables, start_date, end_date)
    replay_ready_days: list[str] = []
    excluded_days: list[dict[str, Any]] = []
    for day in all_dates:
        missing_required = [table for table in required_tables if day not in table_dates[table]]
        missing_optional = [table for table in optional_tables if day not in table_dates[table]]
        if missing_required:
            excluded_days.append({"date": day, "reason": "MISSING_REQUIRED_TABLE", "missing_required_tables": missing_required, "missing_optional_tables": missing_optional})
        else:
            replay_ready_days.append(day)

    result = {
        "symbol": symbol,
        "start_date": start_date,
        "end_date": end_date,
        "raw_root": str(raw_root),
        "tables": rows,
        "replay_ready_day_count": len(replay_ready_days),
        "excluded_day_count": len(excluded_days),
        "replay_ready_days": replay_ready_days,
        "excluded_days": excluded_days,
        "created_at": utc_now_text(),
        "ready": all(row["status"] == "OK" for row in rows if row["required"]) and len(replay_ready_days) >= min_ready_days,
    }
    return result
|
|
|
|
|
|
def write_audit_outputs(args: Any) -> None:
    """Run the source-data audit and write its JSON, text and markdown outputs.

    All files go under ``<run_root>/raw-manifest``. The same audit result is
    written to both ``source_data_audit.json`` and ``source_data_manifest.json``.

    Args:
        args: Namespace providing ``data_root``, ``symbol``, ``start_date``,
            ``end_date``, ``min_ready_days`` and ``run_id`` (plus whatever
            ``run_root`` reads).

    Raises:
        SystemExit: When the audit is not ready. The message distinguishes
            missing required tables from too few replay-ready days.
    """
    root = run_root(args)
    out_dir = root / "raw-manifest"
    min_ready_days = int(args.min_ready_days)
    result = audit_source_data(args.data_root, args.symbol, args.start_date, args.end_date, min_ready_days)
    path = out_dir / "source_data_audit.json"
    write_json(path, result)
    write_json(out_dir / "source_data_manifest.json", result)
    write_json(out_dir / "excluded_days.json", result["excluded_days"])
    write_text(out_dir / "replay_ready_days.txt", "\n".join(result["replay_ready_days"]) + ("\n" if result["replay_ready_days"] else ""))

    report_lines = [
        "# Trader Source Data Audit",
        "",
        f"- symbol: {result['symbol']}",
        f"- raw_root: {result['raw_root']}",
        f"- ready: {result['ready']}",
        f"- replay_ready_day_count: {result['replay_ready_day_count']}",
        f"- excluded_day_count: {result['excluded_day_count']}",
        "",
        "| table | required | file_count | first_date | last_date | status |",
        "| --- | --- | ---: | --- | --- | --- |",
    ]
    report_lines.extend(
        f"| {row['table']} | {row['required']} | {row['file_count']} | {row['first_date']} | {row['last_date']} | {row['status']} |"
        for row in result["tables"]
    )
    write_text(out_dir / "source_data_audit.md", "\n".join(report_lines) + "\n")

    logging.info(
        "trader.training.audit_written runId=%s ready=%s readyDays=%s excludedDays=%s path=%s",
        args.run_id,
        result["ready"],
        result["replay_ready_day_count"],
        result["excluded_day_count"],
        path,
    )
    if not result["ready"]:
        # "ready" can be false for two different reasons; report the real one.
        missing_tables = [row["table"] for row in result["tables"] if row["required"] and row["status"] != "OK"]
        if missing_tables:
            raise SystemExit("required raw tables are missing; see source_data_audit.md")
        raise SystemExit(
            f"only {result['replay_ready_day_count']} replay-ready days, required {min_ready_days}; see source_data_audit.md"
        )
|
|
|
|
|
|
def _audit_date_range(table_dates: dict[str, set[str]], required_tables: tuple[str, ...], start_date: str | None, end_date: str | None) -> list[str]:
|
|
if start_date and end_date:
|
|
start = pd.Timestamp(start_date)
|
|
end = pd.Timestamp(end_date)
|
|
else:
|
|
dates = sorted(set().union(*(table_dates[table] for table in required_tables)))
|
|
if not dates:
|
|
return []
|
|
start = pd.Timestamp(start_date or dates[0])
|
|
end = pd.Timestamp(end_date or dates[-1])
|
|
return [day.strftime("%Y-%m-%d") for day in pd.date_range(start, end, freq="D")]
|
|
|
|
|
|
def _minute_frame(frame: pd.DataFrame, time_column: str = "origin_time") -> pd.DataFrame:
    """Return a copy of ``frame`` with minute-bucketed time columns added.

    ``event_time`` is ``time_column`` passed through ``to_utc_series`` and
    floored to the minute; ``open_time_ms`` is derived from ``event_time`` by
    the ``open_time_ms`` helper. The input frame is not modified.
    """
    bucketed = frame.copy()
    utc_times = to_utc_series(bucketed[time_column])
    bucketed["event_time"] = utc_times.dt.floor("min")
    bucketed["open_time_ms"] = open_time_ms(bucketed["event_time"])
    return bucketed
|
|
|
|
|
|
def _read_candles(raw_root: Path, symbol: str, start_date: str | None, end_date: str | None) -> pd.DataFrame:
    """Load raw candles and return one numeric OHLCV row per symbol and minute.

    The minute bucket is taken from the ``start`` column when present, else
    from ``origin_time``. Duplicate (symbol, minute) rows keep the last one
    after sorting.

    Raises:
        ValueError: When no candle rows are available.
    """
    value_columns = ("open", "high", "low", "close", "volume")
    raw = read_partitioned_table(
        raw_root,
        "candles",
        symbol,
        start_date,
        end_date,
        columns=("origin_time", "start", "open", "high", "low", "close", "volume", "symbol"),
    )
    if raw.empty:
        raise ValueError("candles raw data is required to build replay_1m")

    if "start" in raw.columns:
        time_col = "start"
    else:
        time_col = "origin_time"
    framed = _minute_frame(raw, time_col)

    minute_key = ["symbol", "event_time"]
    selected = framed[[*minute_key, "open_time_ms", *value_columns]]
    candles = selected.sort_values(minute_key).drop_duplicates(minute_key, keep="last")
    for name in value_columns:
        # Unparseable values become NaN rather than raising.
        candles[name] = pd.to_numeric(candles[name], errors="coerce")
    return candles
|
|
|
|
|
|
def _read_trades(raw_root: Path, symbol: str, start_date: str | None, end_date: str | None) -> pd.DataFrame:
    """Aggregate raw trades into per-minute taker buy and sell volume.

    Quantities that fail numeric parsing count as 0. A trade contributes to
    ``taker_buy_volume`` when its upper-cased ``side`` is ``BUY`` and to
    ``taker_sell_volume`` when it is ``SELL``; any other side adds to neither.

    Raises:
        ValueError: When no trade rows are available.
    """
    raw = read_partitioned_table(
        raw_root,
        "trades",
        symbol,
        start_date,
        end_date,
        columns=("origin_time", "side", "quantity", "symbol"),
    )
    if raw.empty:
        raise ValueError("trades raw data is required for taker imbalance")

    trades = _minute_frame(raw)
    quantity = pd.to_numeric(trades["quantity"], errors="coerce").fillna(0.0)
    trades["quantity"] = quantity
    normalized_side = trades["side"].astype(str).str.upper()
    is_buy = normalized_side.eq("BUY")
    is_sell = normalized_side.eq("SELL")
    trades["taker_buy_volume"] = np.where(is_buy, trades["quantity"], 0.0)
    trades["taker_sell_volume"] = np.where(is_sell, trades["quantity"], 0.0)

    per_minute = trades.groupby(["symbol", "event_time", "open_time_ms"], as_index=False, observed=True)
    return per_minute[["taker_buy_volume", "taker_sell_volume"]].sum()
|
|
|
|
|
|
def _read_level1(raw_root: Path, symbol: str, start_date: str | None, end_date: str | None) -> pd.DataFrame:
    """Aggregate top-of-book updates into per-minute quotes, spread and OFI.

    Per minute the result holds the last best bid/ask price, the last spread in
    basis points of the mid price, and ``level1_ofi_1m``: the sum of per-update
    order-flow imbalance divided by the mean top-of-book depth of that minute.

    Per-update order-flow imbalance is computed against the previous update of
    the same symbol:

    * bid side: ``+size`` if the bid price rose, ``size - prev_size`` if it is
      unchanged, ``-prev_size`` if it fell;
    * ask side: ``-size`` if the ask price fell, ``prev_size - size`` if it is
      unchanged, ``+prev_size`` if it rose.

    Positive values therefore mean net buying pressure on both sides. The
    first update of each symbol has no predecessor and contributes 0.

    Raises:
        ValueError: When no level_1 rows are available.
    """
    level1 = read_partitioned_table(
        raw_root,
        "level_1",
        symbol,
        start_date,
        end_date,
        columns=("origin_time", "bid_0_price", "ask_0_price", "bid_0_size", "ask_0_size", "symbol"),
    )
    if level1.empty:
        raise ValueError("level_1 raw data is required for spread and OFI")

    quote_columns = ["bid_0_price", "ask_0_price", "bid_0_size", "ask_0_size"]
    level1 = _minute_frame(level1)
    for column in quote_columns:
        level1[column] = pd.to_numeric(level1[column], errors="coerce")
    level1 = level1.dropna(subset=quote_columns)
    level1 = level1.sort_values(["symbol", "event_time", "origin_time"])

    group = level1.groupby("symbol", sort=False, observed=True)
    prev_bid_price = group["bid_0_price"].shift(1)
    prev_bid_size = group["bid_0_size"].shift(1)
    prev_ask_price = group["ask_0_price"].shift(1)
    prev_ask_size = group["ask_0_size"].shift(1)

    bid_ofi = np.select(
        [level1["bid_0_price"] > prev_bid_price, level1["bid_0_price"].eq(prev_bid_price)],
        [level1["bid_0_size"], level1["bid_0_size"] - prev_bid_size],
        default=-prev_bid_size,
    )
    # Ask-side supply counts against buying pressure: a lower ask adds selling
    # interest (negative), a withdrawn ask removes it (positive). This keeps the
    # price-move branches consistent with the unchanged-price branch.
    ask_ofi = np.select(
        [level1["ask_0_price"] < prev_ask_price, level1["ask_0_price"].eq(prev_ask_price)],
        [-level1["ask_0_size"], prev_ask_size - level1["ask_0_size"]],
        default=prev_ask_size,
    )
    # Comparisons against the missing predecessor fall through to the NaN
    # default; those rows are zeroed here.
    level1["ofi_raw"] = np.nan_to_num(bid_ofi + ask_ofi, nan=0.0)
    level1["depth"] = (level1["bid_0_size"] + level1["ask_0_size"]).clip(lower=1e-12)
    level1["mid"] = (level1["bid_0_price"] + level1["ask_0_price"]) / 2.0
    level1["spread_bps"] = (level1["ask_0_price"] - level1["bid_0_price"]) / level1["mid"] * 10000.0

    agg = level1.groupby(["symbol", "event_time", "open_time_ms"], as_index=False, observed=True).agg(
        best_bid_price=("bid_0_price", "last"),
        best_ask_price=("ask_0_price", "last"),
        spread_bps=("spread_bps", "last"),
        ofi_sum=("ofi_raw", "sum"),
        depth_mean=("depth", "mean"),
    )
    # The clip guards the division against a zero mean depth.
    agg["level1_ofi_1m"] = agg["ofi_sum"] / agg["depth_mean"].clip(lower=1e-12)
    return agg.drop(columns=["ofi_sum", "depth_mean"])
|
|
|
|
|
|
def _read_liquidations(raw_root: Path, symbol: str, start_date: str | None, end_date: str | None) -> pd.DataFrame:
    """Aggregate raw liquidations into per-minute buy and sell notional.

    Liquidations are optional: when no partition files exist an empty frame
    with the expected columns is returned. Otherwise notional is
    ``quantity * price`` (unparseable values count as 0), split by upper-cased
    ``side`` into BUY and SELL sums, and ``liquidation_available`` is set to
    1.0 on every aggregated row.
    """
    buy_col = "liquidation_buy_notional_1m"
    sell_col = "liquidation_sell_notional_1m"

    if not partition_files(raw_root, "liquidations", symbol, start_date, end_date):
        return pd.DataFrame(columns=["symbol", "event_time", "open_time_ms", buy_col, sell_col, "liquidation_available"])

    raw = read_partitioned_table(
        raw_root,
        "liquidations",
        symbol,
        start_date,
        end_date,
        columns=("origin_time", "side", "quantity", "price", "symbol"),
    )
    events = _minute_frame(raw)
    for numeric_column in ("quantity", "price"):
        events[numeric_column] = pd.to_numeric(events[numeric_column], errors="coerce").fillna(0.0)
    events["notional"] = events["quantity"] * events["price"]

    normalized_side = events["side"].astype(str).str.upper()
    events[buy_col] = np.where(normalized_side.eq("BUY"), events["notional"], 0.0)
    events[sell_col] = np.where(normalized_side.eq("SELL"), events["notional"], 0.0)

    per_minute = events.groupby(["symbol", "event_time", "open_time_ms"], as_index=False, observed=True)
    totals = per_minute[[buy_col, sell_col]].sum()
    totals["liquidation_available"] = 1.0
    return totals
|
|
|
|
|
|
def _asof_column(
    replay: pd.DataFrame,
    raw_root: Path,
    table: str,
    symbol: str,
    start_date: str | None,
    end_date: str | None,
    columns: tuple[str, ...],
) -> pd.DataFrame:
    """Attach the latest known value of ``columns`` to every replay minute.

    The raw ``table`` is bucketed to the minute and joined backward as-of onto
    the (symbol, event_time) pairs of ``replay``; a value older than 12 hours
    is treated as missing.

    Args:
        replay: Frame providing ``symbol`` and ``event_time``.
        raw_root: Root of the raw partitioned tables.
        table: Raw table name to read.
        symbol: Symbol passed through to ``read_partitioned_table``.
        start_date: Inclusive lower date bound for the read.
        end_date: Inclusive upper date bound for the read.
        columns: Value columns to carry over. Columns whose name ends in
            ``time`` are left as read; all others are coerced to numeric.

    Returns:
        A frame with ``symbol``, ``event_time`` and ``columns``, one row per
        row of ``replay``.

    Raises:
        ValueError: When the raw table has no rows.
    """
    frame = read_partitioned_table(raw_root, table, symbol, start_date, end_date, columns=("origin_time", "symbol", *columns))
    if frame.empty:
        raise ValueError(f"{table} raw data is required")
    frame = _minute_frame(frame)
    for column in columns:
        if column.endswith("time"):
            continue
        frame[column] = pd.to_numeric(frame[column], errors="coerce")

    # merge_asof requires both sides ordered by the "on" key itself; "by" only
    # groups the matching. Sorting by symbol first breaks that ordering as soon
    # as more than one symbol is present. The stable sort keeps the input order
    # of rows that share a minute, so the same row wins the as-of match.
    left = replay[["symbol", "event_time"]].sort_values("event_time", kind="mergesort")
    right = frame[["symbol", "event_time", *columns]].sort_values("event_time", kind="mergesort")
    return pd.merge_asof(
        left,
        right,
        by="symbol",
        on="event_time",
        direction="backward",
        tolerance=pd.Timedelta(hours=12),
    )
|
|
|
|
|
|
# Columns that must be non-null on every row for a day to count as replay-ready.
REPLAY_REQUIRED_COLUMNS = [
    # candle fields
    "open", "high", "low", "close", "volume",
    # top-of-book fields
    "best_bid_price", "best_ask_price", "spread_bps", "level1_ofi_1m",
    # funding and open-interest as-of fields
    "funding_bps", "mark_price", "index_price", "open_interest",
]
|
|
|
|
# Column selection and order of the persisted replay_1m table.
REPLAY_OUTPUT_COLUMNS = [
    # row identity
    "symbol", "timeframe", "event_time", "open_time_ms",
    # candle fields
    "open", "high", "low", "close", "volume",
    # taker flow
    "taker_buy_volume", "taker_sell_volume",
    # funding and open interest
    "funding_bps", "mark_price", "index_price", "next_funding_time", "open_interest",
    # top of book
    "best_bid_price", "best_ask_price", "spread_bps", "level1_ofi_1m",
    # liquidations
    "liquidation_buy_notional_1m", "liquidation_sell_notional_1m", "liquidation_available",
    # provenance
    "source_coverage",
]
|
|
|
|
|
|
def _replay_date_texts(raw_root: Path, symbol: str, start_date: str | None, end_date: str | None) -> list[str]:
|
|
if start_date and end_date:
|
|
return [day.strftime("%Y-%m-%d") for day in pd.date_range(pd.Timestamp(start_date), pd.Timestamp(end_date), freq="D")]
|
|
files = partition_files(raw_root, "candles", symbol, start_date, end_date)
|
|
dates = sorted({next((part.split("=", 1)[1] for part in file.parts if part.startswith("dt=")), "") for file in files})
|
|
return [date for date in dates if date]
|
|
|
|
|
|
def _previous_date_text(day: str) -> str:
|
|
return (pd.Timestamp(day) - pd.Timedelta(days=1)).strftime("%Y-%m-%d")
|
|
|
|
|
|
def _build_replay_day(raw_root: Path, symbol: str, day: str) -> pd.DataFrame:
    """Assemble the 1-minute replay rows for a single UTC day.

    Candle minutes of ``day`` are the row set. Trades, level-1 and liquidation
    aggregates are left-joined per minute; funding and open interest are
    attached as-of. Missing taker volumes and liquidation fields become 0.

    Returns:
        The day's replay frame, including ``timeframe``, ``source_coverage``
        and ``event_date`` columns.
    """
    minute_key = ["symbol", "event_time", "open_time_ms"]
    asof_key = ["symbol", "event_time"]

    candles = _read_candles(raw_root, symbol, day, day)
    on_day = candles["event_time"].dt.strftime("%Y-%m-%d").eq(day)
    replay = candles[on_day].copy()

    # Read all three sources before merging so a missing source fails the day
    # before any join work is done.
    minute_sources = (
        _read_trades(raw_root, symbol, day, day),
        _read_level1(raw_root, symbol, day, day),
        _read_liquidations(raw_root, symbol, day, day),
    )
    for source in minute_sources:
        replay = replay.merge(source, on=minute_key, how="left")

    taker_columns = ["taker_buy_volume", "taker_sell_volume"]
    replay[taker_columns] = replay[taker_columns].fillna(0.0)
    for column in ("liquidation_buy_notional_1m", "liquidation_sell_notional_1m", "liquidation_available"):
        replay[column] = pd.to_numeric(replay[column], errors="coerce").fillna(0.0)

    # Funding and open interest are as-of values. Reading from the previous UTC
    # day lets the first minutes of the day pick up the last known value
    # without loading the whole training window into memory.
    lookback_start = _previous_date_text(day)

    funding = _asof_column(replay, raw_root, "funding", symbol, lookback_start, day, ("rate", "mark_price", "index_price", "next_funding_time"))
    funding["funding_bps"] = pd.to_numeric(funding["rate"], errors="coerce") * 10000.0
    replay = replay.merge(funding.drop(columns=["rate"]), on=asof_key, how="left")
    replay["next_funding_time"] = to_utc_series(replay["next_funding_time"])

    open_interest = _asof_column(replay, raw_root, "open_interest", symbol, lookback_start, day, ("open_interest",))
    replay = replay.merge(open_interest, on=asof_key, how="left")

    replay["timeframe"] = "1m"
    replay["source_coverage"] = "crypto_lake_raw"
    replay["event_date"] = replay["event_time"].dt.strftime("%Y-%m-%d")
    return replay
|
|
|
|
|
|
def build_replay_1m(args: Any) -> None:
    """Build the 1-minute replay table day by day and persist it with sidecars.

    Each candidate day is built independently. A day is kept when it has at
    least ``args.min_minutes_per_day`` rows and no nulls in
    ``REPLAY_REQUIRED_COLUMNS``; otherwise it is recorded in
    ``excluded_days.json`` with a reason. A day whose build raises is recorded
    as ``DAY_BUILD_FAILED`` and does not abort the run.

    Outputs under ``<run_root>/replay``: ``replay_1m.parquet``,
    ``replay_1m.manifest.json``, ``excluded_days.json`` and
    ``replay_ready_days.txt``. The two day lists are written even when the run
    then fails for lack of ready days.

    Args:
        args: Namespace providing ``run_id``, ``symbol``, ``raw_root``,
            ``start_date``, ``end_date``, ``min_minutes_per_day`` and
            ``min_replay_ready_days`` (plus whatever ``run_root`` reads).

    Raises:
        ValueError: When no candidate dates exist, or when fewer days are
            replay-ready than required (at least one is always required).
    """
    root = run_root(args)
    replay_dir = root / "replay"
    raw_root = args.raw_root or DEFAULT_RAW_ROOT
    logging.info("trader.training.replay_started runId=%s symbol=%s rawRoot=%s", args.run_id, args.symbol, raw_root)
    dates = _replay_date_texts(raw_root, args.symbol, args.start_date, args.end_date)
    if not dates:
        raise ValueError("no candle dates are available for replay_1m")

    # Parse thresholds once, before the expensive per-day work.
    min_minutes_per_day = int(args.min_minutes_per_day)
    # With zero ready days there is nothing to concatenate, so at least one
    # ready day is required regardless of the configured threshold.
    required_ready_days = max(1, int(args.min_replay_ready_days))

    ready_days: list[str] = []
    excluded_days: list[dict[str, Any]] = []
    ready_frames: list[pd.DataFrame] = []
    row_before_filter = 0
    for index, day in enumerate(dates, start=1):
        logging.info("trader.training.replay_day_started runId=%s day=%s index=%s total=%s", args.run_id, day, index, len(dates))
        try:
            day_replay = _build_replay_day(raw_root, args.symbol, day)
        except Exception as exc:
            # Deliberately broad: one bad day is recorded and skipped so the
            # remaining days can still be built.
            excluded_days.append(
                {
                    "date": day,
                    "row_count": 0,
                    "missing_required_rows": 0,
                    "reason": "DAY_BUILD_FAILED",
                    "error": str(exc),
                }
            )
            logging.warning("trader.training.replay_day_failed runId=%s day=%s error=%s", args.run_id, day, exc)
            continue

        row_count = len(day_replay)
        row_before_filter += row_count
        missing_required_rows = int(day_replay[REPLAY_REQUIRED_COLUMNS].isna().any(axis=1).sum())
        ready = row_count >= min_minutes_per_day and missing_required_rows == 0
        if ready:
            ready_days.append(day)
            ready_frames.append(day_replay[REPLAY_OUTPUT_COLUMNS].copy())
        else:
            excluded_days.append(
                {
                    "date": day,
                    "row_count": int(row_count),
                    "missing_required_rows": missing_required_rows,
                    "reason": "MISSING_REQUIRED_MARKET_FIELDS" if missing_required_rows else "INCOMPLETE_MINUTE_COUNT",
                }
            )
        logging.info(
            "trader.training.replay_day_finished runId=%s day=%s ready=%s rows=%s missingRequiredRows=%s",
            args.run_id,
            day,
            ready,
            row_count,
            missing_required_rows,
        )

    # The day lists are useful for diagnosis whether or not the run succeeds.
    write_json(replay_dir / "excluded_days.json", excluded_days)
    write_text(replay_dir / "replay_ready_days.txt", "\n".join(ready_days) + ("\n" if ready_days else ""))
    if len(ready_days) < required_ready_days:
        raise ValueError(f"replay_1m has only {len(ready_days)} replay-ready days, required {required_ready_days}")

    replay = pd.concat(ready_frames, ignore_index=True)
    logging.info(
        "trader.training.replay_ready_days_selected runId=%s readyDays=%s excludedDays=%s rowBefore=%s rowAfter=%s",
        args.run_id,
        len(ready_days),
        len(excluded_days),
        row_before_filter,
        len(replay),
    )
    replay = replay[REPLAY_OUTPUT_COLUMNS].sort_values(["symbol", "event_time"]).reset_index(drop=True)
    path = replay_dir / "replay_1m.parquet"
    data_hash = write_parquet(path, replay)
    write_json(
        replay_dir / "replay_1m.manifest.json",
        manifest(
            path,
            {
                "row_count": len(replay),
                "hash_sha256": data_hash,
                "replay_ready_day_count": len(ready_days),
                "excluded_day_count": len(excluded_days),
                "min_minutes_per_day": min_minutes_per_day,
            },
        ),
    )
    logging.info("trader.training.replay_written runId=%s rowCount=%s readyDays=%s path=%s", args.run_id, len(replay), len(ready_days), path)
|
|
|
|
|
|
def build_splits(args: Any) -> None:
    """Write the fixed time-split manifest, walk-forward folds and embargo report.

    The requested split windows (already shrunk by the gap in
    ``_fixed_split_intervals``) are clipped to the replay's time coverage.
    Every split in ``TRAINING_SPLITS`` must survive the clipping. Folds share
    the fit start and the whole tune window for validation; their train end
    advances in equal steps across the fit window.

    Outputs under ``<run_root>/split``: ``split_manifest.json``,
    ``walk_forward_folds.json`` and ``purge_embargo_report.md``.

    Args:
        args: Namespace providing ``replay_path``, ``gap_minutes``,
            ``fold_count``, ``run_id`` and the ``*_start`` / ``*_end`` date
            attributes of each split (plus whatever ``run_root`` reads).

    Raises:
        ValueError: When the replay has fewer than 10 rows, or when a split
            window lies entirely outside the replay coverage.
    """
    root = run_root(args)
    replay_path = args.replay_path or root / "replay" / "replay_1m.parquet"
    replay = read_parquet(replay_path)
    require_columns(replay, ("event_time", "symbol"), "replay_1m")
    if len(replay) < 10:
        raise ValueError("not enough replay rows to build time splits")

    # Only the coverage bounds are needed, so the frame is neither sorted nor
    # re-indexed here.
    event_times = to_utc_series(replay["event_time"])
    replay_start = event_times.min()
    replay_end = event_times.max()

    gap = int(args.gap_minutes)
    intervals: list[tuple[str, pd.Timestamp, pd.Timestamp]] = []
    for split_id, start, end in _fixed_split_intervals(args, gap):
        clipped_start = max(start, replay_start)
        clipped_end = min(end, replay_end)
        # A window with no overlap with the replay is dropped; the check below
        # then reports it.
        if clipped_start <= clipped_end:
            intervals.append((split_id, clipped_start, clipped_end))
    if {item[0] for item in intervals} != set(TRAINING_SPLITS):
        raise ValueError(f"fixed split dates do not fit replay coverage: replay_start={replay_start} replay_end={replay_end}")

    def _iso_z(value: pd.Timestamp) -> str:
        # Render UTC timestamps with a trailing "Z" instead of "+00:00".
        return value.isoformat().replace("+00:00", "Z")

    split_manifest = {
        "split_version": SPLIT_VERSION,
        "created_at": utc_now_text(),
        "source_replay_path": str(replay_path),
        "gap_minutes": gap,
        # Sealed splits are withheld from broad parameter search. They only answer
        # whether a finished candidate survives final validation and recent stress.
        "sealed_splits": [VALIDATION_LOCKED_SPLIT, LATEST_STRESS_SPLIT],
        "latest_stress_policy": "FINAL_GATE_ONLY",
        "requested_splits": {
            FIT_SPLIT: [args.fit_inner_start, args.fit_inner_end],
            TUNE_SPLIT: [args.tune_inner_start, args.tune_inner_end],
            VALIDATION_LOCKED_SPLIT: [args.validation_locked_start, args.validation_locked_end],
            LATEST_STRESS_SPLIT: [args.latest_stress_start, args.latest_stress_end],
        },
        "splits": [
            {"split_id": split_id, "start": _iso_z(start), "end": _iso_z(end)}
            for split_id, start, end in intervals
        ],
    }

    fold_count = max(1, int(args.fold_count))
    fit_interval = next(item for item in intervals if item[0] == FIT_SPLIT)
    tune_interval = next(item for item in intervals if item[0] == TUNE_SPLIT)
    # fold_count + 1 evenly spaced points across the fit window; fold k trains
    # up to point k + 1.
    train_times = pd.Series(pd.date_range(fit_interval[1], fit_interval[2], periods=fold_count + 1))
    folds = [
        {
            "walk_forward_fold": f"fold_{idx + 1:02d}",
            "train_start": _iso_z(fit_interval[1]),
            "train_end": _iso_z(train_times.iloc[idx + 1]),
            "validation_start": _iso_z(tune_interval[1]),
            "validation_end": _iso_z(tune_interval[2]),
        }
        for idx in range(fold_count)
    ]

    ensure_dir(root / "split")
    write_json(root / "split" / "split_manifest.json", split_manifest)
    write_json(root / "split" / "walk_forward_folds.json", {"split_version": SPLIT_VERSION, "folds": folds})
    _write_purge_embargo_report(root / "split" / "purge_embargo_report.md", intervals, gap)
    logging.info("trader.training.splits_written runId=%s splitCount=%s foldCount=%s", args.run_id, len(split_manifest["splits"]), len(folds))
|
|
|
|
|
|
def assign_split(event_times: pd.Series, split_manifest_path: Path) -> pd.Series:
    """Label each timestamp with the split whose window contains it.

    Windows come from the ``splits`` list of the manifest at
    ``split_manifest_path`` and are inclusive on both ends. Timestamps in no
    window keep the label ``NO_SPLIT``; where windows overlap, the split
    listed later in the manifest wins.

    Returns:
        An object-dtype Series of split ids aligned to ``event_times.index``.
    """
    manifest_data = read_json(split_manifest_path)
    labels = pd.Series("NO_SPLIT", index=event_times.index, dtype="object")
    stamps = to_utc_series(event_times)
    for entry in manifest_data["splits"]:
        lower = pd.Timestamp(entry["start"])
        upper = pd.Timestamp(entry["end"])
        inside = stamps.between(lower, upper, inclusive="both")
        labels.loc[inside] = entry["split_id"]
    return labels
|
|
|
|
|
|
def _fixed_split_intervals(args: Any, gap_minutes: int) -> list[tuple[str, pd.Timestamp, pd.Timestamp]]:
    """Turn the requested split dates into gap-separated UTC windows.

    Each window runs from the start of its first day to the last minute of its
    last day. Every boundary that faces a neighbouring split is moved inward by
    ``gap_minutes``; the fit start and the latest-stress end are left as is.

    Returns:
        ``(split_id, start, end)`` tuples in fit, tune, validation-locked,
        latest-stress order.
    """
    embargo = pd.Timedelta(minutes=gap_minutes)

    fit_window = (FIT_SPLIT, _start_of_day(args.fit_inner_start), _end_of_day(args.fit_inner_end) - embargo)
    tune_window = (TUNE_SPLIT, _start_of_day(args.tune_inner_start) + embargo, _end_of_day(args.tune_inner_end) - embargo)
    locked_window = (
        VALIDATION_LOCKED_SPLIT,
        _start_of_day(args.validation_locked_start) + embargo,
        _end_of_day(args.validation_locked_end) - embargo,
    )
    stress_window = (LATEST_STRESS_SPLIT, _start_of_day(args.latest_stress_start) + embargo, _end_of_day(args.latest_stress_end))
    return [fit_window, tune_window, locked_window, stress_window]
|
|
|
|
|
|
def _start_of_day(value: str) -> pd.Timestamp:
|
|
return pd.Timestamp(value, tz="UTC")
|
|
|
|
|
|
def _end_of_day(value: str) -> pd.Timestamp:
|
|
return pd.Timestamp(value, tz="UTC") + pd.Timedelta(days=1) - pd.Timedelta(minutes=1)
|
|
|
|
|
|
def _write_purge_embargo_report(path: Path, intervals: list[tuple[str, pd.Timestamp, pd.Timestamp]], gap_minutes: int) -> None:
    """Write a markdown report listing the gap and every split window.

    The report has a title, the ``gap_minutes`` value and one table row per
    ``(split_id, start, end)`` interval with ISO-formatted bounds.
    """
    header = [
        "# Purge Embargo Report",
        "",
        f"- gap_minutes: {gap_minutes}",
        "",
        "| split_id | start | end |",
        "| --- | --- | --- |",
    ]
    table_rows = [f"| {split_id} | {start.isoformat()} | {end.isoformat()} |" for split_id, start, end in intervals]
    write_text(path, "\n".join(header + table_rows) + "\n")
|