163 lines
4.9 KiB
Python
163 lines
4.9 KiB
Python
|
|
from __future__ import annotations
|
||
|
|
|
||
|
|
import argparse
|
||
|
|
import hashlib
|
||
|
|
import json
|
||
|
|
import logging
|
||
|
|
from datetime import UTC, datetime
|
||
|
|
from pathlib import Path
|
||
|
|
from typing import Any, Iterable
|
||
|
|
|
||
|
|
import pandas as pd
|
||
|
|
|
||
|
|
|
||
|
|
# Machine-specific default location of the training data tree; the
# ``--data-root`` command-line option (see add_common_args) overrides it.
DEFAULT_DATA_ROOT: Path = Path("/Users/zach/Desktop/quant-strategy-training-data")
# Raw crypto-lake data directory beneath the default data root.
DEFAULT_RAW_ROOT: Path = DEFAULT_DATA_ROOT.joinpath("crypto-lake", "raw")
|
||
|
|
|
||
|
|
|
||
|
|
def setup_logging() -> None:
    """Configure root logging at INFO level with an ``event=``-prefixed format.

    Thin wrapper over :func:`logging.basicConfig`; like that function, it does
    nothing if the root logger already has handlers.
    """
    log_format = "%(asctime)s %(levelname)s event=%(message)s"
    logging.basicConfig(format=log_format, level=logging.INFO)
|
||
|
|
|
||
|
|
|
||
|
|
def add_common_args(parser: argparse.ArgumentParser) -> None:
    """Add the shared command-line options to *parser*, in a fixed order.

    Options:
        --data-root: Path, defaults to DEFAULT_DATA_ROOT.
        --run-id: required string identifying the run.
        --config: optional Path.
        --workspace: optional Path.
        --fail-fast: boolean flag (False unless given).
    """
    shared_options: tuple[tuple[str, dict[str, Any]], ...] = (
        ("--data-root", {"type": Path, "default": DEFAULT_DATA_ROOT}),
        ("--run-id", {"required": True}),
        ("--config", {"type": Path}),
        ("--workspace", {"type": Path}),
        ("--fail-fast", {"action": "store_true"}),
    )
    for flag, options in shared_options:
        parser.add_argument(flag, **options)
|
||
|
|
|
||
|
|
|
||
|
|
def run_root(args: argparse.Namespace) -> Path:
    """Return the directory for this run: ``<data_root>/trader-v4/runs/<run_id>``.

    Reads ``args.data_root`` and ``args.run_id``; nothing is created on disk.
    """
    return args.data_root.joinpath("trader-v4", "runs", args.run_id)
|
||
|
|
|
||
|
|
|
||
|
|
def ensure_dir(path: Path) -> Path:
    """Create directory *path* (with any missing parents) and return it.

    Calling it on an existing directory is a no-op; the same ``Path`` object
    is returned so the call can be chained.
    """
    path.mkdir(exist_ok=True, parents=True)
    return path
|
||
|
|
|
||
|
|
|
||
|
|
def canonical_json_bytes(value: Any) -> bytes:
    """Encode *value* as compact UTF-8 JSON bytes.

    Separators carry no whitespace and non-ASCII characters are emitted as-is
    (then UTF-8 encoded) rather than ``\\u`` escaped.

    NOTE(review): keys are NOT sorted (``sort_keys=False``), so two dicts with
    equal contents but different insertion order produce different bytes and
    therefore different hashes. Confirm this is intended for a "canonical" form.
    """
    text = json.dumps(
        value,
        separators=(",", ":"),
        sort_keys=False,
        ensure_ascii=False,
    )
    return text.encode("utf-8")
|
||
|
|
|
||
|
|
|
||
|
|
def sha256_bytes(data: bytes) -> str:
    """Return the lowercase hexadecimal SHA-256 digest of *data*."""
    hasher = hashlib.sha256()
    hasher.update(data)
    return hasher.hexdigest()
|
||
|
|
|
||
|
|
|
||
|
|
def sha256_json(value: Any) -> str:
    """Return the hex SHA-256 of *value*'s ``canonical_json_bytes`` encoding.

    Because that encoding does not sort keys, the digest depends on dict
    insertion order as well as content.
    """
    encoded = canonical_json_bytes(value)
    return sha256_bytes(encoded)
|
||
|
|
|
||
|
|
|
||
|
|
def sha256_file(path: Path) -> str:
    """Return the hex SHA-256 of the file at *path*.

    The file is streamed in 1 MiB chunks so memory use stays bounded
    regardless of file size.
    """
    chunk_size = 1024 * 1024
    hasher = hashlib.sha256()
    with path.open("rb") as stream:
        while block := stream.read(chunk_size):
            hasher.update(block)
    return hasher.hexdigest()
