Align Continue labels with price plan outcomes
This commit is contained in:
@@ -0,0 +1,98 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
import tempfile
|
||||
import unittest
|
||||
from argparse import Namespace
|
||||
from pathlib import Path
|
||||
|
||||
import pandas as pd
|
||||
|
||||
TRAINING_ROOT = Path(__file__).resolve().parents[1]
|
||||
if str(TRAINING_ROOT) not in sys.path:
|
||||
sys.path.insert(0, str(TRAINING_ROOT))
|
||||
|
||||
from trader_training.io_utils import write_json
|
||||
from trader_training.labels import build_continue_exit_risk_labels
|
||||
|
||||
|
||||
class ContinueLabelsTest(unittest.TestCase):
    """Regression checks for the continue labels written by ``build_continue_exit_risk_labels``."""

    def test_continue_label_uses_first_price_plan_barrier_not_later_mae(self) -> None:
        """A later adverse move must not cancel a continue label once the target was reached.

        The replay fixture prints a high of 100.30 on the second bar and a low
        of 99.80 on the third bar, against an entry of 100.0 with a 20 bps
        target and an 8 bps stop. The test expects the long continue target to
        be 1 and the long expected edge to be 13.5, which equals the plan's
        target distance (20.0) minus its cost (6.5).
        """
        symbol = "BTC-USDT-PERP"
        run_id = "unit-continue"

        with tempfile.TemporaryDirectory() as scratch:
            data_root = Path(scratch)
            run_root = data_root / "trader-v4" / "runs" / run_id
            feature_file = run_root / "feature" / "feature_frame.parquet"
            replay_file = run_root / "replay" / "replay_1m.parquet"
            plan_file = run_root / "label" / "price_plan_context.json"
            config_file = data_root / "label_config.json"

            # Each artifact lives in its own sub-directory of the run root.
            for artifact in (feature_file, replay_file, plan_file):
                artifact.parent.mkdir(parents=True)

            minute_bars = pd.date_range("2026-01-01", periods=5, freq="min", tz="UTC")

            # Single feature sample anchored on the first replay bar.
            feature_frame = pd.DataFrame(
                {
                    "sample_id": ["s0"],
                    "symbol": symbol,
                    "event_time": [minute_bars[0]],
                    "open_time_ms": [0],
                    "split_id": ["fit_inner"],
                    "walk_forward_fold": [0],
                    "data_quality_flag": ["OK"],
                    "spread_bps": [1.0],
                    "spread_rank_24h_pct": [0.1],
                    "realized_vol_15m_bps": [2.0],
                }
            )
            feature_frame.to_parquet(feature_file, index=False)

            # Five one-minute bars: upside spike on bar 1, downside spike on bar 2.
            replay_frame = pd.DataFrame(
                {
                    "event_time": minute_bars,
                    "open_time_ms": [0, 60_000, 120_000, 180_000, 240_000],
                    "symbol": symbol,
                    "open": [100.0, 100.0, 100.0, 100.0, 100.0],
                    "high": [100.0, 100.30, 100.0, 100.0, 100.0],
                    "low": [100.0, 100.0, 99.80, 100.0, 100.0],
                    "close": [100.0, 100.10, 100.0, 100.0, 100.0],
                    "spread_bps": [1.0, 1.1, 1.2, 1.3, 1.4],
                }
            )
            replay_frame.to_parquet(replay_file, index=False)

            label_config = {
                "continue": {"horizon_minutes": 3, "min_expected_continue_edge_bps": 5.0},
                "entry": {"target_bps": 20.0, "stop_bps": 8.0, "max_hold_minutes": 3},
            }
            write_json(config_file, label_config)

            price_plan = {
                "pricePlanId": "unit-plan",
                "pricePlanConfigHash": "unit-hash",
                "targetDistanceBps": 20.0,
                "stopDistanceBps": 8.0,
                "maxHoldMinutes": 3,
                "costBps": 6.5,
                "entryLabelMethod": "PRICE_PLAN_OUTCOME_V1",
            }
            write_json(plan_file, price_plan)

            cli_args = Namespace(
                data_root=data_root,
                run_id=run_id,
                feature_path=feature_file,
                replay_path=replay_file,
                label_config_path=config_file,
                cost_config_path=None,
                price_plan_context_path=plan_file,
            )
            build_continue_exit_risk_labels(cli_args)

            # Read the output while the temporary directory still exists.
            continue_labels = pd.read_parquet(run_root / "label" / "continue_labels.parquet")
            first_row = continue_labels.iloc[0]
            self.assertEqual(1, int(first_row["long_continue_target"]))
            self.assertAlmostEqual(
                13.5,
                float(first_row["long_expected_continue_edge_bps"]),
                places=6,
            )
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
@@ -33,7 +33,9 @@ class StateContinueExperimentTest(unittest.TestCase):
|
||||
"high_since_entry": 100.2,
|
||||
"low_since_entry": 99.95,
|
||||
"future_return_bps": 12.0,
|
||||
"mae_bps": 3.0,
|
||||
"gross_edge_bps": 12.0,
|
||||
"mae_bps": 20.0,
|
||||
"stop_hit": 0,
|
||||
"entry_predicted_edge_bps": 8.5,
|
||||
"entry_direction_prob": 0.64,
|
||||
"add_count": 0.0,
|
||||
|
||||
@@ -25,9 +25,15 @@ from trader_training.schemas import LABEL_VERSION
|
||||
DEFAULT_LABEL_CONFIG = {
|
||||
"direction": {"horizon_minutes": 45, "long_threshold_bps": 5.0, "short_threshold_bps": -5.0},
|
||||
"entry": {"max_hold_minutes": 45, "target_bps": 12.0, "stop_bps": 8.0, "min_expected_net_edge_bps": 3.0},
|
||||
"continue": {"horizon_minutes": 30, "min_expected_continue_edge_bps": 2.0},
|
||||
"exit": {"horizon_minutes": 30, "adverse_move_bps": 8.0, "stagnation_abs_return_bps": 2.0},
|
||||
"risk": {"horizon_minutes": 30, "market_drawdown_bps": 12.0, "vol_expansion_ratio": 1.6, "spike_bps": 20.0},
|
||||
"continue": {"horizon_minutes": 45, "min_expected_continue_edge_bps": 5.0},
|
||||
"exit": {"horizon_minutes": 45, "adverse_move_bps": 20.0, "stagnation_abs_return_bps": 5.0},
|
||||
"risk": {
|
||||
"horizon_minutes": 45,
|
||||
"market_drawdown_bps": 60.0,
|
||||
"position_path_risk_bps": 35.0,
|
||||
"vol_expansion_ratio": 1.8,
|
||||
"spike_bps": 80.0,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@@ -37,7 +43,7 @@ DEFAULT_COST_CONFIG = {
|
||||
"funding_cost_bps": 0.5,
|
||||
}
|
||||
|
||||
ENTRY_LABEL_METHOD = "MAX_FUTURE_EDGE_V1"
|
||||
ENTRY_LABEL_METHOD = "PRICE_PLAN_OUTCOME_V1"
|
||||
|
||||
|
||||
def _load_config(path, default):
|
||||
@@ -53,6 +59,13 @@ def _load_config(path, default):
|
||||
return merged
|
||||
|
||||
|
||||
def _config_number(config: dict[str, Any], keys: tuple[str, ...], default: float) -> float:
|
||||
for key in keys:
|
||||
if key in config:
|
||||
return float(config[key])
|
||||
return default
|
||||
|
||||
|
||||
def _base_frames(args: Any) -> tuple[pd.DataFrame, pd.DataFrame]:
|
||||
root = run_root(args)
|
||||
feature_path = args.feature_path or root / "feature" / "feature_frame.parquet"
|
||||
@@ -144,16 +157,16 @@ def _path_stats_for_group(group: pd.DataFrame, side: str, horizon: int, target_b
|
||||
target_window = future_high >= target_price[:, None]
|
||||
stop_window = future_low <= stop_price[:, None]
|
||||
future_return_bps = (exit_price / entry - 1.0) * 10000.0
|
||||
mfe_bps = (high_max / entry - 1.0) * 10000.0
|
||||
mae_bps = (entry / low_min - 1.0) * 10000.0
|
||||
mfe_bps = np.maximum((high_max / entry - 1.0) * 10000.0, 0.0)
|
||||
mae_bps = np.maximum((entry / low_min - 1.0) * 10000.0, 0.0)
|
||||
else:
|
||||
target_price = entry * (1.0 - target_bps / 10000.0)
|
||||
stop_price = entry * (1.0 + stop_bps / 10000.0)
|
||||
target_window = future_low <= target_price[:, None]
|
||||
stop_window = future_high >= stop_price[:, None]
|
||||
future_return_bps = (entry / exit_price - 1.0) * 10000.0
|
||||
mfe_bps = (entry / low_min - 1.0) * 10000.0
|
||||
mae_bps = (high_max / entry - 1.0) * 10000.0
|
||||
mfe_bps = np.maximum((entry / low_min - 1.0) * 10000.0, 0.0)
|
||||
mae_bps = np.maximum((high_max / entry - 1.0) * 10000.0, 0.0)
|
||||
|
||||
target_any, first_target_idx = _first_hit_index(target_window)
|
||||
stop_any, first_stop_idx = _first_hit_index(stop_window)
|
||||
@@ -310,8 +323,8 @@ def build_entry_labels(args: Any) -> None:
|
||||
merged = features[feature_columns].merge(stats, on=["symbol", "open_time_ms"], how="inner")
|
||||
merged["max_achievable_gross_edge_bps"] = merged["mfe_bps"]
|
||||
merged["max_achievable_net_edge_bps"] = merged["max_achievable_gross_edge_bps"] - cost_bps
|
||||
merged["expected_net_edge_bps"] = merged["max_achievable_net_edge_bps"]
|
||||
merged["entry_target"] = (merged["max_achievable_net_edge_bps"] >= float(entry_conf["min_expected_net_edge_bps"])).astype("int8")
|
||||
merged["expected_net_edge_bps"] = merged["gross_edge_bps"] - cost_bps
|
||||
merged["entry_target"] = (merged["expected_net_edge_bps"] >= float(entry_conf["min_expected_net_edge_bps"])).astype("int8")
|
||||
merged["price_plan_id"] = plan["pricePlanId"]
|
||||
merged["price_plan_hash"] = plan["pricePlanConfigHash"]
|
||||
merged["cost_bps"] = cost_bps
|
||||
@@ -403,9 +416,22 @@ def build_continue_exit_risk_labels(args: Any) -> None:
|
||||
min_continue = float(labels["continue"]["min_expected_continue_edge_bps"])
|
||||
adverse_threshold = float(labels["exit"]["adverse_move_bps"])
|
||||
current_vol = merged["realized_vol_15m_bps"].astype(float).fillna(0.0).clip(lower=1.0)
|
||||
risk_config = labels["risk"]
|
||||
market_risk_threshold = _config_number(
|
||||
risk_config,
|
||||
("market_path_risk_threshold_bps", "market_drawdown_bps"),
|
||||
60.0,
|
||||
)
|
||||
position_risk_threshold = _config_number(
|
||||
risk_config,
|
||||
("position_path_risk_threshold_bps", "position_path_risk_bps"),
|
||||
35.0,
|
||||
)
|
||||
spike_threshold = _config_number(risk_config, ("spike_1m_threshold_bps", "spike_bps"), 80.0)
|
||||
vol_expansion_ratio = _config_number(risk_config, ("vol_expansion_ratio",), 1.8)
|
||||
|
||||
long_edge = merged["long_future_return_bps"] - cost_bps
|
||||
short_edge = merged["short_future_return_bps"] - cost_bps
|
||||
long_edge = merged["long_gross_edge_bps"] - cost_bps
|
||||
short_edge = merged["short_gross_edge_bps"] - cost_bps
|
||||
path_risk = np.maximum(merged["long_mae_bps"], merged["short_mae_bps"])
|
||||
max_path_move = np.maximum.reduce([merged["long_mfe_bps"], merged["short_mfe_bps"], path_risk])
|
||||
if "ret_15m_bps" in merged.columns:
|
||||
@@ -413,7 +439,9 @@ def build_continue_exit_risk_labels(args: Any) -> None:
|
||||
else:
|
||||
reversal = pd.Series(0, index=merged.index, dtype="int8")
|
||||
future_vol = merged["long_future_realized_vol_bps"].fillna(0.0)
|
||||
volatility_expansion = future_vol >= current_vol * float(labels["risk"]["vol_expansion_ratio"])
|
||||
volatility_expansion = future_vol >= current_vol * vol_expansion_ratio
|
||||
spike = max_path_move >= spike_threshold
|
||||
market_risk = (path_risk >= market_risk_threshold) | spike | volatility_expansion
|
||||
liquidity_deterioration = merged["spread_rank_24h_pct"].astype(float).fillna(0.0) >= 0.90
|
||||
|
||||
rows_continue = pd.DataFrame(
|
||||
@@ -421,8 +449,8 @@ def build_continue_exit_risk_labels(args: Any) -> None:
|
||||
"sample_id": merged["sample_id"],
|
||||
"symbol": merged["symbol"],
|
||||
"event_time": merged["event_time"],
|
||||
"long_continue_target": ((long_edge >= min_continue) & (merged["long_mae_bps"] < stop_bps)).astype("int8"),
|
||||
"short_continue_target": ((short_edge >= min_continue) & (merged["short_mae_bps"] < stop_bps)).astype("int8"),
|
||||
"long_continue_target": ((long_edge >= min_continue) & (merged["long_stop_hit"] == 0)).astype("int8"),
|
||||
"short_continue_target": ((short_edge >= min_continue) & (merged["short_stop_hit"] == 0)).astype("int8"),
|
||||
"long_expected_continue_edge_bps": long_edge,
|
||||
"short_expected_continue_edge_bps": short_edge,
|
||||
"split_id": merged["split_id"],
|
||||
@@ -453,17 +481,17 @@ def build_continue_exit_risk_labels(args: Any) -> None:
|
||||
"sample_id": merged["sample_id"],
|
||||
"symbol": merged["symbol"],
|
||||
"event_time": merged["event_time"],
|
||||
"market_risk_target": (path_risk >= float(labels["risk"]["market_drawdown_bps"])).astype("int8"),
|
||||
"market_risk_target": market_risk.astype("int8"),
|
||||
"market_path_risk_bps": path_risk,
|
||||
"long_position_path_risk_bps": merged["long_mae_bps"],
|
||||
"short_position_path_risk_bps": merged["short_mae_bps"],
|
||||
"long_position_risk_target": (merged["long_mae_bps"] >= stop_bps).astype("int8"),
|
||||
"short_position_risk_target": (merged["short_mae_bps"] >= stop_bps).astype("int8"),
|
||||
"market_drawdown_prob_label": (path_risk >= float(labels["risk"]["market_drawdown_bps"])).astype("int8"),
|
||||
"long_position_risk_target": ((merged["long_mae_bps"] >= position_risk_threshold) | (merged["long_stop_hit"] == 1)).astype("int8"),
|
||||
"short_position_risk_target": ((merged["short_mae_bps"] >= position_risk_threshold) | (merged["short_stop_hit"] == 1)).astype("int8"),
|
||||
"market_drawdown_prob_label": (path_risk >= market_risk_threshold).astype("int8"),
|
||||
"volatility_expansion_prob_label": volatility_expansion.astype("int8"),
|
||||
"spike_prob_label": (max_path_move >= float(labels["risk"]["spike_bps"])).astype("int8"),
|
||||
"spike_prob_label": spike.astype("int8"),
|
||||
"liquidity_deterioration_prob_label": liquidity_deterioration.astype("int8"),
|
||||
"position_drawdown_prob_label": (path_risk >= stop_bps).astype("int8"),
|
||||
"position_drawdown_prob_label": (path_risk >= position_risk_threshold).astype("int8"),
|
||||
"split_id": merged["split_id"],
|
||||
"walk_forward_fold": merged["walk_forward_fold"],
|
||||
"label_version": LABEL_VERSION,
|
||||
@@ -475,6 +503,21 @@ def build_continue_exit_risk_labels(args: Any) -> None:
|
||||
("risk", pd.DataFrame(rows_risk), "market_risk_target"),
|
||||
]
|
||||
report_parts = ["# Continue Exit Risk Label Report", ""]
|
||||
report_parts.extend(
|
||||
[
|
||||
"## Risk Thresholds",
|
||||
"",
|
||||
str(
|
||||
{
|
||||
"market_risk_threshold_bps": market_risk_threshold,
|
||||
"position_risk_threshold_bps": position_risk_threshold,
|
||||
"spike_threshold_bps": spike_threshold,
|
||||
"vol_expansion_ratio": vol_expansion_ratio,
|
||||
}
|
||||
),
|
||||
"",
|
||||
]
|
||||
)
|
||||
for name, frame, target in outputs:
|
||||
path = root / "label" / f"{name}_labels.parquet"
|
||||
data_hash = write_parquet(path, frame)
|
||||
|
||||
@@ -380,8 +380,9 @@ def _state_rows_for_age(frame: pd.DataFrame, stop_bps: float, target_bps: float,
|
||||
target_price = np.where(long_mask, entry_price * (1.0 + target_bps / 10000.0), entry_price * (1.0 - target_bps / 10000.0))
|
||||
distance_to_stop = np.where(long_mask, (current_price / stop_price - 1.0) * 10000.0, (stop_price / current_price - 1.0) * 10000.0)
|
||||
distance_to_target = np.where(long_mask, (target_price / current_price - 1.0) * 10000.0, (current_price / target_price - 1.0) * 10000.0)
|
||||
expected_edge = frame["future_return_bps"].astype(float) - cost_bps
|
||||
continue_target = ((expected_edge >= min_continue_edge_bps) & (frame["mae_bps"].astype(float) < stop_bps)).astype("int8")
|
||||
# Continue must score the first price-plan outcome from the current state, not the raw horizon close.
|
||||
expected_edge = frame["gross_edge_bps"].astype(float) - cost_bps
|
||||
continue_target = ((expected_edge >= min_continue_edge_bps) & (frame["stop_hit"].astype(int) == 0)).astype("int8")
|
||||
|
||||
out = frame[
|
||||
[
|
||||
@@ -601,6 +602,8 @@ def _source_manifest(
|
||||
"uses_same_round_model_prediction_as_feature": False,
|
||||
"entry_predicted_edge_bps": "baseline frozen ENTRY ONNX output selected by side",
|
||||
"entry_direction_prob": "baseline frozen DIRECTION ONNX output selected by side",
|
||||
"expected_continue_edge_bps": "price-plan gross edge minus cost; target/stop/timeout outcome is respected",
|
||||
"continue_target": "expected_continue_edge_bps >= threshold and stop is not the first path barrier",
|
||||
"path_features": "position path shape and side-adjusted market pressure at current state time",
|
||||
"add_count": "synthetic first-position diagnostic, fixed to 0",
|
||||
"minutes_since_last_add": "synthetic first-position diagnostic, fixed to 9999",
|
||||
|
||||
Reference in New Issue
Block a user