-
Notifications
You must be signed in to change notification settings - Fork 1.3k
Expand file tree
/
Copy pathtensor.py
More file actions
1804 lines (1444 loc) · 49.8 KB
/
tensor.py
File metadata and controls
1804 lines (1444 loc) · 49.8 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# =============================================================================
"""
Example usage::
import numpy as np
from singa import tensor
from singa import device
# create a tensor with shape (2,3), default CppCPU device and float32
x = tensor.Tensor((2, 3))
x.set_value(0.4)
# create a tensor from a numpy array
npy = np.zeros((3, 3), dtype=np.float32)
y = tensor.from_numpy(npy)
y.uniform(-1, 1) # sample values from the uniform distribution
z = tensor.mult(x, y) # gemm -> z of shape (2, 3)
x += z # element-wise addition
dev = device.get_default_device()
x.to_device(dev) # move the data to the device dev (e.g. a GPU device)
s = tensor.to_numpy(x) # tensor -> numpy array
There are two sets of tensor functions,
Tensor member functions
which would change the internal state of the Tensor instance.
Tensor module functions
which accept Tensor instances as arguments and return Tensor instances.
Every Tensor instance must be initialized before reading data from it.
"""
from __future__ import division
from __future__ import print_function
from __future__ import absolute_import
from deprecated import deprecated
from builtins import object
import numpy as np
from functools import reduce
import re
from . import singa_wrap as singa
from .device import get_default_device
# Data-type codes; the values mirror the enum entries in core.proto.
float32 = 0  # core.proto.kFloat32
int32 = 2  # core.proto.kInt32

