2026-06-27 20:28:31 +08:00
from __future__ import annotations
import json
import logging
from pathlib import Path
from typing import Any
import numpy as np
import pandas as pd
2026-06-27 23:06:43 +08:00
from sklearn . linear_model import HuberRegressor , LogisticRegression , Ridge
2026-06-27 20:28:31 +08:00
from sklearn . metrics import brier_score_loss , mean_absolute_error , roc_auc_score
from sklearn . preprocessing import StandardScaler
from trader_training . io_utils import read_json , read_parquet , run_root , sha256_json , write_json , write_parquet , write_text
2026-06-27 23:06:43 +08:00
from trader_training . labels import DEFAULT_LABEL_CONFIG , _build_path_stats
2026-06-27 20:28:31 +08:00
from trader_training . schemas import FEATURE_ORDER , FIT_SPLIT , LATEST_STRESS_SPLIT , TUNE_SPLIT , VALIDATION_LOCKED_SPLIT
# Position-state inputs appended after FEATURE_ORDER for the "market_plus_state" feature set.
# Order matters: the list is written to position_state_feature_order.json and hashed.
STATE_FEATURES = (
    # Synthetic position geometry from entry price, current close and the fixed price plan.
    [
        "position_side_sign",
        "time_in_position_minutes",
        "unrealized_pnl_bps",
        "mfe_since_entry_bps",
        "mae_since_entry_bps",
        "distance_to_stop_bps",
        "distance_to_target_bps",
    ]
    # Frozen baseline model outputs selected by the entry side.
    + ["entry_predicted_edge_bps", "entry_direction_prob"]
    # Add (pyramiding) state.
    + ["add_count", "minutes_since_last_add"]
)
# Held-out splits shown in the report's diagnostic table (everything except FIT_SPLIT).
EVAL_SPLITS = (TUNE_SPLIT, VALIDATION_LOCKED_SPLIT, LATEST_STRESS_SPLIT)
# Every split that is loaded and scored: the training split followed by the held-out ones.
ALL_SPLITS = (FIT_SPLIT, *EVAL_SPLITS)
def run_state_continue_experiment(args: Any) -> None:
    """Run the position-state CONTINUE diagnostic against a frozen baseline run.

    Builds synthetic open-position rows (one per positive baseline entry and requested age)
    from the baseline run's feature frame, entry labels, frozen ENTRY/DIRECTION ONNX weights,
    price plan and source replay. Trains LONG/SHORT ``market_only`` and ``market_plus_state``
    models, then writes the dataset, manifests, predictions, metrics, verdict and a markdown
    report to ``<run_root>/experiments/state_continue``.

    Args:
        args: Parsed CLI namespace. Reads ``data_root``, ``baseline_run_id``, ``run_id`` and
            ``ages_minutes`` (comma-separated minutes), plus whatever ``run_root`` needs.
            Optional attributes and their defaults:
            ``regressor_kind`` ("huber" or "ridge"), ``ridge_alpha`` (10.0),
            ``huber_max_iter`` (1000), ``regression_target_clip_bps`` (0.0 disables clipping)
            and ``max_rows_per_split`` (falsy disables the cap).

    Raises:
        ValueError: invalid ages, unsupported regressor kind, missing input columns, or an
            empty state dataset.
    """
    root = run_root(args)
    baseline_root = args.data_root / "trader-v4" / "runs" / args.baseline_run_id
    out_dir = root / "experiments" / "state_continue"
    ages = _parse_ages(args.ages_minutes)
    regressor_kind = getattr(args, "regressor_kind", "huber")
    ridge_alpha = float(getattr(args, "ridge_alpha", 10.0))
    huber_max_iter = int(getattr(args, "huber_max_iter", 1000))
    regression_target_clip_bps = float(getattr(args, "regression_target_clip_bps", 0.0))
    # Reject an unknown regressor before the expensive dataset build. Otherwise the error
    # would only surface in _train_side_models, after the dataset and manifests are written.
    if regressor_kind not in ("huber", "ridge"):
        raise ValueError(f"unsupported state continue regressor kind: {regressor_kind}")
    logging.info(
        "trader.training.state_continue_experiment_started runId=%s baselineRunId=%s ages=%s "
        "regressorKind=%s ridgeAlpha=%s huberMaxIter=%s regressionTargetClipBps=%s",
        args.run_id,
        args.baseline_run_id,
        ages,
        regressor_kind,
        ridge_alpha,
        huber_max_iter,
        regression_target_clip_bps,
    )
    feature = _load_feature_frame(baseline_root)
    entry = _load_entry_labels(baseline_root, feature)
    replay = _load_replay(baseline_root)
    plan = read_json(baseline_root / "label" / "price_plan_context.json")
    stop_bps = float(plan["stopDistanceBps"])
    target_bps = float(plan["targetDistanceBps"])
    cost_bps = float(plan["costBps"])
    continue_config = DEFAULT_LABEL_CONFIG["continue"]
    continue_horizon = int(continue_config["horizon_minutes"])
    min_continue_edge_bps = float(continue_config["min_expected_continue_edge_bps"])

    state_frame = _build_state_frame(
        feature, entry, replay, ages, stop_bps, target_bps, cost_bps, continue_horizon, min_continue_edge_bps
    )
    # Optional attribute, read the same way as the other optional settings above.
    max_rows_per_split = getattr(args, "max_rows_per_split", None)
    if max_rows_per_split:
        state_frame = _cap_rows_per_split(state_frame, int(max_rows_per_split))
    dataset_path = out_dir / "state_continue_train.parquet"
    dataset_hash = write_parquet(dataset_path, state_frame)
    # state_frame is not modified after this point, so a single split histogram serves
    # both the log line and the source manifest.
    split_counts = state_frame["split_id"].value_counts().to_dict()
    logging.info(
        "trader.training.state_continue_dataset_written runId=%s rowCount=%s splitCounts=%s path=%s",
        args.run_id,
        len(state_frame),
        split_counts,
        dataset_path,
    )
    source_manifest = _source_manifest(
        args,
        baseline_root,
        ages,
        stop_bps,
        target_bps,
        cost_bps,
        continue_horizon,
        min_continue_edge_bps,
        state_frame,
        dataset_hash,
        regressor_kind,
        ridge_alpha,
        huber_max_iter,
        regression_target_clip_bps,
    )
    write_json(out_dir / "experiment_manifest.json", source_manifest)
    write_json(out_dir / "position_state_feature_schema.json", _state_feature_schema())
    order_hash = write_json(out_dir / "position_state_feature_order.json", STATE_FEATURES)
    write_json(
        out_dir / "position_state_source_manifest.json",
        {
            "entry_predicted_edge_bps": "run-10 frozen ENTRY ONNX output selected by entry side",
            "entry_direction_prob": "run-10 frozen DIRECTION ONNX output selected by entry side",
            "out_of_fold_used": False,
            "frozen_model_output_used": True,
            "frozen_model_output_policy": "baseline model is fixed and is not retrained inside this experiment",
            "replay_decision_trace_used": False,
            "state_feature_order_hash": order_hash,
            "row_count": len(state_frame),
            "split_counts": split_counts,
        },
    )
    # Each side is trained twice so the state features can be judged against a market-only control.
    feature_sets = {
        "market_only": FEATURE_ORDER,
        "market_plus_state": [*FEATURE_ORDER, *STATE_FEATURES],
    }
    results: dict[str, Any] = {}
    prediction_frames: list[pd.DataFrame] = []
    for side in ("LONG", "SHORT"):
        side_frame = state_frame[state_frame["position_side"].eq(side)].copy()
        for feature_set_name, feature_columns in feature_sets.items():
            key = f"{side.lower()}_{feature_set_name}"
            result, side_predictions = _train_side_models(
                side_frame,
                side,
                feature_columns,
                regressor_kind,
                ridge_alpha,
                huber_max_iter,
                regression_target_clip_bps,
            )
            results[key] = result
            side_predictions["side"] = side
            side_predictions["feature_set"] = feature_set_name
            prediction_frames.append(side_predictions)
            tune_metrics = result.get(TUNE_SPLIT, {})
            logging.info(
                "trader.training.state_continue_model_trained runId=%s side=%s featureSet=%s tuneAuc=%s tuneMaeRatio=%s",
                args.run_id,
                side,
                feature_set_name,
                tune_metrics.get("continue_auc"),
                tune_metrics.get("edge_mae_vs_constant_ratio"),
            )
    predictions = pd.concat(prediction_frames, ignore_index=True) if prediction_frames else pd.DataFrame()
    verdict = _verdict(results)
    report_path = out_dir / "state_continue_experiment_report.md"
    write_parquet(out_dir / "state_continue_predictions.parquet", predictions)
    write_json(out_dir / "state_continue_result.json", results)
    write_json(out_dir / "state_continue_verdict.json", verdict)
    write_text(report_path, _report(args, baseline_root, source_manifest, results, verdict))
    logging.info("trader.training.state_continue_experiment_finished runId=%s report=%s", args.run_id, report_path)
