Add state Continue path interaction features
This commit is contained in:
@@ -18,7 +18,7 @@ from trader_training.state_continue_experiment import STATE_FEATURES, _predict_f
|
||||
|
||||
|
||||
class StateContinueExperimentTest(unittest.TestCase):
|
||||
def test_state_rows_include_required_position_and_frozen_entry_features(self) -> None:
|
||||
def test_state_rows_include_required_position_path_and_side_market_features(self) -> None:
|
||||
row = {
|
||||
"current_sample_id": "s0",
|
||||
"symbol": "BTC-USDT-PERP",
|
||||
@@ -41,6 +41,13 @@ class StateContinueExperimentTest(unittest.TestCase):
|
||||
}
|
||||
for feature_name in FEATURE_ORDER:
|
||||
row[feature_name] = 0.0
|
||||
row["ret_1m_bps"] = 2.0
|
||||
row["ret_5m_bps"] = 3.0
|
||||
row["taker_imbalance_1m"] = 0.1
|
||||
row["taker_imbalance_5m"] = 0.2
|
||||
row["book_microprice_basis_bps"] = 4.0
|
||||
row["book_pressure_taker_1m"] = 5.0
|
||||
row["book_pressure_taker_5m"] = 6.0
|
||||
frame = pd.DataFrame([row])
|
||||
|
||||
out = _state_rows_for_age(frame, stop_bps=8.0, target_bps=12.0, cost_bps=6.5)
|
||||
@@ -50,6 +57,17 @@ class StateContinueExperimentTest(unittest.TestCase):
|
||||
self.assertEqual(1, int(out.iloc[0]["continue_target"]))
|
||||
self.assertAlmostEqual(8.5, float(out.iloc[0]["entry_predicted_edge_bps"]))
|
||||
self.assertAlmostEqual(0.64, float(out.iloc[0]["entry_direction_prob"]), places=6)
|
||||
self.assertAlmostEqual(16.5, float(out.iloc[0]["giveback_from_mfe_bps"]), places=4)
|
||||
self.assertAlmostEqual(8.5025, float(out.iloc[0]["recovery_from_mae_bps"]), places=4)
|
||||
self.assertGreater(float(out.iloc[0]["path_efficiency"]), 0.13)
|
||||
self.assertGreater(float(out.iloc[0]["mfe_mae_ratio"]), 3.3)
|
||||
self.assertAlmostEqual(2.0, float(out.iloc[0]["side_ret_1m_bps"]))
|
||||
self.assertAlmostEqual(3.0, float(out.iloc[0]["side_ret_5m_bps"]))
|
||||
self.assertAlmostEqual(0.1, float(out.iloc[0]["side_taker_imbalance_1m"]), places=6)
|
||||
self.assertAlmostEqual(0.2, float(out.iloc[0]["side_taker_imbalance_5m"]), places=6)
|
||||
self.assertAlmostEqual(4.0, float(out.iloc[0]["side_book_microprice_basis_bps"]))
|
||||
self.assertAlmostEqual(5.0, float(out.iloc[0]["side_book_pressure_taker_1m"]))
|
||||
self.assertAlmostEqual(6.0, float(out.iloc[0]["side_book_pressure_taker_5m"]))
|
||||
self.assertAlmostEqual(0.0, float(out.iloc[0]["add_count"]))
|
||||
self.assertAlmostEqual(9999.0, float(out.iloc[0]["minutes_since_last_add"]))
|
||||
|
||||
|
||||
@@ -26,6 +26,17 @@ STATE_FEATURES = [
|
||||
"distance_to_target_bps",
|
||||
"entry_predicted_edge_bps",
|
||||
"entry_direction_prob",
|
||||
"path_efficiency",
|
||||
"giveback_from_mfe_bps",
|
||||
"recovery_from_mae_bps",
|
||||
"mfe_mae_ratio",
|
||||
"side_ret_1m_bps",
|
||||
"side_ret_5m_bps",
|
||||
"side_taker_imbalance_1m",
|
||||
"side_taker_imbalance_5m",
|
||||
"side_book_microprice_basis_bps",
|
||||
"side_book_pressure_taker_1m",
|
||||
"side_book_pressure_taker_5m",
|
||||
"add_count",
|
||||
"minutes_since_last_add",
|
||||
]
|
||||
@@ -55,7 +66,8 @@ def run_state_continue_experiment(args: Any) -> None:
|
||||
)
|
||||
|
||||
feature = _load_feature_frame(baseline_root)
|
||||
entry = _load_entry_labels(baseline_root, feature)
|
||||
frozen_scores = _frozen_entry_scores_by_sample(baseline_root, feature)
|
||||
entry = _load_entry_labels(baseline_root, feature, frozen_scores)
|
||||
replay = _load_replay(baseline_root)
|
||||
plan = read_json(baseline_root / "label" / "price_plan_context.json")
|
||||
stop_bps = float(plan["stopDistanceBps"])
|
||||
@@ -100,6 +112,7 @@ def run_state_continue_experiment(args: Any) -> None:
|
||||
{
|
||||
"entry_predicted_edge_bps": "run-10 frozen ENTRY ONNX output selected by entry side",
|
||||
"entry_direction_prob": "run-10 frozen DIRECTION ONNX output selected by entry side",
|
||||
"path_features": "position path shape and side-adjusted market pressure features computed at current state time",
|
||||
"out_of_fold_used": False,
|
||||
"frozen_model_output_used": True,
|
||||
"frozen_model_output_policy": "baseline model is fixed and is not retrained inside this experiment",
|
||||
@@ -161,7 +174,7 @@ def _load_feature_frame(baseline_root: Path) -> pd.DataFrame:
|
||||
return feature
|
||||
|
||||
|
||||
def _load_entry_labels(baseline_root: Path, feature: pd.DataFrame) -> pd.DataFrame:
|
||||
def _load_entry_labels(baseline_root: Path, feature: pd.DataFrame, frozen_scores: pd.DataFrame) -> pd.DataFrame:
|
||||
entry = read_parquet(baseline_root / "label" / "entry_labels.parquet")
|
||||
required = {"sample_id", "symbol", "event_time", "side", "entry_target", "split_id", "walk_forward_fold"}
|
||||
missing = sorted(required.difference(entry.columns))
|
||||
@@ -169,8 +182,7 @@ def _load_entry_labels(baseline_root: Path, feature: pd.DataFrame) -> pd.DataFra
|
||||
raise ValueError(f"baseline entry labels missing columns: {missing}")
|
||||
entry = entry[(entry["entry_target"] == 1) & (entry["side"].isin(["LONG", "SHORT"]))].copy()
|
||||
entry["entry_open_time_ms"] = pd.to_datetime(entry["event_time"], utc=True).astype("int64") // 1_000_000
|
||||
entry_scores = _frozen_entry_scores_by_sample(baseline_root, feature)
|
||||
entry = entry.merge(entry_scores, on="sample_id", how="inner")
|
||||
entry = entry.merge(frozen_scores, on="sample_id", how="inner")
|
||||
if entry.empty:
|
||||
raise ValueError("state continue entry set is empty after merging frozen baseline model outputs")
|
||||
long_mask = entry["side"].eq("LONG")
|
||||
@@ -180,7 +192,17 @@ def _load_entry_labels(baseline_root: Path, feature: pd.DataFrame) -> pd.DataFra
|
||||
entry["frozen_short_expected_net_edge_bps"],
|
||||
)
|
||||
entry["entry_direction_prob"] = np.where(long_mask, entry["frozen_long_prob"], entry["frozen_short_prob"])
|
||||
return entry[["sample_id", "symbol", "event_time", "side", "entry_open_time_ms", "entry_predicted_edge_bps", "entry_direction_prob"]].copy()
|
||||
return entry[
|
||||
[
|
||||
"sample_id",
|
||||
"symbol",
|
||||
"event_time",
|
||||
"side",
|
||||
"entry_open_time_ms",
|
||||
"entry_predicted_edge_bps",
|
||||
"entry_direction_prob",
|
||||
]
|
||||
].copy()
|
||||
|
||||
|
||||
def _frozen_entry_scores_by_sample(baseline_root: Path, feature: pd.DataFrame) -> pd.DataFrame:
|
||||
@@ -374,6 +396,20 @@ def _state_rows_for_age(frame: pd.DataFrame, stop_bps: float, target_bps: float,
|
||||
out["distance_to_target_bps"] = distance_to_target.astype("float32")
|
||||
out["entry_predicted_edge_bps"] = frame["entry_predicted_edge_bps"].astype("float32")
|
||||
out["entry_direction_prob"] = frame["entry_direction_prob"].astype("float32")
|
||||
safe_mfe = np.maximum(mfe, 0.0)
|
||||
safe_mae = np.maximum(mae, 0.0)
|
||||
out["path_efficiency"] = (unrealized / (safe_mfe + safe_mae + 1.0)).astype("float32")
|
||||
out["giveback_from_mfe_bps"] = (safe_mfe - np.maximum(unrealized, 0.0)).astype("float32")
|
||||
out["recovery_from_mae_bps"] = (unrealized + safe_mae).astype("float32")
|
||||
out["mfe_mae_ratio"] = (safe_mfe / (safe_mae + 1.0)).astype("float32")
|
||||
# Multiply each market-pressure feature by side_sign so that a positive value always means "helps the current position", giving LONG and SHORT rows one shared meaning.
|
||||
out["side_ret_1m_bps"] = (side_sign * frame["ret_1m_bps"].astype(float)).astype("float32")
|
||||
out["side_ret_5m_bps"] = (side_sign * frame["ret_5m_bps"].astype(float)).astype("float32")
|
||||
out["side_taker_imbalance_1m"] = (side_sign * frame["taker_imbalance_1m"].astype(float)).astype("float32")
|
||||
out["side_taker_imbalance_5m"] = (side_sign * frame["taker_imbalance_5m"].astype(float)).astype("float32")
|
||||
out["side_book_microprice_basis_bps"] = (side_sign * frame["book_microprice_basis_bps"].astype(float)).astype("float32")
|
||||
out["side_book_pressure_taker_1m"] = (side_sign * frame["book_pressure_taker_1m"].astype(float)).astype("float32")
|
||||
out["side_book_pressure_taker_5m"] = (side_sign * frame["book_pressure_taker_5m"].astype(float)).astype("float32")
|
||||
out["add_count"] = frame["add_count"].astype("float32")
|
||||
out["minutes_since_last_add"] = frame["minutes_since_last_add"].astype("float32")
|
||||
out["continue_target"] = continue_target
|
||||
@@ -541,6 +577,7 @@ def _source_manifest(
|
||||
"uses_same_round_model_prediction_as_feature": False,
|
||||
"entry_predicted_edge_bps": "baseline frozen ENTRY ONNX output selected by side",
|
||||
"entry_direction_prob": "baseline frozen DIRECTION ONNX output selected by side",
|
||||
"path_features": "position path shape and side-adjusted market pressure at current state time",
|
||||
"add_count": "synthetic first-position diagnostic, fixed to 0",
|
||||
"minutes_since_last_add": "synthetic first-position diagnostic, fixed to 9999",
|
||||
},
|
||||
@@ -558,6 +595,17 @@ def _state_feature_schema() -> list[dict[str, Any]]:
|
||||
{"name": "distance_to_target_bps", "unit": "bps", "source": "price plan and current close", "leakage_check": "uses fixed plan and current price"},
|
||||
{"name": "entry_predicted_edge_bps", "unit": "bps", "source": "baseline frozen ENTRY ONNX", "leakage_check": "baseline model is fixed before this experiment"},
|
||||
{"name": "entry_direction_prob", "unit": "probability", "source": "baseline frozen DIRECTION ONNX", "leakage_check": "baseline model is fixed before this experiment"},
|
||||
{"name": "path_efficiency", "unit": "ratio", "source": "unrealized_pnl_bps / (mfe + mae + 1)", "leakage_check": "uses entry..current path only"},
|
||||
{"name": "giveback_from_mfe_bps", "unit": "bps", "source": "mfe_since_entry_bps - max(unrealized_pnl_bps, 0)", "leakage_check": "uses entry..current path only"},
|
||||
{"name": "recovery_from_mae_bps", "unit": "bps", "source": "unrealized_pnl_bps + mae_since_entry_bps", "leakage_check": "uses entry..current path only"},
|
||||
{"name": "mfe_mae_ratio", "unit": "ratio", "source": "mfe_since_entry_bps / (mae_since_entry_bps + 1)", "leakage_check": "uses entry..current path only"},
|
||||
{"name": "side_ret_1m_bps", "unit": "bps", "source": "position_side_sign * ret_1m_bps", "leakage_check": "uses <= current time feature only"},
|
||||
{"name": "side_ret_5m_bps", "unit": "bps", "source": "position_side_sign * ret_5m_bps", "leakage_check": "uses <= current time feature only"},
|
||||
{"name": "side_taker_imbalance_1m", "unit": "ratio", "source": "position_side_sign * taker_imbalance_1m", "leakage_check": "uses <= current time feature only"},
|
||||
{"name": "side_taker_imbalance_5m", "unit": "ratio", "source": "position_side_sign * taker_imbalance_5m", "leakage_check": "uses <= current time feature only"},
|
||||
{"name": "side_book_microprice_basis_bps", "unit": "bps", "source": "position_side_sign * book_microprice_basis_bps", "leakage_check": "uses <= current time feature only"},
|
||||
{"name": "side_book_pressure_taker_1m", "unit": "bps", "source": "position_side_sign * book_pressure_taker_1m", "leakage_check": "uses <= current time feature only"},
|
||||
{"name": "side_book_pressure_taker_5m", "unit": "bps", "source": "position_side_sign * book_pressure_taker_5m", "leakage_check": "uses <= current time feature only"},
|
||||
{"name": "add_count", "unit": "count", "source": "synthetic position state", "leakage_check": "known at current position time"},
|
||||
{"name": "minutes_since_last_add", "unit": "minute", "source": "synthetic position state", "leakage_check": "known at current position time"},
|
||||
]
|
||||
|
||||
Reference in New Issue
Block a user