# The swig-generated C++ tensor class that the Python Tensor wraps.
CTensor = singa.Tensor
class Tensor(object):
    '''Python Tensor, which wraps a swig converted Tensor from CPP Tensor.

    Args:
        shape (tuple<int>): a tuple of integers for the tensor shape;
            defaults to () (no dimensions).
        device: a swig device. If None, get_default_device() is used.
        dtype: data type code (e.g. the module constants float32 / int32).
            Currently, most operations only accept float32. Ignored when
            ``data`` is a swig tensor, whose own data type is used instead.
        data: a numpy array or swig tensor. A numpy array is copied into a
            newly created swig tensor; a swig tensor is used as-is and must
            live on ``device``.
        requires_grad: boolean indicator for computing the gradient.
        stores_grad: boolean indicator for storing and returning the gradient.
            Some intermediate tensors' gradient can be released
            during the backward propagation. A tensor may require
            grad but not store grad; But if a tensor stores grad
            then it must require grad.
        creator: presumably the autograd operation that produced this
            tensor; if None, an ``autograd.Dummy`` creator is attached.
        name: tensor name. If None, a unique name 'Dummy#<n>' is generated,
            which makes is_dummy() return True.
    '''

    # Counter used to generate unique 'Dummy#<n>' names.
    tensor_count = 0

    def __init__(self,
                 shape=(),
                 device=None,
                 dtype=float32,
                 data=None,
                 requires_grad=True,
                 stores_grad=False,
                 creator=None,
                 name=None):
        if device is None:
            device = get_default_device()
        if isinstance(data, np.ndarray):
            self.data = CTensor(list(data.shape), device, dtype)
            copy_from_numpy(self.data, data)
        elif isinstance(data, CTensor):
            self.data = data
            assert data.device().id() == device.id(), 'not the same device'
        else:
            self.data = CTensor(list(shape), device, dtype)

        # shape and dtype always reflect the underlying swig tensor.
        self.shape = tuple(self.data.shape())
        self.device = device
        self.dtype = self.data.data_type()
        self.requires_grad = requires_grad
        self.stores_grad = stores_grad
        if name is None:
            self.name = 'Dummy#{}'.format(Tensor.tensor_count)
            Tensor.tensor_count += 1
        else:
            self.name = name
        if creator is None:
            # Imported lazily, presumably to avoid a circular import.
            from . import autograd
            self.creator = autograd.Dummy(self, name)
        else:
            self.creator = creator

    def __getitem__(self, keys):
        '''Slice the tensor with integer and/or slice keys, one per axis.

        An integer key keeps the indexed axis with length 1 (the axis is not
        dropped). Negative integers and negative slice bounds count from the
        end of the axis. A slice must be non-empty, lie inside the axis and
        have a step of None or 1.

        Args:
            keys: an int, a slice, or a tuple of ints/slices.

        Returns:
            a new Tensor holding the selected region; its ``shape`` matches
            the sliced data.

        Raises:
            ValueError: for an out-of-range, empty or stepped slice, or an
                unsupported key type.
        '''
        if not isinstance(keys, tuple):
            keys = (keys,)
        ret = self.clone()
        for axis_index, key in enumerate(keys):
            # bool is a subclass of int but is not a valid index here.
            if isinstance(key, (int, np.integer)) and not isinstance(key, bool):
                dim = self.shape[axis_index]
                key = int(key)  # plain int for the swig call
                if key < 0:
                    key += dim
                if not 0 <= key < dim:
                    raise ValueError("Invalid Index")
                ret.data = singa.SliceOn(ret.data, key, key + 1, axis_index)
            elif isinstance(key, slice):
                dim = self.shape[axis_index]
                if key.step not in (None, 1):
                    # Strided slicing is not supported; refuse rather than
                    # silently returning an unstrided slice.
                    raise ValueError("Invalid Index")
                # Compare against None: a stop of 0 must not mean "to the end".
                start = 0 if key.start is None else key.start
                end = dim if key.stop is None else key.stop
                if start < 0:
                    start += dim
                if end < 0:
                    end += dim
                if not (0 <= start < end <= dim):
                    raise ValueError("Invalid Index")
                ret.data = singa.SliceOn(ret.data, start, end, axis_index)
            else:
                raise ValueError("Invalid Index")
        # Keep the Python-side shape in sync with the sliced data.
        ret.shape = tuple(ret.data.shape())
        return ret

    def is_dummy(self):
        '''
        Returns:
            True if the tensor name matches the auto-generated 'Dummy#<n>'
            pattern (i.e. it was created without an explicit name).
        '''
        return re.match(r'Dummy#\d+', self.name) is not None

    def ndim(self):
        '''
        Returns:
            the number of dimensions of the tensor.
        '''
        return self.data.nDim()

    def is_empty(self):
        '''
        Returns:
            True if the tensor has no dimensions (ndim() == 0).
        '''
        return self.ndim() == 0

    def is_transpose(self):
        '''
        Returns:
            True if the internal data is transposed; otherwise False.
        '''
        return self.data.transpose()

    def transpose(self, axes=None):
        ''' To transpose the tensor

        Args:
            axes: a permutation of the tensor's axes; negative values count
                from the end. If None, singa.DefaultTranspose is applied.

        Returns:
            new transposed tensor

        Raises:
            ValueError: if ``axes`` has the wrong length or is not a
                permutation of the tensor's axes.
        '''
        t = Tensor(self.shape, self.device, self.dtype)
        if axes is None:
            t.data = singa.DefaultTranspose(self.data)
            # Take the shape from the transposed swig tensor (expected to be
            # self.shape reversed) instead of assuming it is unchanged.
            t.shape = tuple(t.data.shape())
        else:
            ndim = len(self.shape)
            if len(axes) != ndim:
                raise ValueError('dimensions do not match')
            axes = [ax + ndim if ax < 0 else ax for ax in axes]
            if sorted(axes) != list(range(ndim)):
                raise ValueError(
                    'axes must be a permutation of range({}): {}'.format(
                        ndim, axes))
            t.shape = tuple(self.shape[ax] for ax in axes)
            t.data = singa.Transpose(self.data, axes)
        return t

    def size(self):  # TODO(wangwei) compute size
        '''
        Returns:
            the number of elements of the tensor.
        '''
        return self.data.Size()

    def memsize(self):
        '''
        Returns:
            the number of Bytes allocated for this tensor.
        '''
        return self.data.MemSize()

    def contiguous(self):
        '''
        Returns:
            a new Tensor wrapping singa.Contiguous applied to this tensor's
            data (presumably a contiguous memory layout of the same values).
        '''
        t = Tensor(self.shape, self.device, self.dtype)
        t.data = singa.Contiguous(self.data)
        return t

    def reshape(self, shape):
        '''Return a new tensor with the given shape, and the original
        tensor is not changed.

        Args:
            shape (list<int>): new shape, which should have the same
                volume as the original shape.

        Returns:
            new tensor reshaped; its ``shape`` attribute is a tuple.
        '''
        # Validate before allocating the result tensor.
        assert product(self.shape) == product(shape), \
            'product of shape should be equal'
        t = Tensor(self.shape, self.device, self.dtype)
        t.shape = tuple(shape)
        t.data = singa.Reshape(self.data, list(shape))
        return t

    def reset_like(self, t):
        '''Reset the shape, dtype and device as the given tensor.

        Args:
            t (Tensor): a tensor
        '''
        self.data.ResetLike(t.data)
        self.shape = t.shape
        self.device = t.device
        self.dtype = t.dtype

    def as_type(self, dtype):
        '''Change the data type.

        Args:
            dtype: accepts 'int', 'float', singa.kFloat32, singa.kInt

        Returns:
            new tensor with new type

        Raises:
            TypeError: for any other value of ``dtype``.
        '''
        if dtype == 'int':
            dtype = singa.kInt
        elif dtype == 'float':
            dtype = singa.kFloat32
        elif dtype not in (singa.kInt, singa.kFloat32):
            raise TypeError("invalid data type %s" % dtype)
        t = Tensor(self.shape, self.device, dtype)
        t.data = self.data.AsType(dtype)
        return t

    def to_device(self, device):
        '''Move the tensor data onto a given device.

        Args:
            device: a swig Device converted from CudaGPU or CppCPU or OpenclGPU
        '''
        self.data.ToDevice(device)
        self.device = device

    def to_host(self):
        '''Move the tensor data onto the host and record
        get_default_device() as this tensor's device.
        '''
        self.data.ToHost()
        self.device = get_default_device()

    def l2(self):
        '''
        Returns:
            the L2 norm.
        '''
        return self.data.L2()

    def l1(self):
        '''
        Returns:
            the L1 norm.
        '''
        return self.data.L1()

    def set_value(self, x, inplace=True):
        '''Set all elements of the tensor to be the given value.

        Args:
            x (float): a value (converted with float()) set to all elements.
            inplace: inplace flag; only True is implemented.

        Returns:
            this tensor

        Raises:
            NotImplementedError: if ``inplace`` is False.
        '''
        if not inplace:
            # return new tensor filled with value
            raise NotImplementedError
        self.data.SetFloatValue(float(x))
        return self

    def copy_from_numpy(self, np_array, offset=0):
        ''' Copy the data from the numpy array.

        Args:
            np_array: source numpy array with the same number of elements as
                this tensor. float32 and int/int32 arrays are supported;
                float64 arrays are cast to float32 first.
            offset (int): destination offset. NOTE: currently ignored.
        '''
        assert np_array.size == self.size(), 'tensor shape should be the same'
        if not np_array.ndim == 1:
            np_array = np_array.flatten()
        if np_array.dtype == np.float64:
            # Previously float64 input fell through to the "not implemented"
            # branch and nothing was copied.
            np_array = np_array.astype(np.float32)
        dt = np_array.dtype
        if dt == np.float32:
            self.data.CopyFloatDataFromHostPtr(np_array)
        elif dt == int or dt == np.int32:
            self.data.CopyIntDataFromHostPtr(np_array)
        else:
            print('Not implemented yet for ', dt)

    def copy_data(self, t):
        '''Copy data from other Tensor instance.

        Args:
            t (Tensor): source Tensor with the same number of elements.
        '''
        # Check the type first so a non-Tensor gives the intended message
        # rather than an AttributeError from t.size().
        assert isinstance(t, Tensor), 't must be a singa Tensor instance'
        assert (t.size() == self.size()), "tensor shape should be the same"
        self.data.CopyData(t.data)

    def copy_from(self, t, offset=0):
        ''' Copy the data from the numpy array or other Tensor instance

        Args:
            t (Tensor or np array): source Tensor or numpy array
            offset (int): destination offset. NOTE: currently ignored.

        Raises:
            ValueError: if ``t`` is neither a Tensor nor a numpy array.
        '''
        if isinstance(t, Tensor):
            self.copy_data(t)
        elif isinstance(t, np.ndarray):
            self.copy_from_numpy(t)
        else:
            raise ValueError("t should be Tensor or numpy array.")

    def clone(self):
        '''
        Returns:
            a new Tensor which does deep copy of this tensor
        '''
        return _call_singa_func(self.data.Clone)

    def repeat(self, repeats, axis):
        '''Repeat data of a tensor

        Args:
            repeats (int or a sequence): the number of times to repeat. A
                sequence gives one count per element along ``axis`` and its
                size should match the axis's shape.
            axis (int): the axis to repeat along (negative values count from
                the end). If None, ``repeats`` must be an int and the
                repeated tensor is flattened.

        Returns:
            the tensor which has been repeated, on this tensor's device

        Raises:
            ValueError: for a negative repeat count, an out-of-range axis,
                a sequence of repeats with axis None, or a ``repeats`` that
                is neither an int nor a sequence.
        '''
        t = Tensor()
        t_ndim = self.ndim()
        if axis is not None:
            if axis < 0:
                axis += t_ndim
            if not 0 <= axis < t_ndim:
                raise ValueError(
                    "'axis' is out of range for a {}-d tensor: {}".format(
                        t_ndim, axis))
        if isinstance(repeats, int):
            if repeats < 0:
                raise ValueError(
                    "'repeats' should not be negative: {}".format(repeats))
            if axis is None:
                # The flattened case is requested from the backend with the
                # axis value 9999; the result is 1-d.
                axis = 9999
                t.shape = (product(self.shape) * repeats,)
            else:
                t_shape = list(self.shape)
                t_shape[axis] = self.shape[axis] * repeats
                t.shape = tuple(t_shape)
            t.data = self.data.Repeat([repeats], axis)
        elif isinstance(repeats, (tuple, list)):
            if any(rep < 0 for rep in repeats):
                raise ValueError(
                    "'repeats' should not be negative: {}".format(repeats))
            if axis is None:
                raise ValueError(
                    "when axis is None, 'repeats' should be int: {}".format(
                        repeats))
            t_shape = list(self.shape)
            t_shape[axis] = sum(repeats)
            t.shape = tuple(t_shape)
            t.data = self.data.Repeat(list(repeats), axis)
        else:
            raise ValueError('repeats should be int or sequence')
        # t was created on the default device with the default dtype; make
        # its metadata describe the repeated data instead.
        t.device = self.device
        t.dtype = t.data.data_type()
        return t

    def T(self):
        ''' shallow copy.

        Returns:
            a new Tensor which shares the underlying data memory (shallow copy).
        '''
        return _call_singa_func(singa.DefaultTranspose, self.data)

    def copy(self):
        '''shallow copy calls copy constructor of singa::Tensor

        Returns:
            new tensor copied
        '''
        return _call_singa_func(CTensor, self.data)

    def deepcopy(self):
        '''Same as clone().

        Returns:
            a new Tensor
        '''
        return self.clone()

    def bernoulli(self, p, inplace=True):
        '''Sample 0/1 for each element according to the given probability.

        Args:
            p (float): with probability p, each element is sample to 1.
            inplace: inplace flag; only True is implemented.

        Returns:
            this tensor
        '''
        if not inplace:
            # return new tensor
            raise NotImplementedError
        singa.Bernoulli(float(p), self.data)
        return self

    def gaussian(self, mean, std, inplace=True):
        '''Generate a value for each element following a Gaussian distribution.

        Args:
            mean (float): mean of the distribution
            std (float): standard deviation of the distribution
            inplace: inplace flag; only True is implemented.

        Returns:
            this tensor
        '''
        if not inplace:
            # return new tensor
            raise NotImplementedError
        singa.Gaussian(float(mean), float(std), self.data)
        return self

    def uniform(self, low, high, inplace=True):
        '''Generate a value for each element following a uniform distribution.

        Args:
            low (float): the lower bound
            high (float): the upper bound
            inplace: inplace flag; only True is implemented.

        Returns:
            this tensor
        '''
        if not inplace:
            # return new tensor
            raise NotImplementedError
        singa.Uniform(float(low), float(high), self.data)
        return self

    @deprecated(reason="use broadcast instead")
    def add_column(self, v):
        '''(DEPRECATED, use broadcast)Add a tensor to each column of this tensor.

        Args:
            v (Tensor): a Tensor to be added as a column to this tensor.
        '''
        singa.AddColumn(v.data, self.data)

    @deprecated(reason="use broadcast instead")
    def add_row(self, v):
        '''(DEPRECATED, use broadcast)Add a tensor to each row of this tensor.

        Args:
            v (Tensor): a Tensor to be added as a row to this tensor.
        '''
        singa.AddRow(v.data, self.data)

    @deprecated(reason="use broadcast instead")
    def div_column(self, v):
        '''(DEPRECATED, use broadcast)Divide each column of this tensor by v.

        Args:
            v (Tensor): 1d tensor of the same length the column of self.
        '''
        singa.DivColumn(v.data, self.data)

    @deprecated(reason="use broadcast instead")
    def div_row(self, v):
        '''(DEPRECATED, use broadcast)Divide each row of this tensor by v.

        Args:
            v (Tensor): 1d tensor of the same length the row of self.
        '''
        singa.DivRow(v.data, self.data)

    @deprecated(reason="use broadcast instead")
    def mult_column(self, v):
        '''(DEPRECATED, use broadcast)Multiply each column of this tensor by v element-wisely.

        Args:
            v (Tensor): 1d tensor of the same length the column of self.
        '''
        singa.MultColumn(v.data, self.data)

    @deprecated(reason="use broadcast instead")
    def mult_row(self, v):
        '''(DEPRECATED, use broadcast)Multiply each row of this tensor by v element-wisely.

        Args:
            v (Tensor): 1d tensor of the same length the row of self.
        '''
        singa.MultRow(v.data, self.data)

    # ------------------------------------------------------------------
    # Python in-place operators (+=, -=, *=, /=) mapped onto the
    # singa::Tensor in-place operators.
    # ------------------------------------------------------------------

    def __iadd__(self, x):
        ''' inplace element-wise addition with a tensor or a float value.

        Args:
            x (float or Tensor): input value

        Returns:
            this tensor
        '''
        if isinstance(x, Tensor):
            self.data += x.data
        else:
            self.data += float(x)
        return self

    def __isub__(self, x):
        ''' inplace element-wise subtraction with a tensor or a float value.

        Args:
            x (float or Tensor): input value

        Returns:
            this tensor
        '''
        if isinstance(x, Tensor):
            self.data -= x.data
        else:
            self.data -= float(x)
        return self

    def __imul__(self, x):
        ''' inplace element-wise multiplication with a tensor or a float value.

        Args:
            x (float or Tensor): input value

        Returns:
            this tensor
        '''
        if isinstance(x, Tensor):
            self.data *= x.data
        else:
            self.data *= float(x)
        return self

    def __itruediv__(self, x):
        ''' inplace element-wise division by a tensor or a float value.

        Args:
            x (float or Tensor): input value

        Returns:
            this tensor
        '''
        if isinstance(x, Tensor):
            self.data /= x.data
        else:
            self.data /= float(x)
        return self

    # ------------------------------------------------------------------
    # Python binary operators (+, -, *, /, //, <, <=, >, >=, ==) mapped onto
    # singa binary operators; a non-Tensor right operand uses the
    # corresponding *Float function.
    # https://docs.python.org/2/library/operator.html#mapping-operators-to-functions
    # ------------------------------------------------------------------

    def __add__(self, rhs):
        if isinstance(rhs, Tensor):
            return from_raw_tensor(singa.__add__(self.data, rhs.data))
        else:
            return _call_singa_func(singa.AddFloat, self.data, rhs)

    def __sub__(self, rhs):
        if isinstance(rhs, Tensor):
            return from_raw_tensor(singa.__sub__(self.data, rhs.data))
        else:
            return _call_singa_func(singa.SubFloat, self.data, rhs)

    def __mul__(self, rhs):
        if isinstance(rhs, Tensor):
            return from_raw_tensor(singa.__mul__(self.data, rhs.data))
        else:
            return _call_singa_func(singa.MultFloat, self.data, rhs)

    def __div__(self, rhs):
        # Python 2 division operator.
        if isinstance(rhs, Tensor):
            return from_raw_tensor(singa.__div__(self.data, rhs.data))
        else:
            return _call_singa_func(singa.DivFloat, self.data, rhs)

    def __truediv__(self, rhs):
        if isinstance(rhs, Tensor):
            return from_raw_tensor(singa.__div__(self.data, rhs.data))
        else:
            return _call_singa_func(singa.DivFloat, self.data, rhs)

    def __floordiv__(self, rhs):
        # Divide, then apply singa.Floor to the quotient.
        if isinstance(rhs, Tensor):
            tmp = from_raw_tensor(singa.__div__(self.data, rhs.data))
            return _call_singa_func(singa.Floor, tmp.data)
        else:
            tmp = _call_singa_func(singa.DivFloat, self.data, rhs)
            return _call_singa_func(singa.Floor, tmp.data)

    def __lt__(self, rhs):
        if isinstance(rhs, Tensor):
            return from_raw_tensor(singa.__lt__(self.data, rhs.data))
        else:
            return _call_singa_func(singa.LTFloat, self.data, rhs)

    def __le__(self, rhs):
        if isinstance(rhs, Tensor):
            return from_raw_tensor(singa.__le__(self.data, rhs.data))
        else:
            return _call_singa_func(singa.LEFloat, self.data, rhs)

    def __gt__(self, rhs):
        if isinstance(rhs, Tensor):
            return from_raw_tensor(singa.__gt__(self.data, rhs.data))
        else:
            return _call_singa_func(singa.GTFloat, self.data, rhs)

    def __ge__(self, rhs):
        if isinstance(rhs, Tensor):
            return from_raw_tensor(singa.__ge__(self.data, rhs.data))
        else:
            return _call_singa_func(singa.GEFloat, self.data, rhs)

    def __eq__(self, rhs):
        # Element-wise comparison; comparing with None returns plain False.
        if isinstance(rhs, Tensor):
            return from_raw_tensor(singa.__eq__(self.data, rhs.data))
        elif rhs is None:
            return False
        else:
            return _call_singa_func(singa.EQFloat, self.data, rhs)

    # Reflected operators for a scalar left operand: build a tensor filled
    # with the scalar, then apply the in-place operator with self.

    def __radd__(self, lhs):
        lhs = float(lhs)
        one = Tensor(self.shape, self.device, self.dtype)
        one.set_value(lhs)
        one += self
        return one

    def __rsub__(self, lhs):
        lhs = float(lhs)
        one = Tensor(self.shape, self.device, self.dtype)
        one.set_value(lhs)
        one -= self
        return one

    def __rmul__(self, lhs):
        lhs = float(lhs)
        one = Tensor(self.shape, self.device, self.dtype)
        one.set_value(lhs)
        one *= self
        return one

    def __rdiv__(self, lhs):
        lhs = float(lhs)
        one = Tensor(self.shape, self.device, self.dtype)
        one.set_value(lhs)
        one /= self
        return one

    def __rtruediv__(self, lhs):
        lhs = float(lhs)
        one = Tensor(self.shape, self.device, self.dtype)
        one.set_value(lhs)
        one /= self
        return one

    def __repr__(self):
        return np.array2string(to_numpy(self))
