9acb3460a1
Align entry labels with max future edge, tune direction labeling, and harden regression evaluation. Add training diagnostics, price-plan search, feature screening, and nonlinear benchmark scripts.
150 lines
8.1 KiB
Python
150 lines
8.1 KiB
Python
from __future__ import annotations
|
|
|
|
import logging
|
|
from pathlib import Path
|
|
from typing import Any
|
|
|
|
import numpy as np
|
|
|
|
from trader_training.io_utils import read_json, sha256_file, write_json, write_text
|
|
from trader_training.schemas import FEATURE_ORDER, MODEL_OUTPUTS, OUTPUT_MAPPING
|
|
|
|
|
|
def validate_artifact_bundle(args: Any) -> None:
    """Validate an exported artifact bundle and write the validation reports.

    Reads ``args.artifact_root`` (bundle directory), ``args.require_active``
    and ``args.run_onnx``. Writes ``artifact_validation_result.json`` and
    ``artifact_validation_report.md`` into the parent directory of the bundle.

    Raises:
        SystemExit: when any validation error was recorded, including a
            non-PASS release gate while ``require_active`` is set.
    """
    # Accept plain strings as well as Path objects for the bundle location;
    # everything below relies on the ``/`` operator and ``.parent``.
    root = Path(args.artifact_root)
    errors: list[str] = []
    # Reported as-is when required files are missing and content is never read.
    release_gate: dict[str, Any] = {
        "release_gate_status": "UNKNOWN",
        "release_gate_reasons": ["artifact content not checked"],
    }
    required = [
        "manifests/model_bundle_manifest.json",
        "manifests/model_manifest.json",
        "manifests/calibration_manifest.json",
        "manifests/position_manager_manifest.json",
        "schemas/feature_schema.json",
        "schemas/feature_order.json",
        "schemas/output_schema.json",
        "price_plan_context.json",
        "examples/sample_input.json",
        "examples/sample_output.json",
    ]
    errors.extend(f"missing file: {relative}" for relative in required if not (root / relative).is_file())

    # Content checks read every required file, so they only run once all exist.
    if not errors:
        _validate_content(root, errors, args.require_active, args.run_onnx)
        # Structure validation only proves the bundle is loadable. The release
        # gate is the separate business decision for whether it may become ACTIVE.
        release_gate = _release_gate(root)
        if args.require_active and release_gate["release_gate_status"] != "PASS":
            errors.append(f"release gate must be PASS for ACTIVE, actual={release_gate['release_gate_status']}")

    status = "PASS" if not errors else "FAIL"
    result = {"status": status, "error_count": len(errors), "errors": errors, "artifact_root": str(root), **release_gate}
    output_root = root.parent
    write_json(output_root / "artifact_validation_result.json", result)

    lines = ["# Artifact Validation Report", "", f"- status: {status}", f"- error_count: {len(errors)}", ""]
    lines.extend(f"- {error}" for error in errors)
    # Include the release gate in the human-readable report as well, so a
    # REJECTED gate can be diagnosed without opening the JSON result.
    lines.extend(["", "## Release Gate", "", f"- release_gate_status: {release_gate['release_gate_status']}"])
    lines.extend(f"- {reason}" for reason in release_gate["release_gate_reasons"])
    write_text(output_root / "artifact_validation_report.md", "\n".join(lines) + "\n")

    logging.info("trader.training.artifact_validated status=%s errorCount=%s path=%s", status, len(errors), root)
    if errors:
        raise SystemExit("artifact validation failed; see artifact_validation_report.md")