|
||
|
|
|
||
|
|
|
||
|
|
def write_json(path: Path, value: Any) -> str:
    """Write *value* to *path* as compact JSON plus a trailing newline.

    Parent directories are created first. The JSON bytes come from
    ``canonical_json_bytes``.

    Returns:
        Hex SHA-256 of the encoded JSON *without* the trailing newline. This
        equals ``sha256_json(value)``, but NOT ``sha256_file(path)``, because
        the file on disk also contains the ``\\n``.
        NOTE(review): confirm callers expect the value hash rather than the
        file hash (``write_parquet`` returns the file hash).
    """
    ensure_dir(path.parent)
    encoded = canonical_json_bytes(value)
    digest = sha256_bytes(encoded)
    path.write_bytes(b"".join((encoded, b"\n")))
    return digest
|
||
|
|
|
||
|
|
|
||
|
|
def read_json(path: Path) -> Any:
    """Parse and return the UTF-8 JSON document stored at *path*."""
    return json.loads(path.read_text(encoding="utf-8"))
|
||
|
|
|
||
|
|
|
||
|
|
def write_parquet(path: Path, frame: pd.DataFrame) -> str:
    """Write *frame* to *path* as parquet (index dropped) and return its hash.

    Parent directories are created first. The returned value is the hex
    SHA-256 of the bytes actually written to disk, computed by re-reading
    the file after the write.
    """
    target_dir = path.parent
    ensure_dir(target_dir)
    frame.to_parquet(path, index=False)
    digest = sha256_file(path)
    return digest
|
||
|
|
|
||
|
|
|
||
|
|
def read_parquet(path: Path) -> pd.DataFrame:
    """Load the parquet file at *path* into a DataFrame.

    Raises:
        FileNotFoundError: if *path* is not an existing regular file. This
            includes directories, so partitioned-dataset directories are
            rejected here.
    """
    if path.is_file():
        return pd.read_parquet(path)
    raise FileNotFoundError(f"required parquet is missing: {path}")
|
||
|
|
|
||
|
|
|
||
|
|
def write_text(path: Path, text: str) -> None:
    """Write *text* to *path* as UTF-8, creating parent directories first.

    An existing file is overwritten.
    """
    ensure_dir(path.parent)
    with path.open("w", encoding="utf-8") as handle:
        handle.write(text)
|
||
|
|
|
||
|
|
|
||
|
|
def utc_now_text() -> str:
|
||
|
|
return datetime.now(UTC).replace(microsecond=0).isoformat().replace("+00:00", "Z")
|
||
|
|
|
||
|
|
|
||
|
|
def to_utc_series(values: pd.Series) -> pd.Series:
    """Coerce *values* to a timezone-aware UTC datetime Series.

    Conversion rules, applied in order:

    * datetime64 input (naive or tz-aware) is converted to UTC via
      ``pd.to_datetime(..., utc=True)``.
    * Otherwise, if any value parses as a number, the whole Series is treated
      as Unix epoch timestamps. The unit is inferred from the largest absolute
      value: ``> 1e16`` nanoseconds, ``> 1e13`` microseconds, ``> 1e10``
      milliseconds, otherwise seconds. Entries that do not parse as numbers
      become NaT.
    * Otherwise the values are parsed as date strings; unparseable entries
      become NaT.

    Errors raised by ``pd.to_datetime`` on the first two paths (e.g. epoch
    values outside pandas' representable range) propagate to the caller.
    """
    if pd.api.types.is_datetime64_any_dtype(values):
        return pd.to_datetime(values, utc=True)
    numeric = pd.to_numeric(values, errors="coerce")
    if numeric.notna().any():
        max_value = numeric.dropna().abs().max()
        # Present-day epochs are ~1.7e9 s, ~1.7e12 ms, ~1.7e15 us and ~1.7e18 ns,
        # so each threshold sits between two adjacent magnitudes. The s/ms
        # cut-off is unchanged; us/ns inputs previously fell into "ms" and
        # overflowed pandas' datetime range.
        if max_value > 10_000_000_000_000_000:
            unit = "ns"
        elif max_value > 10_000_000_000_000:
            unit = "us"
        elif max_value > 10_000_000_000:
            unit = "ms"
        else:
            unit = "s"
        return pd.to_datetime(numeric, unit=unit, utc=True)
    return pd.to_datetime(values, utc=True, errors="coerce")
