Skip to content

Commit

Permalink
docs: polish explanations in "_withdraw"
Browse files Browse the repository at this point in the history
fix: return correct amount in "_withdraw"
refactor: reorder parameters in "WithdrawFromFlowStream"
refactor: rename "amount" to "withdrawAmount"
refactor: rename "previous" to "initial"
refactor: rename variables in "_withdraw"
refactor: use "scaled" terminology instead of "normalized"
  • Loading branch information
PaulRBerg committed Sep 17, 2024
1 parent eb096bc commit 2425778
Show file tree
Hide file tree
Showing 16 changed files with 131 additions and 114 deletions.
154 changes: 86 additions & 68 deletions src/SablierFlow.sol
Original file line number Diff line number Diff line change
Expand Up @@ -462,18 +462,18 @@ contract SablierFlow is
uint8 tokenDecimals = _streams[streamId].tokenDecimals;

// Calculate the ongoing debt accrued by multiplying the elapsed time by the rate per second.
uint128 normalizedOngoingDebt = elapsedTime * _streams[streamId].ratePerSecond.unwrap();
uint128 scaledOngoingDebt = elapsedTime * _streams[streamId].ratePerSecond.unwrap();

// If the token decimals are 18, return the normalized ongoing debt and the `block.timestamp`.
// If the token decimals are 18, return the scaled ongoing debt and the `block.timestamp`.
if (tokenDecimals == 18) {
return normalizedOngoingDebt;
return scaledOngoingDebt;
}

// Safe to use unchecked because we use {SafeCast}.
// Safe to use unchecked due to {SafeCast}.
unchecked {
uint8 factor = 18 - tokenDecimals;
// Since debt is denoted in token decimals, denormalize the amount.
ongoingDebt = (normalizedOngoingDebt / (10 ** factor)).toUint128();
// Since debt is denoted in token decimals, descale the amount.
ongoingDebt = (scaledOngoingDebt / (10 ** factor)).toUint128();
}
}

Expand Down Expand Up @@ -760,9 +760,9 @@ contract SablierFlow is
}

