Skip to content

Commit 42c69c7

Browse files
Validate job phase when delivering job
1 parent 2377fcf commit 42c69c7

2 files changed

Lines changed: 16 additions & 23 deletions

File tree

tests/unit/test_job.py

Lines changed: 12 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
FeeType,
1616
OperationPayload,
1717
)
18+
from virtuals_acp.exceptions import ACPError
1819
from virtuals_acp.fare import Fare, FareAmount
1920

2021
TEST_AGENT_ADDRESS = "0x1234567890123456789012345678901234567890"
@@ -558,6 +559,7 @@ def test_should_create_completed_memo_with_deliverable(
558559
mock_memo = MagicMock(spec=ACPMemo)
559560
mock_memo.next_phase = ACPJobPhase.EVALUATION
560561
basic_job.memos = [mock_memo]
562+
basic_job.phase = ACPJobPhase.TRANSACTION
561563

562564
mock_operation = MagicMock(spec=OperationPayload)
563565
mock_contract_client = mock_acp_client.contract_client_by_address.return_value
@@ -575,16 +577,13 @@ def test_should_create_completed_memo_with_deliverable(
575577
mock_contract_client.create_memo.assert_called_once()
576578
assert result == "0xdelivery"
577579

578-
def test_should_raise_error_when_no_evaluation_memo(self, basic_job):
579-
"""Should raise ValueError when latest memo is not EVALUATION phase"""
580-
mock_memo = MagicMock(spec=ACPMemo)
581-
mock_memo.next_phase = ACPJobPhase.TRANSACTION
582-
basic_job.memos = [mock_memo]
580+
def test_should_raise_error_when_not_in_transaction_phase(self, basic_job):
581+
"""Should raise ACPError when job is not in transaction phase"""
582+
basic_job.phase = ACPJobPhase.NEGOTIATION
583583

584-
# DeliverablePayload is Union[str, Dict], so just use a string
585584
deliverable = "Test deliverable"
586585

587-
with pytest.raises(ValueError, match="No transaction memo found"):
586+
with pytest.raises(ACPError, match="Job is not in transaction phase"):
588587
basic_job.deliver(deliverable)
589588

590589
class TestEvaluate:
@@ -937,6 +936,7 @@ def test_should_create_payable_delivery_with_percentage_fee(
937936
mock_memo = MagicMock(spec=ACPMemo)
938937
mock_memo.next_phase = ACPJobPhase.EVALUATION
939938
basic_job.memos = [mock_memo]
939+
basic_job.phase = ACPJobPhase.TRANSACTION
940940

941941
mock_contract_client = mock_acp_client.contract_client_by_address.return_value
942942
mock_contract_client.approve_allowance.return_value = MagicMock()
@@ -967,6 +967,7 @@ def test_should_skip_fee_when_requested(self, basic_job, mock_acp_client):
967967
mock_memo = MagicMock(spec=ACPMemo)
968968
mock_memo.next_phase = ACPJobPhase.EVALUATION
969969
basic_job.memos = [mock_memo]
970+
basic_job.phase = ACPJobPhase.TRANSACTION
970971

971972
mock_contract_client = mock_acp_client.contract_client_by_address.return_value
972973
mock_contract_client.approve_allowance.return_value = MagicMock()
@@ -988,15 +989,13 @@ def test_should_skip_fee_when_requested(self, basic_job, mock_acp_client):
988989
call_args = mock_contract_client.create_payable_memo.call_args[1]
989990
assert call_args['fee_type'] == FeeType.NO_FEE
990991

991-
def test_should_raise_error_when_no_evaluation_memo(self, basic_job):
992-
"""Should raise ValueError when not in EVALUATION phase"""
993-
mock_memo = MagicMock(spec=ACPMemo)
994-
mock_memo.next_phase = ACPJobPhase.TRANSACTION
995-
basic_job.memos = [mock_memo]
992+
def test_should_raise_error_when_not_in_transaction_phase(self, basic_job):
993+
"""Should raise ACPError when job is not in transaction phase"""
994+
basic_job.phase = ACPJobPhase.NEGOTIATION
996995

997996
fare = FareAmount(1000000, basic_job.base_fare)
998997

999-
with pytest.raises(ValueError, match="No transaction memo found"):
998+
with pytest.raises(ACPError, match="Job is not in transaction phase"):
1000999
basic_job.deliver_payable({}, fare)
10011000

10021001
class TestCreatePayableNotification:

virtuals_acp/job.py

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -468,11 +468,8 @@ def _get_memo_by_id(self, memo_id: int) -> Optional[ACPMemo]:
468468
return next((m for m in self.memos if m.id == memo_id), None)
469469

470470
def deliver(self, deliverable: DeliverablePayload) -> str | None:
471-
if (
472-
self.latest_memo is None
473-
or self.latest_memo.next_phase != ACPJobPhase.EVALUATION
474-
):
475-
raise ValueError("No transaction memo found")
471+
if self.phase != ACPJobPhase.TRANSACTION:
472+
raise ACPError("Job is not in transaction phase")
476473

477474
operations: List[OperationPayload] = []
478475

@@ -496,11 +493,8 @@ def deliver_payable(
496493
skip_fee: bool = False,
497494
expired_at: Optional[datetime] = None,
498495
) -> str | None:
499-
if (
500-
self.latest_memo is None
501-
or self.latest_memo.next_phase != ACPJobPhase.EVALUATION
502-
):
503-
raise ValueError("No transaction memo found")
496+
if self.phase != ACPJobPhase.TRANSACTION:
497+
raise ACPError("Job is not in transaction phase")
504498

505499
if expired_at is None:
506500
expired_at = datetime.now(timezone.utc) + timedelta(minutes=5)

0 commit comments

Comments
 (0)