Skip to content

Commit 4f08063

Browse files
feat(instance-types): sync OpenAPI schema fields
Co-authored-by: imagene-shahar <imagene-shahar@users.noreply.github.com>
1 parent 007e2a9 commit 4f08063

2 files changed

Lines changed: 123 additions & 14 deletions

File tree

tests/unit_tests/instance_types/test_instance_types.py

Lines changed: 59 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,23 +14,39 @@
1414
STORAGE_DESCRIPTION = '1800GB NVME'
1515
STORAGE_SIZE = 1800
1616
INSTANCE_TYPE_DESCRIPTION = 'Dedicated Bare metal Server'
17+
BEST_FOR = ['Large model inference', 'Multi-GPU training']
18+
MODEL = 'V100'
19+
NAME = 'Tesla V100'
20+
P2P = '300 GB/s'
1721
PRICE_PER_HOUR = 5.0
1822
SPOT_PRICE_PER_HOUR = 2.5
23+
MAX_DYNAMIC_PRICE = 7.5
24+
SERVERLESS_PRICE = 1.25
25+
SERVERLESS_SPOT_PRICE = 0.75
1926
INSTANCE_TYPE = '8V100.48M'
27+
CURRENCY = 'eur'
28+
MANUFACTURER = 'NVIDIA'
29+
DISPLAY_NAME = 'NVIDIA Tesla V100'
30+
SUPPORTED_OS = ['ubuntu-24.04-cuda-12.8-open-docker']
2031

32+
PRICE_HISTORY_PAYLOAD = [{'date': '2024-01-01', 'price_per_hour': '2.00'}]
2133

