Skip to content

Commit

Permalink
test(invariant): include delay in each period
Browse files Browse the repository at this point in the history
  • Loading branch information
andreivladbrg committed Sep 20, 2024
1 parent 2a7e4e2 commit 3e2c5fd
Show file tree
Hide file tree
Showing 3 changed files with 69 additions and 16 deletions.
18 changes: 10 additions & 8 deletions test/invariant/Flow.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -276,7 +276,7 @@ contract Flow_Invariant_Test is Base_Test {
/// @dev For non-voided streams, the difference between the total amount streamed and the sum of total debt and
/// total withdrawn should never exceed 1. This is indirectly checking that withdrawals do not cause the streamed
/// amount to deviate from the theoretical streamed amount by more than 1.
function invariant_TotalStreamedApproxEqTotalDebtPlusWithdrawn() external view {
function invariant_TotalStreamedEqTotalDebtPlusWithdrawn() external view {
uint256 lastStreamId = flowStore.lastStreamId();
for (uint256 i = 0; i < lastStreamId; ++i) {
uint256 streamId = flowStore.streamIds(i);
Expand All @@ -286,10 +286,10 @@ contract Flow_Invariant_Test is Base_Test {
uint256 totalStreamedAmount =
calculateTotalStreamedAmount(flowStore.streamIds(i), flow.getTokenDecimals(streamId));

assertLe(
totalStreamedAmount - flow.totalDebtOf(streamId) - flowStore.withdrawnAmounts(streamId),
1,
"Invariant violation: total debt - streamed amount - withdrawn amount > 1"
assertEq(
totalStreamedAmount,
flow.totalDebtOf(streamId) + flowStore.withdrawnAmounts(streamId),
"Invariant violation: total streamed amount = total debt + withdrawn amount"
);
}
}
Expand All @@ -310,11 +310,13 @@ contract Flow_Invariant_Test is Base_Test {
FlowStore.Period memory period = flowStore.getPeriod(streamId, i);

// If end time is 0, it means the current period is still active.
uint40 elapsed = period.end > 0 ? period.end - period.start : uint40(block.timestamp) - period.start;
uint40 elapsed = period.end > 0
? period.end - period.start - period.delay
: uint40(block.timestamp) - period.start - period.delay;

totalStreamedAmount += period.ratePerSecond * elapsed;
totalStreamedAmount += (period.ratePerSecond * elapsed) / 10 ** (18 - decimals);
}

return totalStreamedAmount / 10 ** (18 - decimals);
return totalStreamedAmount;
}
}
23 changes: 18 additions & 5 deletions test/invariant/handlers/FlowHandler.sol
Original file line number Diff line number Diff line change
Expand Up @@ -95,13 +95,16 @@ contract FlowHandler is BaseHandler {
vm.assume(newRatePerSecond.unwrap() > mvt / 100 && newRatePerSecond.unwrap() <= 1e18);
}

uint128 previousRatePerSecond = flow.getRatePerSecond(currentStreamId).unwrap();

// The rate per second must be different from the current rate per second.
vm.assume(newRatePerSecond.unwrap() != flow.getRatePerSecond(currentStreamId).unwrap());
vm.assume(newRatePerSecond.unwrap() != previousRatePerSecond);

// Adjust the rate per second.
flow.adjustRatePerSecond(currentStreamId, newRatePerSecond);

flowStore.updatePeriods(currentStreamId, newRatePerSecond.unwrap(), "adjustRatePerSecond");
flowStore.updateDelay(currentStreamId, previousRatePerSecond, decimals);
flowStore.pushPeriod(currentStreamId, newRatePerSecond.unwrap(), "adjustRatePerSecond");
}

function deposit(
Expand Down Expand Up @@ -157,10 +160,14 @@ contract FlowHandler is BaseHandler {
// Paused streams cannot be paused again.
vm.assume(!flow.isPaused(currentStreamId));

flowStore.updateDelay(
currentStreamId, flow.getRatePerSecond(currentStreamId).unwrap(), flow.getTokenDecimals(currentStreamId)
);

// Pause the stream.
flow.pause(currentStreamId);

flowStore.updatePeriods(currentStreamId, 0, "pause");
flowStore.pushPeriod(currentStreamId, 0, "pause");
}

function refund(
Expand Down Expand Up @@ -226,7 +233,7 @@ contract FlowHandler is BaseHandler {
// Restart the stream.
flow.restart(currentStreamId, ratePerSecond);

flowStore.updatePeriods(currentStreamId, ratePerSecond.unwrap(), "restart");
flowStore.pushPeriod(currentStreamId, ratePerSecond.unwrap(), "restart");
}

function void(
Expand All @@ -249,7 +256,7 @@ contract FlowHandler is BaseHandler {
// Void the stream.
flow.void(currentStreamId);

flowStore.updatePeriods(currentStreamId, 0, "void");
flowStore.pushPeriod(currentStreamId, 0, "void");
}

function withdraw(
Expand Down Expand Up @@ -285,5 +292,11 @@ contract FlowHandler is BaseHandler {

// Update the withdrawn amount.
flowStore.updateStreamWithdrawnAmountsSum(currentStreamId, flow.getToken(currentStreamId), amount);

// If the stream isn't paused, update the delay:
uint128 ratePerSecond = flow.getRatePerSecond(currentStreamId).unwrap();
if (ratePerSecond > 0) {
flowStore.updateDelay(currentStreamId, ratePerSecond, flow.getTokenDecimals(currentStreamId));
}
}
}
44 changes: 41 additions & 3 deletions test/invariant/stores/FlowStore.sol
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,13 @@ contract FlowStore {
/// @param ratePerSecond The rate per second for this period.
/// @param start The start time of the period.
/// @param end The end time of the period.
/// @param delay The delay for the period.
struct Period {
string typeOfPeriod;
uint128 ratePerSecond;
uint40 start;
uint40 end;
uint40 delay;
}

/// @dev Each stream is mapped to an array of periods. This is used to calculate the total streamed amount.
Expand All @@ -57,23 +59,59 @@ contract FlowStore {
// Store the stream id and the period during which provided ratePerSecond applies.
streamIds.push(streamId);
periods[streamId].push(
Period({ typeOfPeriod: "create", ratePerSecond: ratePerSecond, start: uint40(block.timestamp), end: 0 })
Period({
typeOfPeriod: "create",
ratePerSecond: ratePerSecond,
start: uint40(block.timestamp),
end: 0,
delay: 0
})
);

// Update the last stream id.
lastStreamId = streamId;
}

function updatePeriods(uint256 streamId, uint128 ratePerSecond, string memory typeOfPeriod) external {
function pushPeriod(uint256 streamId, uint128 ratePerSecond, string memory typeOfPeriod) external {
// Update the end time of the previous period.
periods[streamId][periods[streamId].length - 1].end = uint40(block.timestamp);

// Push the new period with the provided rate per second.
periods[streamId].push(
Period({ typeOfPeriod: typeOfPeriod, ratePerSecond: ratePerSecond, start: uint40(block.timestamp), end: 0 })
Period({
ratePerSecond: ratePerSecond,
start: uint40(block.timestamp),
end: 0,
delay: 0,
typeOfPeriod: typeOfPeriod
})
);
}

function updateDelay(uint256 streamId, uint128 ratePerSecond, uint8 decimals) external {
// Skip the delay update if the decimals are 18.
if (decimals == 18) {
return;
}

uint256 periodCount = periods[streamId].length - 1;
uint128 factor = uint128(10 ** (18 - decimals));
uint40 blockTimestamp = uint40(block.timestamp);
uint40 start = periods[streamId][periodCount].start;

uint128 rescaledStreamedAmount = ratePerSecond * (blockTimestamp - start) / factor * factor;

uint40 delay;
if (rescaledStreamedAmount > ratePerSecond) {
delay = blockTimestamp - start - uint40(rescaledStreamedAmount / ratePerSecond);
// Since we are reverse engineering the delay, we need to subtract 1 from the delay, which would normally be
// added in the constant interval calculation
delay = delay > 0 ? delay - 1 : 0;
}

periods[streamId][periodCount].delay += delay;
}

function updatePreviousValues(
uint256 streamId,
uint40 snapshotTime,
Expand Down

0 comments on commit 3e2c5fd

Please sign in to comment.