Improve Trader entry quality training diagnostics
This commit is contained in:
@@ -0,0 +1,173 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
TRAINING_ROOT = Path(__file__).resolve().parents[1]
|
||||
if str(TRAINING_ROOT) not in sys.path:
|
||||
sys.path.insert(0, str(TRAINING_ROOT))
|
||||
|
||||
from trader_training.labels import DEFAULT_LABEL_CONFIG, _path_stats_for_group
|
||||
from trader_training.pm import _probability_implied_edge, _simulate_open_trades, _threshold_candidates, default_pm_config
|
||||
|
||||
|
||||
class RiskPmFixTest(unittest.TestCase):
|
||||
def test_path_stats_never_writes_negative_adverse_or_favorable_move(self) -> None:
|
||||
frame = pd.DataFrame(
|
||||
{
|
||||
"event_time": pd.date_range("2026-01-01", periods=4, freq="min", tz="UTC"),
|
||||
"open_time_ms": np.arange(4, dtype=np.int64) * 60_000,
|
||||
"symbol": "BTC-USDT-PERP",
|
||||
"close": [100.0, 101.0, 102.0, 103.0],
|
||||
"high": [100.0, 101.0, 102.0, 103.0],
|
||||
"low": [100.0, 101.0, 102.0, 103.0],
|
||||
"spread_bps": [1.0, 1.0, 1.0, 1.0],
|
||||
}
|
||||
)
|
||||
|
||||
long_stats = _path_stats_for_group(frame, "LONG", horizon=2, target_bps=500.0, stop_bps=500.0)
|
||||
short_stats = _path_stats_for_group(frame, "SHORT", horizon=2, target_bps=500.0, stop_bps=500.0)
|
||||
|
||||
self.assertGreaterEqual(float(long_stats["mae_bps"].min()), 0.0)
|
||||
self.assertGreaterEqual(float(long_stats["mfe_bps"].min()), 0.0)
|
||||
self.assertGreaterEqual(float(short_stats["mae_bps"].min()), 0.0)
|
||||
self.assertGreaterEqual(float(short_stats["mfe_bps"].min()), 0.0)
|
||||
|
||||
def test_default_risk_labels_match_design_thresholds(self) -> None:
|
||||
self.assertEqual(45, DEFAULT_LABEL_CONFIG["continue"]["horizon_minutes"])
|
||||
self.assertEqual(60.0, DEFAULT_LABEL_CONFIG["risk"]["market_drawdown_bps"])
|
||||
self.assertEqual(35.0, DEFAULT_LABEL_CONFIG["risk"]["position_path_risk_bps"])
|
||||
self.assertEqual(80.0, DEFAULT_LABEL_CONFIG["risk"]["spike_bps"])
|
||||
self.assertEqual(1.8, DEFAULT_LABEL_CONFIG["risk"]["vol_expansion_ratio"])
|
||||
|
||||
def test_pm_search_covers_low_entry_probability_without_allowing_negative_edge(self) -> None:
|
||||
candidates = _threshold_candidates()
|
||||
|
||||
self.assertTrue(candidates)
|
||||
self.assertLessEqual(max(item["max_market_risk_prob"] for item in candidates), 0.98)
|
||||
self.assertLessEqual(min(item["min_entry_prob"] for item in candidates), 0.03)
|
||||
self.assertGreaterEqual(min(item["min_expected_edge_bps"] for item in candidates), 0.0)
|
||||
|
||||
def test_probability_implied_edge_uses_price_plan_payoff(self) -> None:
|
||||
edge = _probability_implied_edge(
|
||||
pd.Series([0.10, 0.50]),
|
||||
{"targetDistanceBps": 120.0, "stopDistanceBps": 2.0, "costBps": 6.5},
|
||||
)
|
||||
|
||||
self.assertAlmostEqual(3.7, float(edge.iloc[0]), places=6)
|
||||
self.assertAlmostEqual(52.5, float(edge.iloc[1]), places=6)
|
||||
|
||||
def test_pm_backtest_sizing_uses_position_manager_formula_not_fixed_floor(self) -> None:
|
||||
frame = pd.DataFrame(
|
||||
{
|
||||
"sample_id": ["s0"],
|
||||
"symbol": ["BTC-USDT-PERP"],
|
||||
"event_time": pd.to_datetime(["2026-01-01T00:00:00Z"]),
|
||||
"split_id": ["tune_inner"],
|
||||
"long_prob": [0.70],
|
||||
"short_prob": [0.10],
|
||||
"neutral_prob": [0.20],
|
||||
"long_entry_prob": [0.80],
|
||||
"short_entry_prob": [0.20],
|
||||
"market_risk_prob": [0.20],
|
||||
"long_position_risk_prob": [0.10],
|
||||
"short_position_risk_prob": [0.10],
|
||||
"pred_long_expected_net_edge_bps": [40.0],
|
||||
"pred_short_expected_net_edge_bps": [1.0],
|
||||
"actual_long_expected_net_edge_bps": [30.0],
|
||||
"actual_short_expected_net_edge_bps": [-10.0],
|
||||
"long_trade_net_edge_bps": [11.0],
|
||||
"short_trade_net_edge_bps": [-14.5],
|
||||
"long_target_hit": [1],
|
||||
"short_target_hit": [0],
|
||||
"long_stop_hit": [0],
|
||||
"short_stop_hit": [1],
|
||||
"long_time_to_target_ms": [300_000],
|
||||
"short_time_to_target_ms": [-1],
|
||||
"long_time_to_stop_ms": [-1],
|
||||
"short_time_to_stop_ms": [180_000],
|
||||
"long_entry_target": [1],
|
||||
"short_entry_target": [0],
|
||||
}
|
||||
)
|
||||
thresholds = {
|
||||
"long_open_prob": 0.55,
|
||||
"short_open_prob": 0.55,
|
||||
"min_entry_prob": 0.55,
|
||||
"max_market_risk_prob": 0.55,
|
||||
"min_expected_edge_bps": 3.0,
|
||||
"min_direction_margin": 0.02,
|
||||
}
|
||||
|
||||
trades = _simulate_open_trades(
|
||||
frame,
|
||||
thresholds,
|
||||
default_pm_config(),
|
||||
{"stopDistanceBps": 8.0, "costBps": 6.5},
|
||||
)
|
||||
|
||||
self.assertEqual(1, len(trades))
|
||||
self.assertAlmostEqual(11.0, float(trades.iloc[0]["actual_edge_bps"]))
|
||||
self.assertAlmostEqual(30.0, float(trades.iloc[0]["label_max_edge_bps"]))
|
||||
self.assertGreater(float(trades.iloc[0]["planned_ratio"]), 0.05)
|
||||
self.assertLessEqual(float(trades.iloc[0]["planned_ratio"]), 0.20)
|
||||
|
||||
def test_pm_backtest_blocks_overlapping_open_trades_until_exit_and_cooldown(self) -> None:
|
||||
frame = pd.DataFrame(
|
||||
{
|
||||
"sample_id": ["s0", "s1"],
|
||||
"symbol": ["BTC-USDT-PERP", "BTC-USDT-PERP"],
|
||||
"event_time": pd.to_datetime(["2026-01-01T00:00:00Z", "2026-01-01T00:01:00Z"]),
|
||||
"split_id": ["tune_inner", "tune_inner"],
|
||||
"long_prob": [0.70, 0.72],
|
||||
"short_prob": [0.10, 0.10],
|
||||
"neutral_prob": [0.20, 0.18],
|
||||
"long_entry_prob": [0.80, 0.82],
|
||||
"short_entry_prob": [0.20, 0.20],
|
||||
"market_risk_prob": [0.20, 0.20],
|
||||
"long_position_risk_prob": [0.10, 0.10],
|
||||
"short_position_risk_prob": [0.10, 0.10],
|
||||
"pred_long_expected_net_edge_bps": [40.0, 42.0],
|
||||
"pred_short_expected_net_edge_bps": [1.0, 1.0],
|
||||
"actual_long_expected_net_edge_bps": [30.0, 31.0],
|
||||
"actual_short_expected_net_edge_bps": [-10.0, -10.0],
|
||||
"long_trade_net_edge_bps": [11.0, 12.0],
|
||||
"short_trade_net_edge_bps": [-14.5, -14.5],
|
||||
"long_target_hit": [1, 1],
|
||||
"short_target_hit": [0, 0],
|
||||
"long_stop_hit": [0, 0],
|
||||
"short_stop_hit": [1, 1],
|
||||
"long_time_to_target_ms": [300_000, 300_000],
|
||||
"short_time_to_target_ms": [-1, -1],
|
||||
"long_time_to_stop_ms": [-1, -1],
|
||||
"short_time_to_stop_ms": [180_000, 180_000],
|
||||
"long_entry_target": [1, 1],
|
||||
"short_entry_target": [0, 0],
|
||||
}
|
||||
)
|
||||
thresholds = {
|
||||
"long_open_prob": 0.55,
|
||||
"short_open_prob": 0.55,
|
||||
"min_entry_prob": 0.55,
|
||||
"max_market_risk_prob": 0.55,
|
||||
"min_expected_edge_bps": 3.0,
|
||||
"min_direction_margin": 0.02,
|
||||
}
|
||||
|
||||
trades = _simulate_open_trades(
|
||||
frame,
|
||||
thresholds,
|
||||
default_pm_config(),
|
||||
{"stopDistanceBps": 8.0, "costBps": 6.5, "maxHoldMinutes": 45},
|
||||
)
|
||||
|
||||
self.assertEqual(1, len(trades))
|
||||
self.assertEqual("s0", trades.iloc[0]["sample_id"])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
@@ -14,8 +14,10 @@ if str(TRAINING_ROOT) not in sys.path:
|
||||
sys.path.insert(0, str(TRAINING_ROOT))
|
||||
|
||||
from trader_training.onnx_export import LinearHead, export_heads
|
||||
from trader_training.entry_feature_screen import _screen_edge_column
|
||||
from trader_training.io_utils import read_json, write_json
|
||||
from trader_training.labels import ENTRY_LABEL_METHOD, _path_stats_for_group, build_entry_labels
|
||||
from trader_training.ofi_feature_experiment import l1_snapshot_diff_ofi_quote
|
||||
from trader_training.promote import promote_artifact_bundle
|
||||
from trader_training.replay import build_splits
|
||||
from trader_training.schemas import FEATURE_ORDER, LATEST_STRESS_SPLIT, MODEL_OUTPUTS, OUTPUT_MAPPING, TRAINING_SPLITS, VALIDATION_LOCKED_SPLIT
|
||||
@@ -33,6 +35,19 @@ class TrainingContractTest(unittest.TestCase):
|
||||
self.assertEqual(set(fields), set(OUTPUT_MAPPING[model_name]))
|
||||
self.assertEqual([f"prediction[{idx}]" for idx in range(len(fields))], [OUTPUT_MAPPING[model_name][field] for field in fields])
|
||||
|
||||
def test_entry_feature_screen_prefers_actual_plan_edge(self) -> None:
|
||||
dataset = pd.DataFrame(
|
||||
{
|
||||
"long_expected_net_edge_bps": [20.0],
|
||||
"short_expected_net_edge_bps": [15.0],
|
||||
"long_actual_plan_net_edge_bps": [-3.0],
|
||||
"short_actual_plan_net_edge_bps": [4.0],
|
||||
}
|
||||
)
|
||||
|
||||
self.assertEqual("long_actual_plan_net_edge_bps", _screen_edge_column(dataset, "LONG"))
|
||||
self.assertEqual("short_actual_plan_net_edge_bps", _screen_edge_column(dataset, "SHORT"))
|
||||
|
||||
def test_split_builder_uses_locked_validation_contract(self) -> None:
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
data_root = Path(tmp)
|
||||
@@ -90,7 +105,7 @@ class TrainingContractTest(unittest.TestCase):
|
||||
self.assertEqual(120_000, first["time_to_stop_ms"])
|
||||
self.assertAlmostEqual(-8.0, first["gross_edge_bps"])
|
||||
|
||||
def test_entry_label_uses_max_future_edge_not_fixed_target_hit(self) -> None:
|
||||
def test_entry_label_uses_price_plan_outcome_not_max_future_edge(self) -> None:
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
data_root = Path(tmp)
|
||||
run_root = data_root / "trader-v4" / "runs" / "unit-entry"
|
||||
@@ -167,11 +182,200 @@ class TrainingContractTest(unittest.TestCase):
|
||||
labels = pd.read_parquet(run_root / "label" / "entry_labels.parquet")
|
||||
row = labels[labels["sample_id"].eq("s0") & labels["side"].eq("LONG")].iloc[0]
|
||||
self.assertEqual(0, row["target_hit"])
|
||||
self.assertEqual(1, row["entry_target"])
|
||||
self.assertEqual(0, row["entry_target"])
|
||||
self.assertEqual(ENTRY_LABEL_METHOD, row["label_method"])
|
||||
self.assertAlmostEqual(13.5, row["expected_net_edge_bps"], places=6)
|
||||
self.assertAlmostEqual(-6.5, row["expected_net_edge_bps"], places=6)
|
||||
self.assertAlmostEqual(row["gross_edge_bps"] - row["cost_bps"], row["expected_net_edge_bps"], places=6)
|
||||
self.assertAlmostEqual(row["mfe_bps"] - row["cost_bps"], row["max_achievable_net_edge_bps"], places=6)
|
||||
|
||||
def test_entry_opportunity_label_keeps_plan_outcome_for_pm(self) -> None:
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
data_root = Path(tmp)
|
||||
run_root = data_root / "trader-v4" / "runs" / "unit-entry-opportunity"
|
||||
feature_path = run_root / "feature" / "feature_frame.parquet"
|
||||
replay_path = run_root / "replay" / "replay_1m.parquet"
|
||||
plan_path = run_root / "label" / "price_plan_context.json"
|
||||
config_path = data_root / "label_config.json"
|
||||
feature_path.parent.mkdir(parents=True)
|
||||
replay_path.parent.mkdir(parents=True)
|
||||
|
||||
times = pd.date_range("2026-01-01", periods=5, freq="min", tz="UTC")
|
||||
pd.DataFrame(
|
||||
{
|
||||
"sample_id": ["s0"],
|
||||
"symbol": "BTC-USDT-PERP",
|
||||
"event_time": [times[0]],
|
||||
"open_time_ms": [0],
|
||||
"split_id": "fit_inner",
|
||||
"walk_forward_fold": 0,
|
||||
"data_quality_flag": "OK",
|
||||
"spread_bps": 1.0,
|
||||
"spread_rank_24h_pct": 0.1,
|
||||
"realized_vol_15m_bps": 2.0,
|
||||
}
|
||||
).to_parquet(feature_path, index=False)
|
||||
pd.DataFrame(
|
||||
{
|
||||
"event_time": times,
|
||||
"open_time_ms": np.arange(5, dtype=np.int64) * 60_000,
|
||||
"symbol": "BTC-USDT-PERP",
|
||||
"open": [100.0] * 5,
|
||||
"high": [100.0, 100.05, 100.19, 100.20, 100.0],
|
||||
"low": [100.0, 99.99, 99.98, 99.97, 100.0],
|
||||
"close": [100.0] * 5,
|
||||
"spread_bps": 1.0,
|
||||
}
|
||||
).to_parquet(replay_path, index=False)
|
||||
write_json(
|
||||
config_path,
|
||||
{
|
||||
"entry": {
|
||||
"max_hold_minutes": 3,
|
||||
"target_bps": 50.0,
|
||||
"stop_bps": 50.0,
|
||||
"min_expected_net_edge_bps": 3.0,
|
||||
"target_method": "OPPORTUNITY_MFE_V1",
|
||||
}
|
||||
},
|
||||
)
|
||||
write_json(
|
||||
plan_path,
|
||||
{
|
||||
"pricePlanId": "unit-plan",
|
||||
"pricePlanConfigHash": "unit-hash",
|
||||
"targetDistanceBps": 50.0,
|
||||
"stopDistanceBps": 50.0,
|
||||
"maxHoldMinutes": 3,
|
||||
"costBps": 6.5,
|
||||
"entryLabelMethod": ENTRY_LABEL_METHOD,
|
||||
"entryTargetMethod": "OPPORTUNITY_MFE_V1",
|
||||
},
|
||||
)
|
||||
|
||||
build_entry_labels(
|
||||
Namespace(
|
||||
data_root=data_root,
|
||||
run_id="unit-entry-opportunity",
|
||||
feature_path=feature_path,
|
||||
replay_path=replay_path,
|
||||
label_config_path=config_path,
|
||||
cost_config_path=None,
|
||||
price_plan_context_path=plan_path,
|
||||
)
|
||||
)
|
||||
|
||||
labels = pd.read_parquet(run_root / "label" / "entry_labels.parquet")
|
||||
row = labels[labels["sample_id"].eq("s0") & labels["side"].eq("LONG")].iloc[0]
|
||||
self.assertEqual(0, row["target_hit"])
|
||||
self.assertEqual(1, row["entry_target"])
|
||||
self.assertEqual("OPPORTUNITY_MFE_V1", row["label_method"])
|
||||
self.assertAlmostEqual(row["mfe_bps"] - row["cost_bps"], row["expected_net_edge_bps"], places=6)
|
||||
self.assertAlmostEqual(-6.5, row["gross_edge_bps"] - row["cost_bps"], places=6)
|
||||
|
||||
def test_entry_quality_label_rejects_untradable_opportunity(self) -> None:
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
data_root = Path(tmp)
|
||||
run_root = data_root / "trader-v4" / "runs" / "unit-entry-quality"
|
||||
feature_path = run_root / "feature" / "feature_frame.parquet"
|
||||
replay_path = run_root / "replay" / "replay_1m.parquet"
|
||||
plan_path = run_root / "label" / "price_plan_context.json"
|
||||
config_path = data_root / "label_config.json"
|
||||
feature_path.parent.mkdir(parents=True)
|
||||
replay_path.parent.mkdir(parents=True)
|
||||
|
||||
times = pd.date_range("2026-01-01", periods=5, freq="min", tz="UTC")
|
||||
pd.DataFrame(
|
||||
{
|
||||
"sample_id": ["s0"],
|
||||
"symbol": "BTC-USDT-PERP",
|
||||
"event_time": [times[0]],
|
||||
"open_time_ms": [0],
|
||||
"split_id": "fit_inner",
|
||||
"walk_forward_fold": 0,
|
||||
"data_quality_flag": "OK",
|
||||
"spread_bps": 1.0,
|
||||
"spread_rank_24h_pct": 0.1,
|
||||
"realized_vol_15m_bps": 2.0,
|
||||
}
|
||||
).to_parquet(feature_path, index=False)
|
||||
pd.DataFrame(
|
||||
{
|
||||
"event_time": times,
|
||||
"open_time_ms": np.arange(5, dtype=np.int64) * 60_000,
|
||||
"symbol": "BTC-USDT-PERP",
|
||||
"open": [100.0] * 5,
|
||||
"high": [100.0, 100.05, 100.19, 100.20, 100.0],
|
||||
"low": [100.0, 99.99, 99.98, 99.97, 100.0],
|
||||
"close": [100.0] * 5,
|
||||
"spread_bps": 1.0,
|
||||
}
|
||||
).to_parquet(replay_path, index=False)
|
||||
write_json(
|
||||
config_path,
|
||||
{
|
||||
"entry": {
|
||||
"max_hold_minutes": 3,
|
||||
"target_bps": 50.0,
|
||||
"stop_bps": 50.0,
|
||||
"min_expected_net_edge_bps": 3.0,
|
||||
"min_plan_net_edge_bps": 0.0,
|
||||
"max_entry_mae_bps": 12.0,
|
||||
"target_method": "OPPORTUNITY_QUALITY_V1",
|
||||
}
|
||||
},
|
||||
)
|
||||
write_json(
|
||||
plan_path,
|
||||
{
|
||||
"pricePlanId": "unit-plan",
|
||||
"pricePlanConfigHash": "unit-hash",
|
||||
"targetDistanceBps": 50.0,
|
||||
"stopDistanceBps": 50.0,
|
||||
"maxHoldMinutes": 3,
|
||||
"costBps": 6.5,
|
||||
"entryLabelMethod": ENTRY_LABEL_METHOD,
|
||||
"entryTargetMethod": "OPPORTUNITY_QUALITY_V1",
|
||||
},
|
||||
)
|
||||
|
||||
build_entry_labels(
|
||||
Namespace(
|
||||
data_root=data_root,
|
||||
run_id="unit-entry-quality",
|
||||
feature_path=feature_path,
|
||||
replay_path=replay_path,
|
||||
label_config_path=config_path,
|
||||
cost_config_path=None,
|
||||
price_plan_context_path=plan_path,
|
||||
)
|
||||
)
|
||||
|
||||
labels = pd.read_parquet(run_root / "label" / "entry_labels.parquet")
|
||||
row = labels[labels["sample_id"].eq("s0") & labels["side"].eq("LONG")].iloc[0]
|
||||
self.assertEqual("OPPORTUNITY_QUALITY_V1", row["label_method"])
|
||||
self.assertGreater(row["expected_net_edge_bps"], 3.0)
|
||||
self.assertLess(row["actual_plan_net_edge_bps"], 0.0)
|
||||
self.assertEqual(0, row["entry_target"])
|
||||
|
||||
def test_l1_snapshot_diff_ofi_uses_quote_notional_and_signed_ask_side(self) -> None:
|
||||
bid_part, ask_part = l1_snapshot_diff_ofi_quote(
|
||||
pd.Series([101.0, 101.0, 100.5]),
|
||||
pd.Series([2.0, 3.0, 4.0]),
|
||||
pd.Series([102.0, 101.5, 102.5]),
|
||||
pd.Series([5.0, 6.0, 7.0]),
|
||||
pd.Series([100.0, 101.0, 101.0]),
|
||||
pd.Series([1.5, 2.0, 3.0]),
|
||||
pd.Series([102.0, 102.0, 101.5]),
|
||||
pd.Series([4.0, 5.0, 6.0]),
|
||||
)
|
||||
|
||||
self.assertAlmostEqual(202.0, bid_part.iloc[0])
|
||||
self.assertAlmostEqual(-102.0, ask_part.iloc[0])
|
||||
self.assertAlmostEqual(101.0, bid_part.iloc[1])
|
||||
self.assertAlmostEqual(-609.0, ask_part.iloc[1])
|
||||
self.assertAlmostEqual(-303.0, bid_part.iloc[2])
|
||||
self.assertAlmostEqual(609.0, ask_part.iloc[2])
|
||||
|
||||
def test_exported_onnx_accepts_java_feature_shape(self) -> None:
|
||||
import onnxruntime as ort
|
||||
|
||||
|
||||
Reference in New Issue
Block a user