34+
35+
@responses.activate
2236
def test_instance_types(http_client):
2337
# arrange - add response mock
2438
responses.add(
2539
responses.GET,
26-
http_client._base_url + '/instance-types',
40+
http_client._base_url + '/instance-types?currency=eur',
2741
json=[
2842
{
2943
'id': TYPE_ID,
44+
'best_for': BEST_FOR,
3045
'cpu': {
3146
'description': CPU_DESCRIPTION,
3247
'number_of_cores': NUMBER_OF_CORES,
3348
},
49+
'deploy_warning': 'Use updated drivers',
3450
'gpu': {
3551
'description': GPU_DESCRIPTION,
3652
'number_of_gpus': NUMBER_OF_GPUS,
@@ -48,9 +64,19 @@ def test_instance_types(http_client):
4864
'size_in_gigabytes': STORAGE_SIZE,
4965
},
5066
'description': INSTANCE_TYPE_DESCRIPTION,
67+
'model': MODEL,
68+
'name': NAME,
69+
'p2p': P2P,
5170
'price_per_hour': '5.00',
5271
'spot_price': '2.50',
72+
'max_dynamic_price': '7.50',
73+
'serverless_price': '1.25',
74+
'serverless_spot_price': '0.75',
5375
'instance_type': INSTANCE_TYPE,
76+
'currency': CURRENCY,
77+
'manufacturer': MANUFACTURER,
78+
'display_name': DISPLAY_NAME,
79+
'supported_os': SUPPORTED_OS,
5480
}
5581
],
5682
status=200,
@@ -59,7 +85,7 @@ def test_instance_types(http_client):
5985
instance_types_service = InstanceTypesService(http_client)
6086

6187
# act
62-
instance_types = instance_types_service.get()
88+
instance_types = instance_types_service.get(currency='eur')
6389
instance_type = instance_types[0]
6490

6591
# assert
@@ -71,6 +97,18 @@ def test_instance_types(http_client):
7197
assert instance_type.price_per_hour == PRICE_PER_HOUR
7298
assert instance_type.spot_price_per_hour == SPOT_PRICE_PER_HOUR
7399
assert instance_type.instance_type == INSTANCE_TYPE
100+
assert instance_type.best_for == BEST_FOR
101+
assert instance_type.model == MODEL
102+
assert instance_type.name == NAME
103+
assert instance_type.p2p == P2P
104+
assert instance_type.currency == CURRENCY
105+
assert instance_type.manufacturer == MANUFACTURER
106+
assert instance_type.display_name == DISPLAY_NAME
107+
assert instance_type.supported_os == SUPPORTED_OS
108+
assert instance_type.deploy_warning == 'Use updated drivers'
109+
assert instance_type.max_dynamic_price == MAX_DYNAMIC_PRICE
110+
assert instance_type.serverless_price == SERVERLESS_PRICE
111+
assert instance_type.serverless_spot_price == SERVERLESS_SPOT_PRICE
74112
assert isinstance(instance_type.cpu, dict)
75113
assert isinstance(instance_type.gpu, dict)
76114
assert isinstance(instance_type.memory, dict)
@@ -85,3 +123,22 @@ def test_instance_types(http_client):
85123
assert instance_type.memory['size_in_gigabytes'] == MEMORY_SIZE
86124
assert instance_type.gpu_memory['size_in_gigabytes'] == GPU_MEMORY_SIZE
87125
assert instance_type.storage['size_in_gigabytes'] == STORAGE_SIZE
126+
127+
128+
@responses.activate
129+
def test_instance_type_price_history(http_client):
130+
# arrange - add response mock
131+
responses.add(
132+
responses.GET,
133+
http_client._base_url + '/instance-types/price-history',
134+
json=PRICE_HISTORY_PAYLOAD,
135+
status=200,
136+
)
137+
138+
instance_types_service = InstanceTypesService(http_client)
139+
140+
# act
141+
price_history = instance_types_service.get_price_history()
142+
143+
# assert
144+
assert price_history == PRICE_HISTORY_PAYLOAD
Lines changed: 64 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,29 @@
11
from dataclasses import dataclass
2+
from typing import Literal
23

34
from dataclasses_json import dataclass_json
45

56
INSTANCE_TYPES_ENDPOINT = '/instance-types'
67

8+
Currency = Literal['usd', 'eur']
9+
710

811
@dataclass_json
912
@dataclass
1013
class InstanceType:
1114
"""Instance type.
1215
1316
Attributes:
14-
id: instance type id.
15-
instance_type: instance type, e.g. '8V100.48M'.
16-
price_per_hour: instance type price per hour.
17-
spot_price_per_hour: instance type spot price per hour.
18-
description: instance type description.
19-
cpu: instance type cpu details.
20-
gpu: instance type gpu details.
21-
memory: instance type memory details.
22-
gpu_memory: instance type gpu memory details.
23-
storage: instance type storage details.
17+
id: Instance type ID.
18+
instance_type: Instance type, e.g. '8V100.48M'.
19+
price_per_hour: Instance type price per hour.
20+
spot_price_per_hour: Instance type spot price per hour.
21+
description: Instance type description.
22+
cpu: Instance type CPU details.
23+
gpu: Instance type GPU details.
24+
memory: Instance type memory details.
25+
gpu_memory: Instance type GPU memory details.
26+
storage: Instance type storage details.
2427
"""
2528

2629
id: str
@@ -33,6 +36,19 @@ class InstanceType:
3336
memory: dict
3437
gpu_memory: dict
3538
storage: dict
39+
best_for: list[str]
40+
model: str
41+
name: str
42+
p2p: str
43+
currency: Currency
44+
manufacturer: str
45+
display_name: str
46+
supported_os: list[str]
47+
deploy_warning: str | None = None
48+
dynamic_price: float | None = None
49+
max_dynamic_price: float | None = None
50+
serverless_price: float | None = None
51+
serverless_spot_price: float | None = None
3652

3753

3854
class InstanceTypesService:
@@ -41,13 +57,16 @@ class InstanceTypesService:
4157
def __init__(self, http_client) -> None:
4258
self._http_client = http_client
4359

44-
def get(self) -> list[InstanceType]:
60+
def get(self, currency: Currency = 'usd') -> list[InstanceType]:
4561
"""Get all instance types.
4662
4763
:return: list of instance type objects
4864
:rtype: list[InstanceType]
4965
"""
50-
instance_types = self._http_client.get(INSTANCE_TYPES_ENDPOINT).json()
66+
instance_types = self._http_client.get(
67+
INSTANCE_TYPES_ENDPOINT,
68+
params={'currency': currency},
69+
).json()
5170
instance_type_objects = [
5271
InstanceType(
5372
id=instance_type['id'],
@@ -60,8 +79,41 @@ def get(self) -> list[InstanceType]:
6079
memory=instance_type['memory'],
6180
gpu_memory=instance_type['gpu_memory'],
6281
storage=instance_type['storage'],
82+
best_for=instance_type['best_for'],
83+
model=instance_type['model'],
84+
name=instance_type['name'],
85+
p2p=instance_type['p2p'],
86+
currency=instance_type['currency'],
87+
manufacturer=instance_type['manufacturer'],
88+
display_name=instance_type['display_name'],
89+
supported_os=instance_type['supported_os'],
90+
deploy_warning=instance_type.get('deploy_warning'),
91+
dynamic_price=(
92+
float(instance_type['dynamic_price'])
93+
if instance_type.get('dynamic_price') is not None
94+
else None
95+
),
96+
max_dynamic_price=(
97+
float(instance_type['max_dynamic_price'])
98+
if instance_type.get('max_dynamic_price') is not None
99+
else None
100+
),
101+
serverless_price=(
102+
float(instance_type['serverless_price'])
103+
if instance_type.get('serverless_price') is not None
104+
else None
105+
),
106+
serverless_spot_price=(
107+
float(instance_type['serverless_spot_price'])
108+
if instance_type.get('serverless_spot_price') is not None
109+
else None
110+
),
63111
)
64112
for instance_type in instance_types
65113
]
66114

67115
return instance_type_objects
116+
117+
def get_price_history(self):
118+
"""Get the deprecated dynamic price history endpoint as raw JSON."""
119+
return self._http_client.get(f'{INSTANCE_TYPES_ENDPOINT}/price-history').json()

0 commit comments

Comments
 (0)