Improve nonlinear PM diagnostics
This commit is contained in:
@@ -23,7 +23,7 @@ from trader_training.entry_mae_label_diagnostic import diagnose_entry_mae_labels
|
||||
from trader_training.good_trade_structure import _side_frame, _top_fraction_metrics
|
||||
from trader_training.io_utils import read_json, write_json
|
||||
from trader_training.labels import ENTRY_LABEL_METHOD, _path_stats_for_group, build_entry_labels
|
||||
from trader_training.nonlinear_pm_probe import _expanded_threshold_candidates
|
||||
from trader_training.nonlinear_pm_probe import _entry_side_fit_frame, _exit_metrics, _expanded_threshold_candidates
|
||||
from trader_training.ofi_feature_experiment import _load_entry_dataset, l1_snapshot_diff_ofi_quote
|
||||
from trader_training.promote import promote_artifact_bundle
|
||||
from trader_training.replay import build_splits
|
||||
@@ -58,6 +58,57 @@ class TrainingContractTest(unittest.TestCase):
|
||||
},
|
||||
candidates,
|
||||
)
|
||||
self.assertIn(
|
||||
{
|
||||
"long_open_prob": 1.01,
|
||||
"short_open_prob": 0.2,
|
||||
"min_entry_prob": 0.05,
|
||||
"max_market_risk_prob": 0.45,
|
||||
"min_expected_edge_bps": -5.0,
|
||||
"min_direction_margin": 0.0,
|
||||
},
|
||||
candidates,
|
||||
)
|
||||
|
||||
def test_nonlinear_entry_tree_probe_can_use_side_opportunity_rows(self) -> None:
|
||||
direction = pd.DataFrame(
|
||||
{
|
||||
"sample_id": ["s1", "s2", "s3", "s4"],
|
||||
"long_target": [1, 0, 0, 0],
|
||||
"short_target": [0, 1, 0, 0],
|
||||
}
|
||||
)
|
||||
entry = pd.DataFrame(
|
||||
{
|
||||
"sample_id": ["s1", "s2", "s3", "s4"],
|
||||
"split_id": ["fit_inner", "fit_inner", "fit_inner", "fit_inner"],
|
||||
"long_max_achievable_net_edge_bps": [45.0, 10.0, 65.0, 39.0],
|
||||
"short_max_achievable_net_edge_bps": [8.0, 41.0, 15.0, 70.0],
|
||||
}
|
||||
)
|
||||
|
||||
long_frame = _entry_side_fit_frame(direction, entry, "LONG", "side_opportunity", 40.0)
|
||||
short_frame = _entry_side_fit_frame(direction, entry, "SHORT", "side_opportunity", 40.0)
|
||||
|
||||
self.assertEqual(["s1", "s3"], long_frame["sample_id"].tolist())
|
||||
self.assertEqual(["s2", "s4"], short_frame["sample_id"].tolist())
|
||||
|
||||
def test_nonlinear_pm_probe_exit_metrics_describe_trade_outcomes(self) -> None:
|
||||
trades = pd.DataFrame(
|
||||
{
|
||||
"target_hit": [1, 0, 0],
|
||||
"stop_hit": [0, 1, 0],
|
||||
"time_to_exit_ms": [300_000, 600_000, 2_700_000],
|
||||
}
|
||||
)
|
||||
|
||||
metrics = _exit_metrics(trades)
|
||||
|
||||
self.assertAlmostEqual(1 / 3, metrics["target_hit_rate"])
|
||||
self.assertAlmostEqual(1 / 3, metrics["stop_hit_rate"])
|
||||
self.assertAlmostEqual(1 / 3, metrics["timeout_exit_rate"])
|
||||
self.assertAlmostEqual(20.0, metrics["avg_time_to_exit_min"])
|
||||
self.assertAlmostEqual(10.0, metrics["p50_time_to_exit_min"])
|
||||
|
||||
def test_entry_feature_screen_prefers_actual_plan_edge(self) -> None:
|
||||
dataset = pd.DataFrame(
|
||||
|
||||
Reference in New Issue
Block a user