from __future__ import annotations import argparse import hashlib import json import logging from datetime import UTC, datetime from pathlib import Path from typing import Any, Iterable import pandas as pd DEFAULT_DATA_ROOT = Path("/Users/zach/Desktop/quant-strategy-training-data") DEFAULT_RAW_ROOT = DEFAULT_DATA_ROOT / "crypto-lake" / "raw" def setup_logging() -> None: logging.basicConfig( level=logging.INFO, format="%(asctime)s %(levelname)s event=%(message)s", ) def add_common_args(parser: argparse.ArgumentParser) -> None: parser.add_argument("--data-root", type=Path, default=DEFAULT_DATA_ROOT) parser.add_argument("--run-id", required=True) parser.add_argument("--config", type=Path) parser.add_argument("--workspace", type=Path) parser.add_argument("--fail-fast", action="store_true") def run_root(args: argparse.Namespace) -> Path: return args.data_root / "trader-v4" / "runs" / args.run_id def ensure_dir(path: Path) -> Path: path.mkdir(parents=True, exist_ok=True) return path def canonical_json_bytes(value: Any) -> bytes: return json.dumps(value, ensure_ascii=False, sort_keys=False, separators=(",", ":")).encode("utf-8") def sha256_bytes(data: bytes) -> str: return hashlib.sha256(data).hexdigest() def sha256_json(value: Any) -> str: return sha256_bytes(canonical_json_bytes(value)) def sha256_file(path: Path) -> str: digest = hashlib.sha256() with path.open("rb") as fh: for chunk in iter(lambda: fh.read(1024 * 1024), b""): digest.update(chunk) return digest.hexdigest() def write_json(path: Path, value: Any) -> str: ensure_dir(path.parent) data = canonical_json_bytes(value) path.write_bytes(data + b"\n") return sha256_bytes(data) def read_json(path: Path) -> Any: with path.open("r", encoding="utf-8") as fh: return json.load(fh) def write_parquet(path: Path, frame: pd.DataFrame) -> str: ensure_dir(path.parent) frame.to_parquet(path, index=False) return sha256_file(path) def read_parquet(path: Path) -> pd.DataFrame: if not path.is_file(): raise FileNotFoundError(f"required parquet is missing: {path}") return pd.read_parquet(path) def write_text(path: Path, text: str) -> None: ensure_dir(path.parent) path.write_text(text, encoding="utf-8") def utc_now_text() -> str: return datetime.now(UTC).replace(microsecond=0).isoformat().replace("+00:00", "Z") def to_utc_series(values: pd.Series) -> pd.Series: if pd.api.types.is_datetime64_any_dtype(values): return pd.to_datetime(values, utc=True) numeric = pd.to_numeric(values, errors="coerce") if numeric.notna().any(): max_value = numeric.dropna().abs().max() unit = "ms" if max_value > 10_000_000_000 else "s" return pd.to_datetime(numeric, unit=unit, utc=True) return pd.to_datetime(values, utc=True, errors="coerce") def open_time_ms(values: pd.Series) -> pd.Series: dt = to_utc_series(values) return (dt.astype("int64") // 1_000_000).astype("Int64") def date_texts(start_date: str | None, end_date: str | None) -> tuple[str | None, str | None]: return start_date, end_date def partition_files(raw_root: Path, table: str, symbol: str, start_date: str | None, end_date: str | None) -> list[Path]: base = raw_root / f"table={table}" if not base.is_dir(): return [] files = sorted(base.glob(f"exchange=*/symbol={symbol}/dt=*/data.parquet")) selected: list[Path] = [] for file in files: dt_part = next((part.split("=", 1)[1] for part in file.parts if part.startswith("dt=")), "") if start_date and dt_part < start_date: continue if end_date and dt_part > end_date: continue selected.append(file) return selected def read_partitioned_table( raw_root: Path, table: str, symbol: str, start_date: str | None, end_date: str | None, columns: Iterable[str] | None = None, ) -> pd.DataFrame: files = partition_files(raw_root, table, symbol, start_date, end_date) if not files: return pd.DataFrame() logging.info("trader.training.raw_read_started table=%s symbol=%s fileCount=%s", table, symbol, len(files)) frames = [pd.read_parquet(file, columns=list(columns) if columns else None) for file in files] frame = pd.concat(frames, ignore_index=True) logging.info("trader.training.raw_read_finished table=%s rowCount=%s", table, len(frame)) return frame def require_columns(frame: pd.DataFrame, columns: Iterable[str], name: str) -> None: missing = [column for column in columns if column not in frame.columns] if missing: raise ValueError(f"{name} is missing required columns: {missing}") def manifest(path: Path, extra: dict[str, Any]) -> dict[str, Any]: payload = { "path": str(path), "hash_sha256": sha256_file(path) if path.is_file() else None, "created_at": utc_now_text(), } payload.update(extra) return payload