Add Entry condition pair diagnostics

This commit is contained in:
Codex
2026-06-28 08:21:01 +08:00
parent dc4d00a373
commit 3f49af5ba6
3 changed files with 446 additions and 0 deletions
+46
View File
@@ -15,6 +15,7 @@ if str(TRAINING_ROOT) not in sys.path:
from trader_training.onnx_export import LinearHead, export_heads
from trader_training.dynamic_exit_search import search_dynamic_exit_plans
from trader_training.entry_condition_pair_screen import screen_entry_condition_pairs
from trader_training.entry_feature_screen import _bucket_edges, _screen_edge_column
from trader_training.io_utils import read_json, write_json
from trader_training.labels import ENTRY_LABEL_METHOD, _path_stats_for_group, build_entry_labels
@@ -70,6 +71,51 @@ class TrainingContractTest(unittest.TestCase):
self.assertEqual(-np.inf, edges[0])
self.assertEqual(np.inf, edges[-1])
def test_entry_condition_pair_screen_finds_stable_two_feature_filter(self) -> None:
    # Builds a synthetic entry dataset in which only rows where BOTH
    # ret_1m_bps and ret_5m_bps exceed 0.9 are profitable long entries
    # (edge 8.0 bps vs -6.0 bps elsewhere; the short side always loses),
    # runs the condition-pair screen on it, and checks the diagnostics
    # the test then reads back from the run's "diagnostics" directory.
    run_id = "unit-condition-pair"
    sample_count = 1200
    # Monotonic ramp shared by both features, so the top ~10% of rows
    # satisfy the two-feature condition in every split.
    ramp = np.linspace(0.0, 0.999, sample_count)

    def build_split_frame(split_id):
        # Every feature starts at 0.0; only the two return features vary.
        split_frame = pd.DataFrame(
            {feature: 0.0 for feature in FEATURE_ORDER},
            index=np.arange(sample_count),
        )
        split_frame["split_id"] = split_id
        split_frame["ret_1m_bps"] = ramp
        split_frame["ret_5m_bps"] = ramp
        is_good = (split_frame["ret_1m_bps"] > 0.9) & (split_frame["ret_5m_bps"] > 0.9)
        split_frame["long_entry_target"] = is_good.astype(int)
        split_frame["short_entry_target"] = 0
        split_frame["long_actual_plan_net_edge_bps"] = np.where(is_good, 8.0, -6.0)
        split_frame["short_actual_plan_net_edge_bps"] = -6.0
        split_frame["long_mae_bps"] = np.where(is_good, 2.0, 15.0)
        split_frame["short_mae_bps"] = 15.0
        return split_frame

    with tempfile.TemporaryDirectory() as tmp:
        data_root = Path(tmp)
        run_root = data_root / "trader-v4" / "runs" / run_id
        dataset_path = run_root / "dataset" / "entry_train.parquet"
        dataset_path.parent.mkdir(parents=True)

        # One identical block of rows per training split, stacked into a
        # single parquet file.
        dataset = pd.concat(
            [build_split_frame(split_id) for split_id in TRAINING_SPLITS],
            ignore_index=True,
        )
        dataset.to_parquet(dataset_path, index=False)

        screen_args = Namespace(
            data_root=data_root,
            run_id=run_id,
            min_seed_rows=50,
            min_pair_rows=50,
            max_seed_conditions_per_side=8,
            max_buckets_per_feature=2,
        )
        screen_entry_condition_pairs(screen_args)

        diagnostics_dir = run_root / "diagnostics"
        result = read_json(diagnostics_dir / "entry_condition_pair_screen_result.json")
        candidates = pd.read_csv(diagnostics_dir / "entry_condition_pair_candidates.csv")

        self.assertGreater(result["stable_candidate_count"], 0)
        self.assertTrue(candidates["usable_candidate"].any())
        # The test treats the first CSV row as the top-ranked candidate.
        top_candidate = candidates.iloc[0]
        self.assertEqual("LONG", top_candidate["side"])
        self.assertGreater(float(top_candidate["min_eval_edge_bps"]), 0.0)
def test_dynamic_exit_search_writes_plan_diagnostics(self) -> None:
with tempfile.TemporaryDirectory() as tmp:
data_root = Path(tmp)