def _parse_ages(raw: str) -> list[int]:
    """Parse a comma-separated list of position ages in minutes.

    Blank items are ignored. Returns the unique ages in ascending order.

    Raises:
        ValueError: ``invalid ages-minutes: <raw>`` when no age is given, an item is not an
            integer, or any age is not strictly positive.
    """
    text = str(raw)
    try:
        ages = [int(item.strip()) for item in text.split(",") if item.strip()]
    except ValueError as exc:
        # Surface the same user-facing message as the range check, not int()'s generic one.
        raise ValueError(f"invalid ages-minutes: {raw}") from exc
    if not ages or any(age <= 0 for age in ages):
        raise ValueError(f"invalid ages-minutes: {raw}")
    return sorted(set(ages))
def _load_feature_frame(baseline_root: Path) -> pd.DataFrame:
    """Load the baseline feature frame, keeping usable rows that belong to a known split.

    Usable means ``data_quality_flag`` is "OK" or "PARTIAL_OPTIONAL"; known splits are
    ALL_SPLITS. The original index is preserved.

    Raises:
        ValueError: if identifier, split or FEATURE_ORDER columns are missing.
    """
    frame = read_parquet(baseline_root / "feature" / "feature_frame.parquet")
    needed = {
        "sample_id",
        "symbol",
        "event_time",
        "open_time_ms",
        "split_id",
        "walk_forward_fold",
        "data_quality_flag",
        *FEATURE_ORDER,
    }
    absent = sorted(needed.difference(frame.columns))
    if absent:
        raise ValueError(f"baseline feature frame missing columns: {absent}")
    usable_quality = frame["data_quality_flag"].isin(["OK", "PARTIAL_OPTIONAL"])
    known_split = frame["split_id"].isin(ALL_SPLITS)
    return frame.loc[usable_quality & known_split].copy()
2026-06-27 23:06:43 +08:00
def _load_entry_labels(baseline_root: Path, feature: pd.DataFrame) -> pd.DataFrame:
    """Load positive LONG/SHORT entry labels and attach frozen baseline scores for the entry side.

    Returns one row per positive entry label that also has frozen model outputs, with columns
    ``sample_id``, ``symbol``, ``event_time``, ``side``, ``entry_open_time_ms`` (epoch
    milliseconds of ``event_time``), ``entry_predicted_edge_bps`` and
    ``entry_direction_prob``. The two score columns take the frozen LONG or SHORT output
    according to ``side``.

    Raises:
        ValueError: missing required label columns, or no entries left after merging the
            frozen model outputs.
    """
    entry = read_parquet(baseline_root / "label" / "entry_labels.parquet")
    required = {"sample_id", "symbol", "event_time", "side", "entry_target", "split_id", "walk_forward_fold"}
    missing = sorted(required.difference(entry.columns))
    if missing:
        raise ValueError(f"baseline entry labels missing columns: {missing}")
    entry = entry[(entry["entry_target"] == 1) & (entry["side"].isin(["LONG", "SHORT"]))].copy()
    # Compute epoch milliseconds with timedelta arithmetic, which is correct for any
    # datetime64 resolution. `astype("int64") // 1_000_000` assumes nanoseconds, but
    # pandas >= 2 keeps the stored unit (e.g. ms/us read from parquet). That would silently
    # produce join keys that never match replay open_time_ms.
    event_time = pd.to_datetime(entry["event_time"], utc=True)
    epoch = pd.Timestamp("1970-01-01", tz="UTC")
    entry["entry_open_time_ms"] = (event_time - epoch) // pd.Timedelta(milliseconds=1)
    entry_scores = _frozen_entry_scores_by_sample(baseline_root, feature)
    entry = entry.merge(entry_scores, on="sample_id", how="inner")
    if entry.empty:
        raise ValueError("state continue entry set is empty after merging frozen baseline model outputs")
    long_mask = entry["side"].eq("LONG")
    entry["entry_predicted_edge_bps"] = np.where(
        long_mask,
        entry["frozen_long_expected_net_edge_bps"],
        entry["frozen_short_expected_net_edge_bps"],
    )
    entry["entry_direction_prob"] = np.where(long_mask, entry["frozen_long_prob"], entry["frozen_short_prob"])
    return entry[
        [
            "sample_id",
            "symbol",
            "event_time",
            "side",
            "entry_open_time_ms",
            "entry_predicted_edge_bps",
            "entry_direction_prob",
        ]
    ].copy()