|
|
|
|
|
|
def _validate_content(root: Path, errors: list[str], require_active: bool, run_onnx: bool) -> None:
    """Validate manifest and schema contents of a bundle whose required files exist.

    Every contract violation is appended to ``errors`` as a message. When
    ``run_onnx`` is set and no error has been recorded, each model is also
    smoke-tested on the bundled sample input.

    Args:
        root: Bundle directory containing ``manifests/`` and ``schemas/``.
        errors: Accumulator for validation messages; mutated in place.
        require_active: Require the bundle, model, calibrator and
            position-manager statuses to all be ``ACTIVE``.
        run_onnx: Run ONNX sample inference after the static checks pass.
    """
    # Derive expected counts from the shared schema constants so these checks
    # stay consistent with each other (and with the feature-order check).
    feature_count = len(FEATURE_ORDER)
    model_count = len(MODEL_OUTPUTS)

    feature_order = read_json(root / "schemas/feature_order.json")
    # JSON arrays always decode to lists; compare against a list so the check
    # does not depend on whether FEATURE_ORDER is declared as list or tuple.
    if feature_order != list(FEATURE_ORDER):
        errors.append(f"feature_order.json does not match V4 {feature_count}-feature order")

    model_bundle = read_json(root / "manifests/model_bundle_manifest.json")
    bundle_status = model_bundle.get("status")
    if require_active and bundle_status != "ACTIVE":
        errors.append("model_bundle_manifest.status must be ACTIVE for Java SHADOW")
    if bundle_status not in {"CANDIDATE", "ACTIVE"}:
        errors.append("model_bundle_manifest.status must be CANDIDATE or ACTIVE")

    manifests = read_json(root / "manifests/model_manifest.json")
    if len(manifests) != model_count:
        errors.append(f"model_manifest.json must contain exactly {model_count} logical models")
    seen = {item.get("model_type") for item in manifests}
    if seen != set(MODEL_OUTPUTS):
        errors.append(f"model types mismatch: {seen}")
    for item in manifests:
        model_type = item.get("model_type")
        if item.get("model_format") != "ONNX":
            errors.append(f"{model_type} model_format must be ONNX")
        if item.get("input_tensor_name") != "features":
            errors.append(f"{model_type} input tensor must be features")
        # A missing, null or non-object shape is reported instead of crashing.
        input_shape = item.get("input_shape_json") or {}
        if not isinstance(input_shape, dict) or input_shape.get("features") != feature_count:
            errors.append(f"{model_type} input_shape_json.features must be {feature_count}")
        if item.get("onnx_opset_version") != 17:
            errors.append(f"{model_type} opset must be 17")
        if item.get("output_mapping_json") != OUTPUT_MAPPING.get(model_type):
            errors.append(f"{model_type} output_mapping_json does not match Java contract")
        _check_hash(root, item.get("artifact_path"), item.get("artifact_hash_sha256"), errors)
        _check_hash(root, item.get("feature_schema_path"), item.get("feature_schema_hash"), errors)
        _check_hash(root, item.get("feature_order_path"), item.get("feature_order_hash"), errors)
        _check_hash(root, item.get("output_schema_path"), item.get("output_schema_hash"), errors)
        if require_active and item.get("status") != "ACTIVE":
            errors.append(f"{model_type} status must be ACTIVE for Java SHADOW")
        if item.get("status") != bundle_status:
            errors.append(f"{model_type} status does not match model_bundle_manifest.status")

    calibrators = read_json(root / "manifests/calibration_manifest.json")
    if len(calibrators) != model_count:
        errors.append(f"calibration_manifest.json must contain {model_count} calibrators")
    for item in calibrators:
        _check_hash(root, item.get("calibrator_path"), item.get("calibrator_hash_sha256"), errors)
        if require_active and item.get("status") != "ACTIVE":
            errors.append(f"{item.get('model_name')} calibrator status must be ACTIVE")
        if item.get("status") != bundle_status:
            errors.append(f"{item.get('model_name')} calibrator status does not match model_bundle_manifest.status")

    pm_manifest = read_json(root / "manifests/position_manager_manifest.json")
    if require_active and pm_manifest.get("status") != "ACTIVE":
        errors.append("position_manager_manifest.status must be ACTIVE for Java SHADOW")
    if pm_manifest.get("status") != bundle_status:
        errors.append("position_manager_manifest.status does not match model_bundle_manifest.status")

    # Only run inference on a statically valid bundle: the sample run indexes
    # MODEL_OUTPUTS by model_type and loads each artifact_path.
    if run_onnx and not errors:
        _run_sample_inference(root, manifests, errors)
