We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 550ff94 commit 8adcc9eCopy full SHA for 8adcc9e
1 file changed
pufferlib/src/kernels.cu
@@ -26,7 +26,7 @@ typedef __nv_bfloat16 precision_t;
26
constexpr bool USE_BF16 = true;
27
constexpr int PRECISION_SIZE = 2;
28
static constexpr cudaDataType_t CUBLAS_PRECISION = CUDA_R_16BF;
29
-static constexpr cublasComputeType_t CUBLAS_COMPUTE_PRECISION = CUBLAS_COMPUTE_32F_FAST_16BF;
+static constexpr cublasComputeType_t CUBLAS_COMPUTE_PRECISION = CUBLAS_COMPUTE_32F; // Note: fast bf16 is not deterministic
30
#define NCCL_PRECISION ncclBfloat16
31
#define to_float(x) __bfloat162float(x)
32
#define from_float(x) __float2bfloat16(x)
0 commit comments