def _frozen_entry_scores_by_sample(baseline_root: Path, feature: pd.DataFrame) -> pd.DataFrame:
    """Score every unique sample with the frozen baseline DIRECTION and ENTRY heads.

    Returns one row per ``sample_id`` scored by both models (inner join) with columns
    ``frozen_{long,short,neutral}_prob``, ``frozen_{long,short}_entry_prob`` and
    ``frozen_{long,short}_expected_net_edge_bps``.
    """
    model_dir = baseline_root / "model"
    inputs = feature[["sample_id", *FEATURE_ORDER]].drop_duplicates("sample_id").copy()
    direction_heads = {
        "direction": ("softmax", ("frozen_long_prob", "frozen_short_prob", "frozen_neutral_prob")),
    }
    entry_heads = {
        "long_entry_prob": ("sigmoid", ("frozen_long_entry_prob",)),
        "short_entry_prob": ("sigmoid", ("frozen_short_entry_prob",)),
        "long_expected_net_edge_bps": ("identity", ("frozen_long_expected_net_edge_bps",)),
        "short_expected_net_edge_bps": ("identity", ("frozen_short_expected_net_edge_bps",)),
    }
    direction_scores = _predict_frozen_linear_model(model_dir / "direction" / "direction.onnx", inputs, direction_heads)
    entry_scores = _predict_frozen_linear_model(model_dir / "entry" / "entry.onnx", inputs, entry_heads)
    return direction_scores.merge(entry_scores, on="sample_id", how="inner")
def _predict_frozen_linear_model(model_path: Path, frame: pd.DataFrame, heads: dict[str, tuple[str, tuple[str, ...]]]) -> pd.DataFrame:
    """Evaluate linear heads stored as ONNX initializers, without executing the ONNX graph.

    For each head ``name``, the initializers ``{name}_W`` and ``{name}_B`` are applied as
    ``x @ W + B``. Here ``x`` is the FEATURE_ORDER columns of ``frame``: values are coerced
    to numeric, non-finite/unparseable values become 0, and the result is cast to float32.
    The head's activation ("softmax", "sigmoid" or "identity") is then applied. Every output
    column is float32 and rows align with ``frame`` via ``sample_id``.

    Raises:
        SystemExit: the ``onnx`` package is not installed.
        FileNotFoundError: ``model_path`` is not a file.
        ValueError: a head's initializers are missing, its activation is unknown, or its
            output width differs from the number of requested columns.
    """
    try:
        import onnx
        from onnx import numpy_helper
    except ModuleNotFoundError as exc:
        raise SystemExit("Python package 'onnx' is required to read frozen baseline ONNX weights.") from exc
    if not model_path.is_file():
        raise FileNotFoundError(f"frozen model is missing: {model_path}")
    graph = onnx.load(model_path).graph
    tensors = {init.name: numpy_helper.to_array(init) for init in graph.initializer}
    features = (
        frame[FEATURE_ORDER]
        .apply(pd.to_numeric, errors="coerce")
        .replace([np.inf, -np.inf], np.nan)
        .fillna(0.0)
        .astype("float32")
        .to_numpy()
    )
    # None marks the identity activation (raw linear output).
    activations = {"softmax": _softmax, "sigmoid": _sigmoid, "identity": None}
    result = pd.DataFrame({"sample_id": frame["sample_id"].to_numpy()})
    for head, (activation, columns) in heads.items():
        w_key, b_key = f"{head}_W", f"{head}_B"
        if w_key not in tensors or b_key not in tensors:
            raise ValueError(f"frozen model {model_path} is missing head initializers: {head}")
        weights = np.asarray(tensors[w_key], dtype=np.float32)
        bias = np.asarray(tensors[b_key], dtype=np.float32).reshape(1, -1)
        logits = features @ weights + bias
        if activation not in activations:
            raise ValueError(f"unsupported frozen head kind: {activation}")
        transform = activations[activation]
        scores = logits if transform is None else transform(logits)
        if scores.shape[1] != len(columns):
            raise ValueError(f"head {head} output width mismatch: {scores.shape[1]} != {len(columns)}")
        for position, column_name in enumerate(columns):
            result[column_name] = scores[:, position].astype("float32")
    return result
def _softmax(values: np.ndarray) -> np.ndarray:
    """Softmax normalised over axis 1 (each row of a 2-D logit matrix sums to 1).

    The row maximum is subtracted before exponentiating so large logits cannot overflow.
    """
    stabilized = values - np.amax(values, axis=1, keepdims=True)
    numerators = np.exp(stabilized)
    return numerators / numerators.sum(axis=1, keepdims=True)
def _sigmoid(values: np.ndarray) -> np.ndarray:
    """Element-wise logistic function.

    Inputs are clipped to [-50, 50] first so ``np.exp`` cannot overflow. The effect on the
    output is negligible (sigmoid(50) differs from 1 by about 2e-22).
    """
    bounded = np.clip(values, -50.0, 50.0)
    denominator = np.exp(-bounded) + 1.0
    return 1.0 / denominator