|
|
|
|
|
|
def _release_gate(root: Path) -> dict[str, Any]:
    """Compute the business release gate from the bundle's quality metadata.

    Collects a reason for the bundle backtest status and for every model and
    calibrator quality status that is not ``"PASS"``. Any detail reasons
    listed in the manifests are also collected. The gate fails closed: any
    collected reason rejects the bundle.

    Returns:
        ``{"release_gate_status": "PASS" | "REJECTED",
        "release_gate_reasons": [...]}``.
    """
    reasons: list[str] = []

    model_bundle = read_json(root / "manifests/model_bundle_manifest.json")
    backtest_status = model_bundle.get("backtest_status")
    if backtest_status != "PASS":
        reasons.append(f"backtest_status={backtest_status}")
    # ``or []`` tolerates reason lists that are stored as JSON null.
    reasons.extend(model_bundle.get("backtest_status_reasons_json") or [])

    for item in read_json(root / "manifests/model_manifest.json"):
        model_type = item.get("model_type")
        quality_status = item.get("quality_status")
        if quality_status != "PASS":
            reasons.append(f"{model_type}.quality_status={quality_status}")
        reasons.extend(f"{model_type}:{reason}" for reason in (item.get("quality_reasons_json") or []))

    for item in read_json(root / "manifests/calibration_manifest.json"):
        model_name = item.get("model_name")
        quality_status = item.get("quality_status")
        if quality_status != "PASS":
            reasons.append(f"{model_name}.calibration_quality_status={quality_status}")
        reasons.extend(f"{model_name}:calibration:{reason}" for reason in (item.get("quality_reasons_json") or []))

    return {
        "release_gate_status": "PASS" if not reasons else "REJECTED",
        "release_gate_reasons": reasons,
    }
|
|
|
|
|
|
def _check_hash(root: Path, relative: str | None, expected: str | None, errors: list[str]) -> None:
    """Record an error unless ``root / relative`` exists and hashes to ``expected``.

    Appends at most one message to ``errors``:
    - ``missing hash contract for ...`` when the path or the expected hash is empty/None,
    - ``hash target missing: ...`` when the referenced file does not exist,
    - ``hash mismatch: ...`` when its SHA-256 differs from ``expected``.
    """
    if relative and expected:
        target = root / relative
        if not target.is_file():
            errors.append(f"hash target missing: {relative}")
        elif sha256_file(target) != expected:
            errors.append(f"hash mismatch: {relative}")
    else:
        errors.append(f"missing hash contract for {relative}")
|
|
|
|
|
|
def _run_sample_inference(root: Path, manifests: list[dict[str, Any]], errors: list[str]) -> None:
    """Smoke-test every ONNX model on ``examples/sample_input.json``.

    Builds a single-row float32 ``features`` tensor in FEATURE_ORDER, runs
    each model's ``artifact_path`` on CPU, and checks that the last dimension
    of its first output equals the number of outputs declared in
    MODEL_OUTPUTS for that model type. Problems are appended to ``errors``.

    Raises:
        SystemExit: if the ``onnxruntime`` package is not installed.
    """
    try:
        import onnxruntime as ort
    except ModuleNotFoundError as exc:
        raise SystemExit("Python package 'onnxruntime' is required for --run-onnx validation.") from exc

    sample = read_json(root / "examples/sample_input.json")
    missing = [name for name in FEATURE_ORDER if name not in sample]
    if missing:
        errors.append(f"sample_input.json is missing features: {missing}")
        return
    try:
        features = np.array([[float(sample[name]) for name in FEATURE_ORDER]], dtype=np.float32)
    except (TypeError, ValueError) as exc:
        errors.append(f"sample_input.json has a non-numeric feature value: {exc}")
        return

    for item in manifests:
        model_type = item["model_type"]
        try:
            # Explicit providers: onnxruntime >= 1.9 refuses to pick a default
            # on builds that ship more than one execution provider.
            session = ort.InferenceSession(
                str(root / item["artifact_path"]),
                providers=["CPUExecutionProvider"],
            )
            outputs = session.run(None, {"features": features})
        except Exception as exc:  # onnxruntime raises its own pybind exception types
            # Record the failure so the report is still written and the
            # remaining models are still checked.
            errors.append(f"{model_type} sample inference failed: {exc}")
            continue
        first = np.asarray(outputs[0]) if outputs else None
        # A 0-d output has no last dimension to compare.
        if first is None or first.ndim == 0 or first.shape[-1] != len(MODEL_OUTPUTS[model_type]):
            errors.append(f"{model_type} sample output shape is invalid")
|