Add dynamic exit plan search diagnostics
This commit is contained in:
@@ -0,0 +1,39 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
|
||||
import _bootstrap # noqa: F401
|
||||
from trader_training.dynamic_exit_search import search_dynamic_exit_plans
|
||||
from trader_training.io_utils import add_common_args, setup_logging
|
||||
|
||||
|
||||
def _float_tuple(value: str) -> tuple[float, ...]:
|
||||
return tuple(float(item.strip()) for item in value.split(",") if item.strip())
|
||||
|
||||
|
||||
def _int_tuple(value: str) -> tuple[int, ...]:
|
||||
return tuple(int(item.strip()) for item in value.split(",") if item.strip())
|
||||
|
||||
|
||||
def main() -> None:
|
||||
parser = argparse.ArgumentParser()
|
||||
add_common_args(parser)
|
||||
parser.add_argument("--feature-path", type=Path)
|
||||
parser.add_argument("--replay-path", type=Path)
|
||||
parser.add_argument("--label-config-path", type=Path)
|
||||
parser.add_argument("--cost-config-path", type=Path)
|
||||
parser.add_argument("--horizons", type=_int_tuple)
|
||||
parser.add_argument("--targets", type=_float_tuple)
|
||||
parser.add_argument("--stops", type=_float_tuple)
|
||||
parser.add_argument("--trailing-stops", type=_float_tuple)
|
||||
parser.add_argument("--second-target-multipliers", type=_float_tuple)
|
||||
parser.add_argument("--take1-ratios", type=_float_tuple)
|
||||
parser.add_argument("--take2-ratios", type=_float_tuple)
|
||||
args = parser.parse_args()
|
||||
setup_logging()
|
||||
search_dynamic_exit_plans(args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -14,6 +14,7 @@ if str(TRAINING_ROOT) not in sys.path:
|
||||
sys.path.insert(0, str(TRAINING_ROOT))
|
||||
|
||||
from trader_training.onnx_export import LinearHead, export_heads
|
||||
from trader_training.dynamic_exit_search import search_dynamic_exit_plans
|
||||
from trader_training.entry_feature_screen import _bucket_edges, _screen_edge_column
|
||||
from trader_training.io_utils import read_json, write_json
|
||||
from trader_training.labels import ENTRY_LABEL_METHOD, _path_stats_for_group, build_entry_labels
|
||||
@@ -56,6 +57,65 @@ class TrainingContractTest(unittest.TestCase):
|
||||
self.assertEqual(-np.inf, edges[0])
|
||||
self.assertEqual(np.inf, edges[-1])
|
||||
|
||||
def test_dynamic_exit_search_writes_plan_diagnostics(self) -> None:
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
data_root = Path(tmp)
|
||||
run_root = data_root / "trader-v4" / "runs" / "unit-dynamic-exit"
|
||||
feature_path = run_root / "feature" / "feature_frame.parquet"
|
||||
replay_path = run_root / "replay" / "replay_1m.parquet"
|
||||
config_path = data_root / "label_config.json"
|
||||
feature_path.parent.mkdir(parents=True)
|
||||
replay_path.parent.mkdir(parents=True)
|
||||
|
||||
times = pd.date_range("2026-01-01", periods=7, freq="min", tz="UTC")
|
||||
pd.DataFrame(
|
||||
{
|
||||
"sample_id": [f"s{i}" for i in range(4)],
|
||||
"symbol": ["BTC-USDT-PERP"] * 4,
|
||||
"event_time": times[:4],
|
||||
"open_time_ms": [0, 60_000, 120_000, 180_000],
|
||||
"split_id": ["tune_inner", "validation_locked", "latest_stress", "fit_inner"],
|
||||
"walk_forward_fold": [0, 0, 0, 0],
|
||||
"data_quality_flag": ["OK", "OK", "OK", "OK"],
|
||||
}
|
||||
).to_parquet(feature_path, index=False)
|
||||
pd.DataFrame(
|
||||
{
|
||||
"event_time": times,
|
||||
"open_time_ms": [0, 60_000, 120_000, 180_000, 240_000, 300_000, 360_000],
|
||||
"symbol": ["BTC-USDT-PERP"] * 7,
|
||||
"open": [100.0] * 7,
|
||||
"high": [100.0, 100.12, 100.22, 100.24, 100.24, 100.24, 100.24],
|
||||
"low": [100.0, 100.00, 100.00, 100.00, 100.00, 100.00, 100.00],
|
||||
"close": [100.0, 100.10, 100.18, 100.20, 100.20, 100.20, 100.20],
|
||||
"spread_bps": [1.0] * 7,
|
||||
}
|
||||
).to_parquet(replay_path, index=False)
|
||||
write_json(config_path, {"entry": {"min_expected_net_edge_bps": 3.0}})
|
||||
|
||||
search_dynamic_exit_plans(
|
||||
Namespace(
|
||||
data_root=data_root,
|
||||
run_id="unit-dynamic-exit",
|
||||
feature_path=feature_path,
|
||||
replay_path=replay_path,
|
||||
label_config_path=config_path,
|
||||
cost_config_path=None,
|
||||
horizons=(3,),
|
||||
targets=(10.0,),
|
||||
stops=(5.0,),
|
||||
trailing_stops=(4.0,),
|
||||
second_target_multipliers=(2.0,),
|
||||
take1_ratios=(0.5,),
|
||||
take2_ratios=(0.25,),
|
||||
)
|
||||
)
|
||||
|
||||
result = read_json(run_root / "dynamic-exit-search" / "dynamic_exit_search_result.json")
|
||||
self.assertEqual("DYNAMIC_TRAILING_V1", result["best_plan"]["plan_method"])
|
||||
self.assertEqual(1, result["candidate_count"])
|
||||
self.assertTrue((run_root / "dynamic-exit-search" / "dynamic_exit_search_report.md").is_file())
|
||||
|
||||
def test_split_builder_uses_locked_validation_contract(self) -> None:
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
data_root = Path(tmp)
|
||||
|
||||
@@ -0,0 +1,356 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import itertools
|
||||
import json
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from trader_training.io_utils import read_parquet, run_root, write_json, write_text
|
||||
from trader_training.labels import DEFAULT_COST_CONFIG, DEFAULT_LABEL_CONFIG, ENTRY_LABEL_METHOD, _build_path_stats, _load_config
|
||||
from trader_training.schemas import FIT_SPLIT, LATEST_STRESS_SPLIT, TUNE_SPLIT, VALIDATION_LOCKED_SPLIT
|
||||
|
||||
|
||||
EVAL_SPLITS = (FIT_SPLIT, TUNE_SPLIT, VALIDATION_LOCKED_SPLIT, LATEST_STRESS_SPLIT)
|
||||
GATE_SPLITS = (TUNE_SPLIT, VALIDATION_LOCKED_SPLIT, LATEST_STRESS_SPLIT)
|
||||
DEFAULT_HORIZONS = (30, 45, 60)
|
||||
DEFAULT_TARGETS = (8.0, 12.0, 16.0)
|
||||
DEFAULT_STOPS = (4.0, 6.0, 8.0)
|
||||
DEFAULT_TRAILING_STOPS = (4.0, 8.0, 12.0)
|
||||
DEFAULT_SECOND_TARGET_MULTIPLIERS = (2.0,)
|
||||
DEFAULT_TAKE1_RATIOS = (0.50,)
|
||||
DEFAULT_TAKE2_RATIOS = (0.25,)
|
||||
|
||||
|
||||
def search_dynamic_exit_plans(args: Any) -> None:
|
||||
root = run_root(args)
|
||||
replay = read_parquet(args.replay_path or root / "replay" / "replay_1m.parquet")
|
||||
features = read_parquet(args.feature_path or root / "feature" / "feature_frame.parquet")
|
||||
label_config = _load_config(args.label_config_path, DEFAULT_LABEL_CONFIG)
|
||||
cost_config = _load_config(args.cost_config_path, DEFAULT_COST_CONFIG)
|
||||
entry_config = label_config["entry"]
|
||||
cost_bps = float(cost_config["fee_bps"]) + float(cost_config["slippage_bps"]) + float(cost_config["funding_cost_bps"])
|
||||
min_expected_edge_bps = float(entry_config["min_expected_net_edge_bps"])
|
||||
|
||||
trainable = features[
|
||||
features["data_quality_flag"].isin(["OK", "PARTIAL_OPTIONAL"])
|
||||
& features["split_id"].isin(EVAL_SPLITS)
|
||||
][["symbol", "open_time_ms", "split_id"]].copy()
|
||||
if trainable.empty:
|
||||
raise ValueError("dynamic exit search needs trainable feature rows")
|
||||
|
||||
grid = list(
|
||||
itertools.product(
|
||||
args.horizons or DEFAULT_HORIZONS,
|
||||
args.targets or DEFAULT_TARGETS,
|
||||
args.stops or DEFAULT_STOPS,
|
||||
args.trailing_stops or DEFAULT_TRAILING_STOPS,
|
||||
args.second_target_multipliers or DEFAULT_SECOND_TARGET_MULTIPLIERS,
|
||||
args.take1_ratios or DEFAULT_TAKE1_RATIOS,
|
||||
args.take2_ratios or DEFAULT_TAKE2_RATIOS,
|
||||
)
|
||||
)
|
||||
if not grid:
|
||||
raise ValueError("dynamic exit search grid is empty")
|
||||
|
||||
logging.info(
|
||||
"trader.training.dynamic_exit_search_started runId=%s candidateCount=%s",
|
||||
args.run_id,
|
||||
len(grid),
|
||||
)
|
||||
rows: list[dict[str, Any]] = []
|
||||
for index, (horizon, target_bps, stop_bps, trailing_stop_bps, second_multiplier, take1_ratio, take2_ratio) in enumerate(grid, start=1):
|
||||
second_target_bps = float(target_bps) * float(second_multiplier)
|
||||
plan_id = _plan_id(horizon, target_bps, stop_bps, trailing_stop_bps, second_multiplier, take1_ratio, take2_ratio)
|
||||
plan_config = {
|
||||
"plan_method": "DYNAMIC_TRAILING_V1",
|
||||
"partial_take_1_ratio": float(take1_ratio),
|
||||
"partial_take_2_ratio": float(take2_ratio),
|
||||
"second_target_bps": second_target_bps,
|
||||
"trailing_stop_bps": float(trailing_stop_bps),
|
||||
"breakeven_after_first_target": True,
|
||||
}
|
||||
logging.info(
|
||||
"trader.training.dynamic_exit_candidate_start runId=%s candidateIndex=%s candidateCount=%s planId=%s",
|
||||
args.run_id,
|
||||
index,
|
||||
len(grid),
|
||||
plan_id,
|
||||
)
|
||||
stats = _build_path_stats(replay, int(horizon), float(target_bps), float(stop_bps), plan_config=plan_config)
|
||||
merged = stats.merge(trainable, on=["symbol", "open_time_ms"], how="inner")
|
||||
if merged.empty:
|
||||
logging.info("trader.training.dynamic_exit_candidate_skipped runId=%s planId=%s reason=no_trainable_rows", args.run_id, plan_id)
|
||||
continue
|
||||
merged["actual_net_edge_bps"] = merged["gross_edge_bps"].astype("float64") - cost_bps
|
||||
rows.extend(
|
||||
_candidate_rows(
|
||||
merged,
|
||||
plan_id,
|
||||
int(horizon),
|
||||
float(target_bps),
|
||||
float(stop_bps),
|
||||
float(trailing_stop_bps),
|
||||
second_target_bps,
|
||||
float(second_multiplier),
|
||||
float(take1_ratio),
|
||||
float(take2_ratio),
|
||||
cost_bps,
|
||||
min_expected_edge_bps,
|
||||
)
|
||||
)
|
||||
logging.info(
|
||||
"trader.training.dynamic_exit_candidate_done runId=%s planId=%s mergedRows=%s",
|
||||
args.run_id,
|
||||
plan_id,
|
||||
len(merged),
|
||||
)
|
||||
|
||||
result = pd.DataFrame(rows)
|
||||
if result.empty:
|
||||
raise ValueError("dynamic exit search produced no candidate rows")
|
||||
|
||||
summary = _plan_summary(result)
|
||||
best = _select_best_plan(summary)
|
||||
payload = {
|
||||
"run_id": args.run_id,
|
||||
"cost_bps": cost_bps,
|
||||
"min_expected_net_edge_bps": min_expected_edge_bps,
|
||||
"entry_label_method": ENTRY_LABEL_METHOD,
|
||||
"candidate_count": int(summary["plan_id"].nunique()),
|
||||
"robust_candidate_found": bool(best["robust_candidate_found"]),
|
||||
"best_plan": best,
|
||||
}
|
||||
out_dir = root / "dynamic-exit-search"
|
||||
write_json(out_dir / "dynamic_exit_search_result.json", _jsonable(payload))
|
||||
write_text(out_dir / "dynamic_exit_search_rows.csv", result.to_csv(index=False))
|
||||
write_text(out_dir / "dynamic_exit_search_summary.csv", summary.to_csv(index=False))
|
||||
write_text(out_dir / "dynamic_exit_search_report.md", _markdown_report(payload, summary))
|
||||
logging.info(
|
||||
"trader.training.dynamic_exit_search_finished runId=%s candidateCount=%s bestPlan=%s robust=%s bestScore=%.6f",
|
||||
args.run_id,
|
||||
payload["candidate_count"],
|
||||
best["plan_id"],
|
||||
best["robust_candidate_found"],
|
||||
best["score"],
|
||||
)
|
||||
|
||||
|
||||
def _candidate_rows(
|
||||
frame: pd.DataFrame,
|
||||
plan_id: str,
|
||||
horizon: int,
|
||||
target_bps: float,
|
||||
stop_bps: float,
|
||||
trailing_stop_bps: float,
|
||||
second_target_bps: float,
|
||||
second_target_multiplier: float,
|
||||
take1_ratio: float,
|
||||
take2_ratio: float,
|
||||
cost_bps: float,
|
||||
min_expected_edge_bps: float,
|
||||
) -> list[dict[str, Any]]:
|
||||
rows: list[dict[str, Any]] = []
|
||||
for split_id, side in itertools.product(EVAL_SPLITS, ("LONG", "SHORT")):
|
||||
mask = frame["split_id"].eq(split_id) & frame["side"].eq(side)
|
||||
if not mask.any():
|
||||
continue
|
||||
part = frame.loc[mask]
|
||||
actual = part["actual_net_edge_bps"].astype("float64")
|
||||
rows.append(
|
||||
{
|
||||
"plan_id": plan_id,
|
||||
"split_id": split_id,
|
||||
"side": side,
|
||||
"horizon_minutes": horizon,
|
||||
"target_bps": target_bps,
|
||||
"stop_bps": stop_bps,
|
||||
"trailing_stop_bps": trailing_stop_bps,
|
||||
"second_target_bps": second_target_bps,
|
||||
"second_target_multiplier": second_target_multiplier,
|
||||
"partial_take_1_ratio": take1_ratio,
|
||||
"partial_take_2_ratio": take2_ratio,
|
||||
"cost_bps": cost_bps,
|
||||
"rows": int(len(part)),
|
||||
"avg_actual_net_edge_bps": float(actual.mean()),
|
||||
"median_actual_net_edge_bps": float(actual.median()),
|
||||
"p10_actual_net_edge_bps": float(actual.quantile(0.10)),
|
||||
"p90_actual_net_edge_bps": float(actual.quantile(0.90)),
|
||||
"positive_label_rate": float((actual >= min_expected_edge_bps).mean()),
|
||||
"breakeven_rate": float((actual >= 0.0).mean()),
|
||||
"target_hit_rate": float(part["target_hit"].mean()),
|
||||
"stop_hit_rate": float(part["stop_hit"].mean()),
|
||||
"timeout_rate": float(part["timeout_hit"].mean()),
|
||||
"avg_time_to_exit_min": float(part["time_to_exit_ms"].mean() / 60_000.0),
|
||||
"avg_mfe_bps": float(part["mfe_bps"].mean()),
|
||||
"avg_mae_bps": float(part["mae_bps"].mean()),
|
||||
}
|
||||
)
|
||||
return rows
|
||||
|
||||
|
||||
def _plan_summary(rows: pd.DataFrame) -> pd.DataFrame:
|
||||
group_cols = [
|
||||
"plan_id",
|
||||
"horizon_minutes",
|
||||
"target_bps",
|
||||
"stop_bps",
|
||||
"trailing_stop_bps",
|
||||
"second_target_bps",
|
||||
"second_target_multiplier",
|
||||
"partial_take_1_ratio",
|
||||
"partial_take_2_ratio",
|
||||
"side",
|
||||
]
|
||||
metrics = [
|
||||
"avg_actual_net_edge_bps",
|
||||
"median_actual_net_edge_bps",
|
||||
"positive_label_rate",
|
||||
"breakeven_rate",
|
||||
"target_hit_rate",
|
||||
"stop_hit_rate",
|
||||
"timeout_rate",
|
||||
"avg_time_to_exit_min",
|
||||
"avg_mfe_bps",
|
||||
"avg_mae_bps",
|
||||
]
|
||||
split_rows = rows.pivot_table(index=group_cols, columns="split_id", values=metrics, aggfunc="mean")
|
||||
split_rows.columns = [f"{metric}_{split}" for metric, split in split_rows.columns]
|
||||
split_rows = split_rows.reset_index()
|
||||
for split_id in EVAL_SPLITS:
|
||||
for metric in metrics:
|
||||
column = f"{metric}_{split_id}"
|
||||
if column not in split_rows.columns:
|
||||
split_rows[column] = np.nan
|
||||
|
||||
edge_cols = [f"avg_actual_net_edge_bps_{split}" for split in GATE_SPLITS]
|
||||
breakeven_cols = [f"breakeven_rate_{split}" for split in GATE_SPLITS]
|
||||
positive_cols = [f"positive_label_rate_{split}" for split in GATE_SPLITS]
|
||||
stop_cols = [f"stop_hit_rate_{split}" for split in GATE_SPLITS]
|
||||
split_rows["avg_actual_edge_eval"] = split_rows[edge_cols].mean(axis=1)
|
||||
split_rows["min_actual_edge_eval"] = split_rows[edge_cols].min(axis=1)
|
||||
split_rows["min_breakeven_rate_eval"] = split_rows[breakeven_cols].min(axis=1)
|
||||
split_rows["min_positive_label_rate_eval"] = split_rows[positive_cols].min(axis=1)
|
||||
split_rows["max_positive_label_rate_eval"] = split_rows[positive_cols].max(axis=1)
|
||||
split_rows["max_stop_hit_rate_eval"] = split_rows[stop_cols].max(axis=1)
|
||||
split_rows["score"] = (
|
||||
split_rows["avg_actual_edge_eval"].fillna(-999.0) * 8.0
|
||||
+ split_rows["min_actual_edge_eval"].fillna(-999.0) * 4.0
|
||||
+ split_rows["min_breakeven_rate_eval"].fillna(0.0) * 20.0
|
||||
+ split_rows["min_positive_label_rate_eval"].fillna(0.0) * 20.0
|
||||
- split_rows["max_stop_hit_rate_eval"].fillna(1.0) * 8.0
|
||||
)
|
||||
return split_rows.sort_values("score", ascending=False).reset_index(drop=True)
|
||||
|
||||
|
||||
def _select_best_plan(summary: pd.DataFrame) -> dict[str, Any]:
|
||||
robust = summary[
|
||||
(summary["avg_actual_edge_eval"] > 0.0)
|
||||
& (summary["min_actual_edge_eval"] > -1.0)
|
||||
& (summary["min_breakeven_rate_eval"] >= 0.45)
|
||||
& (summary["min_positive_label_rate_eval"] >= 0.03)
|
||||
& (summary["max_positive_label_rate_eval"] <= 0.55)
|
||||
].copy()
|
||||
robust_found = not robust.empty
|
||||
candidates = robust if robust_found else summary
|
||||
row = candidates.sort_values("score", ascending=False, na_position="last").iloc[0]
|
||||
return {
|
||||
"plan_id": str(row["plan_id"]),
|
||||
"plan_method": "DYNAMIC_TRAILING_V1",
|
||||
"side": str(row["side"]),
|
||||
"horizon_minutes": int(row["horizon_minutes"]),
|
||||
"target_bps": float(row["target_bps"]),
|
||||
"stop_bps": float(row["stop_bps"]),
|
||||
"trailing_stop_bps": float(row["trailing_stop_bps"]),
|
||||
"second_target_bps": float(row["second_target_bps"]),
|
||||
"second_target_multiplier": float(row["second_target_multiplier"]),
|
||||
"partial_take_1_ratio": float(row["partial_take_1_ratio"]),
|
||||
"partial_take_2_ratio": float(row["partial_take_2_ratio"]),
|
||||
"breakeven_after_first_target": True,
|
||||
"score": float(row["score"]),
|
||||
"avg_actual_edge_eval": float(row["avg_actual_edge_eval"]),
|
||||
"min_actual_edge_eval": float(row["min_actual_edge_eval"]),
|
||||
"min_breakeven_rate_eval": float(row["min_breakeven_rate_eval"]),
|
||||
"min_positive_label_rate_eval": float(row["min_positive_label_rate_eval"]),
|
||||
"max_positive_label_rate_eval": float(row["max_positive_label_rate_eval"]),
|
||||
"max_stop_hit_rate_eval": float(row["max_stop_hit_rate_eval"]),
|
||||
"robust_candidate_found": bool(robust_found),
|
||||
}
|
||||
|
||||
|
||||
def _plan_id(
|
||||
horizon: int,
|
||||
target_bps: float,
|
||||
stop_bps: float,
|
||||
trailing_stop_bps: float,
|
||||
second_target_multiplier: float,
|
||||
take1_ratio: float,
|
||||
take2_ratio: float,
|
||||
) -> str:
|
||||
return (
|
||||
f"dyn_h{int(horizon)}_t{target_bps:g}_s{stop_bps:g}"
|
||||
f"_trail{trailing_stop_bps:g}_t2x{second_target_multiplier:g}"
|
||||
f"_p{int(round(take1_ratio * 100))}_{int(round(take2_ratio * 100))}"
|
||||
)
|
||||
|
||||
|
||||
def _markdown_report(payload: dict[str, Any], summary: pd.DataFrame) -> str:
|
||||
top = summary.head(20)
|
||||
best = payload["best_plan"]
|
||||
verdict = "找到可继续训练的稳定出场候选。" if payload["robust_candidate_found"] else "没有找到稳定为正的出场候选;只能把最高分组合当成下一轮排查对象。"
|
||||
lines = [
|
||||
"# Dynamic Exit Search Report",
|
||||
"",
|
||||
f"- run_id: `{payload['run_id']}`",
|
||||
f"- cost_bps: {payload['cost_bps']}",
|
||||
f"- min_expected_net_edge_bps: {payload['min_expected_net_edge_bps']}",
|
||||
f"- entry_label_method: `{payload['entry_label_method']}`",
|
||||
f"- candidate_count: {payload['candidate_count']}",
|
||||
f"- verdict: {verdict}",
|
||||
"",
|
||||
"## Best Plan For Next Experiment",
|
||||
"",
|
||||
"```json",
|
||||
json.dumps(best, ensure_ascii=False, sort_keys=False),
|
||||
"```",
|
||||
"",
|
||||
"## Top Plans",
|
||||
"",
|
||||
_markdown_table(top),
|
||||
"",
|
||||
"说明:这里统计的是动态出场后的真实计划收益,已经扣掉手续费、滑点、资金费。它不是上线结论,只用来决定下一轮训练是否值得换出场参数。",
|
||||
"",
|
||||
]
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
def _markdown_table(frame: pd.DataFrame) -> str:
|
||||
if frame.empty:
|
||||
return "无数据。"
|
||||
columns = list(frame.columns)
|
||||
lines = ["| " + " | ".join(columns) + " |", "| " + " | ".join("---" for _ in columns) + " |"]
|
||||
for row in frame.to_dict("records"):
|
||||
values = []
|
||||
for column in columns:
|
||||
value = row.get(column, "")
|
||||
if isinstance(value, float):
|
||||
value = round(value, 6)
|
||||
values.append(str(value))
|
||||
lines.append("| " + " | ".join(values) + " |")
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
def _jsonable(value: Any) -> Any:
|
||||
if isinstance(value, dict):
|
||||
return {str(key): _jsonable(item) for key, item in value.items()}
|
||||
if isinstance(value, list):
|
||||
return [_jsonable(item) for item in value]
|
||||
if isinstance(value, tuple):
|
||||
return [_jsonable(item) for item in value]
|
||||
if isinstance(value, (np.integer,)):
|
||||
return int(value)
|
||||
if isinstance(value, (np.floating,)):
|
||||
return float(value)
|
||||
return value
|
||||
Reference in New Issue
Block a user