Files
quant-trader-service/training/trader_training/validator.py
T
Codex 9acb3460a1 Improve Trader V4 training pipeline
Align entry labels with max future edge, tune direction labeling, and harden regression evaluation.

Add training diagnostics, price-plan search, feature screening, and nonlinear benchmark scripts.
2026-06-27 19:57:29 +08:00

150 lines
8.1 KiB
Python

from __future__ import annotations
import logging
from pathlib import Path
from typing import Any
import numpy as np
from trader_training.io_utils import read_json, sha256_file, write_json, write_text
from trader_training.schemas import FEATURE_ORDER, MODEL_OUTPUTS, OUTPUT_MAPPING
def validate_artifact_bundle(args: Any) -> None:
    """Validate an exported model artifact bundle and write the validation reports.

    Checks that every required manifest, schema and example file exists under
    ``args.artifact_root``. When they all exist, the content is validated and
    the release gate is evaluated. ``artifact_validation_result.json`` and
    ``artifact_validation_report.md`` are written to the parent directory of
    the artifact root whether validation passes or fails.

    Args:
        args: Namespace-like object providing ``artifact_root`` (bundle
            directory), ``require_active`` (statuses and release gate must
            allow ACTIVE) and ``run_onnx`` (also run a sample ONNX inference).

    Raises:
        SystemExit: If at least one validation error was recorded.
    """
    # Accept both str and Path so "/" joins below never fail on a plain string.
    root = Path(args.artifact_root)
    errors: list[str] = []
    release_gate = {"release_gate_status": "UNKNOWN", "release_gate_reasons": ["artifact content not checked"]}
    required = [
        "manifests/model_bundle_manifest.json",
        "manifests/model_manifest.json",
        "manifests/calibration_manifest.json",
        "manifests/position_manager_manifest.json",
        "schemas/feature_schema.json",
        "schemas/feature_order.json",
        "schemas/output_schema.json",
        "price_plan_context.json",
        "examples/sample_input.json",
        "examples/sample_output.json",
    ]
    errors.extend(f"missing file: {relative}" for relative in required if not (root / relative).is_file())

    if not errors:
        try:
            _validate_content(root, errors, args.require_active, args.run_onnx)
            # Structure validation only proves the bundle is loadable. The release
            # gate is the separate business decision for whether it may become ACTIVE.
            release_gate = _release_gate(root)
        except Exception as exc:  # noqa: BLE001 - validation boundary
            # Malformed manifests (bad JSON, unexpected shapes) must end up in
            # the report instead of aborting before it is written; otherwise a
            # stale result file from an earlier run could be left on disk.
            # SystemExit is not an Exception subclass and still propagates.
            logging.exception("trader.training.artifact_validation_crashed path=%s", root)
            errors.append(f"artifact content validation aborted: {type(exc).__name__}: {exc}")
        if args.require_active and release_gate["release_gate_status"] != "PASS":
            errors.append(f"release gate must be PASS for ACTIVE, actual={release_gate['release_gate_status']}")

    status = "PASS" if not errors else "FAIL"
    result = {"status": status, "error_count": len(errors), "errors": errors, "artifact_root": str(root), **release_gate}
    output_root = root.parent
    write_json(output_root / "artifact_validation_result.json", result)

    lines = ["# Artifact Validation Report", "", f"- status: {status}", f"- error_count: {len(errors)}", ""]
    lines.extend(f"- {error}" for error in errors)
    write_text(output_root / "artifact_validation_report.md", "\n".join(lines) + "\n")

    logging.info("trader.training.artifact_validated status=%s errorCount=%s path=%s", status, len(errors), root)
    if errors:
        raise SystemExit("artifact validation failed; see artifact_validation_report.md")
