from __future__ import annotations import logging from pathlib import Path from typing import Any import numpy as np from trader_training.io_utils import read_json, sha256_file, write_json, write_text from trader_training.schemas import FEATURE_ORDER, MODEL_OUTPUTS, OUTPUT_MAPPING def validate_artifact_bundle(args: Any) -> None: root = args.artifact_root errors: list[str] = [] release_gate = {"release_gate_status": "UNKNOWN", "release_gate_reasons": ["artifact content not checked"]} required = [ "manifests/model_bundle_manifest.json", "manifests/model_manifest.json", "manifests/calibration_manifest.json", "manifests/position_manager_manifest.json", "schemas/feature_schema.json", "schemas/feature_order.json", "schemas/output_schema.json", "price_plan_context.json", "examples/sample_input.json", "examples/sample_output.json", ] for relative in required: if not (root / relative).is_file(): errors.append(f"missing file: {relative}") if not errors: _validate_content(root, errors, args.require_active, args.run_onnx) # Structure validation only proves the bundle is loadable. The release # gate is the separate business decision for whether it may become ACTIVE. release_gate = _release_gate(root) if args.require_active and release_gate["release_gate_status"] != "PASS": errors.append(f"release gate must be PASS for ACTIVE, actual={release_gate['release_gate_status']}") status = "PASS" if not errors else "FAIL" result = {"status": status, "error_count": len(errors), "errors": errors, "artifact_root": str(root), **release_gate} output_root = root.parent write_json(output_root / "artifact_validation_result.json", result) lines = ["# Artifact Validation Report", "", f"- status: {status}", f"- error_count: {len(errors)}", ""] for error in errors: lines.append(f"- {error}") write_text(output_root / "artifact_validation_report.md", "\n".join(lines) + "\n") logging.info("trader.training.artifact_validated status=%s errorCount=%s path=%s", status, len(errors), root) if errors: raise SystemExit("artifact validation failed; see artifact_validation_report.md") def _validate_content(root: Path, errors: list[str], require_active: bool, run_onnx: bool) -> None: feature_order = read_json(root / "schemas/feature_order.json") if feature_order != FEATURE_ORDER: errors.append(f"feature_order.json does not match V4 {len(FEATURE_ORDER)}-feature order") model_bundle = read_json(root / "manifests/model_bundle_manifest.json") if require_active and model_bundle.get("status") != "ACTIVE": errors.append("model_bundle_manifest.status must be ACTIVE for Java SHADOW") if model_bundle.get("status") not in {"CANDIDATE", "ACTIVE"}: errors.append("model_bundle_manifest.status must be CANDIDATE or ACTIVE") manifests = read_json(root / "manifests/model_manifest.json") if len(manifests) != 5: errors.append("model_manifest.json must contain exactly five logical models") seen = {item.get("model_type") for item in manifests} if seen != set(MODEL_OUTPUTS): errors.append(f"model types mismatch: {seen}") for item in manifests: model_type = item.get("model_type") if item.get("model_format") != "ONNX": errors.append(f"{model_type} model_format must be ONNX") if item.get("input_tensor_name") != "features": errors.append(f"{model_type} input tensor must be features") if item.get("input_shape_json", {}).get("features") != len(FEATURE_ORDER): errors.append(f"{model_type} input_shape_json.features must be {len(FEATURE_ORDER)}") if item.get("onnx_opset_version") != 17: errors.append(f"{model_type} opset must be 17") if item.get("output_mapping_json") != OUTPUT_MAPPING.get(model_type): errors.append(f"{model_type} output_mapping_json does not match Java contract") _check_hash(root, item.get("artifact_path"), item.get("artifact_hash_sha256"), errors) _check_hash(root, item.get("feature_schema_path"), item.get("feature_schema_hash"), errors) _check_hash(root, item.get("feature_order_path"), item.get("feature_order_hash"), errors) _check_hash(root, item.get("output_schema_path"), item.get("output_schema_hash"), errors) if require_active and item.get("status") != "ACTIVE": errors.append(f"{model_type} status must be ACTIVE for Java SHADOW") if item.get("status") != model_bundle.get("status"): errors.append(f"{model_type} status does not match model_bundle_manifest.status") calibrators = read_json(root / "manifests/calibration_manifest.json") if len(calibrators) != 5: errors.append("calibration_manifest.json must contain five calibrators") for item in calibrators: _check_hash(root, item.get("calibrator_path"), item.get("calibrator_hash_sha256"), errors) if require_active and item.get("status") != "ACTIVE": errors.append(f"{item.get('model_name')} calibrator status must be ACTIVE") if item.get("status") != model_bundle.get("status"): errors.append(f"{item.get('model_name')} calibrator status does not match model_bundle_manifest.status") pm_manifest = read_json(root / "manifests/position_manager_manifest.json") if require_active and pm_manifest.get("status") != "ACTIVE": errors.append("position_manager_manifest.status must be ACTIVE for Java SHADOW") if pm_manifest.get("status") != model_bundle.get("status"): errors.append("position_manager_manifest.status does not match model_bundle_manifest.status") if run_onnx and not errors: _run_sample_inference(root, manifests, errors) def _release_gate(root: Path) -> dict[str, Any]: reasons: list[str] = [] model_bundle = read_json(root / "manifests/model_bundle_manifest.json") if model_bundle.get("backtest_status") != "PASS": reasons.append(f"backtest_status={model_bundle.get('backtest_status')}") reasons.extend(model_bundle.get("backtest_status_reasons_json", [])) for item in read_json(root / "manifests/model_manifest.json"): if item.get("quality_status") != "PASS": reasons.append(f"{item.get('model_type')}.quality_status={item.get('quality_status')}") reasons.extend([f"{item.get('model_type')}:{reason}" for reason in item.get("quality_reasons_json", [])]) for item in read_json(root / "manifests/calibration_manifest.json"): if item.get("quality_status") != "PASS": reasons.append(f"{item.get('model_name')}.calibration_quality_status={item.get('quality_status')}") reasons.extend([f"{item.get('model_name')}:calibration:{reason}" for reason in item.get("quality_reasons_json", [])]) return { "release_gate_status": "PASS" if not reasons else "REJECTED", "release_gate_reasons": reasons, } def _check_hash(root: Path, relative: str | None, expected: str | None, errors: list[str]) -> None: if not relative or not expected: errors.append(f"missing hash contract for {relative}") return path = root / relative if not path.is_file(): errors.append(f"hash target missing: {relative}") return actual = sha256_file(path) if actual != expected: errors.append(f"hash mismatch: {relative}") def _run_sample_inference(root: Path, manifests: list[dict[str, Any]], errors: list[str]) -> None: try: import onnxruntime as ort except ModuleNotFoundError as exc: raise SystemExit("Python package 'onnxruntime' is required for --run-onnx validation.") from exc sample = read_json(root / "examples/sample_input.json") features = np.array([[float(sample[name]) for name in FEATURE_ORDER]], dtype=np.float32) for item in manifests: session = ort.InferenceSession(str(root / item["artifact_path"])) outputs = session.run(None, {"features": features}) if not outputs or outputs[0].shape[-1] != len(MODEL_OUTPUTS[item["model_type"]]): errors.append(f"{item['model_type']} sample output shape is invalid")