|
||
|
|
|
||
|
|
|
||
|
|
def open_time_ms(values: pd.Series) -> pd.Series:
    """Convert *values* to integer milliseconds since the Unix epoch (UTC).

    *values* may be anything ``to_utc_series`` accepts. The result is a
    nullable ``Int64`` Series aligned with the input index; NaT entries become
    ``<NA>``.

    The value is computed as a timedelta from the epoch, not by reinterpreting
    the datetime storage as ``int64``. pandas 2.x can keep non-nanosecond
    resolutions (e.g. ``datetime64[us]`` read from parquet), and dividing the
    raw storage integers by 1_000_000 is only correct for ``[ns]`` data.
    """
    dt = to_utc_series(values)
    epoch = pd.Timestamp("1970-01-01", tz="UTC")
    # Timedelta // Timedelta yields int64, or float64 with NaN where dt is NaT;
    # the nullable Int64 cast turns those NaNs into <NA>.
    millis = (dt - epoch) // pd.Timedelta(milliseconds=1)
    return millis.astype("Int64")
|
||
|
|
|
||
|
|
|
||
|
|
def date_texts(start_date: str | None, end_date: str | None) -> tuple[str | None, str | None]:
|
||
|
|
return start_date, end_date
|
||
|
|
|
||
|
|
|
||
|
|
def partition_files(raw_root: Path, table: str, symbol: str, start_date: str | None, end_date: str | None) -> list[Path]:
|
||
|
|
base = raw_root / f"table={table}"
|
||
|
|
if not base.is_dir():
|
||
|
|
return []
|
||
|
|
files = sorted(base.glob(f"exchange=*/symbol={symbol}/dt=*/data.parquet"))
|
||
|
|
selected: list[Path] = []
|
||
|
|
for file in files:
|
||
|
|
dt_part = next((part.split("=", 1)[1] for part in file.parts if part.startswith("dt=")), "")
|
||
|
|
if start_date and dt_part < start_date:
|
||
|
|
continue
|
||
|
|
if end_date and dt_part > end_date:
|
||
|
|
continue
|
||
|
|
selected.append(file)
|
||
|
|
return selected
|
||
|
|
|
||
|
|
|
||
|
|
def read_partitioned_table(
    raw_root: Path,
    table: str,
    symbol: str,
    start_date: str | None,
    end_date: str | None,
    columns: Iterable[str] | None = None,
) -> pd.DataFrame:
    """Read and concatenate every matching parquet partition of *table* for *symbol*.

    Partitions are chosen by ``partition_files`` with the same
    *raw_root*/*table*/*symbol*/date-bound arguments. Start and finish events
    are logged at INFO.

    Args:
        raw_root: Root directory containing the ``table=...`` directories.
        table: Table name to read.
        symbol: Symbol whose partitions are read.
        start_date: Inclusive lower ``dt`` bound, or None.
        end_date: Inclusive upper ``dt`` bound, or None.
        columns: Optional subset of columns to load from every file. None or an
            empty collection loads all columns. Any iterable is accepted,
            including one-shot iterators such as generators.

    Returns:
        The concatenated rows with a fresh RangeIndex, or an empty DataFrame
        when no partition matches.
    """
    files = partition_files(raw_root, table, symbol, start_date, end_date)
    if not files:
        return pd.DataFrame()
    # Materialize the column selection once. Re-running list(columns) per file
    # exhausts a one-shot iterator on the first file, so every later file would
    # be read with columns=[] (i.e. no columns at all).
    selected_columns = list(columns) if columns is not None else []
    read_columns = selected_columns or None
    logging.info("trader.training.raw_read_started table=%s symbol=%s fileCount=%s", table, symbol, len(files))
    frames = [pd.read_parquet(file, columns=read_columns) for file in files]
    frame = pd.concat(frames, ignore_index=True)
    logging.info("trader.training.raw_read_finished table=%s rowCount=%s", table, len(frame))
    return frame
|
||
|
|
|
||
|
|
|
||
|
|
def require_columns(frame: pd.DataFrame, columns: Iterable[str], name: str) -> None:
    """Ensure *frame* has every column listed in *columns*.

    Args:
        frame: DataFrame to check.
        columns: Column names that must be present.
        name: Label for *frame* used in the error message.

    Raises:
        ValueError: listing all missing columns, in the order given.
    """
    present = frame.columns
    missing = [col for col in columns if col not in present]
    if not missing:
        return
    raise ValueError(f"{name} is missing required columns: {missing}")
|
||
|
|
|
||
|
|
|
||
|
|
def manifest(path: Path, extra: dict[str, Any]) -> dict[str, Any]:
    """Build a new manifest record describing the artifact at *path*.

    The record holds ``path`` (as a string), ``hash_sha256`` (the file's hex
    SHA-256, or None when *path* is not a regular file) and ``created_at``
    (the current UTC time, taken after hashing). Entries from *extra* are
    merged last, so they override any of those keys on collision.
    """
    if path.is_file():
        digest = sha256_file(path)
    else:
        digest = None
    record: dict[str, Any] = {"path": str(path), "hash_sha256": digest, "created_at": utc_now_text()}
    record.update(extra)
    return record
|