2026-06-27 20:28:31 +08:00
def _load_replay(baseline_root: Path) -> pd.DataFrame:
    """Load the source replay named in the baseline split manifest.

    The parquet path is read from ``split/split_manifest.json`` (``source_replay_path``).
    Returns the replay sorted by ``symbol`` then ``open_time_ms`` with a fresh RangeIndex.

    Raises:
        ValueError: if OHLC/time/spread columns are missing.
    """
    manifest = read_json(baseline_root / "split" / "split_manifest.json")
    frame = read_parquet(Path(manifest["source_replay_path"]))
    needed = {"symbol", "event_time", "open_time_ms", "high", "low", "close", "spread_bps"}
    absent = sorted(needed.difference(frame.columns))
    if absent:
        raise ValueError(f"source replay missing columns: {absent}")
    ordered = frame.sort_values(["symbol", "open_time_ms"])
    return ordered.reset_index(drop=True)
def _build_state_frame(
    feature: pd.DataFrame,
    entry: pd.DataFrame,
    replay: pd.DataFrame,
    ages: list[int],
    stop_bps: float,
    target_bps: float,
    cost_bps: float,
    continue_horizon: int,
    min_continue_edge_bps: float,
) -> pd.DataFrame:
    """Expand each positive entry into synthetic open-position rows, one per requested age.

    For every age, an entry opened at ``entry_open_time_ms`` is observed at
    ``entry_open_time_ms + age * 60_000``. It is inner-joined with:

    1. the replay price snapshot for that symbol/time/age (``_state_source_by_age``);
    2. the market feature row at that current time;
    3. the path statistics from ``_build_path_stats`` for the entry's side. These are
       presumably forward-looking over ``continue_horizon`` minutes (defined in
       trader_training.labels).

    Entries without a match in any of these are dropped for that age.

    Returns:
        ``_state_rows_for_age`` output concatenated across ages. +/-inf values are replaced
        by NaN, and rows missing any FEATURE_ORDER/STATE_FEATURES value or label are dropped.
        NOTE(review): the result can still be empty after that dropna; callers are not
        protected by the "no rows" check below.

    Raises:
        ValueError: if no age produced any joined rows.
    """
    future_stats = _build_path_stats(replay, horizon=continue_horizon, target_bps=target_bps, stop_bps=stop_bps)
    # Key the label statistics by the decision bar (the "current" bar of the position).
    future_stats = future_stats.rename(columns={"open_time_ms": "current_open_time_ms"})
    # Prefix the feature frame's identifiers so they do not collide with the entry's own.
    current_feature = feature.rename(columns={"sample_id": "current_sample_id", "event_time": "current_event_time", "open_time_ms": "current_open_time_ms"})
    replay_state_source = _state_source_by_age(replay, ages)
    frames: list[pd.DataFrame] = []
    for age in ages:
        candidates = entry.copy()
        candidates["time_in_position_minutes"] = age
        # First-position diagnostic: no adds have happened; 9999 acts as a "never" sentinel.
        candidates["add_count"] = 0.0
        candidates["minutes_since_last_add"] = 9999.0
        # Minute bars are assumed: the current bar is exactly `age` minutes after entry.
        candidates["current_open_time_ms"] = candidates["entry_open_time_ms"] + age * 60_000
        candidates = candidates.merge(
            replay_state_source[replay_state_source["time_in_position_minutes"].eq(age)],
            on=["symbol", "current_open_time_ms", "time_in_position_minutes"],
            how="inner",
        )
        candidates = candidates.merge(current_feature, on=["symbol", "current_open_time_ms"], how="inner")
        candidates = candidates.merge(
            future_stats,
            left_on=["symbol", "current_open_time_ms", "side"],
            right_on=["symbol", "current_open_time_ms", "side"],
            how="inner",
        )
        if candidates.empty:
            continue
        frames.append(_state_rows_for_age(candidates, stop_bps, target_bps, cost_bps, min_continue_edge_bps))
        # rowCount is the joined count before the final NaN filtering below.
        logging.info("trader.training.state_continue_age_built ageMinutes=%s rowCount=%s", age, len(candidates))
    if not frames:
        raise ValueError("state continue experiment produced no rows")
    out = pd.concat(frames, ignore_index=True)
    out = out.replace([np.inf, -np.inf], np.nan)
    required = [*FEATURE_ORDER, *STATE_FEATURES, "continue_target", "expected_continue_edge_bps"]
    out = out.dropna(subset=required).copy()
    return out
def _state_source_by_age(replay: pd.DataFrame, ages: list[int]) -> pd.DataFrame:
    """Build entry/current price snapshots for every replay bar and every position age.

    For each symbol (bars ordered by ``open_time_ms``) and each age, a row describes a
    position entered at the close of the bar ``age`` bars earlier and observed at the
    current bar's close. The row also carries the highest high / lowest low over the
    ``age + 1`` bars from the entry bar through the current bar.

    ``shift``/``rolling`` count rows, not minutes. A row is therefore kept only when its
    entry bar opened exactly ``age * 60_000`` ms before the current bar (1-minute bars keyed
    by ``open_time_ms``). Without this, a gap in the replay would pair the current bar with
    the wrong entry price and a window spanning more than ``age`` minutes.

    NOTE(review): the window includes the entry bar's own high/low, which may occur before
    the entry close. Confirm this is intended for MFE/MAE.

    Returns:
        Columns ``symbol``, ``current_open_time_ms``, ``time_in_position_minutes``,
        ``entry_price``, ``current_price``, ``high_since_entry`` and ``low_since_entry``.
        Rows with missing values are dropped. An empty DataFrame is returned when there is
        nothing to emit.
    """
    frames: list[pd.DataFrame] = []
    for _, group in replay.groupby("symbol", sort=False, observed=False):
        group = group.sort_values("open_time_ms").copy()
        open_time = group["open_time_ms"]
        for age in ages:
            window = age + 1
            rolling_high = group["high"].rolling(window, min_periods=window).max()
            rolling_low = group["low"].rolling(window, min_periods=window).min()
            # True only where the bar `age` rows back is exactly `age` minutes back (no gaps).
            contiguous = open_time.shift(age) == open_time - age * 60_000
            frame = pd.DataFrame(
                {
                    "symbol": group["symbol"],
                    "current_open_time_ms": open_time,
                    "time_in_position_minutes": age,
                    "entry_price": group["close"].shift(age),
                    "current_price": group["close"],
                    "high_since_entry": rolling_high,
                    "low_since_entry": rolling_low,
                }
            )
            frames.append(frame.loc[contiguous].dropna())
    return pd.concat(frames, ignore_index=True) if frames else pd.DataFrame()