# PlaceHolder is simply another name for Tensor.
PlaceHolder = Tensor

# ---------------------------------------------------------------------------
# Python functions for the global functions in Tensor.h
# ---------------------------------------------------------------------------
def from_raw_tensor(t):
    '''Wrap a raw swig (C++) tensor in a Python Tensor without copying it.

    Args:
        t: a swig-converted tensor (e.g. the result of a singa.* call).

    Returns:
        a new Tensor whose ``data`` is ``t`` itself, with shape, device and
        dtype taken from ``t``.
    '''
    # Create the wrapper with an empty shape so no full-size buffer is
    # allocated only to be replaced by ``t`` on the next line.
    x = Tensor(device=t.device(), dtype=t.data_type())
    x.data = t
    x.shape = tuple(t.shape())
    return x
def from_raw_tensors(tt):
    '''Wrap every raw swig tensor in ``tt`` with from_raw_tensor.

    Args:
        tt: an iterable of swig-converted tensors.

    Returns:
        a list of Tensor instances, in the same order as ``tt``.
    '''
    return [from_raw_tensor(raw) for raw in tt]
def zeros_like(t):
    '''Create a Tensor with the shape, device and dtype of ``t``, filled with 0.

    Args:
        t (Tensor): the reference tensor.

    Returns:
        a new Tensor whose elements are all 0.
    '''
    zeros = Tensor(t.shape, t.device, t.dtype)
    # set_value returns the tensor it modified.
    return zeros.set_value(0.0)
