From bcb8105526b068192c0f85a89bef9904c0f080ed Mon Sep 17 00:00:00 2001 From: Ian He <39037239+ianhe8x@users.noreply.github.com> Date: Fri, 11 Apr 2025 15:51:22 +1200 Subject: [PATCH] whitelist contract consumer --- contracts/StateChannel.sol | 52 +++++++++++++++++++++++++------- publish/ABI/StateChannel.json | 56 ++++++++++++++++++++++++++++++++++ test/ConsumerHost.test.ts | 57 +++++++++++++++++++++++++++++++++++ 3 files changed, 154 insertions(+), 11 deletions(-) diff --git a/contracts/StateChannel.sol b/contracts/StateChannel.sol index 34aad9e1..0ad5b147 100644 --- a/contracts/StateChannel.sol +++ b/contracts/StateChannel.sol @@ -76,6 +76,8 @@ contract StateChannel is Initializable, OwnableUpgradeable, SQParameter { /// @notice The price of the channel mapping(uint256 => uint256) public channelPrice; + mapping(address => bool) public consumerContractWhitelist; + /// @dev ### EVENTS /// @notice Emitted when open a channel for Pay-as-you-go service event ChannelOpen( @@ -112,6 +114,8 @@ contract StateChannel is Initializable, OwnableUpgradeable, SQParameter { address indexer, uint256 amount ); + /// @notice Emitted when set the parameter + event ConsumerContractWhitelistChanged(address consumerContract, bool status); /** * @dev ### FUNCTIONS @@ -144,6 +148,14 @@ contract StateChannel is Initializable, OwnableUpgradeable, SQParameter { emit Parameter('terminateExpiration', abi.encodePacked(terminateExpiration)); } + function setConsumerContractWhitelist( + address consumerContract, + bool status + ) external onlyOwner { + consumerContractWhitelist[consumerContract] = status; + emit ConsumerContractWhitelistChanged(consumerContract, status); + } + /** * @notice Get the channel info * @param channelId channel id @@ -203,7 +215,7 @@ contract StateChannel is Initializable, OwnableUpgradeable, SQParameter { callback ) ); - if (_isContract(consumer)) { + if (_isValidContractConsumer(consumer)) { require(consumer.supportsInterface(type(IConsumer).interfaceId), 'G018'); require(IConsumer(consumer).checkSign(channelId, payload, consumerSign), 'C006'); } else { @@ -270,7 +282,7 @@ contract StateChannel is Initializable, OwnableUpgradeable, SQParameter { bytes32 payload = keccak256( abi.encode(channelId, indexer, consumer, price, preExpirationAt, expiration) ); - if (_isContract(consumer)) { + if (_isValidContractConsumer(consumer)) { require(IConsumer(consumer).checkSign(channelId, payload, consumerSign), 'C006'); } else { _checkSign(payload, consumerSign, consumer, false); @@ -308,7 +320,7 @@ contract StateChannel is Initializable, OwnableUpgradeable, SQParameter { ); // check sign - if (_isContract(consumer)) { + if (_isValidContractConsumer(consumer)) { require(IConsumer(consumer).checkSign(channelId, payload, sign), 'C006'); } else { _checkSign(payload, sign, consumer, false); @@ -361,7 +373,7 @@ contract StateChannel is Initializable, OwnableUpgradeable, SQParameter { ).getController(state.indexer); isIndexer = msg.sender == controller; } - if (_isContract(state.consumer)) { + if (_isValidContractConsumer(state.consumer)) { isConsumer = IConsumer(state.consumer).checkSender(query.channelId, msg.sender); } require(isIndexer || isConsumer, 'G008'); @@ -401,7 +413,7 @@ contract StateChannel is Initializable, OwnableUpgradeable, SQParameter { ).getController(state.indexer); isIndexer = msg.sender == controller; } - if (_isContract(state.consumer)) { + if (_isValidContractConsumer(state.consumer)) { isConsumer = IConsumer(state.consumer).checkSender(channelId, msg.sender); } require(isIndexer || isConsumer, 'G008'); @@ -429,7 +441,7 @@ contract StateChannel is Initializable, OwnableUpgradeable, SQParameter { require(state.status == ChannelStatus.Terminating, 'SC007'); if (state.terminateByIndexer) { bool isConsumer = msg.sender == state.consumer; - if (_isContract(state.consumer)) { + if (_isValidContractConsumer(state.consumer)) { isConsumer = IConsumer(state.consumer).checkSender(query.channelId, msg.sender); } require(isConsumer, 'G008'); @@ -471,7 +483,7 @@ contract StateChannel is Initializable, OwnableUpgradeable, SQParameter { ) private view { address indexer = channels[channelId].indexer; address consumer = channels[channelId].consumer; - if (_isContract(consumer)) { + if (_isValidContractConsumer(consumer)) { require(IConsumer(consumer).checkSign(channelId, payload, consumerSign), 'C006'); } else { _checkSign(payload, consumerSign, consumer, false); @@ -572,7 +584,7 @@ contract StateChannel is Initializable, OwnableUpgradeable, SQParameter { uint256 spent = channels[channelId].spent; address realConsumer = consumer; - if (_isContract(consumer)) { + if (_isValidContractConsumer(consumer)) { realConsumer = IConsumer(consumer).channelConsumer(channelId); } @@ -615,7 +627,7 @@ contract StateChannel is Initializable, OwnableUpgradeable, SQParameter { ); } - if (_isContract(consumer)) { + if (_isValidContractConsumer(consumer)) { IConsumer(consumer).claimed(channelId, realRemain); } @@ -643,7 +655,8 @@ contract StateChannel is Initializable, OwnableUpgradeable, SQParameter { bytes memory callback ) internal { address realConsumer = consumer; - if (_isContract(consumer)) { + bool isCConsumer = _isValidContractConsumer(consumer); + if (isCConsumer) { IConsumer cConsumer = IConsumer(consumer); realConsumer = cConsumer.channelConsumer(channelId); } @@ -654,7 +667,7 @@ contract StateChannel is Initializable, OwnableUpgradeable, SQParameter { uint256 realAmount = 0; if (fundByReward < amount) { realAmount = amount - fundByReward; - if (_isContract(consumer)) { + if (isCConsumer) { IConsumer(consumer).paid(channelId, msg.sender, realAmount, callback); } IERC20(settings.getContractAddress(SQContracts.SQToken)).safeTransferFrom( @@ -667,4 +680,21 @@ contract StateChannel is Initializable, OwnableUpgradeable, SQParameter { channels[channelId].realTotal += realAmount; channels[channelId].total += amount; } + + /// @dev check if consumer is valid contract consumer + /// @return false if it is not a contract + /// true if it is a contract and implements IConsumer interface and in the whitelist + /// throw G018 if it doesn't implements IConsumer or not in the whitelist + function _isValidContractConsumer(address consumer) private view returns (bool) { + if (_isContract(consumer)) { + require( + consumer.supportsInterface(type(IConsumer).interfaceId) && + consumerContractWhitelist[consumer], + 'G018' + ); + return true; + } else { + return false; + } + } } diff --git a/publish/ABI/StateChannel.json b/publish/ABI/StateChannel.json index 26434512..5d48eefb 100644 --- a/publish/ABI/StateChannel.json +++ b/publish/ABI/StateChannel.json @@ -241,6 +241,25 @@ "name": "ChannelTerminate", "type": "event" }, + { + "anonymous": false, + "inputs": [ + { + "indexed": false, + "internalType": "address", + "name": "consumerContract", + "type": "address" + }, + { + "indexed": false, + "internalType": "bool", + "name": "status", + "type": "bool" + } + ], + "name": "ConsumerContractWhitelistChanged", + "type": "event" + }, { "anonymous": false, "inputs": [ @@ -435,6 +454,25 @@ "stateMutability": "nonpayable", "type": "function" }, + { + "inputs": [ + { + "internalType": "address", + "name": "", + "type": "address" + } + ], + "name": "consumerContractWhitelist", + "outputs": [ + { + "internalType": "bool", + "name": "", + "type": "bool" + } + ], + "stateMutability": "view", + "type": "function" + }, { "inputs": [ { @@ -637,6 +675,24 @@ "stateMutability": "nonpayable", "type": "function" }, + { + "inputs": [ + { + "internalType": "address", + "name": "consumerContract", + "type": "address" + }, + { + "internalType": "bool", + "name": "status", + "type": "bool" + } + ], + "name": "setConsumerContractWhitelist", + "outputs": [], + "stateMutability": "nonpayable", + "type": "function" + }, { "inputs": [ { diff --git a/test/ConsumerHost.test.ts b/test/ConsumerHost.test.ts index c09b97d3..08ee9d28 100644 --- a/test/ConsumerHost.test.ts +++ b/test/ConsumerHost.test.ts @@ -279,6 +279,7 @@ describe('ConsumerHost Contract', () => { describe('Consumer Host State Channel should work', () => { beforeEach(async () => { + await stateChannel.setConsumerContractWhitelist(consumerHost.address, true); await registerRunner(token, indexerRegistry, staking, wallet_0, runner, etherParse('2000')); await consumerHost.connect(wallet_0).addSigner(hoster.address); await token.connect(wallet_0).transfer(consumer.address, etherParse('10')); @@ -464,5 +465,61 @@ describe('ConsumerHost Contract', () => { const cBalance4 = await consumerHost.consumers(consumer.address); expect(cBalance4.balance).to.equal(etherParse('0.28')); }); + + it('non whitelisted cconsumer will fail', async () => { + expect(await token.balanceOf(consumerHost.address)).to.equal(etherParse('20')); + await stateChannel.setConsumerContractWhitelist(consumerHost.address, false); + + const channelId = ethers.utils.randomBytes(32); + + const abi = ethers.utils.defaultAbiCoder; + const consumerSign = '0x'; + const amount = etherParse('2'); + const price = etherParse('0.1'); + const expiration = 60; + + const consumerCallback = abi.encode(['address', 'bytes'], [consumer.address, consumerSign]); + + const msg = abi.encode( + ['uint256', 'address', 'address', 'uint256', 'uint256', 'uint256', 'bytes32', 'bytes'], + [ + channelId, + runner.address, + consumerHost.address, + amount, + price, + expiration, + deploymentId, + consumerCallback, + ] + ); + const payloadHash = ethers.utils.keccak256(msg); + + const indexerSign = await runner.signMessage(ethers.utils.arrayify(payloadHash)); + const hosterSign = await hoster.signMessage(ethers.utils.arrayify(payloadHash)); + + const recoveredIndexer = ethers.utils.verifyMessage(ethers.utils.arrayify(payloadHash), indexerSign); + expect(runner.address).to.equal(recoveredIndexer); + + const recoveredHoster = ethers.utils.verifyMessage(ethers.utils.arrayify(payloadHash), hosterSign); + expect(hoster.address).to.equal(recoveredHoster); + + await expect( + stateChannel + .connect(hoster) + .open( + channelId, + runner.address, + consumerHost.address, + amount, + price, + expiration, + deploymentId, + consumerCallback, + indexerSign, + hosterSign + ) + ).to.revertedWith('G018'); + }); }); });