Track replay position state across cycles
This commit is contained in:
@@ -17,6 +17,6 @@ public class TraderReplayController {
|
||||
|
||||
@PostMapping("/api/trader/replay/cycles")
|
||||
public TraderCycleResult runOneCycle(@RequestBody ReplayMarketEvent event) {
|
||||
return runner.runFlatCycle(event);
|
||||
return runner.runCycle(event);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -47,6 +47,12 @@ public record TraderAction(
|
||||
pricePlanConfigHash = requiredText(pricePlanConfigHash, "pricePlanConfigHash");
|
||||
positionRatio = positive(positionRatio, "positionRatio");
|
||||
}
|
||||
if (actionType == TraderActionType.REDUCE_LONG || actionType == TraderActionType.REDUCE_SHORT) {
|
||||
positionRatio = positive(positionRatio, "positionRatio");
|
||||
if (positionRatio.compareTo(ONE) > 0) {
|
||||
throw new IllegalArgumentException("positionRatio must be <= 1 for reduce action");
|
||||
}
|
||||
}
|
||||
if (actionType.reducesExposure() && !reduceOnly) {
|
||||
throw new IllegalArgumentException("reduce/close action must be reduceOnly");
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@ import com.quantai.trader.enums.PositionSide;
|
||||
import com.quantai.trader.enums.TraderActionType;
|
||||
import org.springframework.stereotype.Component;
|
||||
|
||||
import java.math.BigDecimal;
|
||||
import java.util.Map;
|
||||
|
||||
@Component
|
||||
@@ -23,7 +24,7 @@ public class TraderActionFactory {
|
||||
side,
|
||||
finalAction.increasesExposure() ? pmDecision.pricePlanId() : null,
|
||||
finalAction.increasesExposure() ? pmDecision.pricePlanConfigHash() : null,
|
||||
finalAction == TraderActionType.OPEN_LONG || finalAction == TraderActionType.OPEN_SHORT ? pmDecision.targetPositionRatio() : pmDecision.addRatio(),
|
||||
ratioFor(finalAction, pmDecision),
|
||||
null,
|
||||
pmDecision.stopPrice(),
|
||||
pmDecision.targetPrice(),
|
||||
@@ -33,6 +34,15 @@ public class TraderActionFactory {
|
||||
Map.of("riskAllowed", riskDecision.allowAction()));
|
||||
}
|
||||
|
||||
private BigDecimal ratioFor(TraderActionType action, TraderPositionManagerDecision pmDecision) {
|
||||
return switch (action) {
|
||||
case OPEN_LONG, OPEN_SHORT -> pmDecision.targetPositionRatio();
|
||||
case ADD_LONG, ADD_SHORT -> pmDecision.addRatio();
|
||||
case REDUCE_LONG, REDUCE_SHORT -> pmDecision.reduceRatio();
|
||||
case WAIT, HOLD, CLOSE_LONG, CLOSE_SHORT, MOVE_STOP, CANCEL -> null;
|
||||
};
|
||||
}
|
||||
|
||||
private PositionSide sideFor(TraderActionType action, PositionSide pmSide) {
|
||||
return switch (action) {
|
||||
case OPEN_LONG, ADD_LONG, REDUCE_LONG, CLOSE_LONG -> PositionSide.LONG;
|
||||
|
||||
@@ -11,6 +11,8 @@ import com.quantai.trader.outbox.TraderOutboxRepository;
|
||||
import com.quantai.trader.outbox.TraderOutboxEvent;
|
||||
import com.quantai.trader.persistence.TraderDecisionTraceWriter;
|
||||
import com.quantai.trader.position.TraderPositionManager;
|
||||
import com.quantai.trader.replay.state.TraderReplayState;
|
||||
import com.quantai.trader.replay.state.TraderReplayStateStore;
|
||||
import com.quantai.trader.risk.RiskGateInput;
|
||||
import com.quantai.trader.risk.RiskLimits;
|
||||
import com.quantai.trader.risk.TraderRiskGate;
|
||||
@@ -35,6 +37,7 @@ public class TraderP0CycleRunner {
|
||||
private final EvidenceAppender evidenceAppender;
|
||||
private final TraderDecisionTraceWriter traceWriter;
|
||||
private final TraderOutboxRepository outboxRepository;
|
||||
private final TraderReplayStateStore stateStore;
|
||||
|
||||
public TraderP0CycleRunner(TraderProperties properties,
|
||||
TraderArtifactLoader artifactLoader,
|
||||
@@ -44,7 +47,8 @@ public class TraderP0CycleRunner {
|
||||
TraderActionFactory actionFactory,
|
||||
EvidenceAppender evidenceAppender,
|
||||
TraderDecisionTraceWriter traceWriter,
|
||||
TraderOutboxRepository outboxRepository) {
|
||||
TraderOutboxRepository outboxRepository,
|
||||
TraderReplayStateStore stateStore) {
|
||||
this.properties = properties;
|
||||
this.artifactLoader = artifactLoader;
|
||||
this.modelService = modelService;
|
||||
@@ -54,19 +58,21 @@ public class TraderP0CycleRunner {
|
||||
this.evidenceAppender = evidenceAppender;
|
||||
this.traceWriter = traceWriter;
|
||||
this.outboxRepository = outboxRepository;
|
||||
this.stateStore = stateStore;
|
||||
}
|
||||
|
||||
public TraderCycleResult runFlatCycle(ReplayMarketEvent event) {
|
||||
public TraderCycleResult runCycle(ReplayMarketEvent event) {
|
||||
String cycleId = "cycle_" + event.runId() + "_" + event.eventTime().toEpochMilli();
|
||||
TraderArtifactBundle bundle = artifactLoader.loadActiveBundle();
|
||||
TraderDecisionCycle cycle = new TraderDecisionCycle(event.runId(), cycleId, event.symbol(), event.eventTime(),
|
||||
properties.runMode(), bundle.modelBundleVersion(), bundle.calibrationBundleVersion(), bundle.pmConfigVersion());
|
||||
TraderMarketSnapshot snapshot = snapshot(event, cycleId);
|
||||
TraderReplayState state = stateStore.load(cycle, snapshot);
|
||||
evidenceAppender.append(cycle.runId(), cycle.cycleId(), "MARKET_SNAPSHOT", snapshot.dataReady(), "SNAPSHOT_BUILT", null, Map.of());
|
||||
TraderModelOutput modelOutput = modelService.evaluate(snapshot, bundle);
|
||||
evidenceAppender.append(cycle.runId(), cycle.cycleId(), "MODEL_OUTPUT", true, "MODEL_EVALUATED", null, Map.of("modelOutputId", modelOutput.modelOutputId()));
|
||||
PositionManagerInput pmInput = new PositionManagerInput(cycle, snapshot, modelOutput,
|
||||
flatPosition(cycle, snapshot), account(cycle), execution(cycle), bundle.pmConfig());
|
||||
state.positionState(), state.accountState(), state.executionState(), bundle.pmConfig());
|
||||
TraderPositionManagerDecision pmDecision = positionManager.decide(pmInput);
|
||||
evidenceAppender.append(cycle.runId(), cycle.cycleId(), "PM_DECISION", true, pmDecision.reason(), null, Map.of("action", pmDecision.candidateAction().name()));
|
||||
TraderRiskDecision riskDecision = riskGate.evaluate(new RiskGateInput(pmDecision, pmInput.positionState(), pmInput.accountState(),
|
||||
@@ -77,6 +83,7 @@ public class TraderP0CycleRunner {
|
||||
outboxRepository.insert(new TraderOutboxEvent("outbox_" + action.actionId(), action.runId(), action.cycleId(),
|
||||
"TRADER_ACTION", action.actionId(), "ACTION_CREATED", properties.runMode().name() + "_RECORDER",
|
||||
Map.of("actionType", action.actionType().name()), action.idempotencyKey(), "PENDING", Instant.now()));
|
||||
stateStore.advance(state, action, snapshot);
|
||||
log.info("event=trader.cycle.completed runId={} cycleId={} action={} outbox=PENDING", action.runId(), action.cycleId(), action.actionType());
|
||||
return new TraderCycleResult(cycle.runId(), cycle.cycleId(), pmDecision, riskDecision, action);
|
||||
}
|
||||
@@ -88,23 +95,6 @@ public class TraderP0CycleRunner {
|
||||
event.depthNotional5Bps().compareTo(BigDecimal.ZERO) > 0, Map.of(), Map.of());
|
||||
}
|
||||
|
||||
private TraderPositionState flatPosition(TraderDecisionCycle cycle, TraderMarketSnapshot snapshot) {
|
||||
return new TraderPositionState("position_state_" + cycle.cycleId(), cycle.runId(), cycle.cycleId(), cycle.symbol(),
|
||||
PositionSide.NONE, BigDecimal.ZERO, null, snapshot.markPrice(), BigDecimal.ZERO, new BigDecimal("1000"),
|
||||
0, BigDecimal.ONE, null);
|
||||
}
|
||||
|
||||
private TraderAccountState account(TraderDecisionCycle cycle) {
|
||||
return new TraderAccountState("account_state_" + cycle.cycleId(), cycle.runId(), cycle.cycleId(),
|
||||
BigDecimal.ZERO, BigDecimal.ZERO, BigDecimal.ONE, 0);
|
||||
}
|
||||
|
||||
private TraderExecutionState execution(TraderDecisionCycle cycle) {
|
||||
return new TraderExecutionState("execution_state_" + cycle.cycleId(), cycle.runId(), cycle.cycleId(), cycle.symbol(),
|
||||
java.util.List.of(), new BigDecimal("1.5"), 10, 0, new BigDecimal("1"), new BigDecimal("4"),
|
||||
new BigDecimal("5"), new BigDecimal("0.1"), new BigDecimal("0.001"), new BigDecimal("0.001"), BigDecimal.ONE);
|
||||
}
|
||||
|
||||
private RiskLimits riskLimits() {
|
||||
return new RiskLimits(properties.risk().maxDailyLossBps(), properties.risk().maxTotalExposureRatio(),
|
||||
properties.risk().minLiquidationBufferBps(), properties.execution().maxApiErrorCount(),
|
||||
|
||||
@@ -0,0 +1,166 @@
|
||||
package com.quantai.trader.replay.state;
|
||||
|
||||
import com.quantai.trader.config.TraderProperties;
|
||||
import com.quantai.trader.domain.*;
|
||||
import com.quantai.trader.enums.PositionSide;
|
||||
import com.quantai.trader.enums.TraderActionType;
|
||||
import com.quantai.trader.util.TraderNumbers;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
import org.springframework.stereotype.Component;
|
||||
|
||||
import java.math.BigDecimal;
|
||||
import java.math.MathContext;
|
||||
import java.time.Instant;
|
||||
import java.util.List;
|
||||
import java.util.concurrent.ConcurrentHashMap;
|
||||
|
||||
/**
|
||||
* Process-local state for ordered P0 replay/shadow cycles; each runId+symbol stream advances from the prior cycle.
|
||||
*/
|
||||
@Component
|
||||
public class P0ReplayStateStore implements TraderReplayStateStore {
|
||||
private static final Logger log = LoggerFactory.getLogger(P0ReplayStateStore.class);
|
||||
|
||||
private final TraderProperties properties;
|
||||
private final ConcurrentHashMap<String, TraderReplayState> states = new ConcurrentHashMap<>();
|
||||
|
||||
public P0ReplayStateStore(TraderProperties properties) {
|
||||
this.properties = properties;
|
||||
}
|
||||
|
||||
@Override
|
||||
public TraderReplayState load(TraderDecisionCycle cycle, TraderMarketSnapshot snapshot) {
|
||||
TraderReplayState state = states.computeIfAbsent(key(cycle.runId(), cycle.symbol()), ignored -> flatState(cycle, snapshot));
|
||||
TraderReplayState refreshed = refreshForCycle(cycle, snapshot, state);
|
||||
states.put(key(cycle.runId(), cycle.symbol()), refreshed);
|
||||
log.info("event=trader.replay_state.loaded runId={} cycleId={} side={} ratio={}",
|
||||
cycle.runId(), cycle.cycleId(), refreshed.positionState().side(), refreshed.positionState().positionRatio());
|
||||
return refreshed;
|
||||
}
|
||||
|
||||
@Override
|
||||
public TraderReplayState advance(TraderReplayState current, TraderAction action, TraderMarketSnapshot snapshot) {
|
||||
TraderPositionState nextPosition = switch (action.actionType()) {
|
||||
case OPEN_LONG, OPEN_SHORT -> openPosition(current.positionState(), action, snapshot);
|
||||
case ADD_LONG, ADD_SHORT -> addPosition(current.positionState(), action, snapshot);
|
||||
case REDUCE_LONG, REDUCE_SHORT -> reducePosition(current.positionState(), action, snapshot);
|
||||
case CLOSE_LONG, CLOSE_SHORT -> flatPosition(current.positionState(), action.cycleId(), snapshot);
|
||||
case WAIT, HOLD, MOVE_STOP, CANCEL -> refreshPosition(current.positionState(), action.cycleId(), snapshot);
|
||||
};
|
||||
TraderReplayState next = new TraderReplayState(
|
||||
nextPosition,
|
||||
account(action, nextPosition),
|
||||
execution(action));
|
||||
states.put(key(action.runId(), action.symbol()), next);
|
||||
log.info("event=trader.replay_state.advanced runId={} cycleId={} action={} side={} ratio={}",
|
||||
action.runId(), action.cycleId(), action.actionType(), nextPosition.side(), nextPosition.positionRatio());
|
||||
return next;
|
||||
}
|
||||
|
||||
private TraderReplayState flatState(TraderDecisionCycle cycle, TraderMarketSnapshot snapshot) {
|
||||
TraderPositionState position = new TraderPositionState("position_state_" + cycle.cycleId(), cycle.runId(), cycle.cycleId(), cycle.symbol(),
|
||||
PositionSide.NONE, BigDecimal.ZERO, null, snapshot.markPrice(), BigDecimal.ZERO, new BigDecimal("1000"),
|
||||
0, properties.positionManager().maxTotalPositionRatio(), null);
|
||||
return new TraderReplayState(position, account(cycle, position), execution(cycle));
|
||||
}
|
||||
|
||||
private TraderReplayState refreshForCycle(TraderDecisionCycle cycle, TraderMarketSnapshot snapshot, TraderReplayState state) {
|
||||
TraderPositionState position = new TraderPositionState("position_state_" + cycle.cycleId(), cycle.runId(), cycle.cycleId(), cycle.symbol(),
|
||||
state.positionState().side(), state.positionState().positionRatio(), state.positionState().averageEntryPrice(),
|
||||
snapshot.markPrice(), pnlBps(state.positionState().side(), state.positionState().averageEntryPrice(), snapshot.markPrice()),
|
||||
state.positionState().liquidationBufferBps(), state.positionState().addCount(),
|
||||
remainingCapacity(state.positionState().positionRatio()), state.positionState().lastAddTime());
|
||||
return new TraderReplayState(position, account(cycle, position), execution(cycle));
|
||||
}
|
||||
|
||||
private TraderPositionState openPosition(TraderPositionState current, TraderAction action, TraderMarketSnapshot snapshot) {
|
||||
PositionSide side = action.actionType() == TraderActionType.OPEN_LONG ? PositionSide.LONG : PositionSide.SHORT;
|
||||
BigDecimal ratio = TraderNumbers.required(action.positionRatio(), "action.positionRatio");
|
||||
return new TraderPositionState("position_state_" + action.cycleId(), action.runId(), action.cycleId(), action.symbol(),
|
||||
side, ratio, snapshot.markPrice(), snapshot.markPrice(), BigDecimal.ZERO, current.liquidationBufferBps(),
|
||||
0, remainingCapacity(ratio), null);
|
||||
}
|
||||
|
||||
private TraderPositionState addPosition(TraderPositionState current, TraderAction action, TraderMarketSnapshot snapshot) {
|
||||
BigDecimal addRatio = TraderNumbers.required(action.positionRatio(), "action.positionRatio");
|
||||
BigDecimal newRatio = current.positionRatio().add(addRatio);
|
||||
BigDecimal weightedEntry = weightedEntry(current.averageEntryPrice(), current.positionRatio(), snapshot.markPrice(), addRatio);
|
||||
return new TraderPositionState("position_state_" + action.cycleId(), action.runId(), action.cycleId(), action.symbol(),
|
||||
current.side(), newRatio, weightedEntry, snapshot.markPrice(), pnlBps(current.side(), weightedEntry, snapshot.markPrice()),
|
||||
current.liquidationBufferBps(), current.addCount() + 1, remainingCapacity(newRatio), Instant.now());
|
||||
}
|
||||
|
||||
private TraderPositionState reducePosition(TraderPositionState current, TraderAction action, TraderMarketSnapshot snapshot) {
|
||||
BigDecimal reduceRatio = TraderNumbers.required(action.positionRatio(), "action.positionRatio");
|
||||
BigDecimal newRatio = current.positionRatio().multiply(BigDecimal.ONE.subtract(reduceRatio), MathContext.DECIMAL64);
|
||||
if (newRatio.compareTo(new BigDecimal("0.00000001")) <= 0) {
|
||||
return flatPosition(current, action.cycleId(), snapshot);
|
||||
}
|
||||
return new TraderPositionState("position_state_" + action.cycleId(), action.runId(), action.cycleId(), action.symbol(),
|
||||
current.side(), newRatio, current.averageEntryPrice(), snapshot.markPrice(),
|
||||
pnlBps(current.side(), current.averageEntryPrice(), snapshot.markPrice()), current.liquidationBufferBps(),
|
||||
current.addCount(), remainingCapacity(newRatio), current.lastAddTime());
|
||||
}
|
||||
|
||||
private TraderPositionState refreshPosition(TraderPositionState current, String cycleId, TraderMarketSnapshot snapshot) {
|
||||
if (current.isFlat()) {
|
||||
return flatPosition(current, cycleId, snapshot);
|
||||
}
|
||||
return new TraderPositionState("position_state_" + cycleId, current.runId(), cycleId, current.symbol(),
|
||||
current.side(), current.positionRatio(), current.averageEntryPrice(), snapshot.markPrice(),
|
||||
pnlBps(current.side(), current.averageEntryPrice(), snapshot.markPrice()), current.liquidationBufferBps(),
|
||||
current.addCount(), remainingCapacity(current.positionRatio()), current.lastAddTime());
|
||||
}
|
||||
|
||||
private TraderPositionState flatPosition(TraderPositionState current, String cycleId, TraderMarketSnapshot snapshot) {
|
||||
return new TraderPositionState("position_state_" + cycleId, current.runId(), cycleId, current.symbol(),
|
||||
PositionSide.NONE, BigDecimal.ZERO, null, snapshot.markPrice(), BigDecimal.ZERO,
|
||||
current.liquidationBufferBps(), 0, properties.positionManager().maxTotalPositionRatio(), null);
|
||||
}
|
||||
|
||||
private TraderAccountState account(TraderDecisionCycle cycle, TraderPositionState position) {
|
||||
return new TraderAccountState("account_state_" + cycle.cycleId(), cycle.runId(), cycle.cycleId(),
|
||||
BigDecimal.ZERO, position.positionRatio(), remainingCapacity(position.positionRatio()), 0);
|
||||
}
|
||||
|
||||
private TraderAccountState account(TraderAction action, TraderPositionState position) {
|
||||
return new TraderAccountState("account_state_" + action.cycleId(), action.runId(), action.cycleId(),
|
||||
BigDecimal.ZERO, position.positionRatio(), remainingCapacity(position.positionRatio()), 0);
|
||||
}
|
||||
|
||||
private TraderExecutionState execution(TraderDecisionCycle cycle) {
|
||||
return new TraderExecutionState("execution_state_" + cycle.cycleId(), cycle.runId(), cycle.cycleId(), cycle.symbol(),
|
||||
List.of(), new BigDecimal("1.5"), 10, 0, new BigDecimal("1"), new BigDecimal("4"),
|
||||
new BigDecimal("5"), new BigDecimal("0.1"), new BigDecimal("0.001"), new BigDecimal("0.001"), BigDecimal.ONE);
|
||||
}
|
||||
|
||||
private TraderExecutionState execution(TraderAction action) {
|
||||
return new TraderExecutionState("execution_state_" + action.cycleId(), action.runId(), action.cycleId(), action.symbol(),
|
||||
List.of(), new BigDecimal("1.5"), 10, 0, new BigDecimal("1"), new BigDecimal("4"),
|
||||
new BigDecimal("5"), new BigDecimal("0.1"), new BigDecimal("0.001"), new BigDecimal("0.001"), BigDecimal.ONE);
|
||||
}
|
||||
|
||||
private BigDecimal weightedEntry(BigDecimal currentEntry, BigDecimal currentRatio, BigDecimal addPrice, BigDecimal addRatio) {
|
||||
BigDecimal total = currentRatio.add(addRatio);
|
||||
return currentEntry.multiply(currentRatio).add(addPrice.multiply(addRatio)).divide(total, MathContext.DECIMAL64);
|
||||
}
|
||||
|
||||
private BigDecimal pnlBps(PositionSide side, BigDecimal averageEntryPrice, BigDecimal currentPrice) {
|
||||
if (side == PositionSide.NONE || averageEntryPrice == null) {
|
||||
return BigDecimal.ZERO;
|
||||
}
|
||||
BigDecimal priceMove = side.isLong()
|
||||
? currentPrice.subtract(averageEntryPrice)
|
||||
: averageEntryPrice.subtract(currentPrice);
|
||||
return priceMove.divide(averageEntryPrice, MathContext.DECIMAL64).multiply(new BigDecimal("10000"));
|
||||
}
|
||||
|
||||
private BigDecimal remainingCapacity(BigDecimal positionRatio) {
|
||||
return properties.positionManager().maxTotalPositionRatio().subtract(positionRatio).max(BigDecimal.ZERO);
|
||||
}
|
||||
|
||||
private String key(String runId, String symbol) {
|
||||
return runId + "::" + symbol;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,12 @@
|
||||
package com.quantai.trader.replay.state;
|
||||
|
||||
import com.quantai.trader.domain.TraderAccountState;
|
||||
import com.quantai.trader.domain.TraderExecutionState;
|
||||
import com.quantai.trader.domain.TraderPositionState;
|
||||
|
||||
public record TraderReplayState(
|
||||
TraderPositionState positionState,
|
||||
TraderAccountState accountState,
|
||||
TraderExecutionState executionState
|
||||
) {
|
||||
}
|
||||
@@ -0,0 +1,11 @@
|
||||
package com.quantai.trader.replay.state;
|
||||
|
||||
import com.quantai.trader.domain.TraderAction;
|
||||
import com.quantai.trader.domain.TraderDecisionCycle;
|
||||
import com.quantai.trader.domain.TraderMarketSnapshot;
|
||||
|
||||
public interface TraderReplayStateStore {
|
||||
TraderReplayState load(TraderDecisionCycle cycle, TraderMarketSnapshot snapshot);
|
||||
|
||||
TraderReplayState advance(TraderReplayState current, TraderAction action, TraderMarketSnapshot snapshot);
|
||||
}
|
||||
@@ -0,0 +1,25 @@
|
||||
package com.quantai.trader.domain;
|
||||
|
||||
import com.quantai.trader.enums.PositionSide;
|
||||
import com.quantai.trader.enums.TraderActionType;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import static com.quantai.trader.TestFixtures.*;
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
|
||||
class TraderActionFactoryTest {
|
||||
private final TraderActionFactory factory = new TraderActionFactory();
|
||||
|
||||
@Test
|
||||
void mapsReduceDecisionRatioOntoActionPositionRatio() {
|
||||
TraderRiskDecision riskDecision = new TraderRiskDecision(
|
||||
"risk-1", "run-1", "cycle-1", "pm-cycle-1",
|
||||
true, TraderActionType.REDUCE_LONG, TraderActionType.REDUCE_LONG, null, java.util.Map.of());
|
||||
|
||||
TraderAction action = factory.create(pmDecision(TraderActionType.REDUCE_LONG, PositionSide.LONG),
|
||||
riskDecision, "BTC-USDT-PERP");
|
||||
|
||||
assertThat(action.positionRatio()).isEqualByComparingTo("0.50");
|
||||
assertThat(action.reduceOnly()).isTrue();
|
||||
}
|
||||
}
|
||||
@@ -10,6 +10,7 @@ import com.quantai.trader.outbox.TraderOutboxEvent;
|
||||
import com.quantai.trader.outbox.TraderOutboxRepository;
|
||||
import com.quantai.trader.persistence.TraderDecisionTraceWriter;
|
||||
import com.quantai.trader.position.TraderPositionManager;
|
||||
import com.quantai.trader.replay.state.P0ReplayStateStore;
|
||||
import com.quantai.trader.risk.TraderRiskGate;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.api.io.TempDir;
|
||||
@@ -34,18 +35,20 @@ class TraderP0CycleRunnerTest {
|
||||
EvidenceAppender evidenceAppender = new EvidenceAppender(evidenceRepository);
|
||||
RecordingTraceWriter traceWriter = new RecordingTraceWriter();
|
||||
RecordingOutboxRepository outboxRepository = new RecordingOutboxRepository();
|
||||
var properties = propertiesWithArtifactRoot(artifactRoot);
|
||||
TraderP0CycleRunner runner = new TraderP0CycleRunner(
|
||||
propertiesWithArtifactRoot(artifactRoot),
|
||||
new TraderArtifactLoader(propertiesWithArtifactRoot(artifactRoot), objectMapper()),
|
||||
properties,
|
||||
new TraderArtifactLoader(properties, objectMapper()),
|
||||
new ArtifactTraderModelService(),
|
||||
new TraderPositionManager(),
|
||||
new TraderRiskGate(),
|
||||
new TraderActionFactory(),
|
||||
evidenceAppender,
|
||||
traceWriter,
|
||||
outboxRepository);
|
||||
outboxRepository,
|
||||
new P0ReplayStateStore(properties));
|
||||
|
||||
TraderCycleResult result = runner.runFlatCycle(new ReplayMarketEvent(
|
||||
TraderCycleResult result = runner.runCycle(new ReplayMarketEvent(
|
||||
"run-1", "BTC-USDT-PERP", T0, new BigDecimal("100"), new BigDecimal("99.5"),
|
||||
new BigDecimal("1.2"), new BigDecimal("1000")));
|
||||
|
||||
@@ -65,18 +68,20 @@ class TraderP0CycleRunnerTest {
|
||||
EvidenceAppender evidenceAppender = new EvidenceAppender(evidenceRepository);
|
||||
RecordingTraceWriter traceWriter = new RecordingTraceWriter();
|
||||
RecordingOutboxRepository outboxRepository = new RecordingOutboxRepository();
|
||||
var properties = propertiesWithArtifactRoot(artifactRoot);
|
||||
TraderP0CycleRunner runner = new TraderP0CycleRunner(
|
||||
propertiesWithArtifactRoot(artifactRoot),
|
||||
new TraderArtifactLoader(propertiesWithArtifactRoot(artifactRoot), objectMapper()),
|
||||
properties,
|
||||
new TraderArtifactLoader(properties, objectMapper()),
|
||||
new ArtifactTraderModelService(),
|
||||
new TraderPositionManager(),
|
||||
new TraderRiskGate(),
|
||||
new TraderActionFactory(),
|
||||
evidenceAppender,
|
||||
traceWriter,
|
||||
outboxRepository);
|
||||
outboxRepository,
|
||||
new P0ReplayStateStore(properties));
|
||||
|
||||
TraderCycleResult result = runner.runFlatCycle(new ReplayMarketEvent(
|
||||
TraderCycleResult result = runner.runCycle(new ReplayMarketEvent(
|
||||
"run-1", "BTC-USDT-PERP", T0.plusSeconds(60), new BigDecimal("100"), new BigDecimal("99.5"),
|
||||
new BigDecimal("1.2"), BigDecimal.ZERO));
|
||||
|
||||
@@ -87,6 +92,40 @@ class TraderP0CycleRunnerTest {
|
||||
assertThat(evidenceRepository.items()).hasSize(4);
|
||||
}
|
||||
|
||||
@Test
|
||||
void laterCycleUsesPositionStateFromEarlierOpen() throws IOException {
|
||||
writeArtifactBundle(artifactRoot);
|
||||
RecordingEvidenceRepository evidenceRepository = new RecordingEvidenceRepository();
|
||||
EvidenceAppender evidenceAppender = new EvidenceAppender(evidenceRepository);
|
||||
RecordingTraceWriter traceWriter = new RecordingTraceWriter();
|
||||
RecordingOutboxRepository outboxRepository = new RecordingOutboxRepository();
|
||||
var properties = propertiesWithArtifactRoot(artifactRoot);
|
||||
TraderP0CycleRunner runner = new TraderP0CycleRunner(
|
||||
properties,
|
||||
new TraderArtifactLoader(properties, objectMapper()),
|
||||
new ArtifactTraderModelService(),
|
||||
new TraderPositionManager(),
|
||||
new TraderRiskGate(),
|
||||
new TraderActionFactory(),
|
||||
evidenceAppender,
|
||||
traceWriter,
|
||||
outboxRepository,
|
||||
new P0ReplayStateStore(properties));
|
||||
|
||||
TraderCycleResult first = runner.runCycle(new ReplayMarketEvent(
|
||||
"run-state-1", "BTC-USDT-PERP", T0, new BigDecimal("100"), new BigDecimal("99.5"),
|
||||
new BigDecimal("1.2"), new BigDecimal("1000")));
|
||||
TraderCycleResult second = runner.runCycle(new ReplayMarketEvent(
|
||||
"run-state-1", "BTC-USDT-PERP", T0.plusSeconds(60), new BigDecimal("101"), new BigDecimal("100.5"),
|
||||
new BigDecimal("1.2"), new BigDecimal("1000")));
|
||||
|
||||
assertThat(first.action().actionType()).isEqualTo(TraderActionType.OPEN_LONG);
|
||||
assertThat(second.action().actionType()).isEqualTo(TraderActionType.ADD_LONG);
|
||||
assertThat(traceWriter.positionStates()).extracting("side").containsExactly(
|
||||
com.quantai.trader.enums.PositionSide.NONE,
|
||||
com.quantai.trader.enums.PositionSide.LONG);
|
||||
}
|
||||
|
||||
private static final class RecordingEvidenceRepository implements TraderEvidenceRepository {
|
||||
private final List<TraderEvidence> items = new ArrayList<>();
|
||||
|
||||
@@ -115,16 +154,22 @@ class TraderP0CycleRunnerTest {
|
||||
|
||||
private static final class RecordingTraceWriter implements TraderDecisionTraceWriter {
|
||||
private final List<TraderAction> actions = new ArrayList<>();
|
||||
private final List<TraderPositionState> positionStates = new ArrayList<>();
|
||||
|
||||
@Override
|
||||
public void persistCycleTrace(TraderDecisionCycle cycle, TraderMarketSnapshot snapshot, TraderModelOutput modelOutput,
|
||||
TraderPositionState positionState, TraderPositionManagerDecision pmDecision,
|
||||
TraderRiskDecision riskDecision, TraderAction action) {
|
||||
actions.add(action);
|
||||
positionStates.add(positionState);
|
||||
}
|
||||
|
||||
List<TraderAction> actions() {
|
||||
return actions;
|
||||
}
|
||||
|
||||
List<TraderPositionState> positionStates() {
|
||||
return positionStates;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user