Implement Trader V4 training artifact pipeline
This commit is contained in:
@@ -0,0 +1,162 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import hashlib
|
||||
import json
|
||||
import logging
|
||||
from datetime import UTC, datetime
|
||||
from pathlib import Path
|
||||
from typing import Any, Iterable
|
||||
|
||||
import pandas as pd
|
||||
|
||||
|
||||
DEFAULT_DATA_ROOT = Path("/Users/zach/Desktop/quant-strategy-training-data")
|
||||
DEFAULT_RAW_ROOT = DEFAULT_DATA_ROOT / "crypto-lake" / "raw"
|
||||
|
||||
|
||||
def setup_logging() -> None:
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format="%(asctime)s %(levelname)s event=%(message)s",
|
||||
)
|
||||
|
||||
|
||||
def add_common_args(parser: argparse.ArgumentParser) -> None:
|
||||
parser.add_argument("--data-root", type=Path, default=DEFAULT_DATA_ROOT)
|
||||
parser.add_argument("--run-id", required=True)
|
||||
parser.add_argument("--config", type=Path)
|
||||
parser.add_argument("--workspace", type=Path)
|
||||
parser.add_argument("--fail-fast", action="store_true")
|
||||
|
||||
|
||||
def run_root(args: argparse.Namespace) -> Path:
|
||||
return args.data_root / "trader-v4" / "runs" / args.run_id
|
||||
|
||||
|
||||
def ensure_dir(path: Path) -> Path:
|
||||
path.mkdir(parents=True, exist_ok=True)
|
||||
return path
|
||||
|
||||
|
||||
def canonical_json_bytes(value: Any) -> bytes:
|
||||
return json.dumps(value, ensure_ascii=False, sort_keys=False, separators=(",", ":")).encode("utf-8")
|
||||
|
||||
|
||||
def sha256_bytes(data: bytes) -> str:
|
||||
return hashlib.sha256(data).hexdigest()
|
||||
|
||||
|
||||
def sha256_json(value: Any) -> str:
|
||||
return sha256_bytes(canonical_json_bytes(value))
|
||||
|
||||
|
||||
def sha256_file(path: Path) -> str:
|
||||
digest = hashlib.sha256()
|
||||
with path.open("rb") as fh:
|
||||
for chunk in iter(lambda: fh.read(1024 * 1024), b""):
|
||||
digest.update(chunk)
|
||||
return digest.hexdigest()
|
||||
|
||||
|
||||
def write_json(path: Path, value: Any) -> str:
|
||||
ensure_dir(path.parent)
|
||||
data = canonical_json_bytes(value)
|
||||
path.write_bytes(data + b"\n")
|
||||
return sha256_bytes(data)
|
||||
|
||||
|
||||
def read_json(path: Path) -> Any:
|
||||
with path.open("r", encoding="utf-8") as fh:
|
||||
return json.load(fh)
|
||||
|
||||
|
||||
def write_parquet(path: Path, frame: pd.DataFrame) -> str:
|
||||
ensure_dir(path.parent)
|
||||
frame.to_parquet(path, index=False)
|
||||
return sha256_file(path)
|
||||
|
||||
|
||||
def read_parquet(path: Path) -> pd.DataFrame:
|
||||
if not path.is_file():
|
||||
raise FileNotFoundError(f"required parquet is missing: {path}")
|
||||
return pd.read_parquet(path)
|
||||
|
||||
|
||||
def write_text(path: Path, text: str) -> None:
|
||||
ensure_dir(path.parent)
|
||||
path.write_text(text, encoding="utf-8")
|
||||
|
||||
|
||||
def utc_now_text() -> str:
|
||||
return datetime.now(UTC).replace(microsecond=0).isoformat().replace("+00:00", "Z")
|
||||
|
||||
|
||||
def to_utc_series(values: pd.Series) -> pd.Series:
|
||||
if pd.api.types.is_datetime64_any_dtype(values):
|
||||
return pd.to_datetime(values, utc=True)
|
||||
numeric = pd.to_numeric(values, errors="coerce")
|
||||
if numeric.notna().any():
|
||||
max_value = numeric.dropna().abs().max()
|
||||
unit = "ms" if max_value > 10_000_000_000 else "s"
|
||||
return pd.to_datetime(numeric, unit=unit, utc=True)
|
||||
return pd.to_datetime(values, utc=True, errors="coerce")
|
||||
|
||||
|
||||
def open_time_ms(values: pd.Series) -> pd.Series:
|
||||
dt = to_utc_series(values)
|
||||
return (dt.astype("int64") // 1_000_000).astype("Int64")
|
||||
|
||||
|
||||
def date_texts(start_date: str | None, end_date: str | None) -> tuple[str | None, str | None]:
|
||||
return start_date, end_date
|
||||
|
||||
|
||||
def partition_files(raw_root: Path, table: str, symbol: str, start_date: str | None, end_date: str | None) -> list[Path]:
|
||||
base = raw_root / f"table={table}"
|
||||
if not base.is_dir():
|
||||
return []
|
||||
files = sorted(base.glob(f"exchange=*/symbol={symbol}/dt=*/data.parquet"))
|
||||
selected: list[Path] = []
|
||||
for file in files:
|
||||
dt_part = next((part.split("=", 1)[1] for part in file.parts if part.startswith("dt=")), "")
|
||||
if start_date and dt_part < start_date:
|
||||
continue
|
||||
if end_date and dt_part > end_date:
|
||||
continue
|
||||
selected.append(file)
|
||||
return selected
|
||||
|
||||
|
||||
def read_partitioned_table(
|
||||
raw_root: Path,
|
||||
table: str,
|
||||
symbol: str,
|
||||
start_date: str | None,
|
||||
end_date: str | None,
|
||||
columns: Iterable[str] | None = None,
|
||||
) -> pd.DataFrame:
|
||||
files = partition_files(raw_root, table, symbol, start_date, end_date)
|
||||
if not files:
|
||||
return pd.DataFrame()
|
||||
logging.info("trader.training.raw_read_started table=%s symbol=%s fileCount=%s", table, symbol, len(files))
|
||||
frames = [pd.read_parquet(file, columns=list(columns) if columns else None) for file in files]
|
||||
frame = pd.concat(frames, ignore_index=True)
|
||||
logging.info("trader.training.raw_read_finished table=%s rowCount=%s", table, len(frame))
|
||||
return frame
|
||||
|
||||
|
||||
def require_columns(frame: pd.DataFrame, columns: Iterable[str], name: str) -> None:
|
||||
missing = [column for column in columns if column not in frame.columns]
|
||||
if missing:
|
||||
raise ValueError(f"{name} is missing required columns: {missing}")
|
||||
|
||||
|
||||
def manifest(path: Path, extra: dict[str, Any]) -> dict[str, Any]:
|
||||
payload = {
|
||||
"path": str(path),
|
||||
"hash_sha256": sha256_file(path) if path.is_file() else None,
|
||||
"created_at": utc_now_text(),
|
||||
}
|
||||
payload.update(extra)
|
||||
return payload
|
||||
Reference in New Issue
Block a user