2026-06-27 23:06:43 +08:00
def _state_rows_for_age(frame: pd.DataFrame, stop_bps: float, target_bps: float, cost_bps: float, min_continue_edge_bps: float = 5.0) -> pd.DataFrame:
    """Turn joined entry/current/label rows into model-ready position-state rows.

    ``frame`` must carry the entry side, entry/current prices, high/low since entry,
    the market features, the frozen entry scores, the add state, and the label inputs
    ``future_return_bps`` and ``mae_bps``.

    State features (bps, computed as price ratios):
      - unrealized_pnl_bps: long current/entry - 1, short entry/current - 1, minus cost_bps.
      - mfe/mae_since_entry_bps: favourable/adverse excursion from the high/low since
        entry, clamped at 0.
      - distance_to_stop_bps: positive while price is still on the safe side of the stop.
      - distance_to_target_bps: positive while the target has not been reached.

    Labels:
      - expected_continue_edge_bps = future_return_bps - cost_bps.
      - continue_target = 1 iff that edge >= ``min_continue_edge_bps`` and the label
        ``mae_bps`` stays below ``stop_bps``.

    Returns:
        Identifier columns renamed to ``sample_id``/``event_time``/``open_time_ms``/
        ``position_side``, plus ``split_id``, ``walk_forward_fold``, FEATURE_ORDER, the
        state features (float32), ``continue_target`` (int8) and
        ``expected_continue_edge_bps`` (float32).
    """
    side_sign = np.where(frame["side"].eq("LONG"), 1.0, -1.0)
    entry_price = frame["entry_price"].astype(float)
    current_price = frame["current_price"].astype(float)
    high_since = frame["high_since_entry"].astype(float)
    low_since = frame["low_since_entry"].astype(float)
    long_mask = frame["side"].eq("LONG")
    # Ratios are oriented per side so that a positive value is favourable to the position.
    unrealized = np.where(long_mask, (current_price / entry_price - 1.0) * 10000.0, (entry_price / current_price - 1.0) * 10000.0) - cost_bps
    mfe = np.where(long_mask, (high_since / entry_price - 1.0) * 10000.0, (entry_price / low_since - 1.0) * 10000.0)
    mae = np.where(long_mask, (entry_price / low_since - 1.0) * 10000.0, (high_since / entry_price - 1.0) * 10000.0)
    # Fixed price plan anchored at the entry price: stop below/target above for longs, mirrored for shorts.
    stop_price = np.where(long_mask, entry_price * (1.0 - stop_bps / 10000.0), entry_price * (1.0 + stop_bps / 10000.0))
    target_price = np.where(long_mask, entry_price * (1.0 + target_bps / 10000.0), entry_price * (1.0 - target_bps / 10000.0))
    distance_to_stop = np.where(long_mask, (current_price / stop_price - 1.0) * 10000.0, (stop_price / current_price - 1.0) * 10000.0)
    distance_to_target = np.where(long_mask, (target_price / current_price - 1.0) * 10000.0, (current_price / target_price - 1.0) * 10000.0)
    expected_edge = frame["future_return_bps"].astype(float) - cost_bps
    continue_target = ((expected_edge >= min_continue_edge_bps) & (frame["mae_bps"].astype(float) < stop_bps)).astype("int8")
    out = frame[
        [
            "current_sample_id",
            "symbol",
            "current_event_time",
            "current_open_time_ms",
            "side",
            "split_id",
            "walk_forward_fold",
            *FEATURE_ORDER,
        ]
    ].copy()
    # The row represents the decision at the current bar, so current identifiers become the row's identifiers.
    out = out.rename(
        columns={
            "current_sample_id": "sample_id",
            "current_event_time": "event_time",
            "current_open_time_ms": "open_time_ms",
            "side": "position_side",
        }
    )
    out["position_side_sign"] = side_sign.astype("float32")
    out["time_in_position_minutes"] = frame["time_in_position_minutes"].astype("float32")
    out["unrealized_pnl_bps"] = unrealized.astype("float32")
    out["mfe_since_entry_bps"] = np.maximum(mfe, 0.0).astype("float32")
    out["mae_since_entry_bps"] = np.maximum(mae, 0.0).astype("float32")
    out["distance_to_stop_bps"] = distance_to_stop.astype("float32")
    out["distance_to_target_bps"] = distance_to_target.astype("float32")
    out["entry_predicted_edge_bps"] = frame["entry_predicted_edge_bps"].astype("float32")
    out["entry_direction_prob"] = frame["entry_direction_prob"].astype("float32")
    out["add_count"] = frame["add_count"].astype("float32")
    out["minutes_since_last_add"] = frame["minutes_since_last_add"].astype("float32")
    out["continue_target"] = continue_target
    out["expected_continue_edge_bps"] = expected_edge.astype("float32")
    return out
def _cap_rows_per_split(frame: pd.DataFrame, max_rows_per_split: int) -> pd.DataFrame:
    """Keep at most ``max_rows_per_split`` of the most recent rows (by ``event_time``) per split.

    Rows are ordered with a stable sort, so ties on ``event_time`` resolve by original row
    order and the cap is deterministic. One log line is emitted per split. An empty frame
    is returned as-is (with a fresh RangeIndex) instead of failing inside ``pd.concat``.

    Raises:
        ValueError: if ``max_rows_per_split`` is negative. ``DataFrame.tail`` would otherwise
            drop rows from the front instead of capping.
    """
    if max_rows_per_split < 0:
        raise ValueError(f"max_rows_per_split must be >= 0: {max_rows_per_split}")
    if frame.empty:
        return frame.reset_index(drop=True)
    ordered = frame.sort_values("event_time", kind="stable")
    capped: list[pd.DataFrame] = []
    for split_id, part in ordered.groupby("split_id", sort=False, observed=False):
        if len(part) > max_rows_per_split:
            part = part.tail(max_rows_per_split).copy()
        capped.append(part)
        logging.info("trader.training.state_continue_split_capped splitId=%s rowCount=%s maxRows=%s", split_id, len(part), max_rows_per_split)
    return pd.concat(capped, ignore_index=True)
