Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 48 additions & 5 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,29 @@
cmake_minimum_required(VERSION 3.18)

if(UNIX)
if(EXISTS "/usr/bin/gcc-12")
set(SGEMM_DEFAULT_C_COMPILER "/usr/bin/gcc-12")
elseif(EXISTS "/usr/bin/cc")
set(SGEMM_DEFAULT_C_COMPILER "/usr/bin/cc")
endif()

if(EXISTS "/usr/bin/g++-12")
set(SGEMM_DEFAULT_CXX_COMPILER "/usr/bin/g++-12")
elseif(EXISTS "/usr/bin/c++")
set(SGEMM_DEFAULT_CXX_COMPILER "/usr/bin/c++")
endif()

if(NOT DEFINED CMAKE_C_COMPILER AND DEFINED SGEMM_DEFAULT_C_COMPILER)
set(CMAKE_C_COMPILER "${SGEMM_DEFAULT_C_COMPILER}")
endif()
if(NOT DEFINED CMAKE_CXX_COMPILER AND DEFINED SGEMM_DEFAULT_CXX_COMPILER)
set(CMAKE_CXX_COMPILER "${SGEMM_DEFAULT_CXX_COMPILER}")
endif()
if(NOT DEFINED CMAKE_CUDA_HOST_COMPILER AND DEFINED SGEMM_DEFAULT_CXX_COMPILER)
set(CMAKE_CUDA_HOST_COMPILER "${SGEMM_DEFAULT_CXX_COMPILER}")
endif()
endif()

project(sgemm_optimization
VERSION 2.1.0
DESCRIPTION "SGEMM optimization from naive to Tensor Core"
Expand All @@ -15,15 +40,26 @@ set(CMAKE_EXPORT_COMPILE_COMMANDS ON)

# ── CUDA 配置 ────────────────────────────────────────────────────
find_package(CUDAToolkit REQUIRED)
get_target_property(SGEMM_CUDART_LIBRARY CUDA::cudart IMPORTED_LOCATION)
get_filename_component(SGEMM_CUDA_LIBRARY_DIR "${SGEMM_CUDART_LIBRARY}" DIRECTORY)

if(NOT DEFINED CMAKE_CUDA_ARCHITECTURES)
if(CMAKE_VERSION VERSION_GREATER_EQUAL "3.24")
set(CMAKE_CUDA_ARCHITECTURES native)
else()
set(CMAKE_CUDA_ARCHITECTURES 70 75 80 86 89 90)
endif()
# `native` has proven unreliable with some toolkit/driver combinations and
# can silently drop WMMA-capable codegen. Default to a portable set that
# keeps pre-Volta fallback builds working while still emitting Tensor Core
# code for modern GPUs.
set(CMAKE_CUDA_ARCHITECTURES 52 60 61 70 75 80 86 89 90)
endif()

set(SGEMM_HAS_WMMA_TARGET 0)
foreach(cuda_arch IN LISTS CMAKE_CUDA_ARCHITECTURES)
string(REGEX MATCH "^[0-9]+" SGEMM_CUDA_ARCH_NUMBER "${cuda_arch}")
if(SGEMM_CUDA_ARCH_NUMBER AND SGEMM_CUDA_ARCH_NUMBER GREATER_EQUAL 70)
set(SGEMM_HAS_WMMA_TARGET 1)
break()
endif()
endforeach()

# ── 输出目录 ─────────────────────────────────────────────────────
set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin)

Expand All @@ -33,6 +69,7 @@ add_executable(sgemm_benchmark src/main.cu)
target_include_directories(sgemm_benchmark PRIVATE
${CMAKE_CURRENT_SOURCE_DIR}/src
)
target_compile_definitions(sgemm_benchmark PRIVATE SGEMM_HAS_WMMA_TARGET=${SGEMM_HAS_WMMA_TARGET})

target_link_libraries(sgemm_benchmark PRIVATE
CUDA::cudart
Expand Down Expand Up @@ -61,6 +98,8 @@ if(BUILD_TESTS)

add_executable(test_sgemm tests/test_sgemm.cu)
target_include_directories(test_sgemm PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/src)
target_compile_definitions(test_sgemm PRIVATE SGEMM_HAS_WMMA_TARGET=${SGEMM_HAS_WMMA_TARGET})
target_link_options(test_sgemm PRIVATE -L${SGEMM_CUDA_LIBRARY_DIR})
target_link_libraries(test_sgemm PRIVATE
GTest::gtest_main
CUDA::cudart
Expand All @@ -74,6 +113,8 @@ if(BUILD_TESTS)
# 工具层测试
add_executable(test_utils tests/test_utils.cu)
target_include_directories(test_utils PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/src)
target_compile_definitions(test_utils PRIVATE SGEMM_HAS_WMMA_TARGET=${SGEMM_HAS_WMMA_TARGET})
target_link_options(test_utils PRIVATE -L${SGEMM_CUDA_LIBRARY_DIR})
target_link_libraries(test_utils PRIVATE
GTest::gtest_main
CUDA::cudart
Expand All @@ -87,6 +128,8 @@ if(BUILD_TESTS)
# 性能回归测试
add_executable(test_performance tests/test_performance.cu)
target_include_directories(test_performance PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/src)
target_compile_definitions(test_performance PRIVATE SGEMM_HAS_WMMA_TARGET=${SGEMM_HAS_WMMA_TARGET})
target_link_options(test_performance PRIVATE -L${SGEMM_CUDA_LIBRARY_DIR})
target_link_libraries(test_performance PRIVATE
GTest::gtest_main
CUDA::cudart
Expand Down
62 changes: 39 additions & 23 deletions src/kernels/tensor_core_sgemm.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,10 @@
#include <cuda_runtime.h>
#include <functional>

#ifndef SGEMM_HAS_WMMA_TARGET
#define SGEMM_HAS_WMMA_TARGET 0
#endif

// ============================================================================
// WMMA Tile Dimensions
// ============================================================================
Expand All @@ -48,6 +52,7 @@ using tensor_core::WMMA_N;
* 检查当前设备是否支持 Tensor Core (sm_70+)
*/
inline bool tensorCoresAvailable() { return DeviceInfoCache::instance().hasTensorCores(); }
inline constexpr bool tensorCoreFastPathCompiled() { return SGEMM_HAS_WMMA_TARGET != 0; }

/**
* 检查给定维度是否适合 Tensor Core 加速
Expand Down Expand Up @@ -100,24 +105,9 @@ nullFallback(const float *, const float *, float *, int, int, int, cudaStream_t
// Tensor Core Compute - 纯 WMMA 计算路径
// ============================================================================

// WMMA is only available on sm_70+
#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 700
// WMMA is only emitted when the configured target architectures include sm_70+.
#if SGEMM_HAS_WMMA_TARGET
#include <mma.h>
#endif

/**
* FP32 → FP16 转换内核
*/
__global__ void float_to_half_kernel(const float *__restrict__ input, half *__restrict__ output,
int size) {
int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx < size) {
output[idx] = __float2half(input[idx]);
}
}

// WMMA kernel is only available on sm_70+
#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 700

/**
* 纯 Tensor Core WMMA 计算内核
Expand All @@ -132,6 +122,7 @@ __global__ void float_to_half_kernel(const float *__restrict__ input, half *__re
__global__ void tensor_core_sgemm_kernel_fp16(const half *__restrict__ A,
const half *__restrict__ B, float *__restrict__ C,
int M, int K, int N) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700
int warpM = blockIdx.y;
int warpN = blockIdx.x;

Expand Down Expand Up @@ -159,6 +150,14 @@ __global__ void tensor_core_sgemm_kernel_fp16(const half *__restrict__ A,
}

nvcuda::wmma::store_matrix_sync(C + aRow * N + bCol, c_frag, N, nvcuda::wmma::mem_row_major);
#else
(void)A;
(void)B;
(void)C;
(void)M;
(void)K;
(void)N;
#endif
}

/**
Expand All @@ -179,13 +178,24 @@ inline void launch_tensor_core_sgemm_fp16_fast_path(const half *A, const half *B
}

#else
// Stub implementations for older architectures (will not be called)
// Stub implementations when no WMMA-capable target was configured.
inline void launch_tensor_core_sgemm_fp16_fast_path(const half *, const half *, float *, int, int,
int, cudaStream_t) {
// This function should never be called on pre-sm_70 GPUs
throw CudaError("Tensor Core fast path was not compiled for the configured CUDA architectures");
}
#endif

/**
* FP32 → FP16 转换内核
*/
__global__ void float_to_half_kernel(const float *__restrict__ input, half *__restrict__ output,
int size) {
int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx < size) {
output[idx] = __float2half(input[idx]);
}
}

/**
* 纯 WMMA 计算路径入口(FP16 输入)
*
Expand All @@ -198,9 +208,13 @@ inline void launch_tensor_core_sgemm_fp16(const half *A, const half *B, float *C
return;
}

if (!tensorCoreFastPathCompiled()) {
throw CudaError("launch_tensor_core_sgemm_fp16 requires a build that targets sm_70+");
}

if (!tensorCoresAvailable() || !tensorCoreDimensionsSupported(M, K, N)) {
throw CudaError("launch_tensor_core_sgemm_fp16 requires sm_70+ and dimensions aligned "
"to 16");
throw CudaError(
"launch_tensor_core_sgemm_fp16 requires runtime sm_70+ support and dimensions aligned to 16");
}

launch_tensor_core_sgemm_fp16_fast_path(A, B, C, M, K, N, stream);
Expand Down Expand Up @@ -235,7 +249,8 @@ inline void launch_tensor_core_sgemm_with_fallback(const float *A, const float *
}

// Fallback 路径:设备或维度不支持 Tensor Core
if (!tensorCoresAvailable() || !tensorCoreDimensionsSupported(M, K, N)) {
if (!tensorCoreFastPathCompiled() || !tensorCoresAvailable() ||
!tensorCoreDimensionsSupported(M, K, N)) {
fallback(A, B, C, M, K, N, stream);
return;
}
Expand Down Expand Up @@ -283,7 +298,8 @@ inline void launch_tensor_core_sgemm_with_fallback(const float *A, const float *
inline void launch_tensor_core_sgemm_with_fallback(const float *A, const float *B, float *C, int M,
int K, int N, const FallbackKernel &fallback,
cudaStream_t stream = 0) {
launch_tensor_core_sgemm_with_fallback(A, B, C, M, K, N, fallback, stream);
launch_tensor_core_sgemm_with_fallback<const FallbackKernel &>(A, B, C, M, K, N, fallback,
stream);
}

// ============================================================================
Expand Down
4 changes: 2 additions & 2 deletions src/utils/benchmark_metrics.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,8 @@ inline float getTheoreticalPeakGflops() {
DeviceInfoCache &cache = DeviceInfoCache::instance();
const cudaDeviceProp &prop = cache.prop();

// 峰值 GFLOPS = SMs * cores/SM * 2 (FMA) * clock (GHz) * 1000 (MHz factor)
float peakGflops = prop.multiProcessorCount * cache.coresPerSM() * 2 * cache.clockGHz() * 1000;
// 峰值 GFLOPS = SMs * cores/SM * 2 (FMA) * clock (GHz)
float peakGflops = prop.multiProcessorCount * cache.coresPerSM() * 2 * cache.clockGHz();

return peakGflops;
}
Expand Down
22 changes: 22 additions & 0 deletions src/utils/cuda_utils.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -168,8 +168,30 @@ inline void initRandomMatrix(float *data, int rows, int cols, float min_val = -1
// Utility Functions
// ============================================================================

inline bool cudaDeviceAvailable() {
int device_count = 0;
cudaError_t status = cudaGetDeviceCount(&device_count);
if (status == cudaSuccess) {
return device_count > 0;
}

if (status == cudaErrorNoDevice || status == cudaErrorInsufficientDriver ||
status == cudaErrorInitializationError || status == cudaErrorSystemDriverMismatch) {
cudaGetLastError();
return false;
}

cudaGetLastError();
return false;
}

// Get GPU device properties
inline void printGPUInfo() {
if (!cudaDeviceAvailable()) {
printf("GPU Device: unavailable (no CUDA-capable device detected)\n\n");
return;
}

int device;
CUDA_CHECK(cudaGetDevice(&device));

Expand Down
36 changes: 36 additions & 0 deletions tests/gtest_cuda_environment.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
#pragma once

#include <cstring>
#include <gtest/gtest.h>

#include "utils/cuda_utils.cuh"

inline bool isGtestListMode(int argc, char **argv) {
for (int i = 1; i < argc; ++i) {
if (std::strcmp(argv[i], "--gtest_list_tests") == 0) {
return true;
}
}
return false;
}

class CudaTestEnvironment : public ::testing::Environment {
public:
void SetUp() override {
if (!cudaDeviceAvailable()) {
GTEST_SKIP() << "No CUDA-capable device is detected";
}

printGPUInfo();
}
};

inline int runCudaAwareTests(int argc, char **argv) {
::testing::InitGoogleTest(&argc, argv);

if (!isGtestListMode(argc, argv)) {
::testing::AddGlobalTestEnvironment(new CudaTestEnvironment());
}

return RUN_ALL_TESTS();
}
Loading
Loading