Improve Trader V4 training pipeline

Align entry labels with max future edge, tune direction labeling, and harden regression evaluation.

Add training diagnostics, price-plan search, feature screening, and nonlinear benchmark scripts.
This commit is contained in:
Codex
2026-06-27 19:57:29 +08:00
parent e58e4a5572
commit 9acb3460a1
27 changed files with 2059 additions and 341 deletions
+285 -205
View File
@@ -5,6 +5,7 @@ from typing import Any
import numpy as np
import pandas as pd
from numpy.lib.stride_tricks import sliding_window_view
from trader_training.io_utils import (
manifest,
@@ -36,6 +37,8 @@ DEFAULT_COST_CONFIG = {
"funding_cost_bps": 0.5,
}
# Identifier of the entry-labelling scheme. It is mixed into the price-plan
# config hash and written out as "entryLabelMethod" / "entry_label_method" /
# "label_method" in the generated artefacts, so changing the string changes
# the hash of every price plan produced afterwards.
ENTRY_LABEL_METHOD = "MAX_FUTURE_EDGE_V1"
def _load_config(path, default):
if path is None:
@@ -66,94 +69,144 @@ def _base_frames(args: Any) -> tuple[pd.DataFrame, pd.DataFrame]:
return features, replay
def _future_path(group: pd.DataFrame, index: int, horizon: int) -> pd.DataFrame:
start = index + 1
end = min(len(group), index + horizon + 1)
return group.iloc[start:end]
# Column names, in output order, of the path-statistics frames: the empty
# frame built by `_empty_path_stats_frame` uses exactly these columns, and
# `_path_stats_for_group` selects its result with this list.
# ("symbol", "open_time_ms") are the keys the label builders merge on.
PATH_STAT_COLUMNS = [
"symbol",
"open_time_ms",
"side",
"target_hit",
"stop_hit",
"timeout_hit",
"ambiguous_hit",
"time_to_target_ms",
"time_to_stop_ms",
"gross_edge_bps",
"future_return_bps",
"mfe_bps",
"mae_bps",
"future_spread_p80",
"future_realized_vol_bps",
]
def _contiguous_future_path(group: pd.DataFrame, index: int, horizon: int) -> pd.DataFrame:
    """Return the ``horizon`` rows after ``index`` only if they are gap-free.

    The future rows qualify when there are exactly ``horizon`` of them and
    their ``open_time_ms`` values equal the anchor row's ``open_time_ms`` plus
    60_000, 120_000, ... (one step of 60_000 per row). In every other case an
    empty ``DataFrame`` is returned.
    """
    window = _future_path(group, index, horizon)
    if len(window) < horizon:
        # The frame ended before a full horizon was available.
        return pd.DataFrame()

    anchor_ms = int(group.iloc[index]["open_time_ms"])
    steps = np.arange(1, horizon + 1, dtype=np.int64)
    wanted_ms = anchor_ms + steps * 60_000
    seen_ms = window["open_time_ms"].astype("int64").to_numpy()

    if len(seen_ms) == len(wanted_ms) and np.array_equal(seen_ms, wanted_ms):
        return window
    # At least one bar is missing, duplicated or out of order.
    return pd.DataFrame()
def _empty_path_stats_frame() -> pd.DataFrame:
    """Return a zero-row frame whose columns are ``PATH_STAT_COLUMNS``."""
    empty = pd.DataFrame(columns=PATH_STAT_COLUMNS)
    return empty
def _side_return_bps(side: str, entry_price: float, exit_price: float) -> float:
if side == "LONG":
return (exit_price / entry_price - 1.0) * 10000.0
return (entry_price / exit_price - 1.0) * 10000.0
def _first_hit_index(hit_window: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
hit_any = hit_window.any(axis=1)
first_idx = np.argmax(hit_window, axis=1)
first_idx = np.where(hit_any, first_idx, hit_window.shape[1] + 1)
return hit_any, first_idx
def _path_stats(group: pd.DataFrame, index: int, side: str, horizon: int, target_bps: float, stop_bps: float) -> dict[str, Any]:
current = group.iloc[index]
entry = float(current["close"])
path = _contiguous_future_path(group, index, horizon)
if path.empty:
return {"valid": False}
target_price = entry * (1.0 + target_bps / 10000.0) if side == "LONG" else entry * (1.0 - target_bps / 10000.0)
stop_price = entry * (1.0 - stop_bps / 10000.0) if side == "LONG" else entry * (1.0 + stop_bps / 10000.0)
target_hit = False
stop_hit = False
ambiguous = False
time_to_target_ms = -1
time_to_stop_ms = -1
for _, row in path.iterrows():
high = float(row["high"])
low = float(row["low"])
if side == "LONG":
target_now = high >= target_price
stop_now = low <= stop_price
else:
target_now = low <= target_price
stop_now = high >= stop_price
if target_now and stop_now:
ambiguous = True
stop_hit = True
time_to_stop_ms = int(row["open_time_ms"] - current["open_time_ms"])
break
if target_now:
target_hit = True
time_to_target_ms = int(row["open_time_ms"] - current["open_time_ms"])
break
if stop_now:
stop_hit = True
time_to_stop_ms = int(row["open_time_ms"] - current["open_time_ms"])
break
exit_price = float(path.iloc[-1]["close"])
final_return_bps = _side_return_bps(side, entry, exit_price)
if side == "LONG":
mfe = (path["high"].max() / entry - 1.0) * 10000.0
mae = (entry / path["low"].min() - 1.0) * 10000.0
def _path_stats_for_group(group: pd.DataFrame, side: str, horizon: int, target_bps: float, stop_bps: float) -> pd.DataFrame:
if len(group) <= horizon:
return _empty_path_stats_frame()
grouped = group.sort_values("event_time").reset_index(drop=True)
open_ms = grouped["open_time_ms"].astype("int64").to_numpy()
close = grouped["close"].astype("float64").to_numpy()
high = grouped["high"].astype("float64").to_numpy()
low = grouped["low"].astype("float64").to_numpy()
spread = grouped["spread_bps"].astype("float64").to_numpy()
entry = close[:-horizon]
exit_price = close[horizon:]
current_open_ms = open_ms[:-horizon]
bad_gap = (np.diff(open_ms) != 60_000).astype("int64")
gap_cumsum = np.concatenate(([0], np.cumsum(bad_gap)))
contiguous = (gap_cumsum[horizon:] - gap_cumsum[:-horizon]) == 0
finite = np.isfinite(entry) & np.isfinite(exit_price)
valid = contiguous & finite
future_high = sliding_window_view(high[1:], horizon)
future_low = sliding_window_view(low[1:], horizon)
future_spread = sliding_window_view(spread[1:], horizon)
with np.errstate(all="ignore"):
high_max = np.nanmax(future_high, axis=1)
low_min = np.nanmin(future_low, axis=1)
spread_p80 = np.nanquantile(future_spread, 0.8, axis=1)
if horizon > 1:
log_close = np.log(np.clip(close, 1e-12, None))
log_return = np.diff(log_close)
future_log_return = sliding_window_view(log_return[1:], horizon - 1)
with np.errstate(all="ignore"):
realized_vol_bps = np.nanstd(future_log_return, axis=1, ddof=1) * 10000.0
else:
mfe = (entry / path["low"].min() - 1.0) * 10000.0
mae = (path["high"].max() / entry - 1.0) * 10000.0
if target_hit:
gross = target_bps
elif stop_hit:
gross = -stop_bps
realized_vol_bps = np.full(len(entry), np.nan)
if side == "LONG":
target_price = entry * (1.0 + target_bps / 10000.0)
stop_price = entry * (1.0 - stop_bps / 10000.0)
target_window = future_high >= target_price[:, None]
stop_window = future_low <= stop_price[:, None]
future_return_bps = (exit_price / entry - 1.0) * 10000.0
mfe_bps = (high_max / entry - 1.0) * 10000.0
mae_bps = (entry / low_min - 1.0) * 10000.0
else:
gross = final_return_bps
return {
"valid": True,
"target_hit": int(target_hit),
"stop_hit": int(stop_hit),
"timeout_hit": int(not target_hit and not stop_hit),
"ambiguous_hit": int(ambiguous),
"time_to_target_ms": time_to_target_ms,
"time_to_stop_ms": time_to_stop_ms,
"gross_edge_bps": float(gross),
"future_return_bps": float(final_return_bps),
"mfe_bps": float(mfe),
"mae_bps": float(mae),
"future_spread_p80": float(path["spread_bps"].quantile(0.8)),
"future_realized_vol_bps": float(np.log(path["close"].astype(float) / path["close"].astype(float).shift(1)).std() * 10000.0),
}
target_price = entry * (1.0 - target_bps / 10000.0)
stop_price = entry * (1.0 + stop_bps / 10000.0)
target_window = future_low <= target_price[:, None]
stop_window = future_high >= stop_price[:, None]
future_return_bps = (entry / exit_price - 1.0) * 10000.0
mfe_bps = (entry / low_min - 1.0) * 10000.0
mae_bps = (high_max / entry - 1.0) * 10000.0
target_any, first_target_idx = _first_hit_index(target_window)
stop_any, first_stop_idx = _first_hit_index(stop_window)
ambiguous_hit = target_any & stop_any & (first_target_idx == first_stop_idx)
target_hit = target_any & (first_target_idx < first_stop_idx)
stop_hit = stop_any & (first_stop_idx <= first_target_idx)
timeout_hit = ~(target_hit | stop_hit)
gross_edge_bps = np.where(target_hit, target_bps, np.where(stop_hit, -stop_bps, future_return_bps))
out = pd.DataFrame(
{
"symbol": grouped["symbol"].iloc[0],
"open_time_ms": current_open_ms,
"side": side,
"target_hit": target_hit.astype("int8"),
"stop_hit": stop_hit.astype("int8"),
"timeout_hit": timeout_hit.astype("int8"),
"ambiguous_hit": ambiguous_hit.astype("int8"),
"time_to_target_ms": np.where(target_hit, (first_target_idx + 1) * 60_000, -1).astype("int64"),
"time_to_stop_ms": np.where(stop_hit, (first_stop_idx + 1) * 60_000, -1).astype("int64"),
"gross_edge_bps": gross_edge_bps.astype("float64"),
"future_return_bps": future_return_bps.astype("float64"),
"mfe_bps": mfe_bps.astype("float64"),
"mae_bps": mae_bps.astype("float64"),
"future_spread_p80": spread_p80.astype("float64"),
"future_realized_vol_bps": realized_vol_bps.astype("float64"),
}
)
return out.loc[valid, PATH_STAT_COLUMNS].reset_index(drop=True)
def _build_path_stats(replay: pd.DataFrame, horizon: int, target_bps: float, stop_bps: float) -> pd.DataFrame:
    """Compute path statistics for every symbol in ``replay``, for both sides.

    ``replay`` is grouped by its ``symbol`` column in first-appearance order.
    For each group ``_path_stats_for_group`` is called once for ``"LONG"`` and
    once for ``"SHORT"``, and the per-side frames are concatenated with a
    fresh index. When ``replay`` yields no groups, the empty frame from
    ``_empty_path_stats_frame`` is returned instead. Progress is reported
    through ``logging.info``.
    """
    collected: list[pd.DataFrame] = []
    for symbol, symbol_rows in replay.groupby("symbol", sort=False, observed=False):
        logging.info(
            "trader.training.path_stats_group_start symbol=%s horizonMinutes=%s rowCount=%s",
            symbol,
            horizon,
            len(symbol_rows),
        )
        for side in ("LONG", "SHORT"):
            side_stats = _path_stats_for_group(symbol_rows, side, horizon, target_bps, stop_bps)
            collected.append(side_stats)
            logging.info(
                "trader.training.path_stats_side_done symbol=%s side=%s horizonMinutes=%s rowCount=%s",
                symbol,
                side,
                horizon,
                len(side_stats),
            )

    if collected:
        combined = pd.concat(collected, ignore_index=True)
    else:
        # pd.concat raises on an empty list, so fall back to the empty schema.
        combined = _empty_path_stats_frame()
    logging.info("trader.training.path_stats_built horizonMinutes=%s rowCount=%s", horizon, len(combined))
    return combined
def write_price_plan_context(args: Any) -> None:
@@ -164,11 +217,13 @@ def write_price_plan_context(args: Any) -> None:
cost_bps = float(cost["fee_bps"]) + float(cost["slippage_bps"]) + float(cost["funding_cost_bps"])
context = {
"pricePlanId": args.price_plan_id,
"pricePlanConfigHash": sha256_json({"entry": entry, "cost": cost}),
"pricePlanConfigHash": sha256_json({"entry": entry, "cost": cost, "entry_label_method": ENTRY_LABEL_METHOD}),
"stopDistanceBps": float(entry["stop_bps"]),
"targetDistanceBps": float(entry["target_bps"]),
"maxHoldMinutes": int(entry["max_hold_minutes"]),
"minExpectedNetEdgeBps": float(entry["min_expected_net_edge_bps"]),
"costBps": cost_bps,
"entryLabelMethod": ENTRY_LABEL_METHOD,
}
path = root / "label" / "price_plan_context.json"
write_json(path, context)
@@ -178,7 +233,9 @@ def write_price_plan_context(args: Any) -> None:
"target_bps": context["targetDistanceBps"],
"stop_bps": context["stopDistanceBps"],
"max_hold_minutes": context["maxHoldMinutes"],
"min_expected_net_edge_bps": context["minExpectedNetEdgeBps"],
"cost_bps": context["costBps"],
"entry_label_method": context["entryLabelMethod"],
}])
write_parquet(root / "label" / "price_plan_context.parquet", frame)
logging.info("trader.training.price_plan_written runId=%s path=%s", args.run_id, path)
@@ -233,43 +290,62 @@ def build_entry_labels(args: Any) -> None:
features, replay = _base_frames(args)
entry_conf = labels["entry"]
cost_bps = float(cost["fee_bps"]) + float(cost["slippage_bps"]) + float(cost["funding_cost_bps"])
rows: list[dict[str, Any]] = []
groups, index_by_key = _group_replay_with_index(replay)
for feature in features.itertuples(index=False):
key = (feature.symbol, int(feature.open_time_ms))
index = index_by_key.get(key)
if index is None:
continue
group = groups[feature.symbol]
for side in ("LONG", "SHORT"):
stats = _path_stats(group, index, side, int(entry_conf["max_hold_minutes"]), float(entry_conf["target_bps"]), float(entry_conf["stop_bps"]))
if not stats["valid"]:
continue
expected = stats["gross_edge_bps"] - cost_bps
rows.append(
{
"sample_id": feature.sample_id,
"symbol": feature.symbol,
"event_time": feature.event_time,
"side": side,
"price_plan_id": plan["pricePlanId"],
"price_plan_hash": plan["pricePlanConfigHash"],
"target_hit": stats["target_hit"],
"stop_hit": stats["stop_hit"],
"timeout_hit": stats["timeout_hit"],
"ambiguous_hit": stats["ambiguous_hit"],
"time_to_target_ms": stats["time_to_target_ms"],
"time_to_stop_ms": stats["time_to_stop_ms"],
"gross_edge_bps": stats["gross_edge_bps"],
"cost_bps": cost_bps,
"expected_net_edge_bps": expected,
"entry_target": int(stats["target_hit"] == 1 and expected >= float(entry_conf["min_expected_net_edge_bps"])),
"split_id": feature.split_id,
"walk_forward_fold": feature.walk_forward_fold,
"label_version": LABEL_VERSION,
}
)
out = pd.DataFrame(rows)
stats = _build_path_stats(
replay,
int(entry_conf["max_hold_minutes"]),
float(entry_conf["target_bps"]),
float(entry_conf["stop_bps"]),
)
feature_columns = [
"sample_id",
"symbol",
"event_time",
"open_time_ms",
"split_id",
"walk_forward_fold",
"spread_bps",
"spread_rank_24h_pct",
"realized_vol_15m_bps",
]
merged = features[feature_columns].merge(stats, on=["symbol", "open_time_ms"], how="inner")
merged["max_achievable_gross_edge_bps"] = merged["mfe_bps"]
merged["max_achievable_net_edge_bps"] = merged["max_achievable_gross_edge_bps"] - cost_bps
merged["expected_net_edge_bps"] = merged["max_achievable_net_edge_bps"]
merged["entry_target"] = (merged["max_achievable_net_edge_bps"] >= float(entry_conf["min_expected_net_edge_bps"])).astype("int8")
merged["price_plan_id"] = plan["pricePlanId"]
merged["price_plan_hash"] = plan["pricePlanConfigHash"]
merged["cost_bps"] = cost_bps
merged["label_method"] = ENTRY_LABEL_METHOD
merged["label_version"] = LABEL_VERSION
out = merged[
[
"sample_id",
"symbol",
"event_time",
"side",
"price_plan_id",
"price_plan_hash",
"target_hit",
"stop_hit",
"timeout_hit",
"ambiguous_hit",
"time_to_target_ms",
"time_to_stop_ms",
"gross_edge_bps",
"future_return_bps",
"mfe_bps",
"mae_bps",
"max_achievable_gross_edge_bps",
"max_achievable_net_edge_bps",
"cost_bps",
"expected_net_edge_bps",
"entry_target",
"label_method",
"split_id",
"walk_forward_fold",
"label_version",
]
].copy()
path = root / "label" / "entry_labels.parquet"
data_hash = write_parquet(path, out)
_write_label_manifest(root / "label" / "entry_labels.manifest.json", path, out, data_hash)
@@ -286,8 +362,8 @@ def build_position_state_samples(args: Any) -> None:
samples = entry[entry["entry_target"] == 1].copy()
samples["position_age_minutes"] = 0
samples["unrealized_pnl_bps"] = 0.0
samples["mfe_bps"] = samples["gross_edge_bps"].clip(lower=0)
samples["mae_bps"] = (-samples["gross_edge_bps"]).clip(lower=0)
samples["mfe_bps"] = pd.to_numeric(samples["mfe_bps"], errors="coerce").fillna(0.0).clip(lower=0)
samples["mae_bps"] = pd.to_numeric(samples["mae_bps"], errors="coerce").fillna(0.0).clip(lower=0)
path = root / "label" / "position_state_samples.parquet"
data_hash = write_parquet(path, samples)
write_json(root / "label" / "position_state_samples.manifest.json", manifest(path, {"row_count": len(samples), "data_hash_sha256": data_hash}))
@@ -304,80 +380,95 @@ def build_continue_exit_risk_labels(args: Any) -> None:
horizon = int(labels["continue"]["horizon_minutes"])
target_bps = float(plan["targetDistanceBps"])
stop_bps = float(plan["stopDistanceBps"])
rows_continue: list[dict[str, Any]] = []
rows_exit: list[dict[str, Any]] = []
rows_risk: list[dict[str, Any]] = []
groups, index_by_key = _group_replay_with_index(replay)
for feature in features.itertuples(index=False):
key = (feature.symbol, int(feature.open_time_ms))
index = index_by_key.get(key)
if index is None:
continue
group = groups[feature.symbol]
long_stats = _path_stats(group, index, "LONG", horizon, target_bps, stop_bps)
short_stats = _path_stats(group, index, "SHORT", horizon, target_bps, stop_bps)
if not long_stats["valid"] or not short_stats["valid"]:
continue
long_edge = long_stats["future_return_bps"] - cost_bps
short_edge = short_stats["future_return_bps"] - cost_bps
min_continue = float(labels["continue"]["min_expected_continue_edge_bps"])
adverse_threshold = float(labels["exit"]["adverse_move_bps"])
rows_continue.append(
{
"sample_id": feature.sample_id,
"symbol": feature.symbol,
"event_time": feature.event_time,
"long_continue_target": int(long_edge >= min_continue and long_stats["mae_bps"] < stop_bps),
"short_continue_target": int(short_edge >= min_continue and short_stats["mae_bps"] < stop_bps),
"long_expected_continue_edge_bps": long_edge,
"short_expected_continue_edge_bps": short_edge,
"split_id": feature.split_id,
"walk_forward_fold": feature.walk_forward_fold,
"label_version": LABEL_VERSION,
}
)
stagnation = int(abs(long_stats["future_return_bps"]) <= float(labels["exit"]["stagnation_abs_return_bps"]))
rows_exit.append(
{
"sample_id": feature.sample_id,
"symbol": feature.symbol,
"event_time": feature.event_time,
"long_exit_target": int(long_stats["stop_hit"] == 1 or long_stats["mae_bps"] >= adverse_threshold),
"short_exit_target": int(short_stats["stop_hit"] == 1 or short_stats["mae_bps"] >= adverse_threshold),
"long_adverse_move_bps": long_stats["mae_bps"],
"short_adverse_move_bps": short_stats["mae_bps"],
"adverse_move_prob_label": int(max(long_stats["mae_bps"], short_stats["mae_bps"]) >= adverse_threshold),
"reversal_prob_label": int(np.sign(long_stats["future_return_bps"]) != np.sign(feature.ret_15m_bps) if hasattr(feature, "ret_15m_bps") else 0),
"stop_hit_prob_label": int(long_stats["stop_hit"] == 1 or short_stats["stop_hit"] == 1),
"stagnation_prob_label": stagnation,
"split_id": feature.split_id,
"walk_forward_fold": feature.walk_forward_fold,
"label_version": LABEL_VERSION,
}
)
path_risk = max(long_stats["mae_bps"], short_stats["mae_bps"])
vol_ratio = 0.0 if long_stats["future_realized_vol_bps"] != long_stats["future_realized_vol_bps"] else long_stats["future_realized_vol_bps"]
rows_risk.append(
{
"sample_id": feature.sample_id,
"symbol": feature.symbol,
"event_time": feature.event_time,
"market_risk_target": int(path_risk >= float(labels["risk"]["market_drawdown_bps"])),
"market_path_risk_bps": path_risk,
"long_position_path_risk_bps": long_stats["mae_bps"],
"short_position_path_risk_bps": short_stats["mae_bps"],
"long_position_risk_target": int(long_stats["mae_bps"] >= stop_bps),
"short_position_risk_target": int(short_stats["mae_bps"] >= stop_bps),
"market_drawdown_prob_label": int(path_risk >= float(labels["risk"]["market_drawdown_bps"])),
"volatility_expansion_prob_label": int(vol_ratio >= float(labels["risk"]["spike_bps"])),
"spike_prob_label": int(max(long_stats["mfe_bps"], short_stats["mfe_bps"], path_risk) >= float(labels["risk"]["spike_bps"])),
"liquidity_deterioration_prob_label": int(long_stats["future_spread_p80"] >= float(replay["spread_bps"].quantile(0.9))),
"position_drawdown_prob_label": int(max(long_stats["mae_bps"], short_stats["mae_bps"]) >= stop_bps),
"split_id": feature.split_id,
"walk_forward_fold": feature.walk_forward_fold,
"label_version": LABEL_VERSION,
}
)
stats = _build_path_stats(replay, horizon, target_bps, stop_bps)
long_stats = stats[stats["side"] == "LONG"].drop(columns=["side"]).add_prefix("long_")
short_stats = stats[stats["side"] == "SHORT"].drop(columns=["side"]).add_prefix("short_")
long_stats = long_stats.rename(columns={"long_symbol": "symbol", "long_open_time_ms": "open_time_ms"})
short_stats = short_stats.rename(columns={"short_symbol": "symbol", "short_open_time_ms": "open_time_ms"})
feature_columns = [
"sample_id",
"symbol",
"event_time",
"open_time_ms",
"split_id",
"walk_forward_fold",
"spread_bps",
"spread_rank_24h_pct",
"realized_vol_15m_bps",
]
if "ret_15m_bps" in features.columns:
feature_columns.append("ret_15m_bps")
merged = features[feature_columns].merge(long_stats, on=["symbol", "open_time_ms"], how="inner")
merged = merged.merge(short_stats, on=["symbol", "open_time_ms"], how="inner")
min_continue = float(labels["continue"]["min_expected_continue_edge_bps"])
adverse_threshold = float(labels["exit"]["adverse_move_bps"])
current_vol = merged["realized_vol_15m_bps"].astype(float).fillna(0.0).clip(lower=1.0)
long_edge = merged["long_future_return_bps"] - cost_bps
short_edge = merged["short_future_return_bps"] - cost_bps
path_risk = np.maximum(merged["long_mae_bps"], merged["short_mae_bps"])
max_path_move = np.maximum.reduce([merged["long_mfe_bps"], merged["short_mfe_bps"], path_risk])
if "ret_15m_bps" in merged.columns:
reversal = (np.sign(merged["long_future_return_bps"]) != np.sign(merged["ret_15m_bps"])).astype("int8")
else:
reversal = pd.Series(0, index=merged.index, dtype="int8")
future_vol = merged["long_future_realized_vol_bps"].fillna(0.0)
volatility_expansion = future_vol >= current_vol * float(labels["risk"]["vol_expansion_ratio"])
liquidity_deterioration = merged["spread_rank_24h_pct"].astype(float).fillna(0.0) >= 0.90
rows_continue = pd.DataFrame(
{
"sample_id": merged["sample_id"],
"symbol": merged["symbol"],
"event_time": merged["event_time"],
"long_continue_target": ((long_edge >= min_continue) & (merged["long_mae_bps"] < stop_bps)).astype("int8"),
"short_continue_target": ((short_edge >= min_continue) & (merged["short_mae_bps"] < stop_bps)).astype("int8"),
"long_expected_continue_edge_bps": long_edge,
"short_expected_continue_edge_bps": short_edge,
"split_id": merged["split_id"],
"walk_forward_fold": merged["walk_forward_fold"],
"label_version": LABEL_VERSION,
}
)
rows_exit = pd.DataFrame(
{
"sample_id": merged["sample_id"],
"symbol": merged["symbol"],
"event_time": merged["event_time"],
"long_exit_target": ((merged["long_stop_hit"] == 1) | (merged["long_mae_bps"] >= adverse_threshold)).astype("int8"),
"short_exit_target": ((merged["short_stop_hit"] == 1) | (merged["short_mae_bps"] >= adverse_threshold)).astype("int8"),
"long_adverse_move_bps": merged["long_mae_bps"],
"short_adverse_move_bps": merged["short_mae_bps"],
"adverse_move_prob_label": (path_risk >= adverse_threshold).astype("int8"),
"reversal_prob_label": reversal,
"stop_hit_prob_label": ((merged["long_stop_hit"] == 1) | (merged["short_stop_hit"] == 1)).astype("int8"),
"stagnation_prob_label": (merged["long_future_return_bps"].abs() <= float(labels["exit"]["stagnation_abs_return_bps"])).astype("int8"),
"split_id": merged["split_id"],
"walk_forward_fold": merged["walk_forward_fold"],
"label_version": LABEL_VERSION,
}
)
rows_risk = pd.DataFrame(
{
"sample_id": merged["sample_id"],
"symbol": merged["symbol"],
"event_time": merged["event_time"],
"market_risk_target": (path_risk >= float(labels["risk"]["market_drawdown_bps"])).astype("int8"),
"market_path_risk_bps": path_risk,
"long_position_path_risk_bps": merged["long_mae_bps"],
"short_position_path_risk_bps": merged["short_mae_bps"],
"long_position_risk_target": (merged["long_mae_bps"] >= stop_bps).astype("int8"),
"short_position_risk_target": (merged["short_mae_bps"] >= stop_bps).astype("int8"),
"market_drawdown_prob_label": (path_risk >= float(labels["risk"]["market_drawdown_bps"])).astype("int8"),
"volatility_expansion_prob_label": volatility_expansion.astype("int8"),
"spike_prob_label": (max_path_move >= float(labels["risk"]["spike_bps"])).astype("int8"),
"liquidity_deterioration_prob_label": liquidity_deterioration.astype("int8"),
"position_drawdown_prob_label": (path_risk >= stop_bps).astype("int8"),
"split_id": merged["split_id"],
"walk_forward_fold": merged["walk_forward_fold"],
"label_version": LABEL_VERSION,
}
)
outputs = [
("continue", pd.DataFrame(rows_continue), "long_continue_target"),
("exit", pd.DataFrame(rows_exit), "long_exit_target"),
@@ -404,14 +495,3 @@ def _write_distribution_report(path, frame: pd.DataFrame, column: str) -> None:
counts = frame[column].value_counts(dropna=False).to_dict() if not frame.empty else {}
lines = ["# Label Report", "", f"- row_count: {len(frame)}", f"- target_column: {column}", f"- distribution: {counts}", ""]
write_text(path, "\n".join(lines))
def _group_replay_with_index(replay: pd.DataFrame) -> tuple[dict[str, pd.DataFrame], dict[tuple[str, int], int]]:
groups: dict[str, pd.DataFrame] = {}
index_by_key: dict[tuple[str, int], int] = {}
for symbol, group in replay.groupby("symbol", sort=False):
grouped = group.sort_values("event_time").reset_index(drop=True)
groups[symbol] = grouped
for idx, row in grouped.iterrows():
index_by_key[(symbol, int(row["open_time_ms"]))] = idx
return groups, index_by_key