|
10 | 10 | import inspect |
11 | 11 | import json |
12 | 12 | import os |
| 13 | +import sys |
| 14 | +import types |
13 | 15 | import unittest |
14 | 16 | from unittest.mock import MagicMock, patch |
15 | 17 |
|
16 | 18 | import httpx |
17 | 19 |
|
18 | 20 | from cohere.manually_maintained.cohere_aws.mode import Mode |
19 | 21 |
|
| 22 | +if "tokenizers" not in sys.modules: |
| 23 | + tokenizers_stub = types.ModuleType("tokenizers") |
| 24 | + tokenizers_stub.Tokenizer = object |
| 25 | + sys.modules["tokenizers"] = tokenizers_stub |
| 26 | + |
| 27 | +if "fastavro" not in sys.modules: |
| 28 | + fastavro_stub = types.ModuleType("fastavro") |
| 29 | + fastavro_stub.parse_schema = lambda schema: schema |
| 30 | + fastavro_stub.reader = lambda *args, **kwargs: iter(()) |
| 31 | + fastavro_stub.writer = lambda *args, **kwargs: None |
| 32 | + sys.modules["fastavro"] = fastavro_stub |
| 33 | + |
20 | 34 |
|
21 | 35 | class TestSigV4HostHeader(unittest.TestCase): |
22 | 36 | """Fix 1: The headers dict passed to AWSRequest for SigV4 signing must |
@@ -47,8 +61,11 @@ def capture_aws_request(**kwargs): # type: ignore |
47 | 61 | mock_session.get_credentials.return_value = MagicMock() |
48 | 62 | mock_boto3.Session.return_value = mock_session |
49 | 63 |
|
50 | | - with patch("cohere.aws_client.lazy_botocore", return_value=mock_botocore), \ |
51 | | - patch("cohere.aws_client.lazy_boto3", return_value=mock_boto3): |
| 64 | + import cohere.aws_client as aws_client_module |
| 65 | + |
| 66 | + with patch.object(aws_client_module, "lazy_botocore", return_value=mock_botocore), patch.object( |
| 67 | + aws_client_module, "lazy_boto3", return_value=mock_boto3 |
| 68 | + ): |
52 | 69 |
|
53 | 70 | from cohere.aws_client import map_request_to_bedrock |
54 | 71 |
|
|
0 commit comments