Files

163 lines
4.9 KiB
Python
Raw Permalink Normal View History

from __future__ import annotations
import argparse
import hashlib
import json
import logging
from datetime import UTC, datetime
from pathlib import Path
from typing import Any, Iterable
import pandas as pd
# Root of the local training-data tree; the other path defaults are built from it.
# NOTE(review): this is a machine-specific absolute path — consider making it configurable.
DEFAULT_DATA_ROOT = Path("/Users/zach/Desktop/quant-strategy-training-data")
# Raw crypto-lake partitions live under <data root>/crypto-lake/raw.
DEFAULT_RAW_ROOT = DEFAULT_DATA_ROOT.joinpath("crypto-lake", "raw")
def setup_logging() -> None:
    """Configure root logging at INFO level with an ``event=``-prefixed message format.

    Delegates to :func:`logging.basicConfig`, which does nothing if the root
    logger already has handlers configured.
    """
    log_format = "%(asctime)s %(levelname)s event=%(message)s"
    logging.basicConfig(format=log_format, level=logging.INFO)
def add_common_args(parser: argparse.ArgumentParser) -> None:
    """Register the common command-line options on *parser*.

    Adds, in order: ``--data-root`` (a Path, defaulting to ``DEFAULT_DATA_ROOT``),
    the required ``--run-id``, the optional Path options ``--config`` and
    ``--workspace``, and the boolean ``--fail-fast`` flag.
    """
    add = parser.add_argument
    add("--data-root", type=Path, default=DEFAULT_DATA_ROOT)
    add("--run-id", required=True)
    for optional_path_flag in ("--config", "--workspace"):
        add(optional_path_flag, type=Path)
    add("--fail-fast", action="store_true")
def run_root(args: argparse.Namespace) -> Path:
    """Return the directory for this run: ``<data_root>/trader-v4/runs/<run_id>``.

    Only builds the path; nothing is created on disk.
    """
    return args.data_root.joinpath("trader-v4", "runs", args.run_id)
def ensure_dir(path: Path) -> Path:
    """Create *path* and any missing parents if needed, then return it unchanged.

    An already-existing directory is not an error.
    """
    target = path
    target.mkdir(exist_ok=True, parents=True)
    return target
def canonical_json_bytes(value: Any) -> bytes:
    """Serialise *value* to compact UTF-8 encoded JSON bytes.

    The output has no whitespace between tokens and keeps non-ASCII characters
    unescaped. Dict keys are written in insertion order (``sort_keys=False``),
    so two dicts with the same content but different key order encode differently.

    NOTE(review): despite the name, the encoding is not key-order independent.
    ``sha256_json`` and ``write_json`` hash these bytes — confirm this is intended.
    """
    text = json.dumps(
        value,
        separators=(",", ":"),
        sort_keys=False,
        ensure_ascii=False,
    )
    return text.encode("utf-8")
def sha256_bytes(data: bytes) -> str:
    """Return the lowercase hexadecimal SHA-256 digest of *data*."""
    hasher = hashlib.sha256()
    hasher.update(data)
    return hasher.hexdigest()
def sha256_json(value: Any) -> str:
    """Return the hex SHA-256 of *value* encoded with ``canonical_json_bytes``."""
    encoded = canonical_json_bytes(value)
    return sha256_bytes(encoded)
def sha256_file(path: Path) -> str:
    """Return the hex SHA-256 digest of the file at *path*.

    The file is read in 1 MiB chunks, so it is never held in memory all at once.
    """
    chunk_size = 1 << 20  # 1 MiB
    hasher = hashlib.sha256()
    with path.open("rb") as stream:
        while chunk := stream.read(chunk_size):
            hasher.update(chunk)
    return hasher.hexdigest()
def write_json(path: Path, value: Any) -> str:
    """Write *value* to *path* as compact JSON followed by a newline.

    Parent directories are created first. The return value is the SHA-256 of
    the JSON bytes *without* the trailing newline, i.e. the same value as
    ``sha256_json(value)``. It is therefore NOT equal to ``sha256_file(path)``
    for the written file.

    NOTE(review): ``write_parquet`` returns the on-disk file hash instead;
    confirm that callers rely on the content hash here.
    """
    ensure_dir(path.parent)
    payload = canonical_json_bytes(value)
    path.write_bytes(payload + b"\n")
    return sha256_bytes(payload)
