@@ -21,39 +21,76 @@ extern "C" {
2121#endif
2222
2323
24+ /* Data Structures: GpuReduction is only forward-declared here; the functions below take it by pointer. */
25+ struct GpuReduction ;
26+ typedef struct GpuReduction GpuReduction ;
27+
28+
2429/**
2530 * Supported array reduction operations. The comment after each enumerator names the operator or function it applies.
2631 */
2732
2833typedef enum _ga_reduce_op {
29- GA_REDUCE_SUM , /* + */
30- GA_REDUCE_PROD , /* * */
31- GA_REDUCE_PRODNZ , /* * (!=0) */
32- GA_REDUCE_MIN , /* min() */
33- GA_REDUCE_MAX , /* max() */
34- GA_REDUCE_ARGMIN , /* argmin() */
35- GA_REDUCE_ARGMAX , /* argmax() */
36- GA_REDUCE_MINANDARGMIN , /* min(), argmin() */
37- GA_REDUCE_MAXANDARGMAX , /* max(), argmax() */
38- GA_REDUCE_AND , /* & */
39- GA_REDUCE_OR , /* | */
40- GA_REDUCE_XOR , /* ^ */
41- GA_REDUCE_ALL , /* &&/all() */
42- GA_REDUCE_ANY , /* ||/any() */
34+ /* Column header for the comments below (presumably: result stored in dst , result stored in dstArg) */
35+ GA_REDUCE_SUM , /* + */
36+ GA_REDUCE_PROD , /* * */
37+ GA_REDUCE_PRODNZ , /* * (!=0) */
38+ GA_REDUCE_MIN , /* min() */
39+ GA_REDUCE_MAX , /* max() */
40+ GA_REDUCE_ARGMIN , /* argmin() */
41+ GA_REDUCE_ARGMAX , /* argmax() */
42+ GA_REDUCE_MINANDARGMIN , /* min() , argmin() */
43+ GA_REDUCE_MAXANDARGMAX , /* max() , argmax() */
44+ GA_REDUCE_AND , /* & */
45+ GA_REDUCE_OR , /* | */
46+ GA_REDUCE_XOR , /* ^ */
47+ GA_REDUCE_ALL , /* &&/all() */
48+ GA_REDUCE_ANY , /* ||/any() */
4349} ga_reduce_op ;
4450
4551
52+ /* External Functions */
4653
4754/**
48- * @brief Compute a reduction over a list of axes to reduce.
55+ * @brief Create a new GPU reduction operator over a list of axes to reduce.
56+ *
57+ * @param [out] grOut The reduction operator. Deallocate it with GpuReduction_free().
58+ * @param [in] gpuCtx The GPU context.
59+ * @param [in] op The reduction operation to perform.
60+ * @param [in] ndf The minimum number of destination dimensions to support.
61+ * @param [in] ndr The minimum number of reduction dimensions to support.
62+ * @param [in] srcTypeCode The data type of the source operand.
63+ * @param [in] flags Reduction operator creation flags. Currently must be
64+ * set to 0.
65+ *
66+ * @return GA_NO_ERROR if the operator was created successfully, or a non-zero
67+ * error code otherwise.
68+ */
69+
70+ GPUARRAY_PUBLIC int GpuReduction_new (GpuReduction * * grOut ,
71+ gpucontext * gpuCtx ,
72+ ga_reduce_op op ,
73+ unsigned ndf ,
74+ unsigned ndr ,
75+ int srcTypeCode ,
76+ int flags );
77+
78+ /**
79+ * @brief Deallocate the reduction operator gr, which was allocated by GpuReduction_new().
80+ */
81+
82+ GPUARRAY_PUBLIC void GpuReduction_free (GpuReduction * gr );
83+
84+ /**
85+ * @brief Invoke an operator allocated by GpuReduction_new() on a source tensor.
4986 *
5087 * Returns one (in the case of min-and-argmin/max-and-argmax, two) destination
5188 * tensors. The destination tensor(s)' axes are a strict subset of the axes of the
5289 * source tensor. The axes to be reduced are specified by the caller, and the
5390 * reduction is performed over these axes, which are then removed in the
5491 * destination.
55- *
56- * @param [in] op The reduction operation to perform .
92+ *
93+ * @param [in] gr The reduction operator.
5794 * @param [out] dst The destination tensor. Has the same type as the source.
5895 * @param [out] dstArg Destination for the argmin/argmax result of minima/maxima operations. Has type int64.
5996 * @param [in] src The source tensor.
@@ -76,19 +113,20 @@ typedef enum _ga_reduce_op {
76113 *
77114 * where (i3,i4,i1) are the coordinates of the maximum-
78115 * valued element within subtensor [i0,:,i2,:,:] of src.
79- * @return GA_NO_ERROR if the operation was successful, or a non-zero error
80- * code otherwise.
116+ * @param [in] flags Reduction operator invocation flags. Currently must be
117+ * set to 0.
118+ *
119+ * @return GA_NO_ERROR if the operator was invoked successfully, or a non-zero
120+ * error code otherwise.
81121 */
82122
83- GPUARRAY_PUBLIC int GpuArray_reduction (ga_reduce_op op ,
84- GpuArray * dst ,
85- GpuArray * dstArg ,
86- const GpuArray * src ,
87- unsigned reduxLen ,
88- const unsigned * reduxList );
89-
90-
91-
123+ GPUARRAY_PUBLIC int GpuReduction_call (GpuReduction * gr ,
124+ GpuArray * dst ,
125+ GpuArray * dstArg ,
126+ const GpuArray * src ,
127+ unsigned reduxLen ,
128+ const int * reduxList ,
129+ int flags );
92130
93131
94132#ifdef __cplusplus
0 commit comments