Skip to content

Commit 1027f66

Browse files
author
Oscar Torreno
authored
Adding join function to the array class (#49)
1) Fix in the to_arrayfire and the from_arrayfire functions 2) Adding join function to the array class 3) Increasing the version
1 parent 27e5609 commit 1027f66

4 files changed

Lines changed: 46 additions & 9 deletions

File tree

CHANGES.rst

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,16 @@ Changelog
44

55
KHIVA uses `Semantic Versioning <http://semver.org/>`_
66

7+
Version 0.2.2
8+
=============
9+
10+
Added
11+
- Join function for KHIVA arrays.
12+
13+
Fixed
14+
- to_arrayfire and from_arrayfire functions.
15+
- KShape for double precision data.
16+
717
Version 0.2.0
818
=============
919

khiva/array.py

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,14 +10,17 @@
1010
########################################################################################################################
1111
# IMPORT
1212
########################################################################################################################
13-
import numpy as np
13+
1414
import ctypes
15+
import logging
16+
import sys
1517
from collections import deque
16-
from khiva.library import KhivaLibrary
1718
from enum import Enum
19+
20+
import numpy as np
1821
import pandas as pd
19-
import logging
20-
import sys
22+
23+
from khiva.library import KhivaLibrary
2124

2225

2326
########################################################################################################################
@@ -159,7 +162,8 @@ def from_arrayfire(cls, arrayfire):
159162
"""
160163
result = ctypes.c_void_p(0)
161164
KhivaLibrary().c_khiva_library.from_arrayfire(ctypes.pointer(arrayfire.arr), ctypes.pointer(result))
162-
return cls(array_reference=result, arrayfire_reference=True)
165+
arrayfire.arr.value = 0
166+
return cls(array_reference=result, arrayfire_reference=False)
163167

164168
def _create_array(self, data):
165169
""" Creates the KHIVA array in the device.
@@ -306,10 +310,24 @@ def to_pandas(self):
306310

307311
def display(self):
308312
"""
309-
Dispays the data stored in the KHIVA array.
313+
Displays the data stored in the KHIVA array.
310314
"""
311315
KhivaLibrary().c_khiva_library.display(ctypes.pointer(self.arr_reference))
312316

317+
def join(self, dim, other):
318+
"""
319+
Joins the first and second KHIVA arrays along the specified dimension.
320+
:param dim: The dimension along which the join occurs.
321+
:param other: The second input array.
322+
:return: KHIVA Array with the result of this operation.
323+
"""
324+
result = ctypes.c_void_p(0)
325+
KhivaLibrary().c_khiva_library.join(ctypes.pointer(ctypes.c_int(dim)),
326+
ctypes.pointer(self.arr_reference),
327+
ctypes.pointer(other.arr_reference),
328+
ctypes.pointer(result))
329+
return Array(array_reference=result)
330+
313331
def __len__(self):
314332
"""
315333
Return the length.

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
author="Shapelets.io",
2222
author_email="dev@shapelets.io",
2323
name="khiva",
24-
version='0.2.0',
24+
version='0.2.2',
2525
long_description = LONG_DESC,
2626
description="Python bindings for khiva",
2727
license="MPL 2.0",

tests/unit_tests/array_unit_tests.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,12 @@
1212
# IMPORT
1313
########################################################################################################################
1414
import unittest
15+
16+
import arrayfire as af
1517
import numpy as np
1618
import pandas as pd
19+
1720
from khiva.array import Array, dtype
18-
import arrayfire as af
1921
from khiva.library import set_backend, KHIVABackend
2022

2123

@@ -120,6 +122,12 @@ def test_get_type(self):
120122
expected = dtype.s64
121123
self.assertEqual(a.get_type(), expected)
122124

125+
def test_join(self):
126+
a = Array([1, 2, 3, 4], khiva_type=dtype.f64)
127+
b = Array([5, 6, 7, 8], khiva_type=dtype.f64)
128+
c = a.join(0, b)
129+
np.testing.assert_array_equal(c.to_numpy(), np.array([1, 2, 3, 4, 5, 6, 7, 8]))
130+
123131
def testPlus(self):
124132
a = Array([1, 2, 3, 4])
125133
b = Array([1, 2, 3, 4])
@@ -276,8 +284,9 @@ def testCopy(self):
276284

277285
def testArrayfire(self):
278286
a = af.Array([1, 2, 3, 4])
287+
a_data = a.to_list()
279288
b = Array.from_arrayfire(a)
280-
np.testing.assert_array_equal(np.asarray(a.to_list()), np.asarray(b.to_list()))
289+
np.testing.assert_array_equal(np.asarray(a_data), np.asarray(b.to_list()))
281290

282291
def testFromPandas(self):
283292
df = pd.DataFrame([[1, 2, 3, 4], [5, 6, 7, 8]])

0 commit comments

Comments
 (0)