@@ -422,57 +422,47 @@ class vx_device {
422422 return 0 ;
423423 }
424424
425- int start_wg (uint64_t krnl_addr, uint64_t args_addr, uint32_t dim, const uint32_t *grid_dim, const uint32_t *block_dim, uint32_t lmem_size) {
426- // set kernel info
427- CHECK_ERR (this ->dcr_write (VX_DCR_BASE_STARTUP_ADDR0, krnl_addr & 0xffffffff ), {
428- return err;
429- });
430- CHECK_ERR (this ->dcr_write (VX_DCR_BASE_STARTUP_ADDR1, krnl_addr >> 32 ), {
431- return err;
432- });
433- CHECK_ERR (this ->dcr_write (VX_DCR_BASE_STARTUP_ARG0, args_addr & 0xffffffff ), {
434- return err;
435- });
436- CHECK_ERR (this ->dcr_write (VX_DCR_BASE_STARTUP_ARG1, args_addr >> 32 ), {
437- return err;
438- });
425+ int start_wg (uint64_t krnl_addr, uint64_t args_addr, uint32_t ndim, const uint32_t *grid_dim, const uint32_t *block_dim, uint32_t lmem_size) {
426+ uint32_t eff_block_dim[3 ], block_size, warp_step_x, warp_step_y, warp_step_z;
427+ prepare_kernel_launch_params (NUM_THREADS, NUM_WARPS, ndim, block_dim,
428+ eff_block_dim, &block_size, &warp_step_x, &warp_step_y, &warp_step_z);
429+ uint32_t _lmem_size = lmem_size;
439430
440- if (dim > 0 ) {
441- CHECK_ERR (this ->dcr_write (VX_DCR_BASE_GRID_DIM0, grid_dim[0 ]), {
442- return err;
443- });
444- CHECK_ERR (this ->dcr_write (VX_DCR_BASE_BLOCK_DIM0, block_dim[0 ]), {
445- return err;
446- });
447- if (dim > 1 ) {
448- CHECK_ERR (this ->dcr_write (VX_DCR_BASE_GRID_DIM1, grid_dim[1 ]), {
449- return err;
450- });
451- CHECK_ERR (this ->dcr_write (VX_DCR_BASE_BLOCK_DIM1, block_dim[1 ]), {
452- return err;
453- });
454- if (dim > 2 ) {
455- CHECK_ERR (this ->dcr_write (VX_DCR_BASE_GRID_DIM2, grid_dim[2 ]), {
456- return err;
457- });
458- CHECK_ERR (this ->dcr_write (VX_DCR_BASE_BLOCK_DIM2, block_dim[2 ]), {
459- return err; });
460- }
431+ {
432+ uint32_t threads_per_core = NUM_WARPS * NUM_THREADS;
433+ if (block_size > threads_per_core) {
434+ std::cerr << " Error: cannot schedule kernel with block_size > threads_per_core ("
435+ << block_size << " ," << threads_per_core << " )\n " ;
436+ return -1 ;
437+ }
438+ int warps_per_block = (block_size + NUM_THREADS - 1 ) / NUM_THREADS;
439+ int blocks_per_core = NUM_WARPS / warps_per_block;
440+ if (_lmem_size == 0 ) {
441+ uint64_t local_mem_size = (1ull << LMEM_LOG_SIZE);
442+ _lmem_size = static_cast <uint32_t >(local_mem_size / blocks_per_core);
461443 }
462444 }
463445
464- CHECK_ERR (this ->dcr_write (VX_DCR_BASE_LMEM_SIZE, lmem_size), {
465- return err;
466- });
446+ CHECK_ERR (this ->dcr_write (VX_DCR_KMU_STARTUP_ADDR0, krnl_addr & 0xffffffff ), { return err; });
447+ CHECK_ERR (this ->dcr_write (VX_DCR_KMU_STARTUP_ADDR1, static_cast <uint32_t >(krnl_addr >> 32 )), { return err; });
448+ CHECK_ERR (this ->dcr_write (VX_DCR_KMU_STARTUP_ARG0, args_addr & 0xffffffff ), { return err; });
449+ CHECK_ERR (this ->dcr_write (VX_DCR_KMU_STARTUP_ARG1, static_cast <uint32_t >(args_addr >> 32 )), { return err; });
450+ static const uint32_t grid_regs[3 ] = {VX_DCR_KMU_GRID_DIM_X, VX_DCR_KMU_GRID_DIM_Y, VX_DCR_KMU_GRID_DIM_Z};
451+ static const uint32_t block_regs[3 ] = {VX_DCR_KMU_BLOCK_DIM_X, VX_DCR_KMU_BLOCK_DIM_Y, VX_DCR_KMU_BLOCK_DIM_Z};
452+ for (uint32_t i = 0 ; i < 3 ; ++i) {
453+ CHECK_ERR (this ->dcr_write (grid_regs[i], (i < ndim) ? grid_dim[i] : 1 ), { return err; });
454+ CHECK_ERR (this ->dcr_write (block_regs[i], eff_block_dim[i]), { return err; });
455+ }
456+ CHECK_ERR (this ->dcr_write (VX_DCR_KMU_LMEM_SIZE, _lmem_size), { return err; });
457+ CHECK_ERR (this ->dcr_write (VX_DCR_KMU_BLOCK_SIZE, block_size), { return err; });
458+ CHECK_ERR (this ->dcr_write (VX_DCR_KMU_WARP_STEP_X, warp_step_x), { return err; });
459+ CHECK_ERR (this ->dcr_write (VX_DCR_KMU_WARP_STEP_Y, warp_step_y), { return err; });
460+ CHECK_ERR (this ->dcr_write (VX_DCR_KMU_WARP_STEP_Z, warp_step_z), { return err; });
467461
468- // start execution
469462 CHECK_FPGA_ERR (api_.fpgaWriteMMIO64 (fpga_, 0 , MMIO_CMD_TYPE, CMD_RUN), {
470463 return -1 ;
471464 });
472465
473- // clear mpm cache
474- mpm_cache_.clear ();
475-
476466 return 0 ;
477467 }
478468
0 commit comments