def read_json(path: Path) -> Any:
    """Parse and return the UTF-8 JSON document stored at *path*."""
    return json.loads(path.read_text(encoding="utf-8"))
def write_parquet(path: Path, frame: pd.DataFrame) -> str:
    """Write *frame* to *path* as Parquet, without its index, and return the file's SHA-256.

    Parent directories are created first. The hash is computed from the
    bytes actually written to disk.
    """
    target_dir = path.parent
    ensure_dir(target_dir)
    frame.to_parquet(path, index=False)
    digest = sha256_file(path)
    return digest
def read_parquet(path: Path) -> pd.DataFrame:
    """Load the Parquet file at *path* into a DataFrame.

    Raises:
        FileNotFoundError: if *path* is not an existing file.
    """
    if path.is_file():
        return pd.read_parquet(path)
    raise FileNotFoundError(f"required parquet is missing: {path}")
def write_text(path: Path, text: str) -> None:
    """Write *text* to *path* as UTF-8, creating parent directories first."""
    parent_dir = path.parent
    ensure_dir(parent_dir)
    path.write_text(text, encoding="utf-8")
def utc_now_text() -> str:
return datetime.now(UTC).replace(microsecond=0).isoformat().replace("+00:00", "Z")
def to_utc_series(values: pd.Series) -> pd.Series:
    """Coerce *values* to a timezone-aware UTC datetime Series.

    Handling depends on the input:

    * datetime dtype: passed through ``pd.to_datetime(..., utc=True)``
      (naive values are treated as UTC, aware values are converted).
    * at least one value is numeric: treated as Unix-epoch timestamps. The
      unit (s, ms, us or ns) is inferred from the largest absolute value.
      Entries that are not numeric become NaT.
    * otherwise: parsed as date/time text; unparseable entries become NaT.
    """
    if pd.api.types.is_datetime64_any_dtype(values):
        return pd.to_datetime(values, utc=True)
    numeric = pd.to_numeric(values, errors="coerce")
    if numeric.notna().any():
        max_value = numeric.dropna().abs().max()
        # Each threshold is 1e10 in the next-coarser unit (about year 2286),
        # so a realistic epoch value falls in exactly one band. The s/ms bands
        # are unchanged; us/ns values would overflow the datetime range if
        # read as milliseconds.
        if max_value > 10_000_000_000_000_000:
            unit = "ns"
        elif max_value > 10_000_000_000_000:
            unit = "us"
        elif max_value > 10_000_000_000:
            unit = "ms"
        else:
            unit = "s"
        return pd.to_datetime(numeric, unit=unit, utc=True)
    return pd.to_datetime(values, utc=True, errors="coerce")
def open_time_ms(values: pd.Series) -> pd.Series:
    """Convert *values* to integer milliseconds since the Unix epoch (UTC).

    Values are first normalised with ``to_utc_series``. The result is a
    nullable ``Int64`` Series in which missing or unparseable timestamps
    are ``<NA>``.

    The conversion uses timedelta arithmetic rather than
    ``astype("int64") // 1_000_000``. The raw integer of a datetime Series is
    only nanoseconds when its resolution is ``ns``; ``datetime64[ms]`` or
    ``[us]`` data (e.g. read from Parquet) would be scaled wrongly. This form
    also makes NaT become ``<NA>`` instead of depending on how
    ``astype("int64")`` treats NaT.
    """
    dt = to_utc_series(values)
    elapsed = dt - pd.Timestamp("1970-01-01", tz="UTC")
    # Floor division matches the previous rounding toward -inf; NaT gives NaN.
    millis = elapsed // pd.Timedelta(milliseconds=1)
    return millis.astype("Int64")
