Track replay position state across cycles

This commit is contained in:
Codex
2026-06-26 22:17:48 +08:00
parent 4e5f49d6fe
commit dad6b831b4
9 changed files with 295 additions and 30 deletions
@@ -11,6 +11,8 @@ import com.quantai.trader.outbox.TraderOutboxRepository;
import com.quantai.trader.outbox.TraderOutboxEvent;
import com.quantai.trader.persistence.TraderDecisionTraceWriter;
import com.quantai.trader.position.TraderPositionManager;
import com.quantai.trader.replay.state.TraderReplayState;
import com.quantai.trader.replay.state.TraderReplayStateStore;
import com.quantai.trader.risk.RiskGateInput;
import com.quantai.trader.risk.RiskLimits;
import com.quantai.trader.risk.TraderRiskGate;
@@ -35,6 +37,7 @@ public class TraderP0CycleRunner {
private final EvidenceAppender evidenceAppender;
private final TraderDecisionTraceWriter traceWriter;
private final TraderOutboxRepository outboxRepository;
private final TraderReplayStateStore stateStore;
public TraderP0CycleRunner(TraderProperties properties,
TraderArtifactLoader artifactLoader,
@@ -44,7 +47,8 @@ public class TraderP0CycleRunner {
TraderActionFactory actionFactory,
EvidenceAppender evidenceAppender,
TraderDecisionTraceWriter traceWriter,
TraderOutboxRepository outboxRepository) {
TraderOutboxRepository outboxRepository,
TraderReplayStateStore stateStore) {
this.properties = properties;
this.artifactLoader = artifactLoader;
this.modelService = modelService;
@@ -54,19 +58,21 @@ public class TraderP0CycleRunner {
this.evidenceAppender = evidenceAppender;
this.traceWriter = traceWriter;
this.outboxRepository = outboxRepository;
this.stateStore = stateStore;
}
public TraderCycleResult runFlatCycle(ReplayMarketEvent event) {
public TraderCycleResult runCycle(ReplayMarketEvent event) {
String cycleId = "cycle_" + event.runId() + "_" + event.eventTime().toEpochMilli();
TraderArtifactBundle bundle = artifactLoader.loadActiveBundle();
TraderDecisionCycle cycle = new TraderDecisionCycle(event.runId(), cycleId, event.symbol(), event.eventTime(),
properties.runMode(), bundle.modelBundleVersion(), bundle.calibrationBundleVersion(), bundle.pmConfigVersion());
TraderMarketSnapshot snapshot = snapshot(event, cycleId);
TraderReplayState state = stateStore.load(cycle, snapshot);
evidenceAppender.append(cycle.runId(), cycle.cycleId(), "MARKET_SNAPSHOT", snapshot.dataReady(), "SNAPSHOT_BUILT", null, Map.of());
TraderModelOutput modelOutput = modelService.evaluate(snapshot, bundle);
evidenceAppender.append(cycle.runId(), cycle.cycleId(), "MODEL_OUTPUT", true, "MODEL_EVALUATED", null, Map.of("modelOutputId", modelOutput.modelOutputId()));
PositionManagerInput pmInput = new PositionManagerInput(cycle, snapshot, modelOutput,
flatPosition(cycle, snapshot), account(cycle), execution(cycle), bundle.pmConfig());
state.positionState(), state.accountState(), state.executionState(), bundle.pmConfig());
TraderPositionManagerDecision pmDecision = positionManager.decide(pmInput);
evidenceAppender.append(cycle.runId(), cycle.cycleId(), "PM_DECISION", true, pmDecision.reason(), null, Map.of("action", pmDecision.candidateAction().name()));
TraderRiskDecision riskDecision = riskGate.evaluate(new RiskGateInput(pmDecision, pmInput.positionState(), pmInput.accountState(),
@@ -77,6 +83,7 @@ public class TraderP0CycleRunner {
outboxRepository.insert(new TraderOutboxEvent("outbox_" + action.actionId(), action.runId(), action.cycleId(),
"TRADER_ACTION", action.actionId(), "ACTION_CREATED", properties.runMode().name() + "_RECORDER",
Map.of("actionType", action.actionType().name()), action.idempotencyKey(), "PENDING", Instant.now()));
stateStore.advance(state, action, snapshot);
log.info("event=trader.cycle.completed runId={} cycleId={} action={} outbox=PENDING", action.runId(), action.cycleId(), action.actionType());
return new TraderCycleResult(cycle.runId(), cycle.cycleId(), pmDecision, riskDecision, action);
}
@@ -88,23 +95,6 @@ public class TraderP0CycleRunner {
event.depthNotional5Bps().compareTo(BigDecimal.ZERO) > 0, Map.of(), Map.of());
}
private TraderPositionState flatPosition(TraderDecisionCycle cycle, TraderMarketSnapshot snapshot) {
return new TraderPositionState("position_state_" + cycle.cycleId(), cycle.runId(), cycle.cycleId(), cycle.symbol(),
PositionSide.NONE, BigDecimal.ZERO, null, snapshot.markPrice(), BigDecimal.ZERO, new BigDecimal("1000"),
0, BigDecimal.ONE, null);
}
private TraderAccountState account(TraderDecisionCycle cycle) {
return new TraderAccountState("account_state_" + cycle.cycleId(), cycle.runId(), cycle.cycleId(),
BigDecimal.ZERO, BigDecimal.ZERO, BigDecimal.ONE, 0);
}
private TraderExecutionState execution(TraderDecisionCycle cycle) {
return new TraderExecutionState("execution_state_" + cycle.cycleId(), cycle.runId(), cycle.cycleId(), cycle.symbol(),
java.util.List.of(), new BigDecimal("1.5"), 10, 0, new BigDecimal("1"), new BigDecimal("4"),
new BigDecimal("5"), new BigDecimal("0.1"), new BigDecimal("0.001"), new BigDecimal("0.001"), BigDecimal.ONE);
}
private RiskLimits riskLimits() {
return new RiskLimits(properties.risk().maxDailyLossBps(), properties.risk().maxTotalExposureRatio(),
properties.risk().minLiquidationBufferBps(), properties.execution().maxApiErrorCount(),
@@ -0,0 +1,166 @@
package com.quantai.trader.replay.state;
import com.quantai.trader.config.TraderProperties;
import com.quantai.trader.domain.*;
import com.quantai.trader.enums.PositionSide;
import com.quantai.trader.enums.TraderActionType;
import com.quantai.trader.util.TraderNumbers;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.stereotype.Component;
import java.math.BigDecimal;
import java.math.MathContext;
import java.time.Instant;
import java.util.List;
import java.util.concurrent.ConcurrentHashMap;
/**
* Process-local state for ordered P0 replay/shadow cycles; each runId+symbol stream advances from the prior cycle.
*/
@Component
public class P0ReplayStateStore implements TraderReplayStateStore {
private static final Logger log = LoggerFactory.getLogger(P0ReplayStateStore.class);
private final TraderProperties properties;
private final ConcurrentHashMap<String, TraderReplayState> states = new ConcurrentHashMap<>();
public P0ReplayStateStore(TraderProperties properties) {
this.properties = properties;
}
@Override
public TraderReplayState load(TraderDecisionCycle cycle, TraderMarketSnapshot snapshot) {
TraderReplayState state = states.computeIfAbsent(key(cycle.runId(), cycle.symbol()), ignored -> flatState(cycle, snapshot));
TraderReplayState refreshed = refreshForCycle(cycle, snapshot, state);
states.put(key(cycle.runId(), cycle.symbol()), refreshed);
log.info("event=trader.replay_state.loaded runId={} cycleId={} side={} ratio={}",
cycle.runId(), cycle.cycleId(), refreshed.positionState().side(), refreshed.positionState().positionRatio());
return refreshed;
}
@Override
public TraderReplayState advance(TraderReplayState current, TraderAction action, TraderMarketSnapshot snapshot) {
TraderPositionState nextPosition = switch (action.actionType()) {
case OPEN_LONG, OPEN_SHORT -> openPosition(current.positionState(), action, snapshot);
case ADD_LONG, ADD_SHORT -> addPosition(current.positionState(), action, snapshot);
case REDUCE_LONG, REDUCE_SHORT -> reducePosition(current.positionState(), action, snapshot);
case CLOSE_LONG, CLOSE_SHORT -> flatPosition(current.positionState(), action.cycleId(), snapshot);
case WAIT, HOLD, MOVE_STOP, CANCEL -> refreshPosition(current.positionState(), action.cycleId(), snapshot);
};
TraderReplayState next = new TraderReplayState(
nextPosition,
account(action, nextPosition),
execution(action));
states.put(key(action.runId(), action.symbol()), next);
log.info("event=trader.replay_state.advanced runId={} cycleId={} action={} side={} ratio={}",
action.runId(), action.cycleId(), action.actionType(), nextPosition.side(), nextPosition.positionRatio());
return next;
}
private TraderReplayState flatState(TraderDecisionCycle cycle, TraderMarketSnapshot snapshot) {
TraderPositionState position = new TraderPositionState("position_state_" + cycle.cycleId(), cycle.runId(), cycle.cycleId(), cycle.symbol(),
PositionSide.NONE, BigDecimal.ZERO, null, snapshot.markPrice(), BigDecimal.ZERO, new BigDecimal("1000"),
0, properties.positionManager().maxTotalPositionRatio(), null);
return new TraderReplayState(position, account(cycle, position), execution(cycle));
}
private TraderReplayState refreshForCycle(TraderDecisionCycle cycle, TraderMarketSnapshot snapshot, TraderReplayState state) {
TraderPositionState position = new TraderPositionState("position_state_" + cycle.cycleId(), cycle.runId(), cycle.cycleId(), cycle.symbol(),
state.positionState().side(), state.positionState().positionRatio(), state.positionState().averageEntryPrice(),
snapshot.markPrice(), pnlBps(state.positionState().side(), state.positionState().averageEntryPrice(), snapshot.markPrice()),
state.positionState().liquidationBufferBps(), state.positionState().addCount(),
remainingCapacity(state.positionState().positionRatio()), state.positionState().lastAddTime());
return new TraderReplayState(position, account(cycle, position), execution(cycle));
}
private TraderPositionState openPosition(TraderPositionState current, TraderAction action, TraderMarketSnapshot snapshot) {
PositionSide side = action.actionType() == TraderActionType.OPEN_LONG ? PositionSide.LONG : PositionSide.SHORT;
BigDecimal ratio = TraderNumbers.required(action.positionRatio(), "action.positionRatio");
return new TraderPositionState("position_state_" + action.cycleId(), action.runId(), action.cycleId(), action.symbol(),
side, ratio, snapshot.markPrice(), snapshot.markPrice(), BigDecimal.ZERO, current.liquidationBufferBps(),
0, remainingCapacity(ratio), null);
}
private TraderPositionState addPosition(TraderPositionState current, TraderAction action, TraderMarketSnapshot snapshot) {
BigDecimal addRatio = TraderNumbers.required(action.positionRatio(), "action.positionRatio");
BigDecimal newRatio = current.positionRatio().add(addRatio);
BigDecimal weightedEntry = weightedEntry(current.averageEntryPrice(), current.positionRatio(), snapshot.markPrice(), addRatio);
return new TraderPositionState("position_state_" + action.cycleId(), action.runId(), action.cycleId(), action.symbol(),
current.side(), newRatio, weightedEntry, snapshot.markPrice(), pnlBps(current.side(), weightedEntry, snapshot.markPrice()),
current.liquidationBufferBps(), current.addCount() + 1, remainingCapacity(newRatio), Instant.now());
}
private TraderPositionState reducePosition(TraderPositionState current, TraderAction action, TraderMarketSnapshot snapshot) {
BigDecimal reduceRatio = TraderNumbers.required(action.positionRatio(), "action.positionRatio");
BigDecimal newRatio = current.positionRatio().multiply(BigDecimal.ONE.subtract(reduceRatio), MathContext.DECIMAL64);
if (newRatio.compareTo(new BigDecimal("0.00000001")) <= 0) {
return flatPosition(current, action.cycleId(), snapshot);
}
return new TraderPositionState("position_state_" + action.cycleId(), action.runId(), action.cycleId(), action.symbol(),
current.side(), newRatio, current.averageEntryPrice(), snapshot.markPrice(),
pnlBps(current.side(), current.averageEntryPrice(), snapshot.markPrice()), current.liquidationBufferBps(),
current.addCount(), remainingCapacity(newRatio), current.lastAddTime());
}
private TraderPositionState refreshPosition(TraderPositionState current, String cycleId, TraderMarketSnapshot snapshot) {
if (current.isFlat()) {
return flatPosition(current, cycleId, snapshot);
}
return new TraderPositionState("position_state_" + cycleId, current.runId(), cycleId, current.symbol(),
current.side(), current.positionRatio(), current.averageEntryPrice(), snapshot.markPrice(),
pnlBps(current.side(), current.averageEntryPrice(), snapshot.markPrice()), current.liquidationBufferBps(),
current.addCount(), remainingCapacity(current.positionRatio()), current.lastAddTime());
}
private TraderPositionState flatPosition(TraderPositionState current, String cycleId, TraderMarketSnapshot snapshot) {
return new TraderPositionState("position_state_" + cycleId, current.runId(), cycleId, current.symbol(),
PositionSide.NONE, BigDecimal.ZERO, null, snapshot.markPrice(), BigDecimal.ZERO,
current.liquidationBufferBps(), 0, properties.positionManager().maxTotalPositionRatio(), null);
}
private TraderAccountState account(TraderDecisionCycle cycle, TraderPositionState position) {
return new TraderAccountState("account_state_" + cycle.cycleId(), cycle.runId(), cycle.cycleId(),
BigDecimal.ZERO, position.positionRatio(), remainingCapacity(position.positionRatio()), 0);
}
private TraderAccountState account(TraderAction action, TraderPositionState position) {
return new TraderAccountState("account_state_" + action.cycleId(), action.runId(), action.cycleId(),
BigDecimal.ZERO, position.positionRatio(), remainingCapacity(position.positionRatio()), 0);
}
private TraderExecutionState execution(TraderDecisionCycle cycle) {
return new TraderExecutionState("execution_state_" + cycle.cycleId(), cycle.runId(), cycle.cycleId(), cycle.symbol(),
List.of(), new BigDecimal("1.5"), 10, 0, new BigDecimal("1"), new BigDecimal("4"),
new BigDecimal("5"), new BigDecimal("0.1"), new BigDecimal("0.001"), new BigDecimal("0.001"), BigDecimal.ONE);
}
private TraderExecutionState execution(TraderAction action) {
return new TraderExecutionState("execution_state_" + action.cycleId(), action.runId(), action.cycleId(), action.symbol(),
List.of(), new BigDecimal("1.5"), 10, 0, new BigDecimal("1"), new BigDecimal("4"),
new BigDecimal("5"), new BigDecimal("0.1"), new BigDecimal("0.001"), new BigDecimal("0.001"), BigDecimal.ONE);
}
private BigDecimal weightedEntry(BigDecimal currentEntry, BigDecimal currentRatio, BigDecimal addPrice, BigDecimal addRatio) {
BigDecimal total = currentRatio.add(addRatio);
return currentEntry.multiply(currentRatio).add(addPrice.multiply(addRatio)).divide(total, MathContext.DECIMAL64);
}
private BigDecimal pnlBps(PositionSide side, BigDecimal averageEntryPrice, BigDecimal currentPrice) {
if (side == PositionSide.NONE || averageEntryPrice == null) {
return BigDecimal.ZERO;
}
BigDecimal priceMove = side.isLong()
? currentPrice.subtract(averageEntryPrice)
: averageEntryPrice.subtract(currentPrice);
return priceMove.divide(averageEntryPrice, MathContext.DECIMAL64).multiply(new BigDecimal("10000"));
}
private BigDecimal remainingCapacity(BigDecimal positionRatio) {
return properties.positionManager().maxTotalPositionRatio().subtract(positionRatio).max(BigDecimal.ZERO);
}
private String key(String runId, String symbol) {
return runId + "::" + symbol;
}
}
@@ -0,0 +1,12 @@
package com.quantai.trader.replay.state;
import com.quantai.trader.domain.TraderAccountState;
import com.quantai.trader.domain.TraderExecutionState;
import com.quantai.trader.domain.TraderPositionState;
public record TraderReplayState(
TraderPositionState positionState,
TraderAccountState accountState,
TraderExecutionState executionState
) {
}
@@ -0,0 +1,11 @@
package com.quantai.trader.replay.state;
import com.quantai.trader.domain.TraderAction;
import com.quantai.trader.domain.TraderDecisionCycle;
import com.quantai.trader.domain.TraderMarketSnapshot;
public interface TraderReplayStateStore {
TraderReplayState load(TraderDecisionCycle cycle, TraderMarketSnapshot snapshot);
TraderReplayState advance(TraderReplayState current, TraderAction action, TraderMarketSnapshot snapshot);
}