Skip to content

Commit 6d81a74

Browse files
Update contract clients for authentication
1 parent f4c3c9d commit 6d81a74

5 files changed

Lines changed: 62 additions & 19 deletions

File tree

virtuals_acp/client.py

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -5,15 +5,16 @@
55
import signal
66
import sys
77
import threading
8-
8+
import jwt
9+
import socketio
10+
import requests
911
import time
12+
1013
from datetime import datetime, timezone, timedelta
1114
from importlib.metadata import version
1215
from typing import List, Optional, Union, Dict, Any, Callable
13-
14-
import requests
15-
import socketio
1616
from web3 import Web3
17+
from requests.auth import AuthBase
1718

1819
from virtuals_acp.account import ACPAccount
1920
from virtuals_acp.configs.configs import (
@@ -49,11 +50,6 @@
4950

5051

5152

52-
import jwt
53-
import requests
54-
55-
from requests.auth import AuthBase
56-
5753
class BearerAuth(AuthBase):
5854
def __init__(self, get_access_token: Callable[[], str]):
5955
self._get_access_token = get_access_token
@@ -70,7 +66,8 @@ def clear_token(self):
7066

7167

7268
class ACPApiClient:
73-
def __init__(self, acp_url: str, wallet_address: str, require_auth: bool = False):
69+
def __init__(self, acp_contract_client: BaseAcpContractClient, acp_url: str, wallet_address: str, require_auth: bool = False):
70+
self.acp_contract_client = acp_contract_client
7471
self.base_url = f"{acp_url}/api"
7572
self.wallet_address = wallet_address
7673
self.require_auth = require_auth
@@ -130,13 +127,17 @@ def get_access_token(self) -> str:
130127
needs_refresh = True
131128

132129
if not needs_refresh:
133-
return self.access_token
130+
# Access token is still valid
131+
if self.access_token:
132+
return self.access_token
133+
else:
134+
raise Exception("Access token needs refreshing!")
134135

135136
self.access_token = self.refresh_token()
136137
return self.access_token
137138

138139
def refresh_token(self) -> str:
139-
challenge = self.get_auth_challenge(self.wallet_address)
140+
challenge = self.get_auth_challenge()
140141
signature = self.acp_contract_client.sign_typed_data(challenge)
141142

142143
verified = self.verify_auth_challenge(
@@ -204,8 +205,8 @@ def __init__(
204205
"All contract clients must have the same agent wallet address"
205206
)
206207

207-
self.acp_client = ACPApiClient(self.acp_url, self.wallet_address)
208-
self.no_auth_acp_client = ACPApiClient(self.acp_url, self.wallet_address, require_auth=False)
208+
self.acp_client = ACPApiClient(self.acp_contract_client, self.acp_url, self.wallet_address)
209+
self.no_auth_acp_client = ACPApiClient(self.acp_contract_client, self.acp_url, self.wallet_address, require_auth=False)
209210

210211
# Socket.IO setup
211212
self.on_new_task = on_new_task
@@ -235,7 +236,6 @@ def init(self):
235236
logger.info(f"Initializing socket")
236237

237238
try:
238-
# TODO: auth needs to include access token now
239239
auth_data = {
240240
"walletAddress": self.wallet_address,
241241
"accessToken": self.acp_client.get_access_token()

virtuals_acp/contract_clients/base_contract_client.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,11 @@ def validate_session_key_on_chain(
121121
)
122122

123123
@abstractmethod
124-
def get_acp_version(self) -> str:
124+
def get_asset_manager_address(self) -> str:
125+
pass
126+
127+
@abstractmethod
128+
def sign_typed_data(self, typed_data: dict[str, Any]) -> str:
125129
pass
126130

127131
def _build_user_operation(
@@ -151,7 +155,7 @@ def handle_operation(self, trx_data: List[OperationPayload], chain_id: Optional[
151155

152156
@abstractmethod
153157
def get_job_id(
154-
self, receipt: Dict[str, Any], client_address: str, provider_address: str
158+
self, response: Dict[str, Any], client_address: str, provider_address: str
155159
) -> int:
156160
"""Abstract method to retrieve a job ID from a transaction hash and related addresses."""
157161
pass

virtuals_acp/contract_clients/contract_client.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from typing import Dict, Any, Optional, List
77

88
from eth_account import Account
9+
from eth_account.messages import encode_typed_data
910
from web3 import Web3
1011

1112
from virtuals_acp.alchemy import AlchemyAccountKit
@@ -237,4 +238,22 @@ def perform_x402_request(
237238
raise ACPError("Failed to perform X402 request", e)
238239

239240
def get_asset_manager_address(self) -> str:
240-
raise ACPError("Not Supported")
241+
raise ACPError("Not Supported")
242+
243+
def sign_typed_data(self, typed_data: dict[str, Any]) -> str:
244+
domain = typed_data["domain"]
245+
types = typed_data["types"]
246+
primary_type = typed_data["primaryType"]
247+
message = typed_data["message"]
248+
249+
# encode_typed_data expects (domain_data, types, primary_type, message_data)
250+
# It handles EIP-712 hashing internally
251+
signable = encode_typed_data(
252+
domain,
253+
types,
254+
primary_type,
255+
message,
256+
)
257+
258+
signed = self.account.sign_message(signable)
259+
return signed.signature.hex()

virtuals_acp/contract_clients/contract_client_v2.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from typing import Dict, Any, List, Optional
55

66
from eth_account import Account
7+
from eth_account.messages import encode_typed_data
78
from web3 import Web3
89

910
from virtuals_acp.abis.job_manager import JOB_MANAGER_ABI
@@ -198,4 +199,22 @@ def get_x402_payment_details(self, job_id: int) -> AcpJobX402PaymentDetails:
198199
raise ACPError("Failed to get X402 payment details", e)
199200

200201
def get_asset_manager_address(self) -> str:
201-
return self.memo_manager_contract.functions.assetManager().call()
202+
return self.memo_manager_contract.functions.assetManager().call()
203+
204+
def sign_typed_data(self, typed_data: dict[str, Any]) -> str:
205+
domain = typed_data["domain"]
206+
types = typed_data["types"]
207+
primary_type = typed_data["primaryType"]
208+
message = typed_data["message"]
209+
210+
# encode_typed_data expects (domain_data, types, primary_type, message_data)
211+
# It handles EIP-712 hashing internally
212+
signable = encode_typed_data(
213+
domain,
214+
types,
215+
primary_type,
216+
message,
217+
)
218+
219+
signed = self.account.sign_message(signable)
220+
return signed.signature.hex()

virtuals_acp/web3.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from web3 import Web3
44
from virtuals_acp.abis.erc20_abi import ERC20_ABI
55

6+
# TODO: implement wrapper methods in base_contract_client
67

78
def getERC20Balance(
89
public_client: Web3,

0 commit comments

Comments
 (0)