diff --git a/.gitmodules b/.gitmodules index 690924b..56bfede 100644 --- a/.gitmodules +++ b/.gitmodules @@ -4,3 +4,9 @@ [submodule "lib/openzeppelin-contracts"] path = lib/openzeppelin-contracts url = https://github.com/OpenZeppelin/openzeppelin-contracts +[submodule "lib/openzeppelin-foundry-upgrades"] + path = lib/openzeppelin-foundry-upgrades + url = https://github.com/OpenZeppelin/openzeppelin-foundry-upgrades +[submodule "lib/openzeppelin-contracts-upgradeable"] + path = lib/openzeppelin-contracts-upgradeable + url = https://github.com/OpenZeppelin/openzeppelin-contracts-upgradeable diff --git a/foundry.toml b/foundry.toml index 6d5e0ad..5d231ed 100644 --- a/foundry.toml +++ b/foundry.toml @@ -3,6 +3,11 @@ src = "src" out = "out" libs = ["lib"] gas_reports = ["WanderStaking"] +ffi = true +ast = true +build_info = true +extra_output = ["storageLayout"] + # See more config options https://github.com/foundry-rs/foundry/blob/master/crates/config/README.md#all-options [fuzz] diff --git a/lib/openzeppelin-contracts-upgradeable b/lib/openzeppelin-contracts-upgradeable new file mode 160000 index 0000000..fa52531 --- /dev/null +++ b/lib/openzeppelin-contracts-upgradeable @@ -0,0 +1 @@ +Subproject commit fa525310e45f91eb20a6d3baa2644be8e0adba31 diff --git a/lib/openzeppelin-foundry-upgrades b/lib/openzeppelin-foundry-upgrades new file mode 160000 index 0000000..16e0ae2 --- /dev/null +++ b/lib/openzeppelin-foundry-upgrades @@ -0,0 +1 @@ +Subproject commit 16e0ae21e0e39049f619f2396fa28c57fad07368 diff --git a/remappings.txt b/remappings.txt index 662ac05..74935de 100644 --- a/remappings.txt +++ b/remappings.txt @@ -1,4 +1,5 @@ -@openzeppelin/contracts/=lib/openzeppelin-contracts/contracts/ +@openzeppelin/contracts/=lib/openzeppelin-contracts-upgradeable/lib/openzeppelin-contracts/contracts/ +@openzeppelin/contracts-upgradeable/=lib/openzeppelin-contracts-upgradeable/contracts/ ds-test/=lib/openzeppelin-contracts/lib/forge-std/lib/ds-test/src/ erc4626-tests/=lib/openzeppelin-contracts/lib/erc4626-tests/ forge-std/=lib/forge-std/src/ diff --git a/src/WanderStaking.sol b/src/WanderStaking.sol index 023947f..759a9a6 100644 --- a/src/WanderStaking.sol +++ b/src/WanderStaking.sol @@ -3,24 +3,35 @@ pragma solidity ^0.8.13; import {SafeERC20, IERC20} from "@openzeppelin/contracts/token/ERC20/utils/SafeERC20.sol"; -import {Ownable} from "@openzeppelin/contracts/access/Ownable.sol"; -import {Pausable} from "@openzeppelin/contracts/utils/Pausable.sol"; +import {PausableUpgradeable} from "@openzeppelin/contracts-upgradeable/utils/PausableUpgradeable.sol"; +import {Initializable} from "@openzeppelin/contracts-upgradeable/proxy/utils/Initializable.sol"; +import {OwnableUpgradeable} from "@openzeppelin/contracts-upgradeable/access/OwnableUpgradeable.sol"; +import {UUPSUpgradeable} from "@openzeppelin/contracts-upgradeable/proxy/utils/UUPSUpgradeable.sol"; -contract WanderStaking is Ownable, Pausable { +contract WanderStaking is Initializable, PausableUpgradeable, OwnableUpgradeable, UUPSUpgradeable { using SafeERC20 for IERC20; - IERC20 public immutable token; - event Stake(address indexed user, uint256 amount); event Unstake(address indexed user, uint256 amount); + event SpendFromStake(address indexed user, address indexed to, uint256 amount); error ZeroAmount(); error InsufficientBalance(); - mapping(address => uint256) userStake; + IERC20 public token; uint256 internal totalStaked; + mapping(address => uint256) userStake; + + // @custom:oz-upgrades-unsafe-allow constructor + constructor() { + _disableInitializers(); + } + + function initialize(address initialOwner, IERC20 _token) public initializer { + __Pausable_init(); + __Ownable_init(initialOwner); + __UUPSUpgradeable_init(); - constructor(address initialOwner, IERC20 _token) Ownable(initialOwner) { token = _token; } @@ -32,6 +43,8 @@ contract WanderStaking is Ownable, Pausable { _unpause(); } + function _authorizeUpgrade(address newImplementation) internal override onlyOwner {} + function stake(uint256 amount) external whenNotPaused { if (amount == 0) { revert ZeroAmount(); @@ -62,6 +75,24 @@ contract WanderStaking is Ownable, Pausable { token.safeTransfer(msg.sender, amount); } + function spendFromStake(address to, uint256 amount) external whenNotPaused { + if (amount == 0) { + revert ZeroAmount(); + } + + if (userStake[msg.sender] < amount) { + revert InsufficientBalance(); + } + + userStake[msg.sender] -= amount; + totalStaked -= amount; + + emit Unstake(msg.sender, amount); + emit SpendFromStake(msg.sender, to, amount); + + token.safeTransfer(to, amount); + } + function getTotalStaked() external view returns (uint256) { return totalStaked; } diff --git a/test/WanderStaking.t.sol b/test/WanderStaking.t.sol index 5b2dc60..aab981d 100644 --- a/test/WanderStaking.t.sol +++ b/test/WanderStaking.t.sol @@ -6,6 +6,7 @@ import {WanderStaking} from "../src/WanderStaking.sol"; import {TestToken} from "../src/TestToken.sol"; import {SafeERC20, IERC20} from "@openzeppelin/contracts/token/ERC20/utils/SafeERC20.sol"; +import {ERC1967Proxy} from "@openzeppelin/contracts/proxy/ERC1967/ERC1967Proxy.sol"; contract WanderStakingTest is Test { WanderStaking public staking; @@ -14,7 +15,13 @@ contract WanderStakingTest is Test { function setUp() public { token = new TestToken(); - staking = new WanderStaking(address(this), IERC20(token)); + WanderStaking stakingImpl = new WanderStaking(); + + bytes memory data = abi.encodeWithSignature("initialize(address,address)", address(this), address(token)); + ERC1967Proxy proxy = new ERC1967Proxy(address(stakingImpl), data); + + staking = WanderStaking(address(proxy)); + token.approve(address(staking), ~uint256(0)); } @@ -82,4 +89,41 @@ contract WanderStakingTest is Test { vm.expectRevert(); staking.unstake(unstakeAmount); } + + function test_spend(uint64 _amount, uint64 _spendAmount) public { + // hello nick + address to = 0x6F4E4664E9B519DEAB043676D9Aafe6c9621C088; + + vm.assume(_amount > 0); + vm.assume(_spendAmount > 0); + vm.assume(_amount >= _spendAmount); + uint256 amount = uint256(_amount) * (10 ** 18); + uint256 spendAmount = uint256(_spendAmount) * (10 ** 18); + + token.mint(address(this), amount); + staking.stake(amount); + + staking.spendFromStake(to, spendAmount); + + assert(token.balanceOf(to) == spendAmount); + assert(staking.getAmountStaked(address(this)) == amount - spendAmount); + assert(token.balanceOf(address(staking)) == amount - spendAmount); + } + + function test_spend_over(uint64 _amount, uint64 _spendAmount) public { + // hello nick + address to = 0x6F4E4664E9B519DEAB043676D9Aafe6c9621C088; + + vm.assume(_amount > 0); + vm.assume(_spendAmount > 0); + vm.assume(_amount < _spendAmount); + uint256 amount = uint256(_amount) * (10 ** 18); + uint256 spendAmount = uint256(_spendAmount) * (10 ** 18); + + token.mint(address(this), amount); + staking.stake(amount); + + vm.expectRevert(WanderStaking.InsufficientBalance.selector); + staking.spendFromStake(to, spendAmount); + } }