def _validate_content(root: Path, errors: list[str], require_active: bool, run_onnx: bool) -> None:
    """Check the content of the bundle's manifests and schemas.

    Every violation is appended to ``errors``; nothing is returned. The caller
    is expected to have verified that the required files exist.

    Args:
        root: Artifact bundle directory.
        errors: Mutable list that collects human-readable validation errors.
        require_active: When true, every manifest status must be ``ACTIVE``.
        run_onnx: When true and no error has been recorded so far, run the
            sample inference check on every model.
    """
    feature_count = len(FEATURE_ORDER)

    if read_json(root / "schemas/feature_order.json") != FEATURE_ORDER:
        errors.append(f"feature_order.json does not match V4 {feature_count}-feature order")

    # Bundle-level status; every other manifest must agree with it.
    model_bundle = read_json(root / "manifests/model_bundle_manifest.json")
    bundle_status = model_bundle.get("status")
    if require_active and bundle_status != "ACTIVE":
        errors.append("model_bundle_manifest.status must be ACTIVE for Java SHADOW")
    if bundle_status not in {"CANDIDATE", "ACTIVE"}:
        errors.append("model_bundle_manifest.status must be CANDIDATE or ACTIVE")

    # Per-model manifest entries.
    manifests = read_json(root / "manifests/model_manifest.json")
    if len(manifests) != 5:
        errors.append("model_manifest.json must contain exactly five logical models")
    seen = {item.get("model_type") for item in manifests}
    if seen != set(MODEL_OUTPUTS):
        errors.append(f"model types mismatch: {seen}")

    hashed_files = (
        ("artifact_path", "artifact_hash_sha256"),
        ("feature_schema_path", "feature_schema_hash"),
        ("feature_order_path", "feature_order_hash"),
        ("output_schema_path", "output_schema_hash"),
    )
    for item in manifests:
        model_type = item.get("model_type")
        if item.get("model_format") != "ONNX":
            errors.append(f"{model_type} model_format must be ONNX")
        if item.get("input_tensor_name") != "features":
            errors.append(f"{model_type} input tensor must be features")
        # A null or non-object input_shape_json is reported as a mismatch
        # instead of raising AttributeError on ``.get``.
        input_shape = item.get("input_shape_json")
        declared_features = input_shape.get("features") if isinstance(input_shape, dict) else None
        if declared_features != feature_count:
            errors.append(f"{model_type} input_shape_json.features must be {feature_count}")
        if item.get("onnx_opset_version") != 17:
            errors.append(f"{model_type} opset must be 17")
        if item.get("output_mapping_json") != OUTPUT_MAPPING.get(model_type):
            errors.append(f"{model_type} output_mapping_json does not match Java contract")
        for path_key, hash_key in hashed_files:
            _check_hash(root, item.get(path_key), item.get(hash_key), errors)
        if require_active and item.get("status") != "ACTIVE":
            errors.append(f"{model_type} status must be ACTIVE for Java SHADOW")
        if item.get("status") != bundle_status:
            errors.append(f"{model_type} status does not match model_bundle_manifest.status")

    # Calibrator manifest entries.
    calibrators = read_json(root / "manifests/calibration_manifest.json")
    if len(calibrators) != 5:
        errors.append("calibration_manifest.json must contain five calibrators")
    for item in calibrators:
        _check_hash(root, item.get("calibrator_path"), item.get("calibrator_hash_sha256"), errors)
        if require_active and item.get("status") != "ACTIVE":
            errors.append(f"{item.get('model_name')} calibrator status must be ACTIVE")
        if item.get("status") != bundle_status:
            errors.append(f"{item.get('model_name')} calibrator status does not match model_bundle_manifest.status")

    # Position-manager manifest.
    pm_manifest = read_json(root / "manifests/position_manager_manifest.json")
    if require_active and pm_manifest.get("status") != "ACTIVE":
        errors.append("position_manager_manifest.status must be ACTIVE for Java SHADOW")
    if pm_manifest.get("status") != bundle_status:
        errors.append("position_manager_manifest.status does not match model_bundle_manifest.status")

    # Inference is only attempted on a bundle that is otherwise clean.
    if run_onnx and not errors:
        _run_sample_inference(root, manifests, errors)
