diff --git a/training/tests/test_state_continue_experiment.py b/training/tests/test_state_continue_experiment.py
index 4ed433c..a7ff374 100644
--- a/training/tests/test_state_continue_experiment.py
+++ b/training/tests/test_state_continue_experiment.py
@@ -111,6 +111,27 @@ class StateContinueExperimentTest(unittest.TestCase):
         self.assertEqual("NOT_READY_FOR_FORMAL_CHAIN", verdict["status"])
         self.assertTrue(any("above 0.97" in reason for reason in verdict["reasons"]))
 
+    def test_verdict_reports_when_state_features_do_not_beat_market_only(self) -> None:
+        """Comparison failures against market_only are reported even when a threshold also fails."""
+        results = {}
+        for side in ("long", "short"):
+            results[f"{side}_market_only"] = {
+                "validation_locked": {"continue_auc": 0.64, "edge_mae_vs_constant_ratio": 0.965},
+                "latest_stress": {"continue_auc": 0.65, "edge_mae_vs_constant_ratio": 0.964},
+                "regressor_converged": True,
+            }
+            results[f"{side}_market_plus_state"] = {
+                "validation_locked": {"continue_auc": 0.63, "edge_mae_vs_constant_ratio": 0.975},
+                "latest_stress": {"continue_auc": 0.66, "edge_mae_vs_constant_ratio": 0.963},
+                "regressor_converged": True,
+            }
+
+        verdict = _verdict(results)
+
+        self.assertEqual("NOT_READY_FOR_FORMAL_CHAIN", verdict["status"])
+        self.assertTrue(any("continue_auc not better than market_only" in reason for reason in verdict["reasons"]))
+        self.assertTrue(any("edge_mae_vs_constant_ratio not better than market_only" in reason for reason in verdict["reasons"]))
+
     def test_train_side_models_supports_ridge_regressor_diagnostic(self) -> None:
         rows = []
         for split_id in ("fit_inner", "tune_inner", "validation_locked", "latest_stress"):
diff --git a/training/trader_training/state_continue_experiment.py b/training/trader_training/state_continue_experiment.py
index a17d0d6..e882879 100644
--- a/training/trader_training/state_continue_experiment.py
+++ b/training/trader_training/state_continue_experiment.py
@@ -578,17 +578,25 @@ def _verdict(results: dict[str, Any]) -> dict[str, Any]:
             base_auc = base_metric.get("continue_auc")
             plus_mae = plus_metric.get("edge_mae_vs_constant_ratio")
             base_mae = base_metric.get("edge_mae_vs_constant_ratio")
-            if plus_auc is None or plus_auc < 0.60:
+            # The threshold check and the market_only comparison are reported independently,
+            # so a single verdict can surface both failures. A missing plus metric is only a
+            # threshold failure: a "None <= base" comparison reason would be meaningless.
+            auc_ok = plus_auc is not None and plus_auc >= 0.60
+            auc_beats_market_only = plus_auc is None or base_auc is None or plus_auc > base_auc
+            if not auc_ok:
                 reasons.append(f"{side} {split_id} continue_auc below 0.60: {plus_auc}")
-            elif base_auc is not None and plus_auc <= base_auc:
+            if not auc_beats_market_only:
                 reasons.append(f"{side} {split_id} continue_auc not better than market_only: {plus_auc} <= {base_auc}")
-            else:
+            if auc_ok and auc_beats_market_only:
                 passed_checks.append(f"{side} {split_id} continue_auc")
-            if plus_mae is None or plus_mae > 0.97:
+            # Same independent reporting for the edge MAE ratio, where lower is better.
+            mae_ok = plus_mae is not None and plus_mae <= 0.97
+            mae_beats_market_only = plus_mae is None or base_mae is None or plus_mae < base_mae
+            if not mae_ok:
                 reasons.append(f"{side} {split_id} edge_mae_vs_constant_ratio above 0.97: {plus_mae}")
-            elif base_mae is not None and plus_mae >= base_mae:
+            if not mae_beats_market_only:
                 reasons.append(f"{side} {split_id} edge_mae_vs_constant_ratio not better than market_only: {plus_mae} >= {base_mae}")
-            else:
+            if mae_ok and mae_beats_market_only:
                 passed_checks.append(f"{side} {split_id} edge_mae_vs_constant_ratio")
     return {
         "status": "PASS_TO_FORMAL_CHAIN" if not reasons else "NOT_READY_FOR_FORMAL_CHAIN",