|
23 | 23 |
|
24 | 24 | from writerai import Writer, AsyncWriter, APIResponseValidationError |
25 | 25 | from writerai._types import Omit |
26 | | -from writerai._utils import maybe_transform |
27 | 26 | from writerai._models import BaseModel, FinalRequestOptions |
28 | | -from writerai._constants import RAW_RESPONSE_HEADER |
29 | 27 | from writerai._streaming import Stream, AsyncStream |
30 | 28 | from writerai._exceptions import WriterError, APIStatusError, APITimeoutError, APIResponseValidationError |
31 | 29 | from writerai._base_client import ( |
|
36 | 34 | DefaultAsyncHttpxClient, |
37 | 35 | make_request_options, |
38 | 36 | ) |
39 | | -from writerai.types.chat_chat_params import ChatChatParamsNonStreaming |
40 | 37 |
|
41 | 38 | from .utils import update_env |
42 | 39 |
|
@@ -725,60 +722,21 @@ def test_parse_retry_after_header(self, remaining_retries: int, retry_after: str |
725 | 722 |
|
726 | 723 | @mock.patch("writerai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout) |
727 | 724 | @pytest.mark.respx(base_url=base_url) |
728 | | - def test_retrying_timeout_errors_doesnt_leak(self, respx_mock: MockRouter) -> None: |
| 725 | + def test_retrying_timeout_errors_doesnt_leak(self, respx_mock: MockRouter, client: Writer) -> None: |
729 | 726 | respx_mock.post("/v1/chat").mock(side_effect=httpx.TimeoutException("Test timeout error")) |
730 | 727 |
|
731 | 728 | with pytest.raises(APITimeoutError): |
732 | | - self.client.post( |
733 | | - "/v1/chat", |
734 | | - body=cast( |
735 | | - object, |
736 | | - maybe_transform( |
737 | | - dict( |
738 | | - messages=[ |
739 | | - { |
740 | | - "content": "Write a haiku about programming", |
741 | | - "role": "user", |
742 | | - } |
743 | | - ], |
744 | | - model="palmyra-x5", |
745 | | - ), |
746 | | - ChatChatParamsNonStreaming, |
747 | | - ), |
748 | | - ), |
749 | | - cast_to=httpx.Response, |
750 | | - options={"headers": {RAW_RESPONSE_HEADER: "stream"}}, |
751 | | - ) |
| 729 | + client.chat.with_streaming_response.chat(messages=[{"role": "user"}], model="model").__enter__() |
752 | 730 |
|
753 | 731 | assert _get_open_connections(self.client) == 0 |
754 | 732 |
|
755 | 733 | @mock.patch("writerai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout) |
756 | 734 | @pytest.mark.respx(base_url=base_url) |
757 | | - def test_retrying_status_errors_doesnt_leak(self, respx_mock: MockRouter) -> None: |
| 735 | + def test_retrying_status_errors_doesnt_leak(self, respx_mock: MockRouter, client: Writer) -> None: |
758 | 736 | respx_mock.post("/v1/chat").mock(return_value=httpx.Response(500)) |
759 | 737 |
|
760 | 738 | with pytest.raises(APIStatusError): |
761 | | - self.client.post( |
762 | | - "/v1/chat", |
763 | | - body=cast( |
764 | | - object, |
765 | | - maybe_transform( |
766 | | - dict( |
767 | | - messages=[ |
768 | | - { |
769 | | - "content": "Write a haiku about programming", |
770 | | - "role": "user", |
771 | | - } |
772 | | - ], |
773 | | - model="palmyra-x5", |
774 | | - ), |
775 | | - ChatChatParamsNonStreaming, |
776 | | - ), |
777 | | - ), |
778 | | - cast_to=httpx.Response, |
779 | | - options={"headers": {RAW_RESPONSE_HEADER: "stream"}}, |
780 | | - ) |
781 | | - |
| 739 | + client.chat.with_streaming_response.chat(messages=[{"role": "user"}], model="model").__enter__() |
782 | 740 | assert _get_open_connections(self.client) == 0 |
783 | 741 |
|
784 | 742 | @pytest.mark.parametrize("failures_before_success", [0, 2, 4]) |
@@ -1594,60 +1552,25 @@ async def test_parse_retry_after_header(self, remaining_retries: int, retry_afte |
1594 | 1552 |
|
1595 | 1553 | @mock.patch("writerai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout) |
1596 | 1554 | @pytest.mark.respx(base_url=base_url) |
1597 | | - async def test_retrying_timeout_errors_doesnt_leak(self, respx_mock: MockRouter) -> None: |
| 1555 | + async def test_retrying_timeout_errors_doesnt_leak(self, respx_mock: MockRouter, async_client: AsyncWriter) -> None: |
1598 | 1556 | respx_mock.post("/v1/chat").mock(side_effect=httpx.TimeoutException("Test timeout error")) |
1599 | 1557 |
|
1600 | 1558 | with pytest.raises(APITimeoutError): |
1601 | | - await self.client.post( |
1602 | | - "/v1/chat", |
1603 | | - body=cast( |
1604 | | - object, |
1605 | | - maybe_transform( |
1606 | | - dict( |
1607 | | - messages=[ |
1608 | | - { |
1609 | | - "content": "Write a haiku about programming", |
1610 | | - "role": "user", |
1611 | | - } |
1612 | | - ], |
1613 | | - model="palmyra-x5", |
1614 | | - ), |
1615 | | - ChatChatParamsNonStreaming, |
1616 | | - ), |
1617 | | - ), |
1618 | | - cast_to=httpx.Response, |
1619 | | - options={"headers": {RAW_RESPONSE_HEADER: "stream"}}, |
1620 | | - ) |
| 1559 | + await async_client.chat.with_streaming_response.chat( |
| 1560 | + messages=[{"role": "user"}], model="model" |
| 1561 | + ).__aenter__() |
1621 | 1562 |
|
1622 | 1563 | assert _get_open_connections(self.client) == 0 |
1623 | 1564 |
|
1624 | 1565 | @mock.patch("writerai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout) |
1625 | 1566 | @pytest.mark.respx(base_url=base_url) |
1626 | | - async def test_retrying_status_errors_doesnt_leak(self, respx_mock: MockRouter) -> None: |
| 1567 | + async def test_retrying_status_errors_doesnt_leak(self, respx_mock: MockRouter, async_client: AsyncWriter) -> None: |
1627 | 1568 | respx_mock.post("/v1/chat").mock(return_value=httpx.Response(500)) |
1628 | 1569 |
|
1629 | 1570 | with pytest.raises(APIStatusError): |
1630 | | - await self.client.post( |
1631 | | - "/v1/chat", |
1632 | | - body=cast( |
1633 | | - object, |
1634 | | - maybe_transform( |
1635 | | - dict( |
1636 | | - messages=[ |
1637 | | - { |
1638 | | - "content": "Write a haiku about programming", |
1639 | | - "role": "user", |
1640 | | - } |
1641 | | - ], |
1642 | | - model="palmyra-x5", |
1643 | | - ), |
1644 | | - ChatChatParamsNonStreaming, |
1645 | | - ), |
1646 | | - ), |
1647 | | - cast_to=httpx.Response, |
1648 | | - options={"headers": {RAW_RESPONSE_HEADER: "stream"}}, |
1649 | | - ) |
1650 | | - |
| 1571 | + await async_client.chat.with_streaming_response.chat( |
| 1572 | + messages=[{"role": "user"}], model="model" |
| 1573 | + ).__aenter__() |
1651 | 1574 | assert _get_open_connections(self.client) == 0 |
1652 | 1575 |
|
1653 | 1576 | @pytest.mark.parametrize("failures_before_success", [0, 2, 4]) |
|
0 commit comments