Support conditional Entry training

This commit is contained in:
Codex
2026-06-28 08:40:30 +08:00
parent 0323fb3caf
commit 5ad77ffe90
3 changed files with 102 additions and 10 deletions
+15 -1
View File
@@ -25,7 +25,7 @@ from trader_training.ofi_feature_experiment import _load_entry_dataset, l1_snaps
from trader_training.promote import promote_artifact_bundle
from trader_training.replay import build_splits
from trader_training.schemas import FEATURE_ORDER, LATEST_STRESS_SPLIT, MODEL_OUTPUTS, OUTPUT_MAPPING, TRAINING_SPLITS, VALIDATION_LOCKED_SPLIT
from trader_training.training import TARGETS
from trader_training.training import TARGETS, _head_train_mask
class TrainingContractTest(unittest.TestCase):
@@ -65,6 +65,20 @@ class TrainingContractTest(unittest.TestCase):
self.assertEqual("long_actual_plan_net_edge_bps", heads["long_expected_net_edge_bps"])
self.assertEqual("short_actual_plan_net_edge_bps", heads["short_expected_net_edge_bps"])
def test_conditional_entry_training_uses_direction_label_rows(self) -> None:
train = pd.DataFrame({"long_target": [1, 0, 1, 0], "short_target": [0, 1, 0, 1]})
long_mask, long_filter = _head_train_mask("ENTRY", "long_entry_prob", train, Namespace(conditional_entry_direction_labels=True))
short_mask, short_filter = _head_train_mask("ENTRY", "short_expected_net_edge_bps", train, Namespace(conditional_entry_direction_labels=True))
default_mask, default_filter = _head_train_mask("ENTRY", "long_entry_prob", train, Namespace(conditional_entry_direction_labels=False))
self.assertEqual("DIRECTION_LABEL_LONG_FIT_ROWS", long_filter)
self.assertEqual([True, False, True, False], long_mask.tolist())
self.assertEqual("DIRECTION_LABEL_SHORT_FIT_ROWS", short_filter)
self.assertEqual([False, True, False, True], short_mask.tolist())
self.assertEqual("ALL_FIT_ROWS", default_filter)
self.assertEqual([True, True, True, True], default_mask.tolist())
def test_entry_feature_screen_keeps_zero_inflated_event_features(self) -> None:
values = np.concatenate((np.zeros(5000), np.linspace(1.0, 100.0, 500)))
edges = _bucket_edges(values)