from __future__ import annotations

from dataclasses import dataclass
from pathlib import Path

import numpy as np

from trader_training.schemas import FEATURE_ORDER


@dataclass(frozen=True)
class LinearHead:
    """One linear output head, exported as ``activation(features @ weight + bias)``.

    ``kind`` selects the activation applied by ``export_heads``: ``"sigmoid"``,
    ``"softmax"`` (over axis 1) or ``"identity"`` (no activation).
    """

    name: str  # prefix for this head's ONNX node/initializer/tensor names; must be unique per export
    kind: str
    weight: np.ndarray  # (feature_count, width), or 1-D of length feature_count for a width-1 head
    bias: np.ndarray  # `width` values; flattened to shape (1, width) on export


def require_onnx():
    """Import the ``onnx`` package lazily.

    Returns:
        The tuple ``(onnx, TensorProto, helper, numpy_helper)``.

    Raises:
        SystemExit: if the ``onnx`` package is not installed.
    """
    try:
        import onnx
        from onnx import TensorProto, helper, numpy_helper
    except ModuleNotFoundError as exc:
        raise SystemExit("Python package 'onnx' is required. Install training/requirements.txt before export.") from exc
    return onnx, TensorProto, helper, numpy_helper


def _prepare_head(head: LinearHead, feature_count: int) -> tuple[np.ndarray, np.ndarray]:
    """Return ``head``'s float32 weight as (feature_count, width) and bias as (1, width).

    Raises:
        ValueError: if the weight does not have ``feature_count`` rows, or the
            bias does not have exactly one value per weight column.
    """
    weight = np.asarray(head.weight, dtype=np.float32)
    bias = np.asarray(head.bias, dtype=np.float32).reshape(1, -1)
    if weight.ndim == 1:
        # A 1-D weight vector describes a single-output head.
        weight = weight.reshape(-1, 1)
    if weight.ndim != 2 or weight.shape[0] != feature_count:
        raise ValueError(
            f"head {head.name!r}: weight shape {weight.shape} does not match feature_count={feature_count}"
        )
    if bias.shape[1] != weight.shape[1]:
        # Without this check a broadcastable bias would make the declared
        # "prediction" width disagree with the width the graph actually produces.
        raise ValueError(
            f"head {head.name!r}: bias has {bias.shape[1]} values but weight has {weight.shape[1]} columns"
        )
    return weight, bias


def export_heads(path: Path, heads: list[LinearHead], feature_count: int = len(FEATURE_ORDER), opset: int = 17) -> None:
    """Export ``heads`` as one ONNX model and save it to ``path``.

    The graph has one float32 input ``features`` of shape ``[1, feature_count]``
    and one float32 output ``prediction`` of shape ``[1, total_width]``. Each
    head computes ``features @ W + B`` followed by its activation. The head
    outputs are concatenated along axis 1 in the order given by ``heads``.

    Args:
        path: Destination file; missing parent directories are created.
        heads: Heads to export. The list must be non-empty and names must be unique.
        feature_count: Width of the ``features`` input.
        opset: Version of the default-domain ONNX opset to declare.

    Raises:
        SystemExit: if the ``onnx`` package is not installed.
        ValueError: if ``heads`` is empty, a head name repeats, a head's
            weight/bias shapes are inconsistent with ``feature_count`` or
            each other, or a head's ``kind`` is unsupported.
    """
    onnx, TensorProto, helper, numpy_helper = require_onnx()
    if not heads:
        # Concat/Identity need at least one input; fail clearly instead of via the checker.
        raise ValueError("export_heads requires at least one head")

    nodes = []
    initializers = []
    concat_inputs = []
    seen_names: set[str] = set()
    total_width = 0
    for head in heads:
        if head.name in seen_names:
            # All graph names below are derived from head.name, so a repeat breaks graph validity.
            raise ValueError(f"duplicate ONNX head name: {head.name!r}")
        seen_names.add(head.name)

        weight, bias = _prepare_head(head, feature_count)
        total_width += weight.shape[1]

        weight_name = f"{head.name}_W"
        bias_name = f"{head.name}_B"
        linear_name = f"{head.name}_linear"
        out_name = f"{head.name}_out"
        initializers.append(numpy_helper.from_array(weight, weight_name))
        initializers.append(numpy_helper.from_array(bias, bias_name))
        nodes.append(helper.make_node("MatMul", ["features", weight_name], [f"{linear_name}_mm"], name=f"{head.name}_matmul"))
        nodes.append(helper.make_node("Add", [f"{linear_name}_mm", bias_name], [linear_name], name=f"{head.name}_add"))
        if head.kind == "sigmoid":
            nodes.append(helper.make_node("Sigmoid", [linear_name], [out_name], name=f"{head.name}_sigmoid"))
        elif head.kind == "softmax":
            nodes.append(helper.make_node("Softmax", [linear_name], [out_name], name=f"{head.name}_softmax", axis=1))
        elif head.kind == "identity":
            # No activation node: the Add output feeds the final Concat/Identity directly.
            out_name = linear_name
        else:
            raise ValueError(f"unsupported ONNX head kind: {head.kind}")
        concat_inputs.append(out_name)

    if len(concat_inputs) == 1:
        nodes.append(helper.make_node("Identity", concat_inputs, ["prediction"], name="prediction_identity"))
    else:
        nodes.append(helper.make_node("Concat", concat_inputs, ["prediction"], name="prediction_concat", axis=1))

    graph = helper.make_graph(
        nodes,
        "trader_v4_linear_heads",
        [helper.make_tensor_value_info("features", TensorProto.FLOAT, [1, feature_count])],
        [helper.make_tensor_value_info("prediction", TensorProto.FLOAT, [1, total_width])],
        initializer=initializers,
    )
    model = helper.make_model(graph, producer_name="trader-training", opset_imports=[helper.make_opsetid("", opset)])
    # IR version is pinned explicitly rather than left at the installed onnx default.
    # NOTE(review): IR 10 needs a recent onnx release (1.16+) to check/load; confirm against requirements.
    model.ir_version = 10
    onnx.checker.check_model(model)
    path.parent.mkdir(parents=True, exist_ok=True)
    onnx.save(model, path)


def _head_width(head: LinearHead) -> int:
    """Return the number of output columns of ``head``, taken from its bias size."""
    bias = np.asarray(head.bias)
    return int(bias.size)