Add crypto lake replay labels
This commit is contained in:
@@ -46,7 +46,7 @@ public class PlaybookCandidateEngine {
|
||||
Ids.candidateId(cycle, playbook.playbookId()),
|
||||
playbook.playbookId(),
|
||||
playbook.playbookVersion(),
|
||||
TraderSide.LONG,
|
||||
requiredSide(snapshot.setupFeatures(), "side"),
|
||||
playbook.variant(),
|
||||
snapshot.snapshotTime(),
|
||||
pricePlan,
|
||||
@@ -55,6 +55,20 @@ public class PlaybookCandidateEngine {
|
||||
));
|
||||
}
|
||||
|
||||
private TraderSide requiredSide(Map<String, Object> map, String key) {
|
||||
Object value = map.get(key);
|
||||
if (value instanceof TraderSide side) {
|
||||
return side;
|
||||
}
|
||||
if (value instanceof String text && !text.isBlank()) {
|
||||
return TraderSide.valueOf(text.trim().toUpperCase());
|
||||
}
|
||||
throw new TraderException(
|
||||
TraderErrorCode.TRADER_ENTRY_PLAN_INCOMPLETE,
|
||||
"setup feature is required when setupPass=true: " + key
|
||||
);
|
||||
}
|
||||
|
||||
private BigDecimal requiredDecimal(Map<String, Object> map, String key) {
|
||||
Object value = map.get(key);
|
||||
if (value instanceof Number number) {
|
||||
|
||||
@@ -79,14 +79,14 @@ public class TraderDecisionCycleRunner {
|
||||
StageDecision context = contextGate.evaluate(snapshot);
|
||||
evidenceAppender.append(cycle, "CONTEXT_GATE", context);
|
||||
if (context.blocked()) {
|
||||
TraderTrainingSample sample = sampleExporter.export(cycle.withState(TraderState.BLOCKED, "BLOCKED", context.blocker()), null, null, null);
|
||||
TraderTrainingSample sample = sampleExporter.export(cycle.withState(TraderState.BLOCKED, "BLOCKED", context.blocker()), snapshot, null, null, null);
|
||||
return new TraderCycleResult(cycle, null, null, sample);
|
||||
}
|
||||
|
||||
List<PlaybookCandidate> candidates = playbookCandidateEngine.generate(snapshot, cycle);
|
||||
if (candidates.isEmpty()) {
|
||||
evidenceAppender.append(cycle, "PLAYBOOK_CANDIDATE", StageDecision.block("NO_PLAYBOOK_CANDIDATE", "NO_PLAYBOOK_CANDIDATE"));
|
||||
TraderTrainingSample sample = sampleExporter.export(cycle, null, null, null);
|
||||
TraderTrainingSample sample = sampleExporter.export(cycle, snapshot, null, null, null);
|
||||
return new TraderCycleResult(cycle, null, null, sample);
|
||||
}
|
||||
|
||||
@@ -94,7 +94,7 @@ public class TraderDecisionCycleRunner {
|
||||
var trigger = triggerMarkoutService.evaluate(snapshot, selected);
|
||||
evidenceAppender.append(cycle, "TRIGGER_MARKOUT", new StageDecision(trigger.pass(), trigger.reason(), trigger.blocker(), trigger.details()));
|
||||
if (trigger.blocked()) {
|
||||
TraderTrainingSample sample = sampleExporter.export(cycle.withState(TraderState.TRIGGER_WAIT, "WAIT", trigger.blocker()), selected, null, null);
|
||||
TraderTrainingSample sample = sampleExporter.export(cycle.withState(TraderState.TRIGGER_WAIT, "WAIT", trigger.blocker()), snapshot, selected, null, null);
|
||||
return new TraderCycleResult(cycle, null, null, sample);
|
||||
}
|
||||
|
||||
@@ -105,7 +105,7 @@ public class TraderDecisionCycleRunner {
|
||||
RiskDecision risk = riskGate.evaluate(entryCycle, entryPlan, execution);
|
||||
evidenceAppender.append(entryCycle, "RISK_GATE", new StageDecision(risk.allowAction(), risk.allowAction() ? "RISK_PASS" : "RISK_BLOCKED", risk.blocker(), risk.details()));
|
||||
if (execution.blocked() || risk.blocked()) {
|
||||
TraderTrainingSample sample = sampleExporter.export(entryCycle.withState(TraderState.BLOCKED, "BLOCKED", risk.blocker()), selected, null, null);
|
||||
TraderTrainingSample sample = sampleExporter.export(entryCycle.withState(TraderState.BLOCKED, "BLOCKED", risk.blocker()), snapshot, selected, null, null);
|
||||
return new TraderCycleResult(entryCycle, null, null, sample);
|
||||
}
|
||||
|
||||
@@ -115,6 +115,7 @@ public class TraderDecisionCycleRunner {
|
||||
TraderLifecycleResult lifecycle = runPositionLifecycle(entryCycle, selected, action, path, snapshot);
|
||||
TraderTrainingSample sample = sampleExporter.export(
|
||||
lifecycle.finalCycle(),
|
||||
snapshot,
|
||||
selected,
|
||||
lifecycle.lastAction(),
|
||||
lifecycle.finalPath()
|
||||
|
||||
@@ -14,7 +14,8 @@ public record TraderMarketSnapshot(
|
||||
Map<String, Object> setupFeatures,
|
||||
Map<String, Object> triggerFeatures,
|
||||
Map<String, Object> executionFeatures,
|
||||
Map<String, Object> dataQuality
|
||||
Map<String, Object> dataQuality,
|
||||
Map<String, Object> labelInputs
|
||||
) {
|
||||
|
||||
public TraderMarketSnapshot {
|
||||
@@ -23,5 +24,6 @@ public record TraderMarketSnapshot(
|
||||
triggerFeatures = Maps.immutable(triggerFeatures);
|
||||
executionFeatures = Maps.immutable(executionFeatures);
|
||||
dataQuality = Maps.immutable(dataQuality);
|
||||
labelInputs = Maps.immutable(labelInputs);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -31,7 +31,8 @@ public class SnapshotBuilder {
|
||||
Objects.requireNonNull(tick.setupFeatures(), "setupFeatures is required"),
|
||||
Objects.requireNonNull(tick.triggerFeatures(), "triggerFeatures is required"),
|
||||
Objects.requireNonNull(tick.executionFeatures(), "executionFeatures is required"),
|
||||
Objects.requireNonNull(tick.dataQuality(), "dataQuality is required")
|
||||
Objects.requireNonNull(tick.dataQuality(), "dataQuality is required"),
|
||||
Objects.requireNonNull(tick.labelInputs(), "labelInputs is required")
|
||||
);
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,498 @@
|
||||
package com.quantai.trader.replay;
|
||||
|
||||
import com.quantai.trader.domain.TraderException;
|
||||
import com.quantai.trader.enums.TraderErrorCode;
|
||||
import com.quantai.trader.enums.TraderSide;
|
||||
import org.apache.commons.csv.CSVFormat;
|
||||
import org.apache.commons.csv.CSVParser;
|
||||
import org.apache.commons.csv.CSVRecord;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
import org.springframework.stereotype.Component;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.math.BigDecimal;
|
||||
import java.math.MathContext;
|
||||
import java.math.RoundingMode;
|
||||
import java.nio.file.Files;
|
||||
import java.nio.file.Path;
|
||||
import java.time.Instant;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Comparator;
|
||||
import java.util.LinkedHashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.NavigableMap;
|
||||
import java.util.TreeMap;
|
||||
|
||||
@Component
|
||||
public class CryptoLakeReplayCsvMarketEventReader implements ReplayMarketEventReader {
|
||||
|
||||
private static final Logger log = LoggerFactory.getLogger(CryptoLakeReplayCsvMarketEventReader.class);
|
||||
private static final MathContext MC = new MathContext(16, RoundingMode.HALF_UP);
|
||||
private static final String REPLAY_SOURCE_KEY = "cryptoLakeReplay1m";
|
||||
private static final String CANDIDATE_SOURCE_KEY = "candidateEvents";
|
||||
private static final BigDecimal LONG_INVALID_BPS = new BigDecimal("12.0");
|
||||
private static final BigDecimal LONG_STOP_BPS = new BigDecimal("8.0");
|
||||
private static final BigDecimal LONG_TARGET_BPS = new BigDecimal("30.0");
|
||||
private static final BigDecimal SHORT_INVALID_BPS = new BigDecimal("12.0");
|
||||
private static final BigDecimal SHORT_STOP_BPS = new BigDecimal("8.0");
|
||||
private static final BigDecimal SHORT_TARGET_BPS = new BigDecimal("30.0");
|
||||
|
||||
@Override
|
||||
public boolean supports(ReplayRunConfig config) {
|
||||
DataSourceSpec source = config.dataSources() == null ? null : config.dataSources().get(REPLAY_SOURCE_KEY);
|
||||
return source != null && source.path() != null && source.path().endsWith(".csv");
|
||||
}
|
||||
|
||||
@Override
|
||||
public void validateReadable(ReplayRunConfig config) {
|
||||
validateSource(selectReplaySource(config), REPLAY_SOURCE_KEY);
|
||||
DataSourceSpec candidateSource = config.dataSources().get(CANDIDATE_SOURCE_KEY);
|
||||
if (candidateSource != null) {
|
||||
validateSource(candidateSource, CANDIDATE_SOURCE_KEY);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<ReplayClockTick> readTicks(ReplayRunConfig config) {
|
||||
validateReadable(config);
|
||||
NavigableMap<Instant, MarketBar> bars = readReplayBars(config);
|
||||
List<ReplayClockTick> ticks = config.dataSources().containsKey(CANDIDATE_SOURCE_KEY)
|
||||
? readCandidateTicks(config, bars)
|
||||
: readMarketAuditTicks(config, bars);
|
||||
if (ticks.isEmpty()) {
|
||||
throw new TraderException(TraderErrorCode.TRADER_DATA_SOURCE_MISSING, "crypto lake replay csv produced no ticks");
|
||||
}
|
||||
log.info(
|
||||
"event=trader.replay.crypto_lake_csv.loaded runId={} symbol={} tickCount={} candidateMode={}",
|
||||
config.runId(),
|
||||
config.symbol(),
|
||||
ticks.size(),
|
||||
config.dataSources().containsKey(CANDIDATE_SOURCE_KEY)
|
||||
);
|
||||
return ticks;
|
||||
}
|
||||
|
||||
private NavigableMap<Instant, MarketBar> readReplayBars(ReplayRunConfig config) {
|
||||
Path path = Path.of(selectReplaySource(config).path());
|
||||
NavigableMap<Instant, MarketBar> bars = new TreeMap<>();
|
||||
try (CSVParser parser = CSVParser.parse(path, java.nio.charset.StandardCharsets.UTF_8,
|
||||
CSVFormat.DEFAULT.builder().setHeader().setSkipHeaderRecord(true).build())) {
|
||||
for (CSVRecord record : parser) {
|
||||
if (!config.symbol().equals(required(record, "symbol"))) {
|
||||
continue;
|
||||
}
|
||||
if (!"1m".equals(required(record, "timeframe"))) {
|
||||
continue;
|
||||
}
|
||||
MarketBar bar = marketBar(record);
|
||||
bars.put(bar.openTime(), bar);
|
||||
}
|
||||
} catch (IOException ex) {
|
||||
throw new TraderException(TraderErrorCode.TRADER_DATA_SOURCE_MISSING, "failed to read crypto lake replay csv: " + ex.getMessage());
|
||||
}
|
||||
if (bars.isEmpty()) {
|
||||
throw new TraderException(TraderErrorCode.TRADER_DATA_SOURCE_MISSING, "crypto lake replay csv has no rows for symbol: " + config.symbol());
|
||||
}
|
||||
return bars;
|
||||
}
|
||||
|
||||
private List<ReplayClockTick> readMarketAuditTicks(ReplayRunConfig config, NavigableMap<Instant, MarketBar> bars) {
|
||||
List<ReplayClockTick> ticks = new ArrayList<>();
|
||||
List<MarketBar> ordered = List.copyOf(bars.values());
|
||||
for (int i = 0; i < ordered.size(); i++) {
|
||||
MarketBar bar = ordered.get(i);
|
||||
if (outsideRunWindow(config, bar.openTime())) {
|
||||
continue;
|
||||
}
|
||||
ticks.add(toTick(config, bar, null, labelInputs(ordered, i, null)));
|
||||
}
|
||||
return ticks.stream()
|
||||
.sorted(Comparator.comparing(ReplayClockTick::eventTime))
|
||||
.toList();
|
||||
}
|
||||
|
||||
private List<ReplayClockTick> readCandidateTicks(ReplayRunConfig config, NavigableMap<Instant, MarketBar> bars) {
|
||||
Path path = Path.of(config.dataSources().get(CANDIDATE_SOURCE_KEY).path());
|
||||
List<MarketBar> ordered = List.copyOf(bars.values());
|
||||
Map<Instant, Integer> indexByTime = new LinkedHashMap<>();
|
||||
for (int i = 0; i < ordered.size(); i++) {
|
||||
indexByTime.put(ordered.get(i).openTime(), i);
|
||||
}
|
||||
List<ReplayClockTick> ticks = new ArrayList<>();
|
||||
try (CSVParser parser = CSVParser.parse(path, java.nio.charset.StandardCharsets.UTF_8,
|
||||
CSVFormat.DEFAULT.builder().setHeader().setSkipHeaderRecord(true).build())) {
|
||||
for (CSVRecord record : parser) {
|
||||
if (!config.symbol().equals(required(record, "symbol"))) {
|
||||
continue;
|
||||
}
|
||||
Instant candidateTime = Instant.ofEpochMilli(requiredLong(record, "bar_time"));
|
||||
if (outsideRunWindow(config, candidateTime)) {
|
||||
continue;
|
||||
}
|
||||
Map.Entry<Instant, MarketBar> entry = bars.ceilingEntry(candidateTime);
|
||||
if (entry == null || outsideRunWindow(config, entry.getKey())) {
|
||||
continue;
|
||||
}
|
||||
int barIndex = indexByTime.get(entry.getKey());
|
||||
CandidateEvent event = candidateEvent(record, candidateTime);
|
||||
ticks.add(toTick(config, entry.getValue(), event, labelInputs(ordered, barIndex, event.side())));
|
||||
}
|
||||
} catch (IOException ex) {
|
||||
throw new TraderException(TraderErrorCode.TRADER_DATA_SOURCE_MISSING, "failed to read candidate events csv: " + ex.getMessage());
|
||||
}
|
||||
return ticks.stream()
|
||||
.sorted(Comparator.comparing(ReplayClockTick::eventTime))
|
||||
.toList();
|
||||
}
|
||||
|
||||
private ReplayClockTick toTick(
|
||||
ReplayRunConfig config,
|
||||
MarketBar bar,
|
||||
CandidateEvent candidate,
|
||||
Map<String, Object> labelInputs
|
||||
) {
|
||||
List<String> missing = missingFeatures(bar);
|
||||
Map<String, Object> context = new LinkedHashMap<>();
|
||||
context.put("contextPass", missing.isEmpty());
|
||||
context.put("replaySourceType", "CRYPTO_LAKE_1M_CSV");
|
||||
putDecimal(context, "sourceCoverage", bar.sourceCoverage());
|
||||
putDecimal(context, "fundingBps", bar.fundingBps());
|
||||
putDecimal(context, "openInterest", bar.openInterest());
|
||||
putDecimal(context, "volume", bar.volume());
|
||||
|
||||
Map<String, Object> setup = new LinkedHashMap<>();
|
||||
setup.put("setupPass", candidate != null);
|
||||
setup.put("setupName", candidate == null ? "market_audit_only" : "candidate_event_replay");
|
||||
if (candidate != null) {
|
||||
if (bar.close() == null) {
|
||||
throw new TraderException(TraderErrorCode.TRADER_DATA_SOURCE_MISSING, "candidate event matched a replay bar without close price");
|
||||
}
|
||||
setup.put("candidateEventId", candidate.eventId());
|
||||
setup.put("signalType", candidate.signalType());
|
||||
setup.put("side", candidate.side().name());
|
||||
setup.put("sourceService", candidate.sourceService());
|
||||
putDecimal(setup, "entryPrice", bar.close());
|
||||
putDecimal(setup, "invalidPrice", priceByBps(bar.close(), invalidBps(candidate.side()), adverseSign(candidate.side())));
|
||||
putDecimal(setup, "stopPrice", priceByBps(bar.close(), stopBps(candidate.side()), adverseSign(candidate.side())));
|
||||
putDecimal(setup, "targetPrice", priceByBps(bar.close(), targetBps(candidate.side()), favorableSign(candidate.side())));
|
||||
putDecimal(setup, "executionQualityScore", executionQualityScore(bar));
|
||||
}
|
||||
|
||||
Map<String, Object> trigger = new LinkedHashMap<>();
|
||||
if (candidate != null && candidate.triggerScore() != null) {
|
||||
putDecimal(trigger, "triggerScore", candidate.triggerScore());
|
||||
}
|
||||
trigger.put("replayTriggerSource", candidate == null ? "NONE" : "CANDIDATE_EVENT");
|
||||
|
||||
Map<String, Object> execution = new LinkedHashMap<>();
|
||||
putDecimal(execution, "lastPrice", bar.close());
|
||||
putDecimal(execution, "bestBidPrice", bar.bestBidPrice());
|
||||
putDecimal(execution, "bestAskPrice", bar.bestAskPrice());
|
||||
putDecimal(execution, "observedSpreadBps", bar.observedSpreadBps());
|
||||
putDecimal(execution, "expectedSlippageBps", bar.expectedSlippageBps());
|
||||
putDecimal(execution, "p95LatencyMs", bar.p95LatencyMs());
|
||||
|
||||
Map<String, Object> dataQuality = new LinkedHashMap<>();
|
||||
dataQuality.put("missing_features", missing);
|
||||
putDecimal(dataQuality, "sourceCoverage", bar.sourceCoverage());
|
||||
dataQuality.put("replaySourcePath", selectReplaySource(config).path());
|
||||
|
||||
return new ReplayClockTick(
|
||||
config.runId(),
|
||||
config.symbol(),
|
||||
bar.openTime(),
|
||||
context,
|
||||
setup,
|
||||
trigger,
|
||||
execution,
|
||||
dataQuality,
|
||||
labelInputs
|
||||
);
|
||||
}
|
||||
|
||||
private Map<String, Object> labelInputs(List<MarketBar> bars, int index, TraderSide side) {
|
||||
Map<String, Object> labels = new LinkedHashMap<>();
|
||||
labels.put("labelSource", "CRYPTO_LAKE_1M_REPLAY");
|
||||
if (side == null) {
|
||||
labels.put("labelStatus", "MARKET_AUDIT_NO_SIDE");
|
||||
return labels;
|
||||
}
|
||||
MarketBar entry = bars.get(index);
|
||||
labels.put("side", side.name());
|
||||
putDecimal(labels, "entryPrice", entry.close());
|
||||
putIfPresent(labels, "markoutBps1m", markout(bars, index, side, 1));
|
||||
putIfPresent(labels, "markoutBps5m", markout(bars, index, side, 5));
|
||||
putIfPresent(labels, "markoutBps15m", markout(bars, index, side, 15));
|
||||
putIfPresent(labels, "mfeBps15m", mfe(bars, index, side, 15));
|
||||
putIfPresent(labels, "maeBps15m", mae(bars, index, side, 15));
|
||||
putIfPresent(labels, "targetBeforeStop15m", targetBeforeStop(bars, index, side, targetBps(side), stopBps(side), 15));
|
||||
putDecimal(labels, "expectedSlippageBps", entry.expectedSlippageBps());
|
||||
labels.put("labelStatus", hasMandatoryLabels(labels) ? "REPLAY_MARKOUT_LABELED" : "FUTURE_WINDOW_INCOMPLETE");
|
||||
return labels;
|
||||
}
|
||||
|
||||
private boolean hasMandatoryLabels(Map<String, Object> labels) {
|
||||
return labels.containsKey("markoutBps1m")
|
||||
&& labels.containsKey("markoutBps5m")
|
||||
&& labels.containsKey("markoutBps15m");
|
||||
}
|
||||
|
||||
private String markout(List<MarketBar> bars, int index, TraderSide side, int minutes) {
|
||||
if (index + minutes >= bars.size()) {
|
||||
return null;
|
||||
}
|
||||
BigDecimal entry = bars.get(index).close();
|
||||
BigDecimal close = bars.get(index + minutes).close();
|
||||
return decimalText(sideReturnBps(side, entry, close));
|
||||
}
|
||||
|
||||
private String mfe(List<MarketBar> bars, int index, TraderSide side, int minutes) {
|
||||
if (index + minutes >= bars.size()) {
|
||||
return null;
|
||||
}
|
||||
BigDecimal entry = bars.get(index).close();
|
||||
BigDecimal best = BigDecimal.ZERO;
|
||||
for (int i = index + 1; i <= index + minutes; i++) {
|
||||
BigDecimal favorable = side == TraderSide.LONG ? bars.get(i).high() : bars.get(i).low();
|
||||
best = best.max(sideReturnBps(side, entry, favorable));
|
||||
}
|
||||
return decimalText(best.max(BigDecimal.ZERO));
|
||||
}
|
||||
|
||||
private String mae(List<MarketBar> bars, int index, TraderSide side, int minutes) {
|
||||
if (index + minutes >= bars.size()) {
|
||||
return null;
|
||||
}
|
||||
BigDecimal entry = bars.get(index).close();
|
||||
BigDecimal worst = BigDecimal.ZERO;
|
||||
for (int i = index + 1; i <= index + minutes; i++) {
|
||||
BigDecimal adverse = side == TraderSide.LONG ? bars.get(i).low() : bars.get(i).high();
|
||||
BigDecimal signed = sideReturnBps(side, entry, adverse);
|
||||
if (signed.compareTo(BigDecimal.ZERO) < 0) {
|
||||
worst = worst.max(signed.abs());
|
||||
}
|
||||
}
|
||||
return decimalText(worst);
|
||||
}
|
||||
|
||||
private Boolean targetBeforeStop(List<MarketBar> bars, int index, TraderSide side, BigDecimal targetBps, BigDecimal stopBps, int minutes) {
|
||||
if (index + minutes >= bars.size()) {
|
||||
return null;
|
||||
}
|
||||
BigDecimal entry = bars.get(index).close();
|
||||
BigDecimal target = priceByBps(entry, targetBps, favorableSign(side));
|
||||
BigDecimal stop = priceByBps(entry, stopBps, adverseSign(side));
|
||||
for (int i = index + 1; i <= index + minutes; i++) {
|
||||
MarketBar bar = bars.get(i);
|
||||
boolean targetHit = side == TraderSide.LONG
|
||||
? bar.high().compareTo(target) >= 0
|
||||
: bar.low().compareTo(target) <= 0;
|
||||
boolean stopHit = side == TraderSide.LONG
|
||||
? bar.low().compareTo(stop) <= 0
|
||||
: bar.high().compareTo(stop) >= 0;
|
||||
if (targetHit) {
|
||||
return true;
|
||||
}
|
||||
if (stopHit) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
private MarketBar marketBar(CSVRecord record) {
|
||||
return new MarketBar(
|
||||
Instant.parse(required(record, "open_time")),
|
||||
decimal(record, "open"),
|
||||
decimal(record, "high"),
|
||||
decimal(record, "low"),
|
||||
decimal(record, "close"),
|
||||
decimal(record, "volume"),
|
||||
decimal(record, "taker_buy_volume"),
|
||||
decimal(record, "funding_bps"),
|
||||
decimal(record, "open_interest"),
|
||||
decimal(record, "best_bid_price"),
|
||||
decimal(record, "best_ask_price"),
|
||||
decimal(record, "observed_spread_bps"),
|
||||
decimal(record, "expected_slippage_bps"),
|
||||
decimal(record, "p95_latency_ms"),
|
||||
decimal(record, "source_coverage")
|
||||
);
|
||||
}
|
||||
|
||||
private CandidateEvent candidateEvent(CSVRecord record, Instant candidateTime) {
|
||||
String side = required(record, "direction").toUpperCase();
|
||||
return new CandidateEvent(
|
||||
required(record, "event_id"),
|
||||
candidateTime,
|
||||
required(record, "signal_type"),
|
||||
TraderSide.valueOf(side),
|
||||
required(record, "source_service"),
|
||||
firstDecimal(record, "old_fusion_score", "legacy_fusion_score")
|
||||
);
|
||||
}
|
||||
|
||||
private List<String> missingFeatures(MarketBar bar) {
|
||||
List<String> missing = new ArrayList<>();
|
||||
requirePresent(missing, "open", bar.open());
|
||||
requirePresent(missing, "high", bar.high());
|
||||
requirePresent(missing, "low", bar.low());
|
||||
requirePresent(missing, "close", bar.close());
|
||||
requirePresent(missing, "taker_buy_volume", bar.takerBuyVolume());
|
||||
requirePresent(missing, "expected_slippage_bps", bar.expectedSlippageBps());
|
||||
requirePresent(missing, "source_coverage", bar.sourceCoverage());
|
||||
return missing;
|
||||
}
|
||||
|
||||
private void requirePresent(List<String> missing, String field, BigDecimal value) {
|
||||
if (value == null) {
|
||||
missing.add(field);
|
||||
}
|
||||
}
|
||||
|
||||
private DataSourceSpec selectReplaySource(ReplayRunConfig config) {
|
||||
DataSourceSpec source = config.dataSources().get(REPLAY_SOURCE_KEY);
|
||||
if (source == null) {
|
||||
throw new TraderException(TraderErrorCode.TRADER_DATA_SOURCE_MISSING, "dataSources.cryptoLakeReplay1m is required");
|
||||
}
|
||||
return source;
|
||||
}
|
||||
|
||||
private void validateSource(DataSourceSpec source, String sourceType) {
|
||||
if (source.path() == null || source.path().isBlank()) {
|
||||
throw new TraderException(TraderErrorCode.TRADER_DATA_SOURCE_MISSING, "data source path is required: " + sourceType);
|
||||
}
|
||||
Path path = Path.of(source.path());
|
||||
if (!Files.isRegularFile(path) || !Files.isReadable(path)) {
|
||||
throw new TraderException(TraderErrorCode.TRADER_DATA_SOURCE_MISSING, "data source is not readable: " + source.path());
|
||||
}
|
||||
}
|
||||
|
||||
private boolean outsideRunWindow(ReplayRunConfig config, Instant time) {
|
||||
return time.isBefore(config.from()) || !time.isBefore(config.to());
|
||||
}
|
||||
|
||||
private String required(CSVRecord record, String column) {
|
||||
String value = record.get(column);
|
||||
if (value == null || value.isBlank()) {
|
||||
throw new TraderException(TraderErrorCode.TRADER_DATA_SOURCE_MISSING, "csv column is required: " + column);
|
||||
}
|
||||
return value;
|
||||
}
|
||||
|
||||
private long requiredLong(CSVRecord record, String column) {
|
||||
return Long.parseLong(required(record, column));
|
||||
}
|
||||
|
||||
private BigDecimal firstDecimal(CSVRecord record, String... columns) {
|
||||
for (String column : columns) {
|
||||
if (!record.isMapped(column)) {
|
||||
continue;
|
||||
}
|
||||
BigDecimal value = decimal(record, column);
|
||||
if (value != null) {
|
||||
return value;
|
||||
}
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
private BigDecimal decimal(CSVRecord record, String column) {
|
||||
if (!record.isMapped(column)) {
|
||||
return null;
|
||||
}
|
||||
String value = record.get(column);
|
||||
if (value == null || value.isBlank()) {
|
||||
return null;
|
||||
}
|
||||
return new BigDecimal(value);
|
||||
}
|
||||
|
||||
private String decimalText(BigDecimal value) {
|
||||
return value == null ? null : value.stripTrailingZeros().toPlainString();
|
||||
}
|
||||
|
||||
private void putDecimal(Map<String, Object> target, String key, BigDecimal value) {
|
||||
String text = decimalText(value);
|
||||
if (text != null) {
|
||||
target.put(key, text);
|
||||
}
|
||||
}
|
||||
|
||||
private void putIfPresent(Map<String, Object> target, String key, Object value) {
|
||||
if (value != null) {
|
||||
target.put(key, value);
|
||||
}
|
||||
}
|
||||
|
||||
private BigDecimal executionQualityScore(MarketBar bar) {
|
||||
if (bar.expectedSlippageBps() == null) {
|
||||
return null;
|
||||
}
|
||||
BigDecimal score = BigDecimal.ONE.subtract(bar.expectedSlippageBps().divide(new BigDecimal("20.0"), MC), MC);
|
||||
return score.max(new BigDecimal("0.20")).min(BigDecimal.ONE).setScale(8, RoundingMode.HALF_UP);
|
||||
}
|
||||
|
||||
private BigDecimal sideReturnBps(TraderSide side, BigDecimal entry, BigDecimal exit) {
|
||||
BigDecimal gross = exit.subtract(entry, MC)
|
||||
.divide(entry, MC)
|
||||
.multiply(new BigDecimal("10000"), MC);
|
||||
return side == TraderSide.LONG ? gross : gross.negate();
|
||||
}
|
||||
|
||||
private BigDecimal priceByBps(BigDecimal entry, BigDecimal bps, int sign) {
|
||||
BigDecimal multiplier = BigDecimal.ONE.add(BigDecimal.valueOf(sign).multiply(bps, MC).divide(new BigDecimal("10000"), MC), MC);
|
||||
return entry.multiply(multiplier, MC).setScale(8, RoundingMode.HALF_UP);
|
||||
}
|
||||
|
||||
private int favorableSign(TraderSide side) {
|
||||
return side == TraderSide.LONG ? 1 : -1;
|
||||
}
|
||||
|
||||
private int adverseSign(TraderSide side) {
|
||||
return side == TraderSide.LONG ? -1 : 1;
|
||||
}
|
||||
|
||||
private BigDecimal invalidBps(TraderSide side) {
|
||||
return side == TraderSide.LONG ? LONG_INVALID_BPS : SHORT_INVALID_BPS;
|
||||
}
|
||||
|
||||
private BigDecimal stopBps(TraderSide side) {
|
||||
return side == TraderSide.LONG ? LONG_STOP_BPS : SHORT_STOP_BPS;
|
||||
}
|
||||
|
||||
private BigDecimal targetBps(TraderSide side) {
|
||||
return side == TraderSide.LONG ? LONG_TARGET_BPS : SHORT_TARGET_BPS;
|
||||
}
|
||||
|
||||
private record MarketBar(
|
||||
Instant openTime,
|
||||
BigDecimal open,
|
||||
BigDecimal high,
|
||||
BigDecimal low,
|
||||
BigDecimal close,
|
||||
BigDecimal volume,
|
||||
BigDecimal takerBuyVolume,
|
||||
BigDecimal fundingBps,
|
||||
BigDecimal openInterest,
|
||||
BigDecimal bestBidPrice,
|
||||
BigDecimal bestAskPrice,
|
||||
BigDecimal observedSpreadBps,
|
||||
BigDecimal expectedSlippageBps,
|
||||
BigDecimal p95LatencyMs,
|
||||
BigDecimal sourceCoverage
|
||||
) {
|
||||
}
|
||||
|
||||
private record CandidateEvent(
|
||||
String eventId,
|
||||
Instant barTime,
|
||||
String signalType,
|
||||
TraderSide side,
|
||||
String sourceService,
|
||||
BigDecimal triggerScore
|
||||
) {
|
||||
}
|
||||
}
|
||||
@@ -22,6 +22,12 @@ public class JsonlReplayMarketEventReader implements ReplayMarketEventReader {
|
||||
this.objectMapper = new ObjectMapper().findAndRegisterModules();
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean supports(ReplayRunConfig config) {
|
||||
DataSourceSpec source = config.dataSources() == null ? null : config.dataSources().get("ticks");
|
||||
return source != null && source.path() != null && source.path().endsWith(".jsonl");
|
||||
}
|
||||
|
||||
@Override
|
||||
public void validateReadable(ReplayRunConfig config) {
|
||||
DataSourceSpec source = selectReplaySource(config);
|
||||
@@ -76,7 +82,8 @@ public class JsonlReplayMarketEventReader implements ReplayMarketEventReader {
|
||||
fixture.setupFeatures(),
|
||||
fixture.triggerFeatures(),
|
||||
fixture.executionFeatures(),
|
||||
fixture.dataQuality()
|
||||
fixture.dataQuality(),
|
||||
fixture.labelInputs() == null ? Map.of() : fixture.labelInputs()
|
||||
);
|
||||
} catch (IOException ex) {
|
||||
throw new TraderException(TraderErrorCode.TRADER_DATA_SOURCE_MISSING, "invalid replay tick json: " + ex.getMessage());
|
||||
@@ -100,7 +107,8 @@ public class JsonlReplayMarketEventReader implements ReplayMarketEventReader {
|
||||
Map<String, Object> setupFeatures,
|
||||
Map<String, Object> triggerFeatures,
|
||||
Map<String, Object> executionFeatures,
|
||||
Map<String, Object> dataQuality
|
||||
Map<String, Object> dataQuality,
|
||||
Map<String, Object> labelInputs
|
||||
) {
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package com.quantai.trader.replay;
|
||||
|
||||
import java.time.Instant;
|
||||
import java.util.LinkedHashMap;
|
||||
import java.util.Map;
|
||||
|
||||
public record ReplayClockTick(
|
||||
@@ -11,6 +12,23 @@ public record ReplayClockTick(
|
||||
Map<String, Object> setupFeatures,
|
||||
Map<String, Object> triggerFeatures,
|
||||
Map<String, Object> executionFeatures,
|
||||
Map<String, Object> dataQuality
|
||||
Map<String, Object> dataQuality,
|
||||
Map<String, Object> labelInputs
|
||||
) {
|
||||
|
||||
public ReplayClockTick {
|
||||
contextFeatures = immutable(contextFeatures);
|
||||
setupFeatures = immutable(setupFeatures);
|
||||
triggerFeatures = immutable(triggerFeatures);
|
||||
executionFeatures = immutable(executionFeatures);
|
||||
dataQuality = immutable(dataQuality);
|
||||
labelInputs = immutable(labelInputs);
|
||||
}
|
||||
|
||||
private static Map<String, Object> immutable(Map<String, Object> value) {
|
||||
if (value == null || value.isEmpty()) {
|
||||
return Map.of();
|
||||
}
|
||||
return Map.copyOf(new LinkedHashMap<>(value));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4,6 +4,8 @@ import java.util.List;
|
||||
|
||||
public interface ReplayMarketEventReader {
|
||||
|
||||
boolean supports(ReplayRunConfig config);
|
||||
|
||||
void validateReadable(ReplayRunConfig config);
|
||||
|
||||
List<ReplayClockTick> readTicks(ReplayRunConfig config);
|
||||
|
||||
@@ -32,7 +32,7 @@ public class ReplayRunService {
|
||||
private final TraderPlaybookCatalog catalog;
|
||||
private final ReplayRunRepository repository;
|
||||
private final ReplayReportWriter reportWriter;
|
||||
private final ReplayMarketEventReader eventReader;
|
||||
private final List<ReplayMarketEventReader> eventReaders;
|
||||
private final TraderDecisionCycleRunner cycleRunner;
|
||||
private final ExecutorService executorService = Executors.newSingleThreadExecutor(runnable -> {
|
||||
Thread thread = new Thread(runnable, "trader-replay-worker");
|
||||
@@ -44,13 +44,13 @@ public class ReplayRunService {
|
||||
TraderPlaybookCatalog catalog,
|
||||
ReplayRunRepository repository,
|
||||
ReplayReportWriter reportWriter,
|
||||
ReplayMarketEventReader eventReader,
|
||||
List<ReplayMarketEventReader> eventReaders,
|
||||
TraderDecisionCycleRunner cycleRunner
|
||||
) {
|
||||
this.catalog = catalog;
|
||||
this.repository = repository;
|
||||
this.reportWriter = reportWriter;
|
||||
this.eventReader = eventReader;
|
||||
this.eventReaders = List.copyOf(eventReaders);
|
||||
this.cycleRunner = cycleRunner;
|
||||
}
|
||||
|
||||
@@ -58,7 +58,7 @@ public class ReplayRunService {
|
||||
validateRequest(request);
|
||||
TraderPlaybookDefinitionSnapshot playbook = catalog.require(request.playbookId(), request.playbookVersion());
|
||||
request.dataSources().forEach((sourceType, spec) -> validateDataSource(request, sourceType, spec));
|
||||
eventReader.validateReadable(request);
|
||||
readerFor(request).validateReadable(request);
|
||||
|
||||
String runId = Ids.runId(Instant.now());
|
||||
ReplayRunConfig config = request.withRunId(runId);
|
||||
@@ -114,7 +114,7 @@ public class ReplayRunService {
|
||||
playbook.playbookVersion(),
|
||||
ReplayRunStatus.RUNNING
|
||||
);
|
||||
List<ReplayClockTick> ticks = eventReader.readTicks(run.config());
|
||||
List<ReplayClockTick> ticks = readerFor(run.config()).readTicks(run.config());
|
||||
List<TraderCycleResult> results = new ArrayList<>(ticks.size());
|
||||
TraderRuntimeState runtimeState = new TraderRuntimeState(
|
||||
run.runId(),
|
||||
@@ -186,6 +186,16 @@ public class ReplayRunService {
|
||||
.orElseThrow(() -> new IllegalStateException("replay run disappeared: " + runId));
|
||||
}
|
||||
|
||||
private ReplayMarketEventReader readerFor(ReplayRunConfig config) {
|
||||
return eventReaders.stream()
|
||||
.filter(reader -> reader.supports(config))
|
||||
.findFirst()
|
||||
.orElseThrow(() -> new TraderException(
|
||||
TraderErrorCode.TRADER_DATA_SOURCE_MISSING,
|
||||
"no replay reader supports the requested dataSources"
|
||||
));
|
||||
}
|
||||
|
||||
private void validateDataSource(ReplayRunConfig request, String sourceType, DataSourceSpec spec) {
|
||||
if (spec.timezone() == null || spec.timezone().isBlank()) {
|
||||
throw new TraderException(TraderErrorCode.TRADER_DATA_SOURCE_MISSING, "data source timezone is required: " + sourceType);
|
||||
|
||||
@@ -2,6 +2,7 @@ package com.quantai.trader.report;
|
||||
|
||||
import com.quantai.trader.domain.TraderReplayReport;
|
||||
import com.quantai.trader.brain.TraderCycleResult;
|
||||
import com.quantai.trader.domain.TraderTrainingSample;
|
||||
import com.quantai.trader.persistence.ReplayReportRepository;
|
||||
import com.quantai.trader.playbook.TraderPlaybookDefinitionSnapshot;
|
||||
import com.quantai.trader.replay.ReplayRunConfig;
|
||||
@@ -11,8 +12,10 @@ import org.springframework.stereotype.Component;
|
||||
import java.math.BigDecimal;
|
||||
import java.time.Instant;
|
||||
import java.time.temporal.ChronoUnit;
|
||||
import java.util.LinkedHashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
|
||||
@Component
|
||||
public class ReplayReportWriter {
|
||||
@@ -30,13 +33,25 @@ public class ReplayReportWriter {
|
||||
) {
|
||||
int actionCount = (int) results.stream().filter(result -> result.action() != null).count();
|
||||
int sampleCount = (int) results.stream().filter(result -> result.sample() != null).count();
|
||||
SampleAudit audit = audit(results);
|
||||
int monthsCovered = Math.max(1, (int) ChronoUnit.MONTHS.between(
|
||||
config.from().atZone(java.time.ZoneOffset.UTC).withDayOfMonth(1),
|
||||
config.to().atZone(java.time.ZoneOffset.UTC).withDayOfMonth(1)
|
||||
));
|
||||
List<String> failureRisks = actionCount == 0
|
||||
? List.of("no_action_generated", "proxy_only_execution")
|
||||
: List.of("proxy_only_execution");
|
||||
List<String> failureRisks = failureRisks(actionCount, audit);
|
||||
Map<String, Object> auditReport = new LinkedHashMap<>();
|
||||
auditReport.put("replayEngine", replayEngine(config));
|
||||
auditReport.put("tickCount", results.size());
|
||||
auditReport.put("sampleCount", sampleCount);
|
||||
auditReport.put("actionCount", actionCount);
|
||||
auditReport.put("labeledSampleCount", audit.labeledSampleCount());
|
||||
auditReport.put("proxyOnlySampleCount", audit.proxyOnlySampleCount());
|
||||
auditReport.put("positiveNetReturnCount", audit.positiveNetReturnCount());
|
||||
auditReport.put("negativeNetReturnCount", audit.negativeNetReturnCount());
|
||||
auditReport.put("missingNetReturnCount", audit.missingNetReturnCount());
|
||||
putIfPresent(auditReport, "meanNetReturnBps1x", audit.meanNetReturnBps1x());
|
||||
putIfPresent(auditReport, "meanNetReturnBps10x", audit.meanNetReturnBps10x());
|
||||
auditReport.put("labelStatusDistribution", audit.labelStatusDistribution());
|
||||
TraderReplayReport report = new TraderReplayReport(
|
||||
config.runId(),
|
||||
Ids.reportId(config.runId()),
|
||||
@@ -45,21 +60,94 @@ public class ReplayReportWriter {
|
||||
playbook.playbookVersion(),
|
||||
actionCount,
|
||||
monthsCovered,
|
||||
BigDecimal.ZERO,
|
||||
BigDecimal.ZERO,
|
||||
BigDecimal.ZERO,
|
||||
Map.of(
|
||||
"p0ReplayEngine", "jsonl_fixture",
|
||||
"tickCount", results.size(),
|
||||
"sampleCount", sampleCount,
|
||||
"actionCount", actionCount
|
||||
),
|
||||
audit.meanNetReturnBps1x(),
|
||||
audit.meanNetReturnBps10x(),
|
||||
null,
|
||||
auditReport,
|
||||
failureRisks,
|
||||
"P0_OBSERVE_ONLY",
|
||||
audit.labeledSampleCount() > 0 ? "TRAINING_SAMPLE_AUDIT_ONLY" : "P0_OBSERVE_ONLY",
|
||||
null,
|
||||
Instant.now()
|
||||
);
|
||||
repository.insert(report);
|
||||
return report;
|
||||
}
|
||||
|
||||
private String replayEngine(ReplayRunConfig config) {
|
||||
if (config.dataSources().containsKey("cryptoLakeReplay1m")) {
|
||||
return "crypto_lake_1m_csv";
|
||||
}
|
||||
return "jsonl_fixture";
|
||||
}
|
||||
|
||||
private List<String> failureRisks(int actionCount, SampleAudit audit) {
|
||||
java.util.ArrayList<String> risks = new java.util.ArrayList<>();
|
||||
if (actionCount == 0) {
|
||||
risks.add("no_action_generated");
|
||||
}
|
||||
if (audit.labeledSampleCount() == 0) {
|
||||
risks.add("no_replay_markout_labels");
|
||||
}
|
||||
if (audit.proxyOnlySampleCount() > 0) {
|
||||
risks.add("proxy_only_samples_present");
|
||||
}
|
||||
if (audit.missingNetReturnCount() > 0) {
|
||||
risks.add("missing_net_return_labels");
|
||||
}
|
||||
return risks;
|
||||
}
|
||||
|
||||
private SampleAudit audit(List<TraderCycleResult> results) {
|
||||
List<TraderTrainingSample> samples = results.stream()
|
||||
.map(TraderCycleResult::sample)
|
||||
.filter(Objects::nonNull)
|
||||
.toList();
|
||||
int proxyOnly = (int) samples.stream().filter(TraderTrainingSample::proxyOnly).count();
|
||||
int labeled = samples.size() - proxyOnly;
|
||||
int missingNet = (int) samples.stream().filter(sample -> sample.netReturnBps1x() == null).count();
|
||||
int positive = (int) samples.stream()
|
||||
.filter(sample -> sample.netReturnBps1x() != null && sample.netReturnBps1x().compareTo(BigDecimal.ZERO) > 0)
|
||||
.count();
|
||||
int negative = (int) samples.stream()
|
||||
.filter(sample -> sample.netReturnBps1x() != null && sample.netReturnBps1x().compareTo(BigDecimal.ZERO) < 0)
|
||||
.count();
|
||||
BigDecimal mean1x = mean(samples.stream()
|
||||
.map(TraderTrainingSample::netReturnBps1x)
|
||||
.filter(Objects::nonNull)
|
||||
.toList());
|
||||
BigDecimal mean10x = mean(samples.stream()
|
||||
.map(TraderTrainingSample::netReturnBps10x)
|
||||
.filter(Objects::nonNull)
|
||||
.toList());
|
||||
Map<String, Long> labelStatuses = samples.stream()
|
||||
.map(sample -> String.valueOf(sample.labels().getOrDefault("label_status", "UNKNOWN")))
|
||||
.collect(java.util.stream.Collectors.groupingBy(status -> status, LinkedHashMap::new, java.util.stream.Collectors.counting()));
|
||||
return new SampleAudit(labeled, proxyOnly, positive, negative, missingNet, mean1x, mean10x, labelStatuses);
|
||||
}
|
||||
|
||||
private BigDecimal mean(List<BigDecimal> values) {
|
||||
if (values.isEmpty()) {
|
||||
return null;
|
||||
}
|
||||
BigDecimal sum = values.stream().reduce(BigDecimal.ZERO, BigDecimal::add);
|
||||
return sum.divide(BigDecimal.valueOf(values.size()), 8, java.math.RoundingMode.HALF_UP);
|
||||
}
|
||||
|
||||
private void putIfPresent(Map<String, Object> target, String key, Object value) {
|
||||
if (value != null) {
|
||||
target.put(key, value);
|
||||
}
|
||||
}
|
||||
|
||||
private record SampleAudit(
|
||||
int labeledSampleCount,
|
||||
int proxyOnlySampleCount,
|
||||
int positiveNetReturnCount,
|
||||
int negativeNetReturnCount,
|
||||
int missingNetReturnCount,
|
||||
BigDecimal meanNetReturnBps1x,
|
||||
BigDecimal meanNetReturnBps10x,
|
||||
Map<String, Long> labelStatusDistribution
|
||||
) {
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,12 @@
|
||||
package com.quantai.trader.sample;
|
||||
|
||||
import java.math.BigDecimal;
|
||||
import java.util.Map;
|
||||
|
||||
public record TrainingLabelSet(
|
||||
Map<String, Object> labels,
|
||||
BigDecimal netReturnBps1x,
|
||||
BigDecimal netReturnBps10x,
|
||||
boolean proxyOnly
|
||||
) {
|
||||
}
|
||||
@@ -4,6 +4,7 @@ import com.quantai.trader.config.TraderProperties;
|
||||
import com.quantai.trader.domain.PlaybookCandidate;
|
||||
import com.quantai.trader.domain.TraderAction;
|
||||
import com.quantai.trader.domain.TraderDecisionCycle;
|
||||
import com.quantai.trader.domain.TraderMarketSnapshot;
|
||||
import com.quantai.trader.domain.TraderPositionPath;
|
||||
import com.quantai.trader.domain.TraderTrainingSample;
|
||||
import com.quantai.trader.persistence.TraderSampleRepository;
|
||||
@@ -13,6 +14,7 @@ import org.slf4j.LoggerFactory;
|
||||
import org.springframework.stereotype.Component;
|
||||
|
||||
import java.math.BigDecimal;
|
||||
import java.util.LinkedHashMap;
|
||||
import java.util.Map;
|
||||
|
||||
@Component
|
||||
@@ -21,31 +23,23 @@ public class TrainingSampleExporter {
|
||||
private static final Logger log = LoggerFactory.getLogger(TrainingSampleExporter.class);
|
||||
private final TraderProperties properties;
|
||||
private final TraderSampleRepository repository;
|
||||
private final TriggerMarkoutLabeler labeler;
|
||||
|
||||
public TrainingSampleExporter(TraderProperties properties, TraderSampleRepository repository) {
|
||||
public TrainingSampleExporter(TraderProperties properties, TraderSampleRepository repository, TriggerMarkoutLabeler labeler) {
|
||||
this.properties = properties;
|
||||
this.repository = repository;
|
||||
this.labeler = labeler;
|
||||
}
|
||||
|
||||
public TraderTrainingSample export(
|
||||
TraderDecisionCycle cycle,
|
||||
TraderMarketSnapshot snapshot,
|
||||
PlaybookCandidate candidate,
|
||||
TraderAction action,
|
||||
TraderPositionPath path
|
||||
) {
|
||||
Map<String, Object> features = Map.of(
|
||||
"playbookId", candidate == null ? cycle.playbookId() : candidate.playbookId(),
|
||||
"playbookVersion", candidate == null ? cycle.playbookVersion() : candidate.playbookVersion(),
|
||||
"state", cycle.state().name(),
|
||||
"actionType", action == null ? "NONE" : action.actionType().name(),
|
||||
"proxyOnly", true
|
||||
);
|
||||
Map<String, Object> labels = Map.of(
|
||||
"trigger_acceptance", action != null,
|
||||
"target_before_stop", path != null && path.targetBeforeStop(),
|
||||
"stagnation_timeout_hit", path != null && path.stagnationTimeoutHit(),
|
||||
"best_counterfactual_action", action == null ? "WAIT" : action.actionType().name()
|
||||
);
|
||||
TrainingLabelSet labelSet = labeler.label(snapshot, candidate, action, path);
|
||||
Map<String, Object> features = features(cycle, snapshot, candidate, action);
|
||||
TraderTrainingSample sample = new TraderTrainingSample(
|
||||
cycle.runId(),
|
||||
cycle.cycleId(),
|
||||
@@ -56,13 +50,13 @@ public class TrainingSampleExporter {
|
||||
properties.getLabelVersion(),
|
||||
cycle.cycleTime(),
|
||||
features,
|
||||
labels,
|
||||
BigDecimal.ZERO,
|
||||
BigDecimal.ZERO,
|
||||
true
|
||||
labelSet.labels(),
|
||||
labelSet.netReturnBps1x(),
|
||||
labelSet.netReturnBps10x(),
|
||||
labelSet.proxyOnly()
|
||||
);
|
||||
log.info(
|
||||
"event=trader.sample.export_start runId={} cycleId={} symbol={} playbookId={} playbookVersion={} state={} actionId={} positionId={} sampleId={} proxyOnly=true",
|
||||
"event=trader.sample.export_start runId={} cycleId={} symbol={} playbookId={} playbookVersion={} state={} actionId={} positionId={} sampleId={} proxyOnly={} labelStatus={}",
|
||||
cycle.runId(),
|
||||
cycle.cycleId(),
|
||||
cycle.symbol(),
|
||||
@@ -71,11 +65,13 @@ public class TrainingSampleExporter {
|
||||
cycle.state(),
|
||||
sample.actionId(),
|
||||
sample.positionId(),
|
||||
sample.sampleId()
|
||||
sample.sampleId(),
|
||||
sample.proxyOnly(),
|
||||
sample.labels().get("label_status")
|
||||
);
|
||||
repository.insert(sample);
|
||||
log.info(
|
||||
"event=trader.sample.exported runId={} cycleId={} symbol={} playbookId={} playbookVersion={} state={} actionId={} positionId={} sampleId={} proxyOnly=true",
|
||||
"event=trader.sample.exported runId={} cycleId={} symbol={} playbookId={} playbookVersion={} state={} actionId={} positionId={} sampleId={} proxyOnly={} netReturnBps1x={}",
|
||||
cycle.runId(),
|
||||
cycle.cycleId(),
|
||||
cycle.symbol(),
|
||||
@@ -84,8 +80,35 @@ public class TrainingSampleExporter {
|
||||
cycle.state(),
|
||||
sample.actionId(),
|
||||
sample.positionId(),
|
||||
sample.sampleId()
|
||||
sample.sampleId(),
|
||||
sample.proxyOnly(),
|
||||
sample.netReturnBps1x()
|
||||
);
|
||||
return sample;
|
||||
}
|
||||
|
||||
private Map<String, Object> features(
|
||||
TraderDecisionCycle cycle,
|
||||
TraderMarketSnapshot snapshot,
|
||||
PlaybookCandidate candidate,
|
||||
TraderAction action
|
||||
) {
|
||||
Map<String, Object> features = new LinkedHashMap<>();
|
||||
features.put("playbookId", candidate == null ? cycle.playbookId() : candidate.playbookId());
|
||||
features.put("playbookVersion", candidate == null ? cycle.playbookVersion() : candidate.playbookVersion());
|
||||
features.put("state", cycle.state().name());
|
||||
features.put("actionType", action == null ? "NONE" : action.actionType().name());
|
||||
if (candidate != null) {
|
||||
features.put("candidateSide", candidate.side().name());
|
||||
features.put("candidateVariant", candidate.variant());
|
||||
}
|
||||
if (snapshot != null) {
|
||||
features.put("context", snapshot.contextFeatures());
|
||||
features.put("setup", snapshot.setupFeatures());
|
||||
features.put("trigger", snapshot.triggerFeatures());
|
||||
features.put("execution", snapshot.executionFeatures());
|
||||
features.put("dataQuality", snapshot.dataQuality());
|
||||
}
|
||||
return features;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,90 @@
|
||||
package com.quantai.trader.sample;
|
||||
|
||||
import com.quantai.trader.domain.PlaybookCandidate;
|
||||
import com.quantai.trader.domain.TraderAction;
|
||||
import com.quantai.trader.domain.TraderMarketSnapshot;
|
||||
import com.quantai.trader.domain.TraderPositionPath;
|
||||
import org.springframework.stereotype.Component;
|
||||
|
||||
import java.math.BigDecimal;
|
||||
import java.math.MathContext;
|
||||
import java.math.RoundingMode;
|
||||
import java.util.LinkedHashMap;
|
||||
import java.util.Map;
|
||||
|
||||
@Component
|
||||
public class TriggerMarkoutLabeler {
|
||||
|
||||
private static final MathContext MC = new MathContext(16, RoundingMode.HALF_UP);
|
||||
private static final BigDecimal TAKER_FEE_ROUND_TRIP_BPS = new BigDecimal("8.0");
|
||||
|
||||
public TrainingLabelSet label(
|
||||
TraderMarketSnapshot snapshot,
|
||||
PlaybookCandidate candidate,
|
||||
TraderAction action,
|
||||
TraderPositionPath path
|
||||
) {
|
||||
Map<String, Object> labels = new LinkedHashMap<>();
|
||||
labels.put("label_family", "TRIGGER_MARKOUT");
|
||||
labels.put("trigger_acceptance", action != null);
|
||||
labels.put("target_before_stop", path != null && path.targetBeforeStop());
|
||||
labels.put("stagnation_timeout_hit", path != null && path.stagnationTimeoutHit());
|
||||
labels.put("action_type", action == null ? "NONE" : action.actionType().name());
|
||||
if (candidate != null) {
|
||||
labels.put("candidate_side", candidate.side().name());
|
||||
}
|
||||
|
||||
if (snapshot == null || snapshot.labelInputs().isEmpty()) {
|
||||
labels.put("label_status", "PROXY_ONLY_NO_REPLAY_LABEL");
|
||||
labels.put("best_counterfactual_action", action == null ? "WAIT" : action.actionType().name());
|
||||
return new TrainingLabelSet(labels, null, null, true);
|
||||
}
|
||||
|
||||
Map<String, Object> labelInputs = snapshot.labelInputs();
|
||||
labelInputs.forEach((key, value) -> labels.put("replay_" + key, value));
|
||||
String labelStatus = String.valueOf(labelInputs.getOrDefault("labelStatus", "UNKNOWN"));
|
||||
labels.put("label_status", labelStatus);
|
||||
labels.put("best_counterfactual_action", counterfactualAction(labelInputs));
|
||||
|
||||
BigDecimal netReturn1x = netReturn1x(labelInputs);
|
||||
BigDecimal netReturn10x = netReturn1x == null
|
||||
? null
|
||||
: netReturn1x.multiply(BigDecimal.TEN, MC).setScale(8, RoundingMode.HALF_UP);
|
||||
boolean proxyOnly = !"REPLAY_MARKOUT_LABELED".equals(labelStatus);
|
||||
return new TrainingLabelSet(labels, netReturn1x, netReturn10x, proxyOnly);
|
||||
}
|
||||
|
||||
private String counterfactualAction(Map<String, Object> labelInputs) {
|
||||
BigDecimal netReturn = netReturn1x(labelInputs);
|
||||
if (netReturn == null) {
|
||||
return "WAIT";
|
||||
}
|
||||
return netReturn.compareTo(BigDecimal.ZERO) > 0 ? "OPEN_INITIAL" : "WAIT";
|
||||
}
|
||||
|
||||
private BigDecimal netReturn1x(Map<String, Object> labelInputs) {
|
||||
BigDecimal markout15m = decimal(labelInputs.get("markoutBps15m"));
|
||||
BigDecimal expectedSlippage = decimal(labelInputs.get("expectedSlippageBps"));
|
||||
if (markout15m == null || expectedSlippage == null) {
|
||||
return null;
|
||||
}
|
||||
// TriggerMarkout is a market-path label. Round-trip taker fee and
|
||||
// level_1 expected slippage keep the label cost-aware without pretending
|
||||
// we have real App fill feedback.
|
||||
BigDecimal executionCost = TAKER_FEE_ROUND_TRIP_BPS.add(expectedSlippage.multiply(BigDecimal.valueOf(2), MC), MC);
|
||||
return markout15m.subtract(executionCost, MC).setScale(8, RoundingMode.HALF_UP);
|
||||
}
|
||||
|
||||
private BigDecimal decimal(Object value) {
|
||||
if (value instanceof BigDecimal decimal) {
|
||||
return decimal;
|
||||
}
|
||||
if (value instanceof Number number) {
|
||||
return BigDecimal.valueOf(number.doubleValue());
|
||||
}
|
||||
if (value instanceof String text && !text.isBlank()) {
|
||||
return new BigDecimal(text);
|
||||
}
|
||||
return null;
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user