diff --git a/src/main/java/com/quantai/trader/controller/TraderReplayController.java b/src/main/java/com/quantai/trader/controller/TraderReplayController.java index c33e344..9363555 100644 --- a/src/main/java/com/quantai/trader/controller/TraderReplayController.java +++ b/src/main/java/com/quantai/trader/controller/TraderReplayController.java @@ -17,6 +17,6 @@ public class TraderReplayController { @PostMapping("/api/trader/replay/cycles") public TraderCycleResult runOneCycle(@RequestBody ReplayMarketEvent event) { - return runner.runFlatCycle(event); + return runner.runCycle(event); } } diff --git a/src/main/java/com/quantai/trader/domain/TraderAction.java b/src/main/java/com/quantai/trader/domain/TraderAction.java index e2bf39d..1f0d1cf 100644 --- a/src/main/java/com/quantai/trader/domain/TraderAction.java +++ b/src/main/java/com/quantai/trader/domain/TraderAction.java @@ -47,6 +47,12 @@ public record TraderAction( pricePlanConfigHash = requiredText(pricePlanConfigHash, "pricePlanConfigHash"); positionRatio = positive(positionRatio, "positionRatio"); } + if (actionType == TraderActionType.REDUCE_LONG || actionType == TraderActionType.REDUCE_SHORT) { + positionRatio = positive(positionRatio, "positionRatio"); + if (positionRatio.compareTo(ONE) > 0) { + throw new IllegalArgumentException("positionRatio must be <= 1 for reduce action"); + } + } if (actionType.reducesExposure() && !reduceOnly) { throw new IllegalArgumentException("reduce/close action must be reduceOnly"); } diff --git a/src/main/java/com/quantai/trader/domain/TraderActionFactory.java b/src/main/java/com/quantai/trader/domain/TraderActionFactory.java index c3a4cfe..c85c723 100644 --- a/src/main/java/com/quantai/trader/domain/TraderActionFactory.java +++ b/src/main/java/com/quantai/trader/domain/TraderActionFactory.java @@ -4,6 +4,7 @@ import com.quantai.trader.enums.PositionSide; import com.quantai.trader.enums.TraderActionType; import org.springframework.stereotype.Component; +import java.math.BigDecimal; import java.util.Map; @Component @@ -23,7 +24,7 @@ public class TraderActionFactory { side, finalAction.increasesExposure() ? pmDecision.pricePlanId() : null, finalAction.increasesExposure() ? pmDecision.pricePlanConfigHash() : null, - finalAction == TraderActionType.OPEN_LONG || finalAction == TraderActionType.OPEN_SHORT ? pmDecision.targetPositionRatio() : pmDecision.addRatio(), + ratioFor(finalAction, pmDecision), null, pmDecision.stopPrice(), pmDecision.targetPrice(), @@ -33,6 +34,15 @@ public class TraderActionFactory { Map.of("riskAllowed", riskDecision.allowAction())); } + private BigDecimal ratioFor(TraderActionType action, TraderPositionManagerDecision pmDecision) { + return switch (action) { + case OPEN_LONG, OPEN_SHORT -> pmDecision.targetPositionRatio(); + case ADD_LONG, ADD_SHORT -> pmDecision.addRatio(); + case REDUCE_LONG, REDUCE_SHORT -> pmDecision.reduceRatio(); + case WAIT, HOLD, CLOSE_LONG, CLOSE_SHORT, MOVE_STOP, CANCEL -> null; + }; + } + private PositionSide sideFor(TraderActionType action, PositionSide pmSide) { return switch (action) { case OPEN_LONG, ADD_LONG, REDUCE_LONG, CLOSE_LONG -> PositionSide.LONG; diff --git a/src/main/java/com/quantai/trader/replay/TraderP0CycleRunner.java b/src/main/java/com/quantai/trader/replay/TraderP0CycleRunner.java index 5becad9..b3d4bd9 100644 --- a/src/main/java/com/quantai/trader/replay/TraderP0CycleRunner.java +++ b/src/main/java/com/quantai/trader/replay/TraderP0CycleRunner.java @@ -11,6 +11,8 @@ import com.quantai.trader.outbox.TraderOutboxRepository; import com.quantai.trader.outbox.TraderOutboxEvent; import com.quantai.trader.persistence.TraderDecisionTraceWriter; import com.quantai.trader.position.TraderPositionManager; +import com.quantai.trader.replay.state.TraderReplayState; +import com.quantai.trader.replay.state.TraderReplayStateStore; import com.quantai.trader.risk.RiskGateInput; import com.quantai.trader.risk.RiskLimits; import com.quantai.trader.risk.TraderRiskGate; @@ -35,6 +37,7 @@ public class TraderP0CycleRunner { private final EvidenceAppender evidenceAppender; private final TraderDecisionTraceWriter traceWriter; private final TraderOutboxRepository outboxRepository; + private final TraderReplayStateStore stateStore; public TraderP0CycleRunner(TraderProperties properties, TraderArtifactLoader artifactLoader, @@ -44,7 +47,8 @@ public class TraderP0CycleRunner { TraderActionFactory actionFactory, EvidenceAppender evidenceAppender, TraderDecisionTraceWriter traceWriter, - TraderOutboxRepository outboxRepository) { + TraderOutboxRepository outboxRepository, + TraderReplayStateStore stateStore) { this.properties = properties; this.artifactLoader = artifactLoader; this.modelService = modelService; @@ -54,19 +58,21 @@ public class TraderP0CycleRunner { this.evidenceAppender = evidenceAppender; this.traceWriter = traceWriter; this.outboxRepository = outboxRepository; + this.stateStore = stateStore; } - public TraderCycleResult runFlatCycle(ReplayMarketEvent event) { + public TraderCycleResult runCycle(ReplayMarketEvent event) { String cycleId = "cycle_" + event.runId() + "_" + event.eventTime().toEpochMilli(); TraderArtifactBundle bundle = artifactLoader.loadActiveBundle(); TraderDecisionCycle cycle = new TraderDecisionCycle(event.runId(), cycleId, event.symbol(), event.eventTime(), properties.runMode(), bundle.modelBundleVersion(), bundle.calibrationBundleVersion(), bundle.pmConfigVersion()); TraderMarketSnapshot snapshot = snapshot(event, cycleId); + TraderReplayState state = stateStore.load(cycle, snapshot); evidenceAppender.append(cycle.runId(), cycle.cycleId(), "MARKET_SNAPSHOT", snapshot.dataReady(), "SNAPSHOT_BUILT", null, Map.of()); TraderModelOutput modelOutput = modelService.evaluate(snapshot, bundle); evidenceAppender.append(cycle.runId(), cycle.cycleId(), "MODEL_OUTPUT", true, "MODEL_EVALUATED", null, Map.of("modelOutputId", modelOutput.modelOutputId())); PositionManagerInput pmInput = new PositionManagerInput(cycle, snapshot, modelOutput, - flatPosition(cycle, snapshot), account(cycle), execution(cycle), bundle.pmConfig()); + state.positionState(), state.accountState(), state.executionState(), bundle.pmConfig()); TraderPositionManagerDecision pmDecision = positionManager.decide(pmInput); evidenceAppender.append(cycle.runId(), cycle.cycleId(), "PM_DECISION", true, pmDecision.reason(), null, Map.of("action", pmDecision.candidateAction().name())); TraderRiskDecision riskDecision = riskGate.evaluate(new RiskGateInput(pmDecision, pmInput.positionState(), pmInput.accountState(), @@ -77,6 +83,7 @@ public class TraderP0CycleRunner { outboxRepository.insert(new TraderOutboxEvent("outbox_" + action.actionId(), action.runId(), action.cycleId(), "TRADER_ACTION", action.actionId(), "ACTION_CREATED", properties.runMode().name() + "_RECORDER", Map.of("actionType", action.actionType().name()), action.idempotencyKey(), "PENDING", Instant.now())); + stateStore.advance(state, action, snapshot); log.info("event=trader.cycle.completed runId={} cycleId={} action={} outbox=PENDING", action.runId(), action.cycleId(), action.actionType()); return new TraderCycleResult(cycle.runId(), cycle.cycleId(), pmDecision, riskDecision, action); } @@ -88,23 +95,6 @@ public class TraderP0CycleRunner { event.depthNotional5Bps().compareTo(BigDecimal.ZERO) > 0, Map.of(), Map.of()); } - private TraderPositionState flatPosition(TraderDecisionCycle cycle, TraderMarketSnapshot snapshot) { - return new TraderPositionState("position_state_" + cycle.cycleId(), cycle.runId(), cycle.cycleId(), cycle.symbol(), - PositionSide.NONE, BigDecimal.ZERO, null, snapshot.markPrice(), BigDecimal.ZERO, new BigDecimal("1000"), - 0, BigDecimal.ONE, null); - } - - private TraderAccountState account(TraderDecisionCycle cycle) { - return new TraderAccountState("account_state_" + cycle.cycleId(), cycle.runId(), cycle.cycleId(), - BigDecimal.ZERO, BigDecimal.ZERO, BigDecimal.ONE, 0); - } - - private TraderExecutionState execution(TraderDecisionCycle cycle) { - return new TraderExecutionState("execution_state_" + cycle.cycleId(), cycle.runId(), cycle.cycleId(), cycle.symbol(), - java.util.List.of(), new BigDecimal("1.5"), 10, 0, new BigDecimal("1"), new BigDecimal("4"), - new BigDecimal("5"), new BigDecimal("0.1"), new BigDecimal("0.001"), new BigDecimal("0.001"), BigDecimal.ONE); - } - private RiskLimits riskLimits() { return new RiskLimits(properties.risk().maxDailyLossBps(), properties.risk().maxTotalExposureRatio(), properties.risk().minLiquidationBufferBps(), properties.execution().maxApiErrorCount(), diff --git a/src/main/java/com/quantai/trader/replay/state/P0ReplayStateStore.java b/src/main/java/com/quantai/trader/replay/state/P0ReplayStateStore.java new file mode 100644 index 0000000..be09775 --- /dev/null +++ b/src/main/java/com/quantai/trader/replay/state/P0ReplayStateStore.java @@ -0,0 +1,166 @@ +package com.quantai.trader.replay.state; + +import com.quantai.trader.config.TraderProperties; +import com.quantai.trader.domain.*; +import com.quantai.trader.enums.PositionSide; +import com.quantai.trader.enums.TraderActionType; +import com.quantai.trader.util.TraderNumbers; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.stereotype.Component; + +import java.math.BigDecimal; +import java.math.MathContext; +import java.time.Instant; +import java.util.List; +import java.util.concurrent.ConcurrentHashMap; + +/** + * Process-local state for ordered P0 replay/shadow cycles; each runId+symbol stream advances from the prior cycle. + */ +@Component +public class P0ReplayStateStore implements TraderReplayStateStore { + private static final Logger log = LoggerFactory.getLogger(P0ReplayStateStore.class); + + private final TraderProperties properties; + private final ConcurrentHashMap states = new ConcurrentHashMap<>(); + + public P0ReplayStateStore(TraderProperties properties) { + this.properties = properties; + } + + @Override + public TraderReplayState load(TraderDecisionCycle cycle, TraderMarketSnapshot snapshot) { + TraderReplayState state = states.computeIfAbsent(key(cycle.runId(), cycle.symbol()), ignored -> flatState(cycle, snapshot)); + TraderReplayState refreshed = refreshForCycle(cycle, snapshot, state); + states.put(key(cycle.runId(), cycle.symbol()), refreshed); + log.info("event=trader.replay_state.loaded runId={} cycleId={} side={} ratio={}", + cycle.runId(), cycle.cycleId(), refreshed.positionState().side(), refreshed.positionState().positionRatio()); + return refreshed; + } + + @Override + public TraderReplayState advance(TraderReplayState current, TraderAction action, TraderMarketSnapshot snapshot) { + TraderPositionState nextPosition = switch (action.actionType()) { + case OPEN_LONG, OPEN_SHORT -> openPosition(current.positionState(), action, snapshot); + case ADD_LONG, ADD_SHORT -> addPosition(current.positionState(), action, snapshot); + case REDUCE_LONG, REDUCE_SHORT -> reducePosition(current.positionState(), action, snapshot); + case CLOSE_LONG, CLOSE_SHORT -> flatPosition(current.positionState(), action.cycleId(), snapshot); + case WAIT, HOLD, MOVE_STOP, CANCEL -> refreshPosition(current.positionState(), action.cycleId(), snapshot); + }; + TraderReplayState next = new TraderReplayState( + nextPosition, + account(action, nextPosition), + execution(action)); + states.put(key(action.runId(), action.symbol()), next); + log.info("event=trader.replay_state.advanced runId={} cycleId={} action={} side={} ratio={}", + action.runId(), action.cycleId(), action.actionType(), nextPosition.side(), nextPosition.positionRatio()); + return next; + } + + private TraderReplayState flatState(TraderDecisionCycle cycle, TraderMarketSnapshot snapshot) { + TraderPositionState position = new TraderPositionState("position_state_" + cycle.cycleId(), cycle.runId(), cycle.cycleId(), cycle.symbol(), + PositionSide.NONE, BigDecimal.ZERO, null, snapshot.markPrice(), BigDecimal.ZERO, new BigDecimal("1000"), + 0, properties.positionManager().maxTotalPositionRatio(), null); + return new TraderReplayState(position, account(cycle, position), execution(cycle)); + } + + private TraderReplayState refreshForCycle(TraderDecisionCycle cycle, TraderMarketSnapshot snapshot, TraderReplayState state) { + TraderPositionState position = new TraderPositionState("position_state_" + cycle.cycleId(), cycle.runId(), cycle.cycleId(), cycle.symbol(), + state.positionState().side(), state.positionState().positionRatio(), state.positionState().averageEntryPrice(), + snapshot.markPrice(), pnlBps(state.positionState().side(), state.positionState().averageEntryPrice(), snapshot.markPrice()), + state.positionState().liquidationBufferBps(), state.positionState().addCount(), + remainingCapacity(state.positionState().positionRatio()), state.positionState().lastAddTime()); + return new TraderReplayState(position, account(cycle, position), execution(cycle)); + } + + private TraderPositionState openPosition(TraderPositionState current, TraderAction action, TraderMarketSnapshot snapshot) { + PositionSide side = action.actionType() == TraderActionType.OPEN_LONG ? PositionSide.LONG : PositionSide.SHORT; + BigDecimal ratio = TraderNumbers.required(action.positionRatio(), "action.positionRatio"); + return new TraderPositionState("position_state_" + action.cycleId(), action.runId(), action.cycleId(), action.symbol(), + side, ratio, snapshot.markPrice(), snapshot.markPrice(), BigDecimal.ZERO, current.liquidationBufferBps(), + 0, remainingCapacity(ratio), null); + } + + private TraderPositionState addPosition(TraderPositionState current, TraderAction action, TraderMarketSnapshot snapshot) { + BigDecimal addRatio = TraderNumbers.required(action.positionRatio(), "action.positionRatio"); + BigDecimal newRatio = current.positionRatio().add(addRatio); + BigDecimal weightedEntry = weightedEntry(current.averageEntryPrice(), current.positionRatio(), snapshot.markPrice(), addRatio); + return new TraderPositionState("position_state_" + action.cycleId(), action.runId(), action.cycleId(), action.symbol(), + current.side(), newRatio, weightedEntry, snapshot.markPrice(), pnlBps(current.side(), weightedEntry, snapshot.markPrice()), + current.liquidationBufferBps(), current.addCount() + 1, remainingCapacity(newRatio), Instant.now()); + } + + private TraderPositionState reducePosition(TraderPositionState current, TraderAction action, TraderMarketSnapshot snapshot) { + BigDecimal reduceRatio = TraderNumbers.required(action.positionRatio(), "action.positionRatio"); + BigDecimal newRatio = current.positionRatio().multiply(BigDecimal.ONE.subtract(reduceRatio), MathContext.DECIMAL64); + if (newRatio.compareTo(new BigDecimal("0.00000001")) <= 0) { + return flatPosition(current, action.cycleId(), snapshot); + } + return new TraderPositionState("position_state_" + action.cycleId(), action.runId(), action.cycleId(), action.symbol(), + current.side(), newRatio, current.averageEntryPrice(), snapshot.markPrice(), + pnlBps(current.side(), current.averageEntryPrice(), snapshot.markPrice()), current.liquidationBufferBps(), + current.addCount(), remainingCapacity(newRatio), current.lastAddTime()); + } + + private TraderPositionState refreshPosition(TraderPositionState current, String cycleId, TraderMarketSnapshot snapshot) { + if (current.isFlat()) { + return flatPosition(current, cycleId, snapshot); + } + return new TraderPositionState("position_state_" + cycleId, current.runId(), cycleId, current.symbol(), + current.side(), current.positionRatio(), current.averageEntryPrice(), snapshot.markPrice(), + pnlBps(current.side(), current.averageEntryPrice(), snapshot.markPrice()), current.liquidationBufferBps(), + current.addCount(), remainingCapacity(current.positionRatio()), current.lastAddTime()); + } + + private TraderPositionState flatPosition(TraderPositionState current, String cycleId, TraderMarketSnapshot snapshot) { + return new TraderPositionState("position_state_" + cycleId, current.runId(), cycleId, current.symbol(), + PositionSide.NONE, BigDecimal.ZERO, null, snapshot.markPrice(), BigDecimal.ZERO, + current.liquidationBufferBps(), 0, properties.positionManager().maxTotalPositionRatio(), null); + } + + private TraderAccountState account(TraderDecisionCycle cycle, TraderPositionState position) { + return new TraderAccountState("account_state_" + cycle.cycleId(), cycle.runId(), cycle.cycleId(), + BigDecimal.ZERO, position.positionRatio(), remainingCapacity(position.positionRatio()), 0); + } + + private TraderAccountState account(TraderAction action, TraderPositionState position) { + return new TraderAccountState("account_state_" + action.cycleId(), action.runId(), action.cycleId(), + BigDecimal.ZERO, position.positionRatio(), remainingCapacity(position.positionRatio()), 0); + } + + private TraderExecutionState execution(TraderDecisionCycle cycle) { + return new TraderExecutionState("execution_state_" + cycle.cycleId(), cycle.runId(), cycle.cycleId(), cycle.symbol(), + List.of(), new BigDecimal("1.5"), 10, 0, new BigDecimal("1"), new BigDecimal("4"), + new BigDecimal("5"), new BigDecimal("0.1"), new BigDecimal("0.001"), new BigDecimal("0.001"), BigDecimal.ONE); + } + + private TraderExecutionState execution(TraderAction action) { + return new TraderExecutionState("execution_state_" + action.cycleId(), action.runId(), action.cycleId(), action.symbol(), + List.of(), new BigDecimal("1.5"), 10, 0, new BigDecimal("1"), new BigDecimal("4"), + new BigDecimal("5"), new BigDecimal("0.1"), new BigDecimal("0.001"), new BigDecimal("0.001"), BigDecimal.ONE); + } + + private BigDecimal weightedEntry(BigDecimal currentEntry, BigDecimal currentRatio, BigDecimal addPrice, BigDecimal addRatio) { + BigDecimal total = currentRatio.add(addRatio); + return currentEntry.multiply(currentRatio).add(addPrice.multiply(addRatio)).divide(total, MathContext.DECIMAL64); + } + + private BigDecimal pnlBps(PositionSide side, BigDecimal averageEntryPrice, BigDecimal currentPrice) { + if (side == PositionSide.NONE || averageEntryPrice == null) { + return BigDecimal.ZERO; + } + BigDecimal priceMove = side.isLong() + ? currentPrice.subtract(averageEntryPrice) + : averageEntryPrice.subtract(currentPrice); + return priceMove.divide(averageEntryPrice, MathContext.DECIMAL64).multiply(new BigDecimal("10000")); + } + + private BigDecimal remainingCapacity(BigDecimal positionRatio) { + return properties.positionManager().maxTotalPositionRatio().subtract(positionRatio).max(BigDecimal.ZERO); + } + + private String key(String runId, String symbol) { + return runId + "::" + symbol; + } +} diff --git a/src/main/java/com/quantai/trader/replay/state/TraderReplayState.java b/src/main/java/com/quantai/trader/replay/state/TraderReplayState.java new file mode 100644 index 0000000..b6ccb1d --- /dev/null +++ b/src/main/java/com/quantai/trader/replay/state/TraderReplayState.java @@ -0,0 +1,12 @@ +package com.quantai.trader.replay.state; + +import com.quantai.trader.domain.TraderAccountState; +import com.quantai.trader.domain.TraderExecutionState; +import com.quantai.trader.domain.TraderPositionState; + +public record TraderReplayState( + TraderPositionState positionState, + TraderAccountState accountState, + TraderExecutionState executionState +) { +} diff --git a/src/main/java/com/quantai/trader/replay/state/TraderReplayStateStore.java b/src/main/java/com/quantai/trader/replay/state/TraderReplayStateStore.java new file mode 100644 index 0000000..a400e2c --- /dev/null +++ b/src/main/java/com/quantai/trader/replay/state/TraderReplayStateStore.java @@ -0,0 +1,11 @@ +package com.quantai.trader.replay.state; + +import com.quantai.trader.domain.TraderAction; +import com.quantai.trader.domain.TraderDecisionCycle; +import com.quantai.trader.domain.TraderMarketSnapshot; + +public interface TraderReplayStateStore { + TraderReplayState load(TraderDecisionCycle cycle, TraderMarketSnapshot snapshot); + + TraderReplayState advance(TraderReplayState current, TraderAction action, TraderMarketSnapshot snapshot); +} diff --git a/src/test/java/com/quantai/trader/domain/TraderActionFactoryTest.java b/src/test/java/com/quantai/trader/domain/TraderActionFactoryTest.java new file mode 100644 index 0000000..52c002d --- /dev/null +++ b/src/test/java/com/quantai/trader/domain/TraderActionFactoryTest.java @@ -0,0 +1,25 @@ +package com.quantai.trader.domain; + +import com.quantai.trader.enums.PositionSide; +import com.quantai.trader.enums.TraderActionType; +import org.junit.jupiter.api.Test; + +import static com.quantai.trader.TestFixtures.*; +import static org.assertj.core.api.Assertions.assertThat; + +class TraderActionFactoryTest { + private final TraderActionFactory factory = new TraderActionFactory(); + + @Test + void mapsReduceDecisionRatioOntoActionPositionRatio() { + TraderRiskDecision riskDecision = new TraderRiskDecision( + "risk-1", "run-1", "cycle-1", "pm-cycle-1", + true, TraderActionType.REDUCE_LONG, TraderActionType.REDUCE_LONG, null, java.util.Map.of()); + + TraderAction action = factory.create(pmDecision(TraderActionType.REDUCE_LONG, PositionSide.LONG), + riskDecision, "BTC-USDT-PERP"); + + assertThat(action.positionRatio()).isEqualByComparingTo("0.50"); + assertThat(action.reduceOnly()).isTrue(); + } +} diff --git a/src/test/java/com/quantai/trader/replay/TraderP0CycleRunnerTest.java b/src/test/java/com/quantai/trader/replay/TraderP0CycleRunnerTest.java index 6351d82..cc01e12 100644 --- a/src/test/java/com/quantai/trader/replay/TraderP0CycleRunnerTest.java +++ b/src/test/java/com/quantai/trader/replay/TraderP0CycleRunnerTest.java @@ -10,6 +10,7 @@ import com.quantai.trader.outbox.TraderOutboxEvent; import com.quantai.trader.outbox.TraderOutboxRepository; import com.quantai.trader.persistence.TraderDecisionTraceWriter; import com.quantai.trader.position.TraderPositionManager; +import com.quantai.trader.replay.state.P0ReplayStateStore; import com.quantai.trader.risk.TraderRiskGate; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.io.TempDir; @@ -34,18 +35,20 @@ class TraderP0CycleRunnerTest { EvidenceAppender evidenceAppender = new EvidenceAppender(evidenceRepository); RecordingTraceWriter traceWriter = new RecordingTraceWriter(); RecordingOutboxRepository outboxRepository = new RecordingOutboxRepository(); + var properties = propertiesWithArtifactRoot(artifactRoot); TraderP0CycleRunner runner = new TraderP0CycleRunner( - propertiesWithArtifactRoot(artifactRoot), - new TraderArtifactLoader(propertiesWithArtifactRoot(artifactRoot), objectMapper()), + properties, + new TraderArtifactLoader(properties, objectMapper()), new ArtifactTraderModelService(), new TraderPositionManager(), new TraderRiskGate(), new TraderActionFactory(), evidenceAppender, traceWriter, - outboxRepository); + outboxRepository, + new P0ReplayStateStore(properties)); - TraderCycleResult result = runner.runFlatCycle(new ReplayMarketEvent( + TraderCycleResult result = runner.runCycle(new ReplayMarketEvent( "run-1", "BTC-USDT-PERP", T0, new BigDecimal("100"), new BigDecimal("99.5"), new BigDecimal("1.2"), new BigDecimal("1000"))); @@ -65,18 +68,20 @@ class TraderP0CycleRunnerTest { EvidenceAppender evidenceAppender = new EvidenceAppender(evidenceRepository); RecordingTraceWriter traceWriter = new RecordingTraceWriter(); RecordingOutboxRepository outboxRepository = new RecordingOutboxRepository(); + var properties = propertiesWithArtifactRoot(artifactRoot); TraderP0CycleRunner runner = new TraderP0CycleRunner( - propertiesWithArtifactRoot(artifactRoot), - new TraderArtifactLoader(propertiesWithArtifactRoot(artifactRoot), objectMapper()), + properties, + new TraderArtifactLoader(properties, objectMapper()), new ArtifactTraderModelService(), new TraderPositionManager(), new TraderRiskGate(), new TraderActionFactory(), evidenceAppender, traceWriter, - outboxRepository); + outboxRepository, + new P0ReplayStateStore(properties)); - TraderCycleResult result = runner.runFlatCycle(new ReplayMarketEvent( + TraderCycleResult result = runner.runCycle(new ReplayMarketEvent( "run-1", "BTC-USDT-PERP", T0.plusSeconds(60), new BigDecimal("100"), new BigDecimal("99.5"), new BigDecimal("1.2"), BigDecimal.ZERO)); @@ -87,6 +92,40 @@ class TraderP0CycleRunnerTest { assertThat(evidenceRepository.items()).hasSize(4); } + @Test + void laterCycleUsesPositionStateFromEarlierOpen() throws IOException { + writeArtifactBundle(artifactRoot); + RecordingEvidenceRepository evidenceRepository = new RecordingEvidenceRepository(); + EvidenceAppender evidenceAppender = new EvidenceAppender(evidenceRepository); + RecordingTraceWriter traceWriter = new RecordingTraceWriter(); + RecordingOutboxRepository outboxRepository = new RecordingOutboxRepository(); + var properties = propertiesWithArtifactRoot(artifactRoot); + TraderP0CycleRunner runner = new TraderP0CycleRunner( + properties, + new TraderArtifactLoader(properties, objectMapper()), + new ArtifactTraderModelService(), + new TraderPositionManager(), + new TraderRiskGate(), + new TraderActionFactory(), + evidenceAppender, + traceWriter, + outboxRepository, + new P0ReplayStateStore(properties)); + + TraderCycleResult first = runner.runCycle(new ReplayMarketEvent( + "run-state-1", "BTC-USDT-PERP", T0, new BigDecimal("100"), new BigDecimal("99.5"), + new BigDecimal("1.2"), new BigDecimal("1000"))); + TraderCycleResult second = runner.runCycle(new ReplayMarketEvent( + "run-state-1", "BTC-USDT-PERP", T0.plusSeconds(60), new BigDecimal("101"), new BigDecimal("100.5"), + new BigDecimal("1.2"), new BigDecimal("1000"))); + + assertThat(first.action().actionType()).isEqualTo(TraderActionType.OPEN_LONG); + assertThat(second.action().actionType()).isEqualTo(TraderActionType.ADD_LONG); + assertThat(traceWriter.positionStates()).extracting("side").containsExactly( + com.quantai.trader.enums.PositionSide.NONE, + com.quantai.trader.enums.PositionSide.LONG); + } + private static final class RecordingEvidenceRepository implements TraderEvidenceRepository { private final List items = new ArrayList<>(); @@ -115,16 +154,22 @@ class TraderP0CycleRunnerTest { private static final class RecordingTraceWriter implements TraderDecisionTraceWriter { private final List actions = new ArrayList<>(); + private final List positionStates = new ArrayList<>(); @Override public void persistCycleTrace(TraderDecisionCycle cycle, TraderMarketSnapshot snapshot, TraderModelOutput modelOutput, TraderPositionState positionState, TraderPositionManagerDecision pmDecision, TraderRiskDecision riskDecision, TraderAction action) { actions.add(action); + positionStates.add(positionState); } List actions() { return actions; } + + List positionStates() { + return positionStates; + } } }