Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 41 additions & 11 deletions contracts/StateChannel.sol
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,8 @@ contract StateChannel is Initializable, OwnableUpgradeable, SQParameter {
/// @notice The price of the channel
mapping(uint256 => uint256) public channelPrice;

mapping(address => bool) public consumerContractWhitelist;

/// @dev ### EVENTS
/// @notice Emitted when open a channel for Pay-as-you-go service
event ChannelOpen(
Expand Down Expand Up @@ -112,6 +114,8 @@ contract StateChannel is Initializable, OwnableUpgradeable, SQParameter {
address indexer,
uint256 amount
);
/// @notice Emitted when set the parameter
event ConsumerContractWhitelistChanged(address consumerContract, bool status);

/**
* @dev ### FUNCTIONS
Expand Down Expand Up @@ -144,6 +148,14 @@ contract StateChannel is Initializable, OwnableUpgradeable, SQParameter {
emit Parameter('terminateExpiration', abi.encodePacked(terminateExpiration));
}

function setConsumerContractWhitelist(
address consumerContract,
bool status
) external onlyOwner {
consumerContractWhitelist[consumerContract] = status;
emit ConsumerContractWhitelistChanged(consumerContract, status);
}

/**
* @notice Get the channel info
* @param channelId channel id
Expand Down Expand Up @@ -203,7 +215,7 @@ contract StateChannel is Initializable, OwnableUpgradeable, SQParameter {
callback
)
);
if (_isContract(consumer)) {
if (_isValidContractConsumer(consumer)) {
require(consumer.supportsInterface(type(IConsumer).interfaceId), 'G018');
require(IConsumer(consumer).checkSign(channelId, payload, consumerSign), 'C006');
} else {
Expand Down Expand Up @@ -270,7 +282,7 @@ contract StateChannel is Initializable, OwnableUpgradeable, SQParameter {
bytes32 payload = keccak256(
abi.encode(channelId, indexer, consumer, price, preExpirationAt, expiration)
);
if (_isContract(consumer)) {
if (_isValidContractConsumer(consumer)) {
require(IConsumer(consumer).checkSign(channelId, payload, consumerSign), 'C006');
} else {
_checkSign(payload, consumerSign, consumer, false);
Expand Down Expand Up @@ -308,7 +320,7 @@ contract StateChannel is Initializable, OwnableUpgradeable, SQParameter {
);

// check sign
if (_isContract(consumer)) {
if (_isValidContractConsumer(consumer)) {
require(IConsumer(consumer).checkSign(channelId, payload, sign), 'C006');
} else {
_checkSign(payload, sign, consumer, false);
Expand Down Expand Up @@ -361,7 +373,7 @@ contract StateChannel is Initializable, OwnableUpgradeable, SQParameter {
).getController(state.indexer);
isIndexer = msg.sender == controller;
}
if (_isContract(state.consumer)) {
if (_isValidContractConsumer(state.consumer)) {
isConsumer = IConsumer(state.consumer).checkSender(query.channelId, msg.sender);
}
require(isIndexer || isConsumer, 'G008');
Expand Down Expand Up @@ -401,7 +413,7 @@ contract StateChannel is Initializable, OwnableUpgradeable, SQParameter {
).getController(state.indexer);
isIndexer = msg.sender == controller;
}
if (_isContract(state.consumer)) {
if (_isValidContractConsumer(state.consumer)) {
isConsumer = IConsumer(state.consumer).checkSender(channelId, msg.sender);
}
require(isIndexer || isConsumer, 'G008');
Expand Down Expand Up @@ -429,7 +441,7 @@ contract StateChannel is Initializable, OwnableUpgradeable, SQParameter {
require(state.status == ChannelStatus.Terminating, 'SC007');
if (state.terminateByIndexer) {
bool isConsumer = msg.sender == state.consumer;
if (_isContract(state.consumer)) {
if (_isValidContractConsumer(state.consumer)) {
isConsumer = IConsumer(state.consumer).checkSender(query.channelId, msg.sender);
}
require(isConsumer, 'G008');
Expand Down Expand Up @@ -471,7 +483,7 @@ contract StateChannel is Initializable, OwnableUpgradeable, SQParameter {
) private view {
address indexer = channels[channelId].indexer;
address consumer = channels[channelId].consumer;
if (_isContract(consumer)) {
if (_isValidContractConsumer(consumer)) {
require(IConsumer(consumer).checkSign(channelId, payload, consumerSign), 'C006');
} else {
_checkSign(payload, consumerSign, consumer, false);
Expand Down Expand Up @@ -572,7 +584,7 @@ contract StateChannel is Initializable, OwnableUpgradeable, SQParameter {
uint256 spent = channels[channelId].spent;

address realConsumer = consumer;
if (_isContract(consumer)) {
if (_isValidContractConsumer(consumer)) {
realConsumer = IConsumer(consumer).channelConsumer(channelId);
}

Expand Down Expand Up @@ -615,7 +627,7 @@ contract StateChannel is Initializable, OwnableUpgradeable, SQParameter {
);
}

if (_isContract(consumer)) {
if (_isValidContractConsumer(consumer)) {
IConsumer(consumer).claimed(channelId, realRemain);
}

Expand Down Expand Up @@ -643,7 +655,8 @@ contract StateChannel is Initializable, OwnableUpgradeable, SQParameter {
bytes memory callback
) internal {
address realConsumer = consumer;
if (_isContract(consumer)) {
bool isCConsumer = _isValidContractConsumer(consumer);
if (isCConsumer) {
IConsumer cConsumer = IConsumer(consumer);
realConsumer = cConsumer.channelConsumer(channelId);
}
Expand All @@ -654,7 +667,7 @@ contract StateChannel is Initializable, OwnableUpgradeable, SQParameter {
uint256 realAmount = 0;
if (fundByReward < amount) {
realAmount = amount - fundByReward;
if (_isContract(consumer)) {
if (isCConsumer) {
IConsumer(consumer).paid(channelId, msg.sender, realAmount, callback);
}
IERC20(settings.getContractAddress(SQContracts.SQToken)).safeTransferFrom(
Expand All @@ -667,4 +680,21 @@ contract StateChannel is Initializable, OwnableUpgradeable, SQParameter {
channels[channelId].realTotal += realAmount;
channels[channelId].total += amount;
}

/// @dev check if consumer is valid contract consumer
/// @return false if it is not a contract
/// true if it is a contract and implements IConsumer interface and in the whitelist
/// throw G018 if it doesn't implements IConsumer or not in the whitelist
function _isValidContractConsumer(address consumer) private view returns (bool) {
if (_isContract(consumer)) {
require(
consumer.supportsInterface(type(IConsumer).interfaceId) &&
consumerContractWhitelist[consumer],
'G018'
);
return true;
} else {
return false;
}
}
}
56 changes: 56 additions & 0 deletions publish/ABI/StateChannel.json
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,25 @@
"name": "ChannelTerminate",
"type": "event"
},
{
"anonymous": false,
"inputs": [
{
"indexed": false,
"internalType": "address",
"name": "consumerContract",
"type": "address"
},
{
"indexed": false,
"internalType": "bool",
"name": "status",
"type": "bool"
}
],
"name": "ConsumerContractWhitelistChanged",
"type": "event"
},
{
"anonymous": false,
"inputs": [
Expand Down Expand Up @@ -435,6 +454,25 @@
"stateMutability": "nonpayable",
"type": "function"
},
{
"inputs": [
{
"internalType": "address",
"name": "",
"type": "address"
}
],
"name": "consumerContractWhitelist",
"outputs": [
{
"internalType": "bool",
"name": "",
"type": "bool"
}
],
"stateMutability": "view",
"type": "function"
},
{
"inputs": [
{
Expand Down Expand Up @@ -637,6 +675,24 @@
"stateMutability": "nonpayable",
"type": "function"
},
{
"inputs": [
{
"internalType": "address",
"name": "consumerContract",
"type": "address"
},
{
"internalType": "bool",
"name": "status",
"type": "bool"
}
],
"name": "setConsumerContractWhitelist",
"outputs": [],
"stateMutability": "nonpayable",
"type": "function"
},
{
"inputs": [
{
Expand Down
57 changes: 57 additions & 0 deletions test/ConsumerHost.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,7 @@ describe('ConsumerHost Contract', () => {

describe('Consumer Host State Channel should work', () => {
beforeEach(async () => {
await stateChannel.setConsumerContractWhitelist(consumerHost.address, true);
await registerRunner(token, indexerRegistry, staking, wallet_0, runner, etherParse('2000'));
await consumerHost.connect(wallet_0).addSigner(hoster.address);
await token.connect(wallet_0).transfer(consumer.address, etherParse('10'));
Expand Down Expand Up @@ -464,5 +465,61 @@ describe('ConsumerHost Contract', () => {
const cBalance4 = await consumerHost.consumers(consumer.address);
expect(cBalance4.balance).to.equal(etherParse('0.28'));
});

it('non whitelisted cconsumer will fail', async () => {
expect(await token.balanceOf(consumerHost.address)).to.equal(etherParse('20'));
await stateChannel.setConsumerContractWhitelist(consumerHost.address, false);

const channelId = ethers.utils.randomBytes(32);

const abi = ethers.utils.defaultAbiCoder;
const consumerSign = '0x';
const amount = etherParse('2');
const price = etherParse('0.1');
const expiration = 60;

const consumerCallback = abi.encode(['address', 'bytes'], [consumer.address, consumerSign]);

const msg = abi.encode(
['uint256', 'address', 'address', 'uint256', 'uint256', 'uint256', 'bytes32', 'bytes'],
[
channelId,
runner.address,
consumerHost.address,
amount,
price,
expiration,
deploymentId,
consumerCallback,
]
);
const payloadHash = ethers.utils.keccak256(msg);

const indexerSign = await runner.signMessage(ethers.utils.arrayify(payloadHash));
const hosterSign = await hoster.signMessage(ethers.utils.arrayify(payloadHash));

const recoveredIndexer = ethers.utils.verifyMessage(ethers.utils.arrayify(payloadHash), indexerSign);
expect(runner.address).to.equal(recoveredIndexer);

const recoveredHoster = ethers.utils.verifyMessage(ethers.utils.arrayify(payloadHash), hosterSign);
expect(hoster.address).to.equal(recoveredHoster);

await expect(
stateChannel
.connect(hoster)
.open(
channelId,
runner.address,
consumerHost.address,
amount,
price,
expiration,
deploymentId,
consumerCallback,
indexerSign,
hosterSign
)
).to.revertedWith('G018');
});
});
});