-
Notifications
You must be signed in to change notification settings - Fork 4
Expand file tree
/
Copy pathtest_numpy_type.py
More file actions
126 lines (92 loc) · 3.14 KB
/
test_numpy_type.py
File metadata and controls
126 lines (92 loc) · 3.14 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
from __future__ import annotations

import json

import numpy as np
from fastapi.testclient import TestClient
from pydantic import BaseModel, RootModel

import labthings_fastapi as lt
from labthings_fastapi.testing import create_thing_without_server
from labthings_fastapi.types.numpy import NDArray, DenumpifyingDict
# A pydantic RootModel whose entire value is an NDArray; exercised by
# test_rootmodel. Documented with a comment rather than a docstring because
# pydantic copies class docstrings into the generated JSON schema.
class ArrayModel(RootModel):
    root: NDArray
def check_field_works_with_list(data):
    """Validate ``data`` into an NDArray field and check it survives dumping.

    ``data`` may be a (possibly nested) list or any array-like. After
    validation the field must hold an ``np.ndarray``; ``model_dump`` must
    return values equal to ``data`` with the same shape; and generating the
    JSON schema and the JSON dump must not raise.
    """

    class Model(BaseModel):
        a: NDArray

    m = Model(a=data)
    assert isinstance(m.a, np.ndarray)
    d = m.model_dump()
    # np.array_equal also compares shapes. An elementwise `==` followed by
    # `.all()` broadcasts, so e.g. [1] would "equal" [[[1]]].
    assert np.array_equal(d["a"], data)
    m.model_json_schema()
    m.model_dump_json()
def check_field_works_with_ndarray(data):
    """Validate an ``np.ndarray`` into an NDArray field and check dumping.

    After validation the field must still hold an ``np.ndarray``;
    ``model_dump`` must return values equal to ``data`` with the same shape;
    and generating the JSON schema and the JSON dump must not raise.
    """

    class Model(BaseModel):
        a: NDArray

    m = Model(a=data)
    assert isinstance(m.a, np.ndarray)
    d = m.model_dump()
    # Compare shape as well as values; an elementwise `==` would broadcast.
    assert np.array_equal(d["a"], data)
    m.model_json_schema()
    m.model_dump_json()
def test_1d():
    """Check NDArray fields accept 1D lists and 1D numpy arrays."""
    check_field_works_with_list([1])
    check_field_works_with_list([1, 2, 3])
    # np.arange returns an ndarray, so exercise the ndarray-specific checker.
    check_field_works_with_ndarray(np.arange(10))
def test_3d():
    """Check NDArray fields accept three-dimensional nested lists."""
    check_field_works_with_list([[[1]]])
    check_field_works_with_list([[[2]]])
    # The cases above have length 1 along every axis; also check an array
    # where each axis holds more than one element.
    check_field_works_with_list([[[0, 1], [2, 3]], [[4, 5], [6, 7]]])
def test_2d():
    """Check NDArray fields accept two-dimensional nested lists."""
    check_field_works_with_list([[1, 2]])
    # A single row is a degenerate 2D case; also check a multi-row array.
    check_field_works_with_list([[1, 2], [3, 4]])
def test_0d():
    """Check a scalar is accepted by an NDArray field and dumps back to 1.

    Also makes sure schema generation and JSON dumping do not raise.
    """

    class Model(BaseModel):
        a: NDArray

    scalar_model = Model(a=1)
    assert scalar_model.a == 1
    dumped = scalar_model.model_dump()
    assert dumped["a"] == 1
    scalar_model.model_json_schema()
    scalar_model.model_dump_json()
class MyNumpyThing(lt.Thing):
    """A Thing whose actions and property are annotated with NDArray."""

    @lt.action
    def action_with_arrays(self, a: NDArray) -> NDArray:
        """Return the input array multiplied by two."""
        doubled = 2 * a
        return doubled

    @lt.action
    def read_array(self) -> NDArray:
        """Return the fixed array ``[1, 2]``."""
        return np.array((1, 2))

    @lt.property
    def array_property(self) -> NDArray:
        """A property whose value is the array ``[3, 4, 5]``."""
        return np.array((3, 4, 5))
def test_thing_description():
    """Check the Thing Description validates for a Thing using NDArray."""
    numpy_thing = create_thing_without_server(MyNumpyThing)
    validation_result = numpy_thing.validate_thing_description()
    # Validation returns None when the description is valid.
    assert validation_result is None
def test_denumpifying_dict():
    """Check DenumpifyingDict converts arrays to lists, including nested ones."""
    d = DenumpifyingDict(
        root={
            "a": np.array([1, 2, 3]),
            "b": [np.arange(10), np.arange(10)],
            "c": {"ca": np.array([1, 2, 3])},
            "d": {"da": [np.arange(10), np.arange(10)]},
            "e": None,
            "f": 1,
        }
    )
    dump = d.model_dump()
    assert dump["a"] == [1, 2, 3]
    assert dump["e"] is None
    assert dump["f"] == 1
    # Arrays nested inside lists and dicts must also end up as plain lists.
    # Parse the JSON dump and compare the whole structure, so the nested
    # entries ("b", "c", "d") are actually checked rather than just built.
    expected = {
        "a": [1, 2, 3],
        "b": [list(range(10)), list(range(10))],
        "c": {"ca": [1, 2, 3]},
        "d": {"da": [list(range(10)), list(range(10))]},
        "e": None,
        "f": 1,
    }
    assert json.loads(d.model_dump_json()) == expected
def test_rootmodel():
    """Check that RootModels with NDArray convert between array and list.

    Both list and ndarray input must validate to an ``np.ndarray``. The
    Python-mode dump must match ``[0, 1, 2]`` in shape and value, and the
    JSON dump must be the plain list ``[0, 1, 2]``.
    """
    # Named `data` rather than `input` to avoid shadowing the builtin.
    for data in ([0, 1, 2], np.arange(3)):
        m = ArrayModel(root=data)
        assert isinstance(m.root, np.ndarray)
        # np.array_equal checks shape too; elementwise `==` would broadcast.
        assert np.array_equal(m.model_dump(), [0, 1, 2])
        assert json.loads(m.model_dump_json()) == [0, 1, 2]
def test_numpy_over_http():
    """Read NDArray-typed values from a Thing through a ThingClient over HTTP."""
    server = lt.ThingServer({"np_thing": MyNumpyThing})
    with TestClient(server.app) as client:
        thing_client = lt.ThingClient.from_url("/np_thing/", client=client)
        # The property value is compared as an array, whatever its wire form.
        prop_value = np.asarray(thing_client.array_property)
        assert np.array_equal(prop_value, np.array([3, 4, 5]))
        action_value = np.asarray(thing_client.read_array())
        assert np.array_equal(action_value, np.array([1, 2]))