Skip to content

Commit

Permalink
fix reentrancy
Browse files Browse the repository at this point in the history
  • Loading branch information
syntrust committed Sep 5, 2024
1 parent 3254603 commit 242341f
Show file tree
Hide file tree
Showing 3 changed files with 131 additions and 6 deletions.
16 changes: 13 additions & 3 deletions contracts/StorageContract.sol
Original file line number Diff line number Diff line change
Expand Up @@ -78,9 +78,20 @@ abstract contract StorageContract is DecentralizedKV {
/// @notice Treasury address
address public treasury;

/// @notice
/// @notice Prepaid timestamp of last mined
uint256 public prepaidLastMineTime;

/// @notice Locker to prevent from reentrancy
bool private locked;

/// @notice Prevent from reentrancy
modifier noReentrant() {
require(!locked, "StorageContract: No reentrancy allowed!");
locked = true;
_;
locked = false;
}

// TODO: Reserve extra slots (to a total of 50?) in the storage layout for future upgrades

/// @notice Emitted when a block is mined.
Expand Down Expand Up @@ -239,7 +250,6 @@ abstract contract StorageContract is DecentralizedKV {
MiningLib.update(infos[_shardId], _minedTs, _diff);

require(treasuryReward + minerReward <= address(this).balance, "StorageContract: not enough balance");
// TODO: avoid reentrancy attack
payable(treasury).transfer(treasuryReward);
payable(_miner).transfer(minerReward);
emit MinedBlock(_shardId, _diff, infos[_shardId].blockMined, _minedTs, _miner, minerReward);
Expand Down Expand Up @@ -307,7 +317,7 @@ abstract contract StorageContract is DecentralizedKV {
bytes calldata _randaoProof,
bytes[] calldata _inclusiveProofs,
bytes[] calldata _decodeProof
) public virtual {
) public virtual noReentrant {
_mine(
_blockNum, _shardId, _miner, _nonce, _encodedSamples, _masks, _randaoProof, _inclusiveProofs, _decodeProof
);
Expand Down
102 changes: 99 additions & 3 deletions contracts/test/StorageContractTest.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,20 @@ import "forge-std/Test.sol";
import "forge-std/Vm.sol";

contract StorageContractTest is Test {
uint256 constant STORAGE_COST = 1000;
uint256 constant STORAGE_COST = 10000000;
uint256 constant SHARD_SIZE_BITS = 19;
uint256 constant MAX_KV_SIZE = 17;
uint256 constant PREPAID_AMOUNT = 2 * STORAGE_COST;
TestStorageContract storageContract;

function setUp() public {
storageContract = new TestStorageContract(
StorageContract.Config(MAX_KV_SIZE, SHARD_SIZE_BITS, 2, 0, 0, 0), 0, STORAGE_COST, 0
StorageContract.Config(MAX_KV_SIZE, SHARD_SIZE_BITS, 2, 0, 0, 0),
0,
STORAGE_COST,
340282366367469178095360967382638002176
);
storageContract.initialize(0, PREPAID_AMOUNT, 0, address(0x1), address(0x1));
storageContract.initialize(0, PREPAID_AMOUNT, 0, vm.addr(1), address(0x1));
}

function testMiningReward() public {
Expand All @@ -40,4 +43,97 @@ contract StorageContractTest is Test {
(,, reward) = storageContract.miningRewards(0, 1);
assertEq(reward, storageContract.paymentIn(PREPAID_AMOUNT + STORAGE_COST * 2, 0, 1));
}

function testRewardMiner() public {
address miner = vm.addr(2);
uint256 mineTs = 10000;
uint256 diff = 1;

vm.expectRevert("StorageContract: not enough balance");
storageContract.rewardMiner(0, miner, mineTs, 1);

vm.deal(address(storageContract), 1000);

(,, uint256 reward) = storageContract.miningRewards(0, mineTs);
storageContract.rewardMiner(0, miner, mineTs, diff);
(uint256 l, uint256 d, uint256 b) = storageContract.infos(0);
assertEq(l, mineTs);
assertEq(d, diff);
assertEq(b, 1);
assertEq(miner.balance, reward);
}

function testReentrancy() public {
vm.pauseGasMetering();
uint256 prefund = 1000;
// Without reentrancy protection, the fund could be drained by 29 times re-entrances given current params.
vm.deal(address(storageContract), prefund);
storageContract.setKvEntryCount(1);
Attacker attacker = new Attacker(storageContract);
vm.prank(address(attacker));

uint256 _blockNum = 1;
uint256 _shardId = 0;
uint256 _nonce = 0;
bytes32[] memory _encodedSamples = new bytes32[](0);
uint256[] memory _masks = new uint256[](0);
bytes memory _randaoProof = "0x01";
bytes[] memory _inclusiveProofs = new bytes[](0);
bytes[] memory _decodeProof = new bytes[](0);

vm.expectRevert("StorageContract: No reentrancy allowed!");
storageContract.mine(
_blockNum,
_shardId,
address(attacker),
_nonce,
_encodedSamples,
_masks,
_randaoProof,
_inclusiveProofs,
_decodeProof
);
}
}

contract Attacker {
// cannot access imported vm directly
address internal constant VM_ADDRESS = address(uint160(uint256(keccak256("hevm cheat code"))));
Vm vm = Vm(VM_ADDRESS);
TestStorageContract storageContract;
uint256 blockNumber = 1;
uint256 count = 0;

constructor(TestStorageContract _storageContract) {
storageContract = _storageContract;
}

fallback() external payable {
uint256 _shardId = 0;
uint256 _nonce = 0;
bytes32[] memory _encodedSamples = new bytes32[](0);
uint256[] memory _masks = new uint256[](0);
bytes memory _randaoProof = "0x01";
bytes[] memory _inclusiveProofs = new bytes[](0);
bytes[] memory _decodeProof = new bytes[](0);

blockNumber += 60;
vm.roll(blockNumber + 20);
vm.warp(block.number * 12);
uint256 reward = storageContract.miningReward(_shardId, blockNumber);
if (address(storageContract).balance >= reward) {
storageContract.mine(
blockNumber,
_shardId,
address(this),
_nonce,
_encodedSamples,
_masks,
_randaoProof,
_inclusiveProofs,
_decodeProof
);
count++;
}
}
}
19 changes: 19 additions & 0 deletions contracts/test/TestStorageContract.sol
Original file line number Diff line number Diff line change
Expand Up @@ -41,4 +41,23 @@ contract TestStorageContract is StorageContract {
function miningRewards(uint256 _shardId, uint256 _minedTs) public view returns (bool, uint256, uint256) {
return _miningReward(_shardId, _minedTs);
}

function rewardMiner(uint256 _shardId, address _miner, uint256 _minedTs, uint256 _diff) public {
return _rewardMiner(_shardId, _miner, _minedTs, _diff);
}

function _mine(
uint256 _blockNum,
uint256 _shardId,
address _miner,
uint256 _nonce,
bytes32[] memory _encodedSamples,
uint256[] memory _masks,
bytes calldata _randaoProof,
bytes[] calldata _inclusiveProofs,
bytes[] calldata _decodeProof
) internal override {
uint256 mineTs = _getMinedTs(_blockNum);
_rewardMiner(_shardId, _miner, mineTs, 1);
}
}

0 comments on commit 242341f

Please sign in to comment.