2026-06-27 23:06:43 +08:00
def _train_side_models(
    frame: pd.DataFrame,
    side: str,
    feature_columns: list[str],
    regressor_kind: str = "huber",
    ridge_alpha: float = 10.0,
    huber_max_iter: int = 1000,
    regression_target_clip_bps: float = 0.0,
) -> tuple[dict[str, Any], pd.DataFrame]:
    """Fit one side's CONTINUE classifier and edge regressor on FIT_SPLIT, then score every split.

    Features are standardised with a scaler fit on FIT_SPLIT rows only. The classifier is a
    LogisticRegression on ``continue_target``. The regressor (Huber or Ridge) is fit on
    ``expected_continue_edge_bps``, optionally clipped to +/-``regression_target_clip_bps``;
    predictions are clipped the same way so they stay on the training target's scale.

    Returns:
        ``(metrics, predictions)``.
        ``metrics`` maps each split id present in ``frame`` to ``_split_metrics`` output. It
        also holds row/feature counts, the feature hash and regressor metadata.
        ``regressor_max_iter`` is None for Ridge, which runs without an iteration cap here.
        ``predictions`` holds identifiers, state columns, labels, ``continue_prob`` and
        ``predicted_continue_edge_bps`` for every scored row.

    Raises:
        ValueError: unsupported ``regressor_kind``, no FIT_SPLIT rows, or FIT_SPLIT rows with a
            single ``continue_target`` class (LogisticRegression cannot be fit).
    """
    # Validate before any (potentially slow) fitting.
    if regressor_kind not in ("huber", "ridge"):
        raise ValueError(f"unsupported state continue regressor kind: {regressor_kind}")
    train = frame[frame["split_id"].eq(FIT_SPLIT)].copy()
    if train.empty:
        raise ValueError(f"state continue {side} has no fit_inner rows")
    y_train_cls = train["continue_target"].astype(int).to_numpy()
    if np.unique(y_train_cls).size < 2:
        raise ValueError(f"state continue {side} fit_inner rows have a single continue_target class")
    scaler = StandardScaler()
    x_train = scaler.fit_transform(train[feature_columns].astype("float32"))
    y_train_reg = train["expected_continue_edge_bps"].astype(float).to_numpy()
    y_train_fit = y_train_reg
    if regression_target_clip_bps > 0:
        y_train_fit = np.clip(y_train_reg, -regression_target_clip_bps, regression_target_clip_bps)
    clf = LogisticRegression(max_iter=500)
    clf.fit(x_train, y_train_cls)
    reg_max_iter: int | None
    if regressor_kind == "huber":
        reg_max_iter = huber_max_iter
        reg = HuberRegressor(alpha=0.001, epsilon=1.35, max_iter=huber_max_iter)
    else:
        # huber_max_iter does not apply to Ridge; do not report it as Ridge's cap.
        reg_max_iter = None
        reg = Ridge(alpha=ridge_alpha)
    reg.fit(x_train, y_train_fit)
    metrics: dict[str, Any] = {}
    prediction_frames: list[pd.DataFrame] = []
    for split_id in ALL_SPLITS:
        part = frame[frame["split_id"].eq(split_id)].copy()
        if part.empty:
            continue
        x = scaler.transform(part[feature_columns].astype("float32"))
        y_cls = part["continue_target"].astype(int).to_numpy()
        y_reg = part["expected_continue_edge_bps"].astype(float).to_numpy()
        proba = clf.predict_proba(x)[:, 1]
        pred_edge = reg.predict(x)
        if regression_target_clip_bps > 0:
            pred_edge = np.clip(pred_edge, -regression_target_clip_bps, regression_target_clip_bps)
        # Constant baselines use the unclipped training targets.
        metrics[split_id] = _split_metrics(y_train_cls, y_train_reg, y_cls, y_reg, proba, pred_edge)
        pred_frame = part[
            [
                "sample_id",
                "symbol",
                "event_time",
                "split_id",
                "position_side",
                "time_in_position_minutes",
                "unrealized_pnl_bps",
                "mfe_since_entry_bps",
                "mae_since_entry_bps",
                "entry_predicted_edge_bps",
                "entry_direction_prob",
                "continue_target",
                "expected_continue_edge_bps",
            ]
        ].copy()
        pred_frame["continue_prob"] = proba.astype("float32")
        pred_frame["predicted_continue_edge_bps"] = pred_edge.astype("float32")
        prediction_frames.append(pred_frame)
    metrics["row_count"] = int(len(frame))
    metrics["feature_count"] = len(feature_columns)
    metrics["feature_hash"] = sha256_json(feature_columns)
    # n_iter_ may be None, a scalar, or an array (e.g. Ridge's iterative solvers); reduce to one int.
    n_iter = getattr(reg, "n_iter_", None)
    iterations = int(np.max(n_iter)) if n_iter is not None else 0
    metrics["regressor_kind"] = regressor_kind
    metrics["ridge_alpha"] = ridge_alpha if regressor_kind == "ridge" else None
    metrics["regressor_iterations"] = iterations
    metrics["regressor_max_iter"] = reg_max_iter
    # Hitting the iteration cap is treated as non-convergence; without a cap there is nothing to compare.
    if n_iter is None or reg_max_iter is None:
        metrics["regressor_converged"] = True
    else:
        metrics["regressor_converged"] = 0 <= iterations < reg_max_iter
    metrics["regression_target_clip_bps"] = regression_target_clip_bps if regression_target_clip_bps > 0 else None
    return metrics, pd.concat(prediction_frames, ignore_index=True)