/// @dev See the documentation for the user-facing functions that call this internal function.
function _withdraw(uint256 streamId, address to, uint128 amount) internal returns (uint128) {
function _withdraw(uint256 streamId, address to, uint128 withdrawAmount) internal returns (uint128) {
// Check: the withdraw amount is not zero.
if (amount == 0) {
if (withdrawAmount == 0) {
revert Errors.SablierFlow_WithdrawAmountZero(streamId);
}

Expand All @@ -777,105 +777,123 @@ contract SablierFlow is
revert Errors.SablierFlow_WithdrawalAddressNotRecipient({ streamId: streamId, caller: msg.sender, to: to });
}

// Calculate the total debt.
uint128 totalDebt = _totalDebtOf(streamId);
// Calculate the total debt at the beginning of the withdrawal.
uint128 initialTotalDebt = _totalDebtOf(streamId);

// Calculate the withdrawable amount.
uint128 balance = _streams[streamId].balance;
uint128 withdrawableAmount;
// Load the initial balance.
uint128 initialBalance = _streams[streamId].balance;

if (balance < totalDebt) {
// If the stream balance is less than the total debt, the withdrawable amount is the balance.
withdrawableAmount = balance;
} else {
// Otherwise, the withdrawable amount is the total debt.
withdrawableAmount = totalDebt;
// If the stream balance is less than the total debt, the withdrawable amount is the balance.
uint128 withdrawableAmount;
if (initialBalance < initialTotalDebt) {
withdrawableAmount = initialBalance;
}
// Otherwise, the withdrawable amount is the total debt.
else {
withdrawableAmount = initialTotalDebt;
}

// Check: the withdraw amount is not greater than the withdrawable amount.
if (amount > withdrawableAmount) {
revert Errors.SablierFlow_Overdraw(streamId, amount, withdrawableAmount);
}

if (amount <= _streams[streamId].snapshotDebt) {
// The `if` condition is triggered when:
// - The stream is paused or voided i.e. total debt = snapshot debt, allowing users to withdraw the entire
// withdrawable amount.
// - The amount does not exceed the snapshot debt for non paused streams.
//
// Effect: reduce the amount from the snapshot debt and leave ongoing debt unchanged.
_streams[streamId].snapshotDebt -= amount;
} else {
// Otherwise, reduce the differece from the ongoing debt by adjusting the snapshot time.
// Note:
/// - If rate per second is zero i.e. a paused stream, the `if` condition will be executed. Therefore, the
/// following division will never throw a division by zero error.
//
// Explaination:
// - Division by rps introduces many-to-one relation between amount and time. There can exist a range
// [amount, amount + rps), which maps to the same time. Therefore, we need to adjust the amount withdrawn to
// ensure that it matches the lower bound of the range. This guarantees that the streamed amount is not lost
// due to the rounding by the division.
// Steps:
// - We set snapshot debt to 0.
// - the difference, amount - snapshot debt, is deducted from the the ongoing debt.
// - To do that, we calculate the time it would take to stream the difference at the current rate per
// second.
// - We then add this time to the snapshot time to get the new snapshot time.
// - The new snapshot time = snapshot time + (amount - snapshot debt) / rate per second.
_streams[streamId].snapshotTime += uint40(
((amount - _streams[streamId].snapshotDebt) * (10 ** (18 - _streams[streamId].tokenDecimals)))
/ _streams[streamId].ratePerSecond.unwrap()
);
if (withdrawAmount > withdrawableAmount) {
revert Errors.SablierFlow_Overdraw(streamId, withdrawAmount, withdrawableAmount);
}

uint128 ongoingDebt;

// If the withdraw amount is less than the snapshot debt, use the snapshot debt as a funding source for the
// withdrawal and leave both the withdraw amount the ongoing debt unchanged.
//
// The condition is evaluated true in the following cases:
// - The stream is not paused and the amount does not exceed the snapshot debt.
// - The stream is paused or voided, i.e. total debt == snapshot debt.
if (withdrawAmount <= _streams[streamId].snapshotDebt) {
_streams[streamId].snapshotDebt -= withdrawAmount;
}
// Otherwise, adjust the snapshot time, set the snapshot debt to zero, and also adjust the withdraw amount.
//
// Dividing by the rps produces a many-to-one relation between time inputs and streamed amounts. There exists a
// range [amount, amount + rps) that maps to the same time. This is especially problematic for tokens with small
// decimals, e.g., USDC which has 6 decimals.
//
// To solve this, we need to adjust the amount withdrawn to ensure that it equals the lower bound of the range.
// This guarantees that part of the streamed amount is not lost due to rounding errors.
//
// Steps:
// - Calculate the difference between the withdraw amount the snapshot debt.
// - Scale the difference up to 18 decimals.
// - Divide it by the rate per second, which is also an 18-decimal number, and obtain the time it would take to
// stream the difference at the current rate per second.
// - Add the resultant value to the snapshot time.
// - Set the snapshot debt to zero.
// - Recalculate the ongoing debt based on the new snapshot time.
// - Set the withdraw amount to the initial total debt minus the ongoing debt. This may result in a value less
// than the initial withdraw amount.
//
// Note: the rate per second cannot be zero because this can only happen when the stream is paused. In that
// case, the `if` condition will be executed.
else {
uint128 difference;
unchecked {
difference = withdrawAmount - _streams[streamId].snapshotDebt;
}
uint128 scaledDifference = difference * uint128(10 ** (18 - _streams[streamId].tokenDecimals));
uint128 rps = _streams[streamId].ratePerSecond.unwrap();
_streams[streamId].snapshotTime += uint40(scaledDifference / rps);

// Set the snapshot debt to zero.
_streams[streamId].snapshotDebt = 0;

// Adjust the amount withdrawn so that previous total debt - new total debt = amount withdrawn. Note
// that, at this point, new total debt = ongoing debt.
amount = totalDebt - _ongoingDebtOf(streamId);
// Adjust the withdraw amount. At this point, new total debt == ongoing debt.
ongoingDebt = _ongoingDebtOf(streamId);
withdrawAmount = initialTotalDebt - ongoingDebt;
}

// Effect: update the stream balance.
_streams[streamId].balance -= amount;
_streams[streamId].balance -= withdrawAmount;

// Load the variables in memory.
IERC20 token = _streams[streamId].token;
UD60x18 protocolFee = protocolFee[token];

// Calculate the protocol fee amount and the net withdraw amount.
uint128 netWithdrawnAmount;
uint128 feeAmount;

uint128 protocolFeeAmount;
if (protocolFee > ZERO) {
// Calculate the protocol fee amount and the net withdraw amount.
(feeAmount, netWithdrawnAmount) = Helpers.calculateAmountsFromFee({ totalAmount: amount, fee: protocolFee });
(protocolFeeAmount, netWithdrawnAmount) =
Helpers.calculateAmountsFromFee({ totalAmount: withdrawAmount, fee: protocolFee });

// Safe to use unchecked because addition cannot overflow.
unchecked {
// Effect: update the protocol revenue.
protocolRevenue[token] += feeAmount;
protocolRevenue[token] += protocolFeeAmount;
}
} else {
netWithdrawnAmount = amount;
netWithdrawnAmount = withdrawAmount;
}

// Interaction: perform the ERC-20 transfer.
token.safeTransfer({ to: to, value: netWithdrawnAmount });

// Protocol Invariant: the difference in total debt should be equal to the difference in the stream balance.
assert(totalDebt - _totalDebtOf(streamId) == balance - _streams[streamId].balance);
// Protocol Invariant: the new total debt is equal to the ongoing debt.
uint128 newTotalDebt = _totalDebtOf(streamId);
// TODO: this assertion does not work, it leads to failed tests
// assert(newTotalDebt == ongoingDebt);

// Protocol Invariant: the difference between total debts should be equal to the difference between stream
// balances.
assert(initialTotalDebt - newTotalDebt == initialBalance - _streams[streamId].balance);

// Log the withdrawal.
emit ISablierFlow.WithdrawFromFlowStream({
streamId: streamId,
to: to,
token: token,
caller: msg.sender,
protocolFeeAmount: feeAmount,
withdrawAmount: netWithdrawnAmount
withdrawAmount: netWithdrawnAmount,
protocolFeeAmount: protocolFeeAmount
});

// Return the amount withdrawn + protocol fee.
return amount;
return netWithdrawnAmount + protocolFeeAmount;
}
}
12 changes: 6 additions & 6 deletions src/interfaces/ISablierFlow.sol
Original file line number Diff line number Diff line change
Expand Up @@ -93,17 +93,17 @@ interface ISablierFlow is
/// @param to The address that received the withdrawn tokens.
/// @param token The contract address of the ERC-20 token that was withdrawn.
/// @param caller The address that performed the withdrawal, which can be the recipient or an approved operator.
/// @param protocolFeeAmount The amount of protocol fee deducted from the withdrawn amount, denoted in token's
/// decimals.
/// @param withdrawAmount The amount withdrawn to the recipient after subtracting the protocol fee, denoted in
/// token's decimals.
/// @param protocolFeeAmount The amount of protocol fee deducted from the withdrawn amount, denoted in token's
/// decimals.
event WithdrawFromFlowStream(
uint256 indexed streamId,
address indexed to,
IERC20 indexed token,
address caller,
uint128 protocolFeeAmount,
uint128 withdrawAmount
uint128 withdrawAmount,
uint128 protocolFeeAmount
);

/*//////////////////////////////////////////////////////////////////////////
Expand Down Expand Up @@ -403,7 +403,7 @@ interface ISablierFlow is
/// @param to The address receiving the withdrawn tokens.
/// @param amount The amount to withdraw, denoted in token's decimals.
/// @return amountWithdrawn The amount withdrawn to the recipient including protocol fee, denoted in token's
/// decimals. This may slightly differ from `amount` provided.
/// decimals. This may slightly differ from the `amount` provided.
function withdraw(uint256 streamId, address to, uint128 amount) external returns (uint128 amountWithdrawn);

/// @notice Withdraws the entire withdrawable amount from the stream to the provided address `to`.
Expand All @@ -419,6 +419,6 @@ interface ISablierFlow is
/// @param streamId The ID of the stream to withdraw from.
/// @param to The address receiving the withdrawn tokens.
/// @return amountWithdrawn The amount withdrawn to the recipient including protocol fee, denoted in token's
/// decimals. This may slightly differ from `coveredDebt` amount.
/// decimals. This may slightly differ from the covered debt value.
function withdrawMax(uint256 streamId, address to) external returns (uint128 amountWithdrawn);
}
8 changes: 4 additions & 4 deletions test/integration/concrete/batch/batch.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -345,8 +345,8 @@ contract Batch_Integration_Concrete_Test is Integration_Test {
to: users.recipient,
token: usdc,
caller: users.sender,
protocolFeeAmount: 0,
withdrawAmount: WITHDRAW_AMOUNT_6D
withdrawAmount: WITHDRAW_AMOUNT_6D,
protocolFeeAmount: 0
});

vm.expectEmit({ emitter: address(flow) });
Expand All @@ -362,8 +362,8 @@ contract Batch_Integration_Concrete_Test is Integration_Test {
to: users.recipient,
token: usdc,
protocolFeeAmount: 0,
caller: users.sender,
withdrawAmount: WITHDRAW_AMOUNT_6D
withdrawAmount: WITHDRAW_AMOUNT_6D,
caller: users.sender
});

vm.expectEmit({ emitter: address(flow) });
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ contract UncoveredDebtOf_Integration_Concrete_Test is Integration_Test {
// Simulate the passage of time to accumulate uncovered debt for one month.
vm.warp({ newTimestamp: WARP_SOLVENCY_PERIOD + ONE_MONTH });

uint128 totalStreamed = getDenormalizedAmount(RATE_PER_SECOND_U128 * (SOLVENCY_PERIOD + ONE_MONTH), 6);
uint128 totalStreamed = getDescaledAmount(RATE_PER_SECOND_U128 * (SOLVENCY_PERIOD + ONE_MONTH), 6);

// It should return non-zero value.
uint128 actualUncoveredDebt = flow.uncoveredDebtOf(defaultStreamId);
Expand Down
36 changes: 18 additions & 18 deletions test/integration/concrete/withdraw/withdraw.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -214,25 +214,25 @@ contract Withdraw_Integration_Concrete_Test is Integration_Test {
struct Vars {
uint128 feeAmount;
// Previous values.
uint128 previousProtocolRevenue;
uint40 previousSnapshotTime;
uint128 previousStreamBalance;
uint256 previousTokenBalance;
uint128 previousTotalDebt;
uint256 previousUserBalance;
uint128 initialProtocolRevenue;
uint40 initialSnapshotTime;
uint128 initialStreamBalance;
uint256 initialTokenBalance;
uint128 initialTotalDebt;
uint256 initialUserBalance;
}

Vars internal vars;

function _test_Withdraw(uint256 streamId, address to, uint128 withdrawAmount) private {
IERC20 token = flow.getToken(streamId);

vars.previousProtocolRevenue = flow.protocolRevenue(token);
vars.previousTokenBalance = token.balanceOf(address(flow));
vars.previousTotalDebt = flow.totalDebtOf(streamId);
vars.previousSnapshotTime = flow.getSnapshotTime(streamId);
vars.previousStreamBalance = flow.getBalance(streamId);
vars.previousUserBalance = token.balanceOf(to);
vars.initialProtocolRevenue = flow.protocolRevenue(token);
vars.initialTokenBalance = token.balanceOf(address(flow));
vars.initialTotalDebt = flow.totalDebtOf(streamId);
vars.initialSnapshotTime = flow.getSnapshotTime(streamId);
vars.initialStreamBalance = flow.getBalance(streamId);
vars.initialUserBalance = token.balanceOf(to);

vm.expectEmit({ emitter: address(flow) });
emit MetadataUpdate({ _tokenId: streamId });
Expand All @@ -245,24 +245,24 @@ contract Withdraw_Integration_Concrete_Test is Integration_Test {
}

// Assert the protocol revenue.
assertEq(flow.protocolRevenue(token), vars.previousProtocolRevenue + vars.feeAmount, "protocol revenue");
assertEq(flow.protocolRevenue(token), vars.initialProtocolRevenue + vars.feeAmount, "protocol revenue");

// Check the states after the withdrawal.
assertEq(
vars.previousTokenBalance - token.balanceOf(address(flow)),
vars.initialTokenBalance - token.balanceOf(address(flow)),
actualWithdrawnAmount - vars.feeAmount,
"token balance == amount withdrawn - fee amount"
);
assertEq(
vars.previousTotalDebt - flow.totalDebtOf(streamId), actualWithdrawnAmount, "total debt == amount withdrawn"
vars.initialTotalDebt - flow.totalDebtOf(streamId), actualWithdrawnAmount, "total debt == amount withdrawn"
);
assertEq(
vars.previousStreamBalance - flow.getBalance(streamId),
vars.initialStreamBalance - flow.getBalance(streamId),
actualWithdrawnAmount,
"stream balance == amount withdrawn"
);
assertEq(
token.balanceOf(to) - vars.previousUserBalance,
token.balanceOf(to) - vars.initialUserBalance,
actualWithdrawnAmount - vars.feeAmount,
"user balance == token balance - fee amount"
);
Expand All @@ -275,7 +275,7 @@ contract Withdraw_Integration_Concrete_Test is Integration_Test {
);

// It should update snapshot time.
assertGe(flow.getSnapshotTime(streamId), vars.previousSnapshotTime, "snapshot time");
assertGe(flow.getSnapshotTime(streamId), vars.initialSnapshotTime, "snapshot time");

// It should return the actual withdrawn amount.
assertGe(withdrawAmount, actualWithdrawnAmount, "withdrawn amount");
Expand Down
2 changes: 1 addition & 1 deletion test/integration/fuzz/Fuzz.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ abstract contract Shared_Integration_Fuzz_Test is Integration_Test {
uint128 amountSeed = uint128(uint256(keccak256(abi.encodePacked(flow.nextStreamId(), decimals))));
// Bound the amount between a realistic range.
uint128 amount = boundUint128(amountSeed, 1, 1_000_000_000e18);
uint128 depositAmount = getDenormalizedAmount(amount, decimals);
uint128 depositAmount = getDescaledAmount(amount, decimals);

// Deposit into the stream.
deposit(streamId, depositAmount);
Expand Down
2 changes: 1 addition & 1 deletion test/integration/fuzz/coveredDebtOf.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ contract CoveredDebtOf_Integration_Fuzz_Test is Shared_Integration_Fuzz_Test {

// Assert that the covered debt equals the ongoing debt.
uint128 actualCoveredDebt = flow.coveredDebtOf(streamId);
uint128 expectedCoveredDebt = getDenormalizedAmount(ratePerSecond * (warpTimestamp - MAY_1_2024), decimals);
uint128 expectedCoveredDebt = getDescaledAmount(ratePerSecond * (warpTimestamp - MAY_1_2024), decimals);
assertEq(actualCoveredDebt, expectedCoveredDebt);
}

Expand Down
Loading

0 comments on commit 2425778

Please sign in to comment.