Implement Trader V4 training artifact pipeline
This commit is contained in:
@@ -0,0 +1,417 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from trader_training.io_utils import (
|
||||
manifest,
|
||||
read_json,
|
||||
read_parquet,
|
||||
require_columns,
|
||||
run_root,
|
||||
sha256_json,
|
||||
to_utc_series,
|
||||
write_json,
|
||||
write_parquet,
|
||||
write_text,
|
||||
)
|
||||
from trader_training.schemas import LABEL_VERSION
|
||||
|
||||
|
||||
DEFAULT_LABEL_CONFIG = {
|
||||
"direction": {"horizon_minutes": 45, "long_threshold_bps": 5.0, "short_threshold_bps": -5.0},
|
||||
"entry": {"max_hold_minutes": 45, "target_bps": 12.0, "stop_bps": 8.0, "min_expected_net_edge_bps": 3.0},
|
||||
"continue": {"horizon_minutes": 30, "min_expected_continue_edge_bps": 2.0},
|
||||
"exit": {"horizon_minutes": 30, "adverse_move_bps": 8.0, "stagnation_abs_return_bps": 2.0},
|
||||
"risk": {"horizon_minutes": 30, "market_drawdown_bps": 12.0, "vol_expansion_ratio": 1.6, "spike_bps": 20.0},
|
||||
}
|
||||
|
||||
|
||||
DEFAULT_COST_CONFIG = {
|
||||
"fee_bps": 4.0,
|
||||
"slippage_bps": 2.0,
|
||||
"funding_cost_bps": 0.5,
|
||||
}
|
||||
|
||||
|
||||
def _load_config(path, default):
|
||||
if path is None:
|
||||
return default
|
||||
value = read_json(path)
|
||||
merged = default.copy()
|
||||
for key, item in value.items():
|
||||
if isinstance(item, dict) and isinstance(merged.get(key), dict):
|
||||
merged[key] = {**merged[key], **item}
|
||||
else:
|
||||
merged[key] = item
|
||||
return merged
|
||||
|
||||
|
||||
def _base_frames(args: Any) -> tuple[pd.DataFrame, pd.DataFrame]:
|
||||
root = run_root(args)
|
||||
feature_path = args.feature_path or root / "feature" / "feature_frame.parquet"
|
||||
replay_path = args.replay_path or root / "replay" / "replay_1m.parquet"
|
||||
features = read_parquet(feature_path)
|
||||
replay = read_parquet(replay_path)
|
||||
require_columns(features, ("sample_id", "symbol", "event_time", "open_time_ms", "split_id", "walk_forward_fold", "data_quality_flag"), "feature_frame")
|
||||
require_columns(replay, ("symbol", "event_time", "open_time_ms", "open", "high", "low", "close", "spread_bps"), "replay_1m")
|
||||
features = features.copy()
|
||||
replay = replay.copy()
|
||||
features["event_time"] = to_utc_series(features["event_time"])
|
||||
replay["event_time"] = to_utc_series(replay["event_time"])
|
||||
replay = replay.sort_values(["symbol", "event_time"]).reset_index(drop=True)
|
||||
return features, replay
|
||||
|
||||
|
||||
def _future_path(group: pd.DataFrame, index: int, horizon: int) -> pd.DataFrame:
|
||||
start = index + 1
|
||||
end = min(len(group), index + horizon + 1)
|
||||
return group.iloc[start:end]
|
||||
|
||||
|
||||
def _contiguous_future_path(group: pd.DataFrame, index: int, horizon: int) -> pd.DataFrame:
|
||||
path = _future_path(group, index, horizon)
|
||||
if len(path) < horizon:
|
||||
return pd.DataFrame()
|
||||
current_ms = int(group.iloc[index]["open_time_ms"])
|
||||
expected = current_ms + np.arange(1, horizon + 1, dtype=np.int64) * 60_000
|
||||
actual = path["open_time_ms"].astype("int64").to_numpy()
|
||||
if len(actual) != len(expected) or not np.array_equal(actual, expected):
|
||||
return pd.DataFrame()
|
||||
return path
|
||||
|
||||
|
||||
def _side_return_bps(side: str, entry_price: float, exit_price: float) -> float:
|
||||
if side == "LONG":
|
||||
return (exit_price / entry_price - 1.0) * 10000.0
|
||||
return (entry_price / exit_price - 1.0) * 10000.0
|
||||
|
||||
|
||||
def _path_stats(group: pd.DataFrame, index: int, side: str, horizon: int, target_bps: float, stop_bps: float) -> dict[str, Any]:
|
||||
current = group.iloc[index]
|
||||
entry = float(current["close"])
|
||||
path = _contiguous_future_path(group, index, horizon)
|
||||
if path.empty:
|
||||
return {"valid": False}
|
||||
target_price = entry * (1.0 + target_bps / 10000.0) if side == "LONG" else entry * (1.0 - target_bps / 10000.0)
|
||||
stop_price = entry * (1.0 - stop_bps / 10000.0) if side == "LONG" else entry * (1.0 + stop_bps / 10000.0)
|
||||
target_hit = False
|
||||
stop_hit = False
|
||||
ambiguous = False
|
||||
time_to_target_ms = -1
|
||||
time_to_stop_ms = -1
|
||||
for _, row in path.iterrows():
|
||||
high = float(row["high"])
|
||||
low = float(row["low"])
|
||||
if side == "LONG":
|
||||
target_now = high >= target_price
|
||||
stop_now = low <= stop_price
|
||||
else:
|
||||
target_now = low <= target_price
|
||||
stop_now = high >= stop_price
|
||||
if target_now and stop_now:
|
||||
ambiguous = True
|
||||
stop_hit = True
|
||||
time_to_stop_ms = int(row["open_time_ms"] - current["open_time_ms"])
|
||||
break
|
||||
if target_now:
|
||||
target_hit = True
|
||||
time_to_target_ms = int(row["open_time_ms"] - current["open_time_ms"])
|
||||
break
|
||||
if stop_now:
|
||||
stop_hit = True
|
||||
time_to_stop_ms = int(row["open_time_ms"] - current["open_time_ms"])
|
||||
break
|
||||
exit_price = float(path.iloc[-1]["close"])
|
||||
final_return_bps = _side_return_bps(side, entry, exit_price)
|
||||
if side == "LONG":
|
||||
mfe = (path["high"].max() / entry - 1.0) * 10000.0
|
||||
mae = (entry / path["low"].min() - 1.0) * 10000.0
|
||||
else:
|
||||
mfe = (entry / path["low"].min() - 1.0) * 10000.0
|
||||
mae = (path["high"].max() / entry - 1.0) * 10000.0
|
||||
if target_hit:
|
||||
gross = target_bps
|
||||
elif stop_hit:
|
||||
gross = -stop_bps
|
||||
else:
|
||||
gross = final_return_bps
|
||||
return {
|
||||
"valid": True,
|
||||
"target_hit": int(target_hit),
|
||||
"stop_hit": int(stop_hit),
|
||||
"timeout_hit": int(not target_hit and not stop_hit),
|
||||
"ambiguous_hit": int(ambiguous),
|
||||
"time_to_target_ms": time_to_target_ms,
|
||||
"time_to_stop_ms": time_to_stop_ms,
|
||||
"gross_edge_bps": float(gross),
|
||||
"future_return_bps": float(final_return_bps),
|
||||
"mfe_bps": float(mfe),
|
||||
"mae_bps": float(mae),
|
||||
"future_spread_p80": float(path["spread_bps"].quantile(0.8)),
|
||||
"future_realized_vol_bps": float(np.log(path["close"].astype(float) / path["close"].astype(float).shift(1)).std() * 10000.0),
|
||||
}
|
||||
|
||||
|
||||
def write_price_plan_context(args: Any) -> None:
|
||||
root = run_root(args)
|
||||
cost = _load_config(args.cost_config_path, DEFAULT_COST_CONFIG)
|
||||
labels = _load_config(args.label_config_path, DEFAULT_LABEL_CONFIG)
|
||||
entry = labels["entry"]
|
||||
cost_bps = float(cost["fee_bps"]) + float(cost["slippage_bps"]) + float(cost["funding_cost_bps"])
|
||||
context = {
|
||||
"pricePlanId": args.price_plan_id,
|
||||
"pricePlanConfigHash": sha256_json({"entry": entry, "cost": cost}),
|
||||
"stopDistanceBps": float(entry["stop_bps"]),
|
||||
"targetDistanceBps": float(entry["target_bps"]),
|
||||
"maxHoldMinutes": int(entry["max_hold_minutes"]),
|
||||
"costBps": cost_bps,
|
||||
}
|
||||
path = root / "label" / "price_plan_context.json"
|
||||
write_json(path, context)
|
||||
frame = pd.DataFrame([{
|
||||
"price_plan_id": context["pricePlanId"],
|
||||
"price_plan_hash": context["pricePlanConfigHash"],
|
||||
"target_bps": context["targetDistanceBps"],
|
||||
"stop_bps": context["stopDistanceBps"],
|
||||
"max_hold_minutes": context["maxHoldMinutes"],
|
||||
"cost_bps": context["costBps"],
|
||||
}])
|
||||
write_parquet(root / "label" / "price_plan_context.parquet", frame)
|
||||
logging.info("trader.training.price_plan_written runId=%s path=%s", args.run_id, path)
|
||||
|
||||
|
||||
def build_direction_labels(args: Any) -> None:
|
||||
root = run_root(args)
|
||||
config = _load_config(args.label_config_path, DEFAULT_LABEL_CONFIG)["direction"]
|
||||
features, replay = _base_frames(args)
|
||||
horizon = int(config["horizon_minutes"])
|
||||
replay = replay[["symbol", "event_time", "open_time_ms", "close"]].copy()
|
||||
future = replay[["symbol", "open_time_ms", "close"]].copy()
|
||||
future["open_time_ms"] = future["open_time_ms"].astype("int64") - horizon * 60_000
|
||||
future = future.rename(columns={"close": "future_close"})
|
||||
merged = features.merge(replay[["symbol", "open_time_ms", "close"]], on=["symbol", "open_time_ms"], how="left")
|
||||
merged = merged.merge(future, on=["symbol", "open_time_ms"], how="left")
|
||||
merged["future_return_bps"] = (merged["future_close"] / merged["close"] - 1.0) * 10000.0
|
||||
merged["direction_label"] = np.select(
|
||||
[merged["future_return_bps"] >= float(config["long_threshold_bps"]), merged["future_return_bps"] <= float(config["short_threshold_bps"])],
|
||||
["LONG", "SHORT"],
|
||||
default="NEUTRAL",
|
||||
)
|
||||
out = pd.DataFrame(
|
||||
{
|
||||
"sample_id": merged["sample_id"],
|
||||
"symbol": merged["symbol"],
|
||||
"event_time": merged["event_time"],
|
||||
"horizon_minutes": horizon,
|
||||
"future_return_bps": merged["future_return_bps"],
|
||||
"direction_label": merged["direction_label"],
|
||||
"long_target": merged["direction_label"].eq("LONG").astype("int8"),
|
||||
"short_target": merged["direction_label"].eq("SHORT").astype("int8"),
|
||||
"neutral_target": merged["direction_label"].eq("NEUTRAL").astype("int8"),
|
||||
"split_id": merged["split_id"],
|
||||
"walk_forward_fold": merged["walk_forward_fold"],
|
||||
"label_version": LABEL_VERSION,
|
||||
}
|
||||
).dropna(subset=["future_return_bps"])
|
||||
path = root / "label" / "direction_labels.parquet"
|
||||
data_hash = write_parquet(path, out)
|
||||
_write_label_manifest(root / "label" / "direction_labels.manifest.json", path, out, data_hash)
|
||||
_write_distribution_report(root / "label" / "direction_label_report.md", out, "direction_label")
|
||||
logging.info("trader.training.direction_labels_written runId=%s rowCount=%s", args.run_id, len(out))
|
||||
|
||||
|
||||
def build_entry_labels(args: Any) -> None:
|
||||
root = run_root(args)
|
||||
labels = _load_config(args.label_config_path, DEFAULT_LABEL_CONFIG)
|
||||
cost = _load_config(args.cost_config_path, DEFAULT_COST_CONFIG)
|
||||
plan_path = args.price_plan_context_path or root / "label" / "price_plan_context.json"
|
||||
plan = read_json(plan_path)
|
||||
features, replay = _base_frames(args)
|
||||
entry_conf = labels["entry"]
|
||||
cost_bps = float(cost["fee_bps"]) + float(cost["slippage_bps"]) + float(cost["funding_cost_bps"])
|
||||
rows: list[dict[str, Any]] = []
|
||||
groups, index_by_key = _group_replay_with_index(replay)
|
||||
for feature in features.itertuples(index=False):
|
||||
key = (feature.symbol, int(feature.open_time_ms))
|
||||
index = index_by_key.get(key)
|
||||
if index is None:
|
||||
continue
|
||||
group = groups[feature.symbol]
|
||||
for side in ("LONG", "SHORT"):
|
||||
stats = _path_stats(group, index, side, int(entry_conf["max_hold_minutes"]), float(entry_conf["target_bps"]), float(entry_conf["stop_bps"]))
|
||||
if not stats["valid"]:
|
||||
continue
|
||||
expected = stats["gross_edge_bps"] - cost_bps
|
||||
rows.append(
|
||||
{
|
||||
"sample_id": feature.sample_id,
|
||||
"symbol": feature.symbol,
|
||||
"event_time": feature.event_time,
|
||||
"side": side,
|
||||
"price_plan_id": plan["pricePlanId"],
|
||||
"price_plan_hash": plan["pricePlanConfigHash"],
|
||||
"target_hit": stats["target_hit"],
|
||||
"stop_hit": stats["stop_hit"],
|
||||
"timeout_hit": stats["timeout_hit"],
|
||||
"ambiguous_hit": stats["ambiguous_hit"],
|
||||
"time_to_target_ms": stats["time_to_target_ms"],
|
||||
"time_to_stop_ms": stats["time_to_stop_ms"],
|
||||
"gross_edge_bps": stats["gross_edge_bps"],
|
||||
"cost_bps": cost_bps,
|
||||
"expected_net_edge_bps": expected,
|
||||
"entry_target": int(stats["target_hit"] == 1 and expected >= float(entry_conf["min_expected_net_edge_bps"])),
|
||||
"split_id": feature.split_id,
|
||||
"walk_forward_fold": feature.walk_forward_fold,
|
||||
"label_version": LABEL_VERSION,
|
||||
}
|
||||
)
|
||||
out = pd.DataFrame(rows)
|
||||
path = root / "label" / "entry_labels.parquet"
|
||||
data_hash = write_parquet(path, out)
|
||||
_write_label_manifest(root / "label" / "entry_labels.manifest.json", path, out, data_hash)
|
||||
_write_distribution_report(root / "label" / "entry_label_report.md", out, "entry_target")
|
||||
logging.info("trader.training.entry_labels_written runId=%s rowCount=%s", args.run_id, len(out))
|
||||
|
||||
|
||||
def build_position_state_samples(args: Any) -> None:
|
||||
root = run_root(args)
|
||||
entry_path = args.entry_label_path or root / "label" / "entry_labels.parquet"
|
||||
entry = read_parquet(entry_path)
|
||||
if entry.empty:
|
||||
raise ValueError("entry labels are required before building position samples")
|
||||
samples = entry[entry["entry_target"] == 1].copy()
|
||||
samples["position_age_minutes"] = 0
|
||||
samples["unrealized_pnl_bps"] = 0.0
|
||||
samples["mfe_bps"] = samples["gross_edge_bps"].clip(lower=0)
|
||||
samples["mae_bps"] = (-samples["gross_edge_bps"]).clip(lower=0)
|
||||
path = root / "label" / "position_state_samples.parquet"
|
||||
data_hash = write_parquet(path, samples)
|
||||
write_json(root / "label" / "position_state_samples.manifest.json", manifest(path, {"row_count": len(samples), "data_hash_sha256": data_hash}))
|
||||
logging.info("trader.training.position_samples_written runId=%s rowCount=%s", args.run_id, len(samples))
|
||||
|
||||
|
||||
def build_continue_exit_risk_labels(args: Any) -> None:
|
||||
root = run_root(args)
|
||||
labels = _load_config(args.label_config_path, DEFAULT_LABEL_CONFIG)
|
||||
cost = _load_config(args.cost_config_path, DEFAULT_COST_CONFIG)
|
||||
plan = read_json(args.price_plan_context_path or root / "label" / "price_plan_context.json")
|
||||
features, replay = _base_frames(args)
|
||||
cost_bps = float(cost["fee_bps"]) + float(cost["slippage_bps"]) + float(cost["funding_cost_bps"])
|
||||
horizon = int(labels["continue"]["horizon_minutes"])
|
||||
target_bps = float(plan["targetDistanceBps"])
|
||||
stop_bps = float(plan["stopDistanceBps"])
|
||||
rows_continue: list[dict[str, Any]] = []
|
||||
rows_exit: list[dict[str, Any]] = []
|
||||
rows_risk: list[dict[str, Any]] = []
|
||||
groups, index_by_key = _group_replay_with_index(replay)
|
||||
for feature in features.itertuples(index=False):
|
||||
key = (feature.symbol, int(feature.open_time_ms))
|
||||
index = index_by_key.get(key)
|
||||
if index is None:
|
||||
continue
|
||||
group = groups[feature.symbol]
|
||||
long_stats = _path_stats(group, index, "LONG", horizon, target_bps, stop_bps)
|
||||
short_stats = _path_stats(group, index, "SHORT", horizon, target_bps, stop_bps)
|
||||
if not long_stats["valid"] or not short_stats["valid"]:
|
||||
continue
|
||||
long_edge = long_stats["future_return_bps"] - cost_bps
|
||||
short_edge = short_stats["future_return_bps"] - cost_bps
|
||||
min_continue = float(labels["continue"]["min_expected_continue_edge_bps"])
|
||||
adverse_threshold = float(labels["exit"]["adverse_move_bps"])
|
||||
rows_continue.append(
|
||||
{
|
||||
"sample_id": feature.sample_id,
|
||||
"symbol": feature.symbol,
|
||||
"event_time": feature.event_time,
|
||||
"long_continue_target": int(long_edge >= min_continue and long_stats["mae_bps"] < stop_bps),
|
||||
"short_continue_target": int(short_edge >= min_continue and short_stats["mae_bps"] < stop_bps),
|
||||
"long_expected_continue_edge_bps": long_edge,
|
||||
"short_expected_continue_edge_bps": short_edge,
|
||||
"split_id": feature.split_id,
|
||||
"walk_forward_fold": feature.walk_forward_fold,
|
||||
"label_version": LABEL_VERSION,
|
||||
}
|
||||
)
|
||||
stagnation = int(abs(long_stats["future_return_bps"]) <= float(labels["exit"]["stagnation_abs_return_bps"]))
|
||||
rows_exit.append(
|
||||
{
|
||||
"sample_id": feature.sample_id,
|
||||
"symbol": feature.symbol,
|
||||
"event_time": feature.event_time,
|
||||
"long_exit_target": int(long_stats["stop_hit"] == 1 or long_stats["mae_bps"] >= adverse_threshold),
|
||||
"short_exit_target": int(short_stats["stop_hit"] == 1 or short_stats["mae_bps"] >= adverse_threshold),
|
||||
"long_adverse_move_bps": long_stats["mae_bps"],
|
||||
"short_adverse_move_bps": short_stats["mae_bps"],
|
||||
"adverse_move_prob_label": int(max(long_stats["mae_bps"], short_stats["mae_bps"]) >= adverse_threshold),
|
||||
"reversal_prob_label": int(np.sign(long_stats["future_return_bps"]) != np.sign(feature.ret_15m_bps) if hasattr(feature, "ret_15m_bps") else 0),
|
||||
"stop_hit_prob_label": int(long_stats["stop_hit"] == 1 or short_stats["stop_hit"] == 1),
|
||||
"stagnation_prob_label": stagnation,
|
||||
"split_id": feature.split_id,
|
||||
"walk_forward_fold": feature.walk_forward_fold,
|
||||
"label_version": LABEL_VERSION,
|
||||
}
|
||||
)
|
||||
path_risk = max(long_stats["mae_bps"], short_stats["mae_bps"])
|
||||
vol_ratio = 0.0 if long_stats["future_realized_vol_bps"] != long_stats["future_realized_vol_bps"] else long_stats["future_realized_vol_bps"]
|
||||
rows_risk.append(
|
||||
{
|
||||
"sample_id": feature.sample_id,
|
||||
"symbol": feature.symbol,
|
||||
"event_time": feature.event_time,
|
||||
"market_risk_target": int(path_risk >= float(labels["risk"]["market_drawdown_bps"])),
|
||||
"market_path_risk_bps": path_risk,
|
||||
"long_position_path_risk_bps": long_stats["mae_bps"],
|
||||
"short_position_path_risk_bps": short_stats["mae_bps"],
|
||||
"long_position_risk_target": int(long_stats["mae_bps"] >= stop_bps),
|
||||
"short_position_risk_target": int(short_stats["mae_bps"] >= stop_bps),
|
||||
"market_drawdown_prob_label": int(path_risk >= float(labels["risk"]["market_drawdown_bps"])),
|
||||
"volatility_expansion_prob_label": int(vol_ratio >= float(labels["risk"]["spike_bps"])),
|
||||
"spike_prob_label": int(max(long_stats["mfe_bps"], short_stats["mfe_bps"], path_risk) >= float(labels["risk"]["spike_bps"])),
|
||||
"liquidity_deterioration_prob_label": int(long_stats["future_spread_p80"] >= float(replay["spread_bps"].quantile(0.9))),
|
||||
"position_drawdown_prob_label": int(max(long_stats["mae_bps"], short_stats["mae_bps"]) >= stop_bps),
|
||||
"split_id": feature.split_id,
|
||||
"walk_forward_fold": feature.walk_forward_fold,
|
||||
"label_version": LABEL_VERSION,
|
||||
}
|
||||
)
|
||||
outputs = [
|
||||
("continue", pd.DataFrame(rows_continue), "long_continue_target"),
|
||||
("exit", pd.DataFrame(rows_exit), "long_exit_target"),
|
||||
("risk", pd.DataFrame(rows_risk), "market_risk_target"),
|
||||
]
|
||||
report_parts = ["# Continue Exit Risk Label Report", ""]
|
||||
for name, frame, target in outputs:
|
||||
path = root / "label" / f"{name}_labels.parquet"
|
||||
data_hash = write_parquet(path, frame)
|
||||
_write_label_manifest(root / "label" / f"{name}_labels.manifest.json", path, frame, data_hash)
|
||||
report_parts.append(f"## {name}")
|
||||
report_parts.append("")
|
||||
report_parts.append(str(frame[target].value_counts(dropna=False).to_dict() if not frame.empty else {}))
|
||||
report_parts.append("")
|
||||
logging.info("trader.training.%s_labels_written runId=%s rowCount=%s", name, args.run_id, len(frame))
|
||||
write_text(root / "label" / "continue_exit_risk_label_report.md", "\n".join(report_parts) + "\n")
|
||||
|
||||
|
||||
def _write_label_manifest(path, parquet_path, frame: pd.DataFrame, data_hash: str) -> None:
|
||||
write_json(path, manifest(parquet_path, {"row_count": len(frame), "label_version": LABEL_VERSION, "data_hash_sha256": data_hash}))
|
||||
|
||||
|
||||
def _write_distribution_report(path, frame: pd.DataFrame, column: str) -> None:
|
||||
counts = frame[column].value_counts(dropna=False).to_dict() if not frame.empty else {}
|
||||
lines = ["# Label Report", "", f"- row_count: {len(frame)}", f"- target_column: {column}", f"- distribution: {counts}", ""]
|
||||
write_text(path, "\n".join(lines))
|
||||
|
||||
|
||||
def _group_replay_with_index(replay: pd.DataFrame) -> tuple[dict[str, pd.DataFrame], dict[tuple[str, int], int]]:
|
||||
groups: dict[str, pd.DataFrame] = {}
|
||||
index_by_key: dict[tuple[str, int], int] = {}
|
||||
for symbol, group in replay.groupby("symbol", sort=False):
|
||||
grouped = group.sort_values("event_time").reset_index(drop=True)
|
||||
groups[symbol] = grouped
|
||||
for idx, row in grouped.iterrows():
|
||||
index_by_key[(symbol, int(row["open_time_ms"]))] = idx
|
||||
return groups, index_by_key
|
||||
Reference in New Issue
Block a user