def _release_gate(root: Path) -> dict[str, Any]:
    """Evaluate whether the bundle's recorded quality results allow a release.

    The gate passes only when the bundle's ``backtest_status`` and every
    model's and calibrator's ``quality_status`` equal ``"PASS"``. Each failing
    status is reported together with the detailed reasons stored next to it.

    Args:
        root: Artifact bundle directory.

    Returns:
        A dict with ``release_gate_status`` (``"PASS"`` or ``"REJECTED"``) and
        ``release_gate_reasons`` (list of strings, empty on PASS).
    """

    def as_reason_list(value: Any) -> list[Any]:
        # Tolerate an explicit null (was: TypeError) and a bare string
        # (was: split into single characters) in the manifest.
        if value is None:
            return []
        if isinstance(value, str):
            return [value]
        return list(value)

    reasons: list[str] = []

    model_bundle = read_json(root / "manifests/model_bundle_manifest.json")
    backtest_status = model_bundle.get("backtest_status")
    if backtest_status != "PASS":
        reasons.append(f"backtest_status={backtest_status}")
        reasons.extend(as_reason_list(model_bundle.get("backtest_status_reasons_json")))

    for item in read_json(root / "manifests/model_manifest.json"):
        quality_status = item.get("quality_status")
        if quality_status != "PASS":
            model_type = item.get("model_type")
            reasons.append(f"{model_type}.quality_status={quality_status}")
            reasons.extend(f"{model_type}:{reason}" for reason in as_reason_list(item.get("quality_reasons_json")))

    for item in read_json(root / "manifests/calibration_manifest.json"):
        quality_status = item.get("quality_status")
        if quality_status != "PASS":
            model_name = item.get("model_name")
            reasons.append(f"{model_name}.calibration_quality_status={quality_status}")
            reasons.extend(
                f"{model_name}:calibration:{reason}" for reason in as_reason_list(item.get("quality_reasons_json"))
            )

    return {
        "release_gate_status": "PASS" if not reasons else "REJECTED",
        "release_gate_reasons": reasons,
    }
def _check_hash(root: Path, relative: str | None, expected: str | None, errors: list[str]) -> None:
    """Verify that a file referenced by a manifest matches its recorded SHA-256.

    Exactly one message is appended to ``errors`` when the path or the hash is
    absent from the manifest, when the file does not exist under ``root``, or
    when the computed digest differs from ``expected``. Nothing is appended
    when the digest matches.

    Args:
        root: Artifact bundle directory that ``relative`` is joined onto.
        relative: Path of the target file as recorded in the manifest.
        expected: Digest recorded in the manifest for that file.
        errors: Mutable list that collects validation errors.
    """
    if relative and expected:
        target = root / relative
        if target.is_file():
            # Only hash files that exist; the digest is compared verbatim.
            if sha256_file(target) != expected:
                errors.append(f"hash mismatch: {relative}")
        else:
            errors.append(f"hash target missing: {relative}")
    else:
        # Either side of the (path, hash) pair is missing from the manifest.
        errors.append(f"missing hash contract for {relative}")
def _run_sample_inference(root: Path, manifests: list[dict[str, Any]], errors: list[str]) -> None:
    """Run every model once on the bundled sample input and check the output width.

    The sample is read from ``examples/sample_input.json`` and arranged in
    ``FEATURE_ORDER`` as a single float32 row. For each manifest entry the
    last dimension of the first output must equal the number of outputs
    declared for that model type in ``MODEL_OUTPUTS``. Problems with the
    sample or with an individual model are appended to ``errors`` so the
    caller can still report them, instead of aborting the validation run.

    Args:
        root: Artifact bundle directory.
        manifests: Model manifest entries; each must provide ``model_type``
            and ``artifact_path``.
        errors: Mutable list that collects validation errors.

    Raises:
        SystemExit: If the ``onnxruntime`` package is not installed.
    """
    try:
        import onnxruntime as ort
    except ModuleNotFoundError as exc:
        raise SystemExit("Python package 'onnxruntime' is required for --run-onnx validation.") from exc

    sample = read_json(root / "examples/sample_input.json")
    if not isinstance(sample, dict):
        errors.append("sample_input.json must be a JSON object keyed by feature name")
        return
    missing = [name for name in FEATURE_ORDER if name not in sample]
    if missing:
        errors.append(f"sample_input.json is missing features: {missing}")
        return
    try:
        features = np.array([[float(sample[name]) for name in FEATURE_ORDER]], dtype=np.float32)
    except (TypeError, ValueError) as exc:
        errors.append(f"sample_input.json has a non-numeric feature value: {exc}")
        return

    for item in manifests:
        model_type = item["model_type"]
        try:
            session = ort.InferenceSession(str(root / item["artifact_path"]))
            outputs = session.run(None, {"features": features})
        except Exception as exc:  # noqa: BLE001 - onnxruntime raises its own exception types
            # One unloadable or failing model must not hide the results for the others.
            errors.append(f"{model_type} sample inference failed: {exc}")
            continue
        # Outputs without a shape, or with an empty (0-d) shape, are invalid too.
        shape = getattr(outputs[0], "shape", ()) if outputs else ()
        if not shape or shape[-1] != len(MODEL_OUTPUTS[model_type]):
            errors.append(f"{model_type} sample output shape is invalid")