def _split_metrics(
    y_train_cls: np.ndarray,
    y_train_reg: np.ndarray,
    y_cls: np.ndarray,
    y_reg: np.ndarray,
    proba: np.ndarray,
    pred_edge: np.ndarray,
) -> dict[str, Any]:
    """Score one split against constant baselines derived from the training targets.

    The classification baseline predicts the training positive rate for every row. The
    regression baseline predicts the training median edge. Ratios below 1.0 mean the model
    beats its baseline, and a ratio is None when the baseline error is 0. ``continue_auc``
    is included only when ``y_cls`` contains both classes.
    """
    baseline_rate = float(np.mean(y_train_cls))
    baseline_edge = float(np.median(y_train_reg))
    brier = float(brier_score_loss(y_cls, proba))
    constant_brier = float(brier_score_loss(y_cls, np.full(len(y_cls), baseline_rate)))
    edge_mae = float(mean_absolute_error(y_reg, pred_edge))
    edge_constant_mae = float(mean_absolute_error(y_reg, np.full(len(y_reg), baseline_edge)))
    scores: dict[str, Any] = {
        "row_count": int(len(y_cls)),
        "positive_rate": float(np.mean(y_cls)),
        "brier": brier,
        "constant_brier": constant_brier,
        "edge_mae": edge_mae,
        "edge_constant_mae": edge_constant_mae,
    }
    both_classes_present = len(np.unique(y_cls)) == 2
    if both_classes_present:
        scores["continue_auc"] = float(roc_auc_score(y_cls, proba))
    scores["brier_vs_constant_ratio"] = None if not constant_brier > 0 else float(brier / constant_brier)
    scores["edge_mae_vs_constant_ratio"] = None if not edge_constant_mae > 0 else float(edge_mae / edge_constant_mae)
    return scores
def _source_manifest(
    args: Any,
    baseline_root: Path,
    ages: list[int],
    stop_bps: float,
    target_bps: float,
    cost_bps: float,
    continue_horizon: int,
    min_continue_edge_bps: float,
    state_frame: pd.DataFrame,
    dataset_hash: str,
    regressor_kind: str,
    ridge_alpha: float,
    huber_max_iter: int,
    regression_target_clip_bps: float,
) -> dict[str, Any]:
    """Assemble the contents of ``experiment_manifest.json``.

    Records:
      - run identity and price-plan / CONTINUE label parameters;
      - regressor settings (hyper-parameters that do not apply to ``regressor_kind``, and a
        disabled target clip, are stored as None);
      - the dataset hash and row/split/side counts;
      - the state feature inputs and the leakage policy for the state features.
    """
    ridge_setting = ridge_alpha if regressor_kind == "ridge" else None
    huber_setting = huber_max_iter if regressor_kind == "huber" else None
    clip_setting = regression_target_clip_bps if regression_target_clip_bps > 0 else None
    feature_inputs = {
        "market_feature_count": len(FEATURE_ORDER),
        "state_features": STATE_FEATURES,
        "state_feature_count": len(STATE_FEATURES),
    }
    leakage_policy = {
        "uses_future_entry_label_as_feature": False,
        "uses_same_round_model_prediction_as_feature": False,
        "entry_predicted_edge_bps": "baseline frozen ENTRY ONNX output selected by side",
        "entry_direction_prob": "baseline frozen DIRECTION ONNX output selected by side",
        "add_count": "synthetic first-position diagnostic, fixed to 0",
        "minutes_since_last_add": "synthetic first-position diagnostic, fixed to 9999",
    }
    return {
        "experiment": "state_continue_diagnostic_v1",
        "run_id": args.run_id,
        "baseline_run_id": args.baseline_run_id,
        "baseline_root": str(baseline_root),
        "ages_minutes": ages,
        "target_bps": target_bps,
        "stop_bps": stop_bps,
        "cost_bps": cost_bps,
        "continue_horizon_minutes": continue_horizon,
        "min_continue_edge_bps": min_continue_edge_bps,
        "regressor_kind": regressor_kind,
        "ridge_alpha": ridge_setting,
        "huber_max_iter": huber_setting,
        "regression_target_clip_bps": clip_setting,
        "dataset_hash_sha256": dataset_hash,
        "row_count": int(len(state_frame)),
        "split_counts": state_frame["split_id"].value_counts().to_dict(),
        "side_counts": state_frame["position_side"].value_counts().to_dict(),
        "feature_inputs": feature_inputs,
        "leakage_policy": leakage_policy,
    }
def _state_feature_schema() -> list[dict[str, Any]]:
    """Describe each position-state feature: its unit, data source and leakage argument.

    Entries appear in STATE_FEATURES order. Each entry has the keys ``name``, ``unit``,
    ``source`` and ``leakage_check``.
    """
    known_now = "known at current position time"
    plan_and_price = "price plan and current close"
    plan_check = "uses fixed plan and current price"
    frozen_check = "baseline model is fixed before this experiment"
    rows = (
        ("position_side_sign", "-1/1", "synthetic position state", known_now),
        ("time_in_position_minutes", "minute", "entry time to current time", known_now),
        ("unrealized_pnl_bps", "bps", "entry price and current close", "uses <= current time price"),
        ("mfe_since_entry_bps", "bps", "high since entry", "uses only entry..current high"),
        ("mae_since_entry_bps", "bps", "low/high since entry", "uses only entry..current low/high"),
        ("distance_to_stop_bps", "bps", plan_and_price, plan_check),
        ("distance_to_target_bps", "bps", plan_and_price, plan_check),
        ("entry_predicted_edge_bps", "bps", "baseline frozen ENTRY ONNX", frozen_check),
        ("entry_direction_prob", "probability", "baseline frozen DIRECTION ONNX", frozen_check),
        ("add_count", "count", "synthetic position state", known_now),
        ("minutes_since_last_add", "minute", "synthetic position state", known_now),
    )
    keys = ("name", "unit", "source", "leakage_check")
    return [dict(zip(keys, row)) for row in rows]
