Skip to content

Commit ffd0204

Browse files
authored
Fix issue with numpy interop (#51)
Reading from an Array: - When an Array of real or complex numbers with more than two dimensions was benig converted to numpy this conversion didn't return an np.array with a proper shape. Creating an Array: - When an Array of real numbers with more than two dimensions was built. The shape in the device (arrayfire) was not correct. - For complex numbers it occurrs with more than one dimension.
1 parent 1c719c6 commit ffd0204

3 files changed

Lines changed: 48 additions & 27 deletions

File tree

khiva/array.py

Lines changed: 23 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
import ctypes
1515
import logging
1616
import sys
17-
from collections import deque
1817
from enum import Enum
1918

2019
import numpy as np
@@ -177,21 +176,24 @@ def _create_array(self, data):
177176
if isinstance(data, pd.DataFrame):
178177
data = data.values
179178
shape = np.array(data.shape)
180-
shape = shape[shape > 1]
181-
shape = deque(shape)
182-
shape.rotate(1)
179+
180+
if data.size > 1:
181+
trimmed_dims = shape
182+
for _ in range(0, 3):
183+
if trimmed_dims[-1] == 1:
184+
trimmed_dims = trimmed_dims[:-1]
185+
shape = trimmed_dims[::-1]
186+
else:
187+
shape = np.array([1])
188+
183189
c_array_n = (ctypes.c_longlong * len(shape))(*(np.array(shape)).astype(np.longlong))
184190
c_ndims = ctypes.c_uint(len(shape))
185191
c_complex = np.iscomplexobj(data)
186192

187193
if c_complex:
188-
data = np.array([data.real, data.imag])
189-
c = deque(range(1, len(data.shape)))
190-
c.rotate(1)
191-
c.append(0)
192-
array_joint = np.transpose(data, c).flatten()
193-
else:
194-
array_joint = data.flatten()
194+
data = np.dstack((data.real.flatten(), data.imag.flatten()))
195+
196+
array_joint = data.flatten()
195197

196198
c_array_joint = (_get_array_type(self.khiva_type.value) * len(array_joint))(
197199
*array_joint)
@@ -212,25 +214,22 @@ def _get_data(self):
212214
c_result_array = (_get_array_type(self.khiva_type.value) * self.result_l)(*initialized_result_array)
213215
KhivaLibrary().c_khiva_library.get_data(ctypes.pointer(self.arr_reference), ctypes.pointer(c_result_array))
214216

215-
dims = self.get_dims()
216-
if dims[dims > 1].size > 0:
217-
dims = dims[dims > 1]
218-
else:
219-
dims = np.array([1])
220-
221217
a = np.array(c_result_array)
222218

223219
if self._is_complex():
224220
a = np.array(np.split(a, self.result_l / 2))
225221
a = np.apply_along_axis(lambda args: [complex(*args)], 1, a)
226-
a = a.reshape(dims)
227-
c = deque(range(len(a.shape)))
228-
c.rotate(-1)
229-
a = np.transpose(a, c)
222+
223+
# Clean up the last n dimensions if these are equal to 1
224+
if a.size > 1:
225+
trimmed_dims = self.get_dims()
226+
for _ in range(0, 3):
227+
if trimmed_dims[-1] == 1:
228+
trimmed_dims = trimmed_dims[:-1]
230229
else:
231-
dims = deque(dims)
232-
dims.rotate(1)
233-
a = a.reshape(dims)
230+
trimmed_dims = np.array([1])
231+
232+
a = a.reshape(trimmed_dims[::-1])
234233

235234
a = a.astype(_get_numpy_type(self.khiva_type.value))
236235
return a

tests/unit_tests/array_unit_tests.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,21 +30,43 @@ class ArrayTest(unittest.TestCase):
3030
def setUp(self):
3131
set_backend(KHIVABackend.KHIVA_BACKEND_CPU)
3232

33+
def test_real_1d_creation(self):
34+
a = Array([1, 5, 3, 1])
35+
np.testing.assert_array_equal(a.dims, np.array([4, 1, 1, 1]))
36+
37+
def test_single_value_creation(self):
38+
a = Array([1])
39+
np.testing.assert_array_equal(a.dims, np.array([1, 1, 1, 1]))
40+
3341
def test_real_1d(self):
3442
a = Array([1, 2, 3, 4, 5, 6, 7, 8])
3543
expected = np.array([1, 2, 3, 4, 5, 6, 7, 8])
3644
np.testing.assert_array_equal(a.to_numpy(), expected)
3745

46+
def test_real_2d_creation(self):
47+
a = Array([[1, 5, 3, 1], [2, 6, 9, 8], [3, 4, 1, 3]])
48+
np.testing.assert_array_equal(a.dims, np.array([4, 3, 1, 1]))
49+
3850
def test_real_2d(self):
3951
a = Array([[1, 2, 3, 4], [5, 6, 7, 8]])
4052
expected = np.array([[1, 2, 3, 4], [5, 6, 7, 8]])
4153
np.testing.assert_array_equal(a.to_numpy(), expected)
4254

55+
def test_real_3d_creation(self):
56+
a = Array([[[1, 5, 3, 1], [2, 6, 9, 8], [3, 4, 1, 3]],
57+
[[3, 7, 4, 2], [4, 8, 1, 9], [1, 5, 9, 2]]])
58+
np.testing.assert_array_equal(a.dims, np.array([4, 3, 2, 1]))
59+
4360
def test_real_3d(self):
4461
a = Array([[[1, 5], [2, 6]], [[3, 7], [4, 8]]])
4562
expected = np.array([[[1, 5], [2, 6]], [[3, 7], [4, 8]]])
4663
np.testing.assert_array_equal(a.to_numpy(), expected)
4764

65+
def test_real_3d_large_column(self):
66+
a = Array([[[1, 5, 3], [2, 6, 9]], [[3, 7, 4], [4, 8, 1]]])
67+
expected = np.array([[[1, 5, 3], [2, 6, 9]], [[3, 7, 4], [4, 8, 1]]])
68+
np.testing.assert_array_equal(a.to_numpy(), expected)
69+
4870
def test_real_4d(self):
4971
a = Array([[[[1, 9], [2, 10]], [[3, 11], [4, 12]]], [[[5, 13], [6, 14]], [[7, 15], [8, 16]]]])
5072
expected = np.array([[[[1, 9], [2, 10]], [[3, 11], [4, 12]]], [[[5, 13], [6, 14]], [[7, 15], [8, 16]]]])
@@ -254,7 +276,7 @@ def testCols(self):
254276
def testRow(self):
255277
a = Array(np.transpose([[1, 2], [3, 4]]), dtype.s32)
256278
c = a.get_row(0)
257-
np.testing.assert_array_equal(c.to_numpy(), [1, 2])
279+
np.testing.assert_array_equal(c.to_numpy(), np.transpose(np.array([[1, 2]])))
258280

259281
def testRows(self):
260282
a = Array(np.transpose([[1, 2], [3, 4], [5, 6]]), dtype.s32)

tests/unit_tests/matrix_unit_tests.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -68,8 +68,8 @@ def test_find_best_n_motifs_multiple_profiles(self):
6868
find_best_n_motifs_result = find_best_n_motifs(stomp_result[0], stomp_result[1], 3, 1)
6969
a = find_best_n_motifs_result[1].to_numpy()
7070
b = find_best_n_motifs_result[2].to_numpy()
71-
np.testing.assert_array_almost_equal(a, np.array([[12, 12], [12, 12]]), decimal=self.DECIMAL)
72-
np.testing.assert_array_almost_equal(b, np.array([[1, 1], [1, 1]]), decimal=self.DECIMAL)
71+
np.testing.assert_array_almost_equal(a, np.array([[[12], [12]], [[12], [12]]]), decimal=self.DECIMAL)
72+
np.testing.assert_array_almost_equal(b, np.array([[[1], [1]], [[1], [1]]]), decimal=self.DECIMAL)
7373

7474
def test_find_best_n_motifs_mirror(self):
7575
stomp_result = stomp_self_join(Array([10.1, 11, 10.2, 10.15, 10.775, 10.1, 11, 10.2], dtype.f32), 3)

0 commit comments

Comments
 (0)