def date_texts(start_date: str | None, end_date: str | None) -> tuple[str | None, str | None]:
    """Return ``(start_date, end_date)`` unchanged.

    This is currently a pass-through: no parsing, normalisation or validation
    is done.
    """
    bounds = (start_date, end_date)
    return bounds
def partition_files(
    raw_root: Path,
    table: str,
    symbol: str,
    start_date: str | None,
    end_date: str | None,
) -> list[Path]:
    """List the ``data.parquet`` partitions of *table* for *symbol*, optionally filtered by date.

    The search pattern is
    ``<raw_root>/table=<table>/exchange=*/symbol=<symbol>/dt=<date>/data.parquet``,
    so partitions from every exchange are included. Results are sorted by
    full path, which groups them by exchange and then by date.

    ``start_date`` and ``end_date`` are inclusive bounds compared *as strings*
    with the ``dt=`` value. They must therefore use the same zero-padded format
    as the directory names (presumably ``YYYY-MM-DD`` — confirm). ``None`` or
    ``""`` disables a bound.

    Returns an empty list when ``<raw_root>/table=<table>`` is not a directory.
    """
    table_dir = raw_root / f"table={table}"
    if not table_dir.is_dir():
        return []
    selected: list[Path] = []
    for file in sorted(table_dir.glob(f"exchange=*/symbol={symbol}/dt=*/data.parquet")):
        # The glob guarantees the parent directory is ``dt=<date>``. Read the
        # date from there instead of scanning every path component, which would
        # pick up an unrelated ``dt=`` directory inside raw_root.
        dt_part = file.parent.name.partition("=")[2]
        if start_date and dt_part < start_date:
            continue
        if end_date and dt_part > end_date:
            continue
        selected.append(file)
    return selected
def read_partitioned_table(
    raw_root: Path,
    table: str,
    symbol: str,
    start_date: str | None,
    end_date: str | None,
    columns: Iterable[str] | None = None,
) -> pd.DataFrame:
    """Read and concatenate every partition that ``partition_files`` selects.

    Args:
        raw_root, table, symbol, start_date, end_date: forwarded unchanged to
            ``partition_files``.
        columns: optional subset of columns to read from each file. ``None`` or
            an empty iterable reads all columns.

    Returns:
        One DataFrame with a fresh RangeIndex, or an empty ``pd.DataFrame()``
        (no columns) when no partition matches.
    """
    files = partition_files(raw_root, table, symbol, start_date, end_date)
    if not files:
        return pd.DataFrame()
    # Materialise once so a one-shot iterator (e.g. a generator) is not
    # exhausted after the first file. An empty selection means "all columns".
    column_arg = (list(columns) if columns is not None else None) or None
    logging.info("trader.training.raw_read_started table=%s symbol=%s fileCount=%s", table, symbol, len(files))
    frames = [pd.read_parquet(file, columns=column_arg) for file in files]
    frame = pd.concat(frames, ignore_index=True)
    logging.info("trader.training.raw_read_finished table=%s rowCount=%s", table, len(frame))
    return frame
def require_columns(frame: pd.DataFrame, columns: Iterable[str], name: str) -> None:
    """Raise ``ValueError`` if *frame* lacks any of *columns*.

    The error message starts with *name* and lists the missing column names
    in the order they appear in *columns*.
    """
    present = frame.columns
    missing = []
    for column in columns:
        if column not in present:
            missing.append(column)
    if not missing:
        return
    raise ValueError(f"{name} is missing required columns: {missing}")
def manifest(path: Path, extra: dict[str, Any]) -> dict[str, Any]:
    """Build a metadata record for the artifact at *path*.

    The record contains ``path`` (as text), ``hash_sha256`` (SHA-256 of the
    file, or ``None`` if *path* is not an existing file) and ``created_at``
    (current UTC time text). The entries of *extra* are then merged in, and
    any key in *extra* overrides the value above.
    """
    digest = sha256_file(path) if path.is_file() else None
    record: dict[str, Any] = {
        "path": str(path),
        "hash_sha256": digest,
        "created_at": utc_now_text(),
    }
    record.update(extra)
    return record