@@ -374,6 +374,11 @@ static bool matchSumReductionOfMulUnary(linalg::GenericOp op) {
374374 return false ;
375375}
376376
377+ // / Determines if the given value is a dense tensor instead of a sparse one.
378+ static bool isDenseTensor (Value v) {
379+ return (sparse_tensor::getSparseTensorType (v).isAllDense ());
380+ }
381+
377382// / Test for sorted COO with suitable data and coordinates types.
378383static bool isAdmissibleCOO (SparseTensorType &aTp) {
379384 return aTp.isCompressedLvl (0 ) && aTp.isOrderedLvl (0 ) && !aTp.isUniqueLvl (0 ) &&
@@ -656,6 +661,109 @@ static LogicalResult rewriteSpMM(PatternRewriter &rewriter,
656661 return success ();
657662}
658663
664+ // Match and rewrite 2:4 SpMM kernels.
665+ static LogicalResult rewrite2To4SpMM (PatternRewriter &rewriter,
666+ linalg::GenericOp op) {
667+ Location loc = op.getLoc ();
668+ Value A = op.getOperand (0 );
669+ Value B = op.getOperand (1 );
670+ Value C = op.getOperand (2 ); // we have C = AB
671+ SmallVector<Value> tokens;
672+
673+ // All input should be dense tensors.
674+ if (!isDenseTensor (A) || !isDenseTensor (B) || !isDenseTensor (C))
675+ return failure ();
676+
677+ Value bufA = genTensorToMemref (rewriter, loc, A);
678+ Value matA = genAllocCopy (rewriter, loc, bufA, tokens);
679+ Value bufB = genTensorToMemref (rewriter, loc, B);
680+ Value matB = genAllocCopy (rewriter, loc, bufB, tokens);
681+ Value bufC = genTensorToMemref (rewriter, loc, C);
682+ Value matC = genAllocCopy (rewriter, loc, bufC, tokens);
683+ genBlockingWait (rewriter, loc, tokens);
684+ tokens.clear ();
685+ Value szm = linalg::createOrFoldDimOp (rewriter, loc, matA, 0 );
686+ Value szk = linalg::createOrFoldDimOp (rewriter, loc, matB, 0 );
687+ Value szn = linalg::createOrFoldDimOp (rewriter, loc, matC, 1 );
688+
689+ Type indexTp = rewriter.getIndexType ();
690+ Type dnTensorHandleTp = rewriter.getType <gpu::SparseDnTensorHandleType>();
691+ Type spMatHandleTp = rewriter.getType <gpu::SparseSpMatHandleType>();
692+ Type tokenTp = rewriter.getType <gpu::AsyncTokenType>();
693+ Value token = genFirstWait (rewriter, loc);
694+ Operation *spGenA = rewriter.create <gpu::Create2To4SpMatOp>(
695+ loc, spMatHandleTp, tokenTp, token, szm, szk, matA);
696+
697+ Value spMatA = spGenA->getResult (0 );
698+ token = spGenA->getResult (1 );
699+ auto dmatB = rewriter.create <gpu::CreateDnTensorOp>(
700+ loc, dnTensorHandleTp, tokenTp, token, matB,
701+ SmallVector<Value>{szk, szn});
702+ Value dnB = dmatB.getResult (0 );
703+ token = dmatB.getAsyncToken ();
704+ auto dmatC = rewriter.create <gpu::CreateDnTensorOp>(
705+ loc, dnTensorHandleTp, tokenTp, token, matC,
706+ SmallVector<Value>{szm, szn});
707+ Value dnC = dmatC.getResult (0 );
708+ token = dmatC.getAsyncToken ();
709+
710+ auto dmatCType = llvm::cast<ShapedType>(matC.getType ()).getElementType ();
711+
712+ // Precompute buffersize for SpMM.
713+ SmallVector<Type> bufferTypes_{indexTp, indexTp, indexTp};
714+ TypeRange bufferTypes (bufferTypes_);
715+ auto bufferComp = rewriter.create <gpu::SpMMBufferSizeOp>(
716+ loc, bufferTypes, tokenTp, token, gpu::TransposeMode::NON_TRANSPOSE,
717+ gpu::TransposeMode::NON_TRANSPOSE, spMatA, dnB, dnC,
718+ /* computeType=*/ dmatCType);
719+
720+ token = bufferComp.getAsyncToken ();
721+ Value bufferSz = bufferComp.getResult (0 );
722+ auto buf = genAllocBuffer (rewriter, loc, bufferSz, token);
723+ Value buffer = buf.getResult (0 );
724+ token = buf.getAsyncToken ();
725+
726+ Value bufferSz2 = bufferComp.getResult (1 );
727+ auto buf2 = genAllocBuffer (rewriter, loc, bufferSz2, token);
728+ Value buffer2 = buf2.getResult (0 );
729+ token = buf2.getAsyncToken ();
730+
731+ Value bufferSz3 = bufferComp.getResult (2 );
732+ auto buf3 = genAllocBuffer (rewriter, loc, bufferSz3, token);
733+ Value buffer3 = buf3.getResult (0 );
734+ token = buf3.getAsyncToken ();
735+
736+ auto dnCType = llvm::cast<ShapedType>(matC.getType ()).getElementType ();
737+
738+ // Perform the SpMM.
739+ auto spmmComp = rewriter.create <gpu::SpMMOp>(
740+ loc, tokenTp, token, spMatA, dnB, dnC, /* computeType=*/ dnCType,
741+ SmallVector<Value>{buffer, buffer2, buffer3});
742+ token = spmmComp.getAsyncToken ();
743+
744+ // Copy data back to host and free all the resources.
745+ token = rewriter.create <gpu::DestroySpMatOp>(loc, tokenTp, token, spMatA)
746+ .getAsyncToken ();
747+ token = rewriter.create <gpu::DestroyDnTensorOp>(loc, tokenTp, token, dnB)
748+ .getAsyncToken ();
749+ token = rewriter.create <gpu::DestroyDnTensorOp>(loc, tokenTp, token, dnC)
750+ .getAsyncToken ();
751+ SmallVector<Value> newDynamicSizes;
752+
753+ token = genDeallocMemRef (rewriter, loc, buffer, token);
754+ token = genDeallocMemRef (rewriter, loc, buffer2, token);
755+ token = genDeallocMemRef (rewriter, loc, buffer3, token);
756+ token = genDeallocMemRef (rewriter, loc, matA, token);
757+ token = genDeallocMemRef (rewriter, loc, matB, token);
758+ token = genCopyMemRef (rewriter, loc, bufC, matC, token);
759+ token = genDeallocMemRef (rewriter, loc, matC, token);
760+ tokens.push_back (token);
761+ genBlockingWait (rewriter, loc, tokens);
762+ tokens.clear ();
763+ rewriter.replaceOpWithNewOp <bufferization::ToTensorOp>(op, bufC);
764+ return success ();
765+ }
766+
659767// / Match and rewrite SDDMM kernel.
660768static LogicalResult rewriteSDDMM (PatternRewriter &rewriter,
661769 linalg::GenericOp op, bool enableRT) {
@@ -906,6 +1014,9 @@ struct LinalgOpRewriter : public OpRewritePattern<linalg::GenericOp> {
9061014 // TODO: add transposed {i, k}, {k, j}
9071015 // TODO: maybe add transposed {i, j} in future
9081016 maps == infer ({{i, k}, {k, j}, {i, j}}) && matchSumOfMultOfArgs (op)) {
1017+ if (op->getAttr (" DENSE24" ))
1018+ return rewrite2To4SpMM (rewriter, op);
1019+
9091020 return rewriteSpMM (rewriter, op, enableRT);
9101021 }
9111022
0 commit comments