diff --git a/training/tests/test_training_contract.py b/training/tests/test_training_contract.py index 7e23f95..b9947ae 100644 --- a/training/tests/test_training_contract.py +++ b/training/tests/test_training_contract.py @@ -14,7 +14,7 @@ if str(TRAINING_ROOT) not in sys.path: sys.path.insert(0, str(TRAINING_ROOT)) from trader_training.onnx_export import LinearHead, export_heads -from trader_training.entry_feature_screen import _screen_edge_column +from trader_training.entry_feature_screen import _bucket_edges, _screen_edge_column from trader_training.io_utils import read_json, write_json from trader_training.labels import ENTRY_LABEL_METHOD, _path_stats_for_group, build_entry_labels from trader_training.ofi_feature_experiment import l1_snapshot_diff_ofi_quote @@ -48,6 +48,14 @@ class TrainingContractTest(unittest.TestCase): self.assertEqual("long_actual_plan_net_edge_bps", _screen_edge_column(dataset, "LONG")) self.assertEqual("short_actual_plan_net_edge_bps", _screen_edge_column(dataset, "SHORT")) + def test_entry_feature_screen_keeps_zero_inflated_event_features(self) -> None: + values = np.concatenate((np.zeros(5000), np.linspace(1.0, 100.0, 500))) + edges = _bucket_edges(values) + + self.assertGreaterEqual(len(edges), 3) + self.assertEqual(-np.inf, edges[0]) + self.assertEqual(np.inf, edges[-1]) + def test_split_builder_uses_locked_validation_contract(self) -> None: with tempfile.TemporaryDirectory() as tmp: data_root = Path(tmp) diff --git a/training/trader_training/entry_feature_screen.py b/training/trader_training/entry_feature_screen.py index 783c31c..5a4317d 100644 --- a/training/trader_training/entry_feature_screen.py +++ b/training/trader_training/entry_feature_screen.py @@ -146,7 +146,18 @@ def _bucket_edges(values: np.ndarray) -> np.ndarray: edges = np.quantile(clean, quantiles) edges = np.unique(edges) if edges.size < 3: - return np.array([], dtype="float64") + non_zero = clean[clean != 0.0] + if non_zero.size < 300: + return np.array([], dtype="float64") + # 突破/扫单类特征常常绝大多数为 0。普通十分位会全挤在 0, + # 这里单独保留“没有事件”和“有事件强弱”两类桶,避免漏掉稀有但可能有用的信号。 + event_edges = np.unique(np.quantile(non_zero, np.linspace(0.0, 1.0, 6))) + if event_edges.size < 2: + return np.array([-np.inf, 0.0, np.inf], dtype="float64") + edges = np.unique(np.concatenate(([-np.inf, 0.0], event_edges[1:-1], [np.inf]))).astype("float64") + if edges.size < 3: + return np.array([], dtype="float64") + return edges edges[0] = -np.inf edges[-1] = np.inf return edges