2026-06-27 23:06:43 +08:00
def _verdict(
    results: dict[str, Any],
    *,
    auc_min: float = 0.60,
    edge_mae_ratio_max: float = 0.97,
) -> dict[str, Any]:
    """Decide whether the state features are ready for the formal training chain.

    For each side, the ``market_plus_state`` model must have a converged regressor. On both
    VALIDATION_LOCKED_SPLIT and LATEST_STRESS_SPLIT it must also reach
    ``continue_auc >= auc_min`` and ``edge_mae_vs_constant_ratio <= edge_mae_ratio_max``, while
    strictly beating the ``market_only`` model whenever that model's metric is available.

    Args:
        results: ``_train_side_models`` metrics keyed ``"{long|short}_{market_only|market_plus_state}"``.
        auc_min: minimum acceptable CONTINUE AUC.
        edge_mae_ratio_max: maximum acceptable edge MAE relative to the constant baseline.
            The defaults reproduce the original fixed acceptance rule.

    Returns:
        A dict with ``status`` ("PASS_TO_FORMAL_CHAIN" when there are no failure reasons,
        otherwise "NOT_READY_FOR_FORMAL_CHAIN"), the ``acceptance_rule`` actually applied,
        ``passed_checks`` and failure ``reasons``.
    """
    reasons: list[str] = []
    passed_checks: list[str] = []
    for side in ("long", "short"):
        plus = results[f"{side}_market_plus_state"]
        base = results[f"{side}_market_only"]
        if not plus.get("regressor_converged"):
            reasons.append(f"{side} market_plus_state regressor did not converge")
        for split_id in (VALIDATION_LOCKED_SPLIT, LATEST_STRESS_SPLIT):
            plus_metric = plus.get(split_id, {})
            base_metric = base.get(split_id, {})
            plus_auc = plus_metric.get("continue_auc")
            base_auc = base_metric.get("continue_auc")
            plus_mae = plus_metric.get("edge_mae_vs_constant_ratio")
            base_mae = base_metric.get("edge_mae_vs_constant_ratio")
            # A missing metric (e.g. single-class split) counts as a failure for the candidate model.
            if plus_auc is None or plus_auc < auc_min:
                reasons.append(f"{side} {split_id} continue_auc below {auc_min:.2f}: {plus_auc}")
            elif base_auc is not None and plus_auc <= base_auc:
                reasons.append(f"{side} {split_id} continue_auc not better than market_only: {plus_auc} <= {base_auc}")
            else:
                passed_checks.append(f"{side} {split_id} continue_auc")
            if plus_mae is None or plus_mae > edge_mae_ratio_max:
                reasons.append(f"{side} {split_id} edge_mae_vs_constant_ratio above {edge_mae_ratio_max:.2f}: {plus_mae}")
            elif base_mae is not None and plus_mae >= base_mae:
                reasons.append(f"{side} {split_id} edge_mae_vs_constant_ratio not better than market_only: {plus_mae} >= {base_mae}")
            else:
                passed_checks.append(f"{side} {split_id} edge_mae_vs_constant_ratio")
    return {
        "status": "PASS_TO_FORMAL_CHAIN" if not reasons else "NOT_READY_FOR_FORMAL_CHAIN",
        "acceptance_rule": {
            "validation_and_latest_auc_min": auc_min,
            "validation_and_latest_edge_mae_vs_constant_max": edge_mae_ratio_max,
            "must_beat_market_only": True,
            "regressor_must_converge": True,
        },
        "passed_checks": passed_checks,
        "reasons": reasons,
    }
def _report(args: Any, baseline_root: Path, manifest: dict[str, Any], results: dict[str, Any], verdict: dict[str, Any]) -> str:
    """Render the markdown experiment report.

    Sections, in order:
      - run parameters from ``manifest``;
      - the baseline run's CONTINUE head metrics, read from
        ``model/model_train_manifest.json`` (absent heads or metrics render as ``None``);
      - the side / feature-set / split diagnostic table over EVAL_SPLITS;
      - the verdict rule text;
      - the verdict status and reasons;
      - the passed checks, if any.
    """
    baseline = read_json(baseline_root / "model" / "model_train_manifest.json")
    # Tolerate a baseline without CONTINUE metrics. The report is written last, and a KeyError
    # here would lose the summary of an otherwise complete run.
    continue_metrics = baseline.get("CONTINUE", {}).get("metrics", {})

    def baseline_metric(head: str, metric: str) -> Any:
        return continue_metrics.get(head, {}).get(metric)

    lines = [
        "# State Continue Experiment Report",
        "",
        f"- run_id: `{args.run_id}`",
        f"- baseline_run_id: `{args.baseline_run_id}`",
        f"- row_count: `{manifest['row_count']}`",
        f"- ages_minutes: `{manifest['ages_minutes']}`",
        f"- regressor_kind: `{manifest['regressor_kind']}`",
        f"- ridge_alpha: `{manifest.get('ridge_alpha')}`",
        f"- huber_max_iter: `{manifest['huber_max_iter']}`",
        f"- regression_target_clip_bps: `{manifest['regression_target_clip_bps']}`",
        f"- continue_horizon_minutes: `{manifest['continue_horizon_minutes']}`",
        f"- min_continue_edge_bps: `{manifest['min_continue_edge_bps']}`",
        "",
        "## Baseline run-10 Continue",
        "",
        "| head | auc | mae_vs_constant |",
        "| --- | ---: | ---: |",
        f"| long_continue_prob | {baseline_metric('long_continue_prob', 'auc')} | |",
        f"| short_continue_prob | {baseline_metric('short_continue_prob', 'auc')} | |",
        f"| long_expected_continue_edge_bps | | {baseline_metric('long_expected_continue_edge_bps', 'mae_vs_constant_ratio')} |",
        f"| short_expected_continue_edge_bps | | {baseline_metric('short_expected_continue_edge_bps', 'mae_vs_constant_ratio')} |",
        "",
        "## Diagnostic Result",
        "",
        "| side | feature_set | split | rows | auc | brier_ratio | mae_ratio |",
        "| --- | --- | --- | ---: | ---: | ---: | ---: |",
    ]
    for key, item in results.items():
        # Keys look like "long_market_plus_state": the side, then the feature-set name.
        side, feature_set = key.split("_", 1)
        for split_id in EVAL_SPLITS:
            metric = item.get(split_id, {})
            lines.append(
                f"| {side.upper()} | {feature_set} | {split_id} | {metric.get('row_count')} | {metric.get('continue_auc')} | {metric.get('brier_vs_constant_ratio')} | {metric.get('edge_mae_vs_constant_ratio')} |"
            )
    lines.extend(
        [
            "",
            "## Verdict Rule",
            "",
            # Rule (translated): state features enter the formal chain only when market_plus_state
            # beats market_only and validation_locked / latest_stress do not regress.
            "状态特征只有在 `market_plus_state` 同时好过 `market_only`,并且 validation_locked / latest_stress 没有反向变差时,才进入正式链路。",
            "",
            "## Verdict",
            "",
            f"- status: `{verdict['status']}`",
            f"- reasons: `{len(verdict['reasons'])}`",
            "",
        ]
    )
    lines.extend(f"- {reason}" for reason in verdict["reasons"])
    if verdict["passed_checks"]:
        lines.extend(["", "## Passed Checks", ""])
        lines.extend(f"- {item}" for item in verdict["passed_checks"])
    return "\n".join(lines)