forked from YdrMaster/operators
-
Notifications
You must be signed in to change notification settings - Fork 24
Expand file tree
/
Copy pathrandom_sample.cuh
More file actions
40 lines (32 loc) · 1.49 KB
/
random_sample.cuh
File metadata and controls
40 lines (32 loc) · 1.49 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
#ifndef __CUDA_RANDOM_SAMPLE_H__
#define __CUDA_RANDOM_SAMPLE_H__
#include <cstddef>
#include <cstdint>

#include "../../../devices/cuda/cuda_handle.h"
#include "operators.h"
// Per-operator state for random sampling on the CUDA backend.
// NOTE(review): field order is kept exactly as before on purpose; the creating
// code (outside this header) may build it with brace/aggregate initialization.
typedef struct RandomSampleCudaDescriptor {
    Device device;   // device kind; presumably copied from the CudaHandle_t
    int device_id;   // presumably the CUDA device ordinal from the handle
    DT dtype;        // presumably the element type of the `probs` tensor
    int voc;         // presumably the vocabulary size (length of `probs`)
    DT rDtype;       // presumably the element type of the `result` tensor
    int rLength;     // presumably the element count of `result` — TODO confirm
    int step;        // meaning not visible from this header — TODO confirm
} *RandomSampleCudaDescriptor_t;
infiniopStatus_t cudaCreateRandomSampleDescriptor(CudaHandle_t handle,
RandomSampleCudaDescriptor_t *desc_ptr, infiniopTensorDescriptor_t result,
infiniopTensorDescriptor_t probs);
// Writes two scratch-size requirements for `voc` elements of type `dtype`
// through the reference out-parameters `size_radix_sort` and `size_scan`.
// Presumably these are the temporary-storage sizes of the radix sort and the
// prefix scan used during sampling — TODO confirm against the definition.
infiniopStatus_t random_sample_workspace(
    size_t &size_radix_sort,
    size_t &size_scan,
    int voc,
    DT dtype);
// Reports, through `size`, the workspace size required by cudaRandomSample for
// `desc` (presumably in bytes, matching cudaRandomSample's `workspace_size`).
// The out-parameter is uint64_t so it matches the `uint64_t workspace_size`
// parameter of cudaRandomSample. On LP64 Linux this is the same type as the
// previous `unsigned long int`, so nothing changes there. On LLP64 targets
// (Windows), `unsigned long` is only 32 bits and is a distinct type from
// uint64_t. There, the old signature could truncate the size or reject
// `uint64_t *` arguments.
// NOTE(review): the out-of-file definition should use this same parameter type.
infiniopStatus_t cudaGetRandomSampleWorkspaceSize(RandomSampleCudaDescriptor_t desc,
                                                  uint64_t *size);
// Runs random sampling for `desc`, reading `probs` and writing `result`.
//   workspace / workspace_size : caller-provided scratch buffer and its size
//       (presumably at least what cudaGetRandomSampleWorkspaceSize reports).
//   result / probs : output and read-only input buffers (presumably device
//       memory whose layout matches the descriptor).
//   random_val, topp, topk, temperature : sampling controls. Presumably these
//       are a uniform random draw, a nucleus (top-p) threshold, a top-k cutoff
//       and a temperature divisor — TODO confirm against the definition.
//   stream : opaque stream handle, presumably a cudaStream_t.
infiniopStatus_t cudaRandomSample(
    RandomSampleCudaDescriptor_t desc,
    void *workspace, uint64_t workspace_size,
    void *result, const void *probs,
    float random_val, float topp, int topk, float temperature,
    void *stream);
// Disposes of a descriptor obtained from cudaCreateRandomSampleDescriptor
// (presumably freeing it; `desc` must not be used afterwards — TODO confirm).
infiniopStatus_t cudaDestroyRandomSampleDescriptor(
    RandomSampleCudaDescriptor_t desc);
#endif