def ones_like(t):
    '''Create a Tensor with the shape, device and dtype of ``t``, filled with 1.

    Args:
        t (Tensor): the reference tensor.

    Returns:
        a new Tensor whose elements are all 1.
    '''
    ones = Tensor(t.shape, t.device, t.dtype)
    # set_value returns the tensor it modified.
    return ones.set_value(1.0)
def product(shape):
    '''Return the product of the entries of ``shape``.

    Args:
        shape: an iterable of ints (e.g. a tensor shape).

    Returns:
        the number of elements described by ``shape``; 1 for an empty
        shape, matching the element count of a 0-d (scalar) array.
    '''
    # The initializer makes an empty shape return 1 instead of raising
    # TypeError from reduce().
    return reduce(lambda x, y: x * y, shape, 1)
def sizeof(dtype):
    '''Return the size in bytes of one element of a SINGA data type.

    Args:
        dtype: a SINGA data type code (as defined in core.proto).

    Returns:
        the number of bytes reported by singa.SizeOf for ``dtype``.
    '''
    n_bytes = singa.SizeOf(dtype)
    return n_bytes
def contiguous(tensor):
    '''Apply singa.Contiguous to the data of ``tensor``.

    Args:
        tensor (Tensor): input tensor.

    Returns:
        a new Tensor wrapping the result (presumably the same values in a
        contiguous memory layout).
    '''
    raw = tensor.data
    return _call_singa_func(singa.Contiguous, raw)
def reshape(tensor, shape):
    '''Return ``tensor`` reshaped to ``shape``; the input is left unchanged.

    Args:
        tensor (Tensor): the tensor to reshape.
        shape (list<int>): the new shape, which should describe the same
            number of elements as the old shape.

    Returns:
        the new Tensor
    '''
    raw = tensor.data
    return _call_singa_func(singa.Reshape, raw, shape)
def transpose(t, axes=None):
    '''Functional form of Tensor.transpose.

    Args:
        t (Tensor): input tensor.
        axes: axes to transpose; None uses the tensor's default transpose.

    Returns:
        the transposed tensor
    '''
    return t.transpose(axes)
def copy_data_to_from(dst, src, size, dst_offset=0, src_offset=0):
    '''Copy ``size`` elements from ``src`` into ``dst``; the two Tensors may
    live on different devices.

    Args:
        dst (Tensor): destination Tensor.
        src (Tensor): source Tensor.
        size (int): number of elements to copy.
        dst_offset (int): element offset into dst where writing starts.
        src_offset (int): element offset into src where reading starts.
    '''
    dst_raw, src_raw = dst.data, src.data
    singa.CopyDataToFrom(dst_raw, src_raw, size, dst_offset, src_offset)
def from_numpy(np_array, dev=None):
    '''Create a Tensor instance with the shape, dtype and values from the numpy
    array.

    float64 input is stored as float32 and int64 / native-int input as
    int32; any other dtype is rejected.

    Args:
        np_array: the numpy array (ndarray subclasses such as memmap are
            accepted and viewed as a plain ndarray).
        dev: optional device. If given, the tensor is moved onto it after
            being filled; otherwise it stays on the default device.

    Returns:
        A Tensor instance holding a copy of the array's values.

    Raises:
        AssertionError: if ``np_array`` is not a numpy array, or its dtype
            cannot be mapped to float32 / int32.
    '''
    assert isinstance(np_array, np.ndarray), 'Must input numpy array'
    # Drop ndarray subclasses (e.g. np.matrix, whose flatten() stays 2-d)
    # so the copy below sees a plain ndarray.
    np_array = np.asarray(np_array)
    # convert to float32 / int32 arrays
    if np_array.dtype == np.float64:
        np_array = np_array.astype(np.float32)
    if np_array.dtype == np.int64 or np_array.dtype == int:
        np_array = np_array.astype(np.int32)
    if np_array.dtype == np.float32:
        dtype = float32
    else:
        assert np_array.dtype == np.int32, \
            'Only float and int tensors are supported'
        dtype = int32
    ret = Tensor(np_array.shape, dtype=dtype)
    ret.copy_from_numpy(np_array)
    if dev:
        ret.to_device(dev)
    return ret
def to_host(t):
    '''Return a deep copy of ``t`` moved onto the host; ``t`` is unchanged.

    Args:
        t (Tensor): a Tensor

    Returns:
        new Tensor at host
    '''
    host_copy = t.clone()
    host_copy.to_host()
    return host_copy
def to_numpy(t):
    '''Copy the tensor into a numpy array.

    Args:
        t (Tensor): a float32 or int32 Tensor

    Returns:
        a numpy array with the same shape as ``t``

    Raises:
        NotImplementedError: if ``t`` has any other data type.
    '''
    th = to_host(t)
    if th.dtype == float32:
        np_array = th.data.GetFloatValue(int(th.size()))
    elif th.dtype == int32:
        np_array = th.data.GetIntValue(int(th.size()))
    else:
        # Previously this printed a message and then crashed with an
        # UnboundLocalError on np_array; fail with a clear error instead.
        raise NotImplementedError(
            'to_numpy is not implemented for dtype {}'.format(th.dtype))
    return np_array.reshape(th.shape)
def abs(t):
    '''Element-wise absolute value.

    Args:
        t (Tensor): input Tensor

    Returns:
        a new Tensor holding abs(x) for every element x of t
    '''
    op = singa.Abs
    return _call_singa_func(op, t.data)
def exp(t):
    '''Element-wise exponential.

    Args:
        t (Tensor): input Tensor

    Returns:
        a new Tensor holding exp(x) for every element x of t
    '''
    op = singa.Exp
    return _call_singa_func(op, t.data)
def ceil(t):
    '''Element-wise ceiling.

    Args:
        t (Tensor): input Tensor

    Returns:
        a new Tensor holding ceil(x) for every element x of t
    '''
    op = singa.Ceil
    return _call_singa_func(op, t.data)
def log(t):
    '''Element-wise logarithm.

    Args:
        t (Tensor): input Tensor

    Returns:
        a new Tensor holding log(x) for every element x of t
    '''
    op = singa.Log
    return _call_singa_func(op, t.data)
def sigmoid(t):
    '''Element-wise sigmoid.

    Args:
        t (Tensor): input Tensor

    Returns:
        a new Tensor holding sigmoid(x) for every element x of t
    '''
    op = singa.Sigmoid
    return _call_singa_func(op, t.data)
def sign(t):
    '''Element-wise sign.

    Args:
        t (Tensor): input Tensor

    Returns:
        a new Tensor holding sign(x) for every element x of t
    '''
    op = singa.Sign
    return _call_singa_func(op, t.data)
def sqrt(t):
    '''Element-wise square root.

    Args:
        t (Tensor): input Tensor

    Returns:
        a new Tensor holding sqrt(x) for every element x of t
    '''
    op = singa.Sqrt
    return _call_singa_func(op, t.data)
def square(t):
    '''Element-wise square.

    Args:
        t (Tensor): input Tensor

    Returns:
        a new Tensor holding x * x for every element x of t
    '''
    op = singa.Square
    return _call_singa_func(op, t.data)
def tanh(t):
'''
Args: