Skip to content

Commit 95957a5

Browse files
committed
minor update
1 parent 2799f57 commit 95957a5

3 files changed

Lines changed: 12 additions & 5 deletions

File tree

ci/blackbox.sh

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,7 @@ parse_args() {
7474
--log=*) LOGFILE=${i#*=} ;;
7575
--nohup) TEMPBUILD=1 ;;
7676
--help) show_help; exit 0 ;;
77+
--*) echo "Invalid argument: $i"; show_usage; exit 1 ;;
7778
*) show_usage; exit 1 ;;
7879
esac
7980
done
@@ -140,6 +141,11 @@ main() {
140141
set_driver_path
141142
set_app_path
142143

144+
if [ $SAIF -eq 1 ] && [ "$DRIVER" = "simx" ]; then
145+
echo "Error: SAIF is not supported with the simx driver"
146+
exit 1
147+
fi
148+
143149
# execute on default installed GPU
144150
if [ "$DRIVER" = "gpu" ]; then
145151
run_app

kernel/include/vx_tensor.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -790,7 +790,8 @@ struct wgmma_context {
790790

791791
// Load A fragment (NRA=4 config)
792792
template <mem_layout src_layout = row_major, typename Frag>
793-
static __attribute__((always_inline)) void load_a_sync(Frag &dst, const void *src, size_t ldm) {
793+
static __attribute__((always_inline)) void load_matrix_sync(Frag &dst, const void *src, size_t ldm) {
794+
static_assert(Frag::Use == matrix_a, "only matrix_a fragment can be loaded from registers in wgmma_context");
794795
ctx_a::template load_matrix_sync<src_layout>(dst, src, ldm);
795796
}
796797

tests/regression/sgemm_tcu_wg/kernel.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -57,16 +57,16 @@ __kernel void kernel_main(kernel_arg_t* __UNIFORM__ arg) {
5757
auto A_warp = A_smem + warp_rank * ctx::xtileM * ctx::tileK;
5858
auto desc_b = vt::vx_make_smem_desc(B_smem, ctx::xtileN * sizeof(ctx::input_t));
5959

60-
#if defined(WGMMA_RS)
60+
#if defined(WGMMA_RS)
6161
// RS: A from registers, B from smem
6262
ctx::fragment_a fragA;
63-
ctx::load_a_sync(fragA, A_warp, ctx::tileK);
63+
ctx::load_matrix_sync(fragA, A_warp, ctx::tileK);
6464
ctx::wgmma_sync(fragC, fragA, desc_b, fragC);
65-
#else
65+
#else
6666
// SS: both from smem
6767
auto desc_a = vt::vx_make_smem_desc(A_warp, ctx::tileK * sizeof(ctx::input_t));
6868
ctx::wgmma_sync(fragC, desc_a, desc_b, fragC);
69-
#endif
69+
#endif
7070

7171
__syncthreads();
7272
}

0 commit comments

Comments
 (0)