Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions include/PTO/IR/PTO.h
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@
#include "PTO/IR/PTOOps.h.inc"

namespace mlir {
class MLIRContext;
class TypeConverter;

namespace pto {
Expand Down Expand Up @@ -101,15 +102,17 @@ enum class PTOParserTargetArch {
A5,
};

void setPTOParserTargetArch(PTOParserTargetArch arch);
PTOParserTargetArch getPTOParserTargetArch();
void setPTOParserTargetArch(MLIRContext *context, PTOParserTargetArch arch);
PTOParserTargetArch getPTOParserTargetArch(MLIRContext *context);

class ScopedPTOParserTargetArch {
public:
explicit ScopedPTOParserTargetArch(PTOParserTargetArch arch);
explicit ScopedPTOParserTargetArch(MLIRContext *context,
PTOParserTargetArch arch);
~ScopedPTOParserTargetArch();

private:
MLIRContext *context;
PTOParserTargetArch previousArch;
};

Expand Down
2 changes: 1 addition & 1 deletion lib/PTO/IR/PTO.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,7 @@ static VerifierTargetArch getVerifierTargetArch(Operation *op) {
: VerifierTargetArch::A2A3;
}

switch (getPTOParserTargetArch()) {
switch (getPTOParserTargetArch(op ? op->getContext() : nullptr)) {
case PTOParserTargetArch::A5:
return VerifierTargetArch::A5;
case PTOParserTargetArch::A3:
Expand Down
41 changes: 30 additions & 11 deletions lib/PTO/IR/PTOTypeDefs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,31 +14,50 @@
//===- PTOTypeDefs.cpp --------------------------------------------*- C++ -*-===//
#include "PTO/IR/PTO.h"
#include "mlir/IR/DialectImplementation.h"
#include <mutex>
#include <unordered_map>

using namespace mlir;
using namespace mlir::pto;

namespace {
thread_local PTOParserTargetArch currentParserTargetArch =
PTOParserTargetArch::Unspecified;
std::mutex parserTargetArchMutex;
std::unordered_map<const MLIRContext *, PTOParserTargetArch>
parserTargetArchByContext;
Comment on lines +24 to +26
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Using a global map and mutex to store context-scoped state is generally discouraged in MLIR. It introduces a global synchronization bottleneck that can impact performance during parallel verification and risks memory leaks if MLIRContext objects are destroyed without explicitly clearing the map.

A more idiomatic and efficient approach is to use an MLIRContext extension. Extensions are owned by the context, ensuring their lifetime is correctly managed, and they avoid global locks for lookup.

Example implementation:

class PTOParserArchExtension : public MLIRContext::Extension {
public:
  using Extension::Extension;
  PTOParserTargetArch arch = PTOParserTargetArch::Unspecified;
};

// In get/set functions:
context->getOrCreateExtension<PTOParserArchExtension>().arch = arch;

}

void mlir::pto::setPTOParserTargetArch(PTOParserTargetArch arch) {
currentParserTargetArch = arch;
// Associates `arch` with `context` in the file-local
// `parserTargetArchByContext` table.
//
// - A null `context` is ignored.
// - Passing `PTOParserTargetArch::Unspecified` removes any entry stored for
//   `context` instead of recording the value.
//
// The table is accessed while holding `parserTargetArchMutex`.
void mlir::pto::setPTOParserTargetArch(MLIRContext *context,
                                       PTOParserTargetArch arch) {
  if (context == nullptr)
    return;

  const std::lock_guard<std::mutex> guard(parserTargetArchMutex);
  if (arch != PTOParserTargetArch::Unspecified) {
    parserTargetArchByContext[context] = arch;
    return;
  }
  // "Unspecified" is represented by the absence of an entry.
  parserTargetArchByContext.erase(context);
}

PTOParserTargetArch mlir::pto::getPTOParserTargetArch() {
return currentParserTargetArch;
// Looks up the parser target architecture stored for `context` in the
// file-local `parserTargetArchByContext` table.
//
// Returns `PTOParserTargetArch::Unspecified` when `context` is null or when
// no entry is stored for it. The lookup is performed while holding
// `parserTargetArchMutex`.
PTOParserTargetArch mlir::pto::getPTOParserTargetArch(MLIRContext *context) {
  if (context == nullptr)
    return PTOParserTargetArch::Unspecified;

  const std::lock_guard<std::mutex> guard(parserTargetArchMutex);
  const auto entry = parserTargetArchByContext.find(context);
  return entry != parserTargetArchByContext.end()
             ? entry->second
             : PTOParserTargetArch::Unspecified;
}

mlir::pto::ScopedPTOParserTargetArch::ScopedPTOParserTargetArch(
PTOParserTargetArch arch)
: previousArch(getPTOParserTargetArch()) {
setPTOParserTargetArch(arch);
MLIRContext *context, PTOParserTargetArch arch)
: context(context), previousArch(getPTOParserTargetArch(context)) {
setPTOParserTargetArch(context, arch);
}

mlir::pto::ScopedPTOParserTargetArch::~ScopedPTOParserTargetArch() {
setPTOParserTargetArch(previousArch);
setPTOParserTargetArch(context, previousArch);
}

static SmallVector<int64_t, 4> canonicalizeTileBufValidShape(ArrayRef<int64_t> validShape) {
Expand Down Expand Up @@ -419,7 +438,7 @@ static Type buildTileBufType(AsmParser &parser,
if (memorySpace != AddressSpace::LEFT)
return parsedBLayout;

switch (getPTOParserTargetArch()) {
switch (getPTOParserTargetArch(parser.getContext())) {
case PTOParserTargetArch::A3:
return BLayout::RowMajor;
case PTOParserTargetArch::A5:
Expand Down
36 changes: 36 additions & 0 deletions test/basic/issue487_tsel_i8_parser_arch_scope.pto
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
// RUN: ptoas --pto-arch=a5 %s -o - 2>&1 | FileCheck %s --check-prefix=A5
// RUN: not ptoas --pto-arch=a3 %s -o - 2>&1 | FileCheck %s --check-prefix=A3

// Keep module attr absent on purpose: this exercises parser-scope arch fallback.
module {
func.func @f0() {
%mask = pto.alloc_tile : !pto.tile_buf<loc=vec, dtype=i8, rows=2, cols=128, v_row=2, v_col=128, blayout=row_major, slayout=none_box, fractal=512, pad=0>
%src0 = pto.alloc_tile : !pto.tile_buf<loc=vec, dtype=i8, rows=2, cols=128, v_row=2, v_col=128, blayout=row_major, slayout=none_box, fractal=512, pad=0>
%src1 = pto.alloc_tile : !pto.tile_buf<loc=vec, dtype=i8, rows=2, cols=128, v_row=2, v_col=128, blayout=row_major, slayout=none_box, fractal=512, pad=0>
%tmp = pto.alloc_tile : !pto.tile_buf<loc=vec, dtype=i8, rows=1, cols=64, v_row=1, v_col=64, blayout=row_major, slayout=none_box, fractal=512, pad=0>
%dst = pto.alloc_tile : !pto.tile_buf<loc=vec, dtype=i8, rows=2, cols=128, v_row=2, v_col=128, blayout=row_major, slayout=none_box, fractal=512, pad=0>

pto.tsel ins(%mask, %src0, %src1, %tmp : !pto.tile_buf<loc=vec, dtype=i8, rows=2, cols=128, v_row=2, v_col=128, blayout=row_major, slayout=none_box, fractal=512, pad=0>, !pto.tile_buf<loc=vec, dtype=i8, rows=2, cols=128, v_row=2, v_col=128, blayout=row_major, slayout=none_box, fractal=512, pad=0>, !pto.tile_buf<loc=vec, dtype=i8, rows=2, cols=128, v_row=2, v_col=128, blayout=row_major, slayout=none_box, fractal=512, pad=0>, !pto.tile_buf<loc=vec, dtype=i8, rows=1, cols=64, v_row=1, v_col=64, blayout=row_major, slayout=none_box, fractal=512, pad=0>)
outs(%dst : !pto.tile_buf<loc=vec, dtype=i8, rows=2, cols=128, v_row=2, v_col=128, blayout=row_major, slayout=none_box, fractal=512, pad=0>)
return
}

// Keep >=2 func.func to trigger verifier's IsolatedFromAbove parallel branch.
func.func @f1() {
%mask = pto.alloc_tile : !pto.tile_buf<loc=vec, dtype=i8, rows=2, cols=128, v_row=2, v_col=128, blayout=row_major, slayout=none_box, fractal=512, pad=0>
%src0 = pto.alloc_tile : !pto.tile_buf<loc=vec, dtype=i8, rows=2, cols=128, v_row=2, v_col=128, blayout=row_major, slayout=none_box, fractal=512, pad=0>
%src1 = pto.alloc_tile : !pto.tile_buf<loc=vec, dtype=i8, rows=2, cols=128, v_row=2, v_col=128, blayout=row_major, slayout=none_box, fractal=512, pad=0>
%tmp = pto.alloc_tile : !pto.tile_buf<loc=vec, dtype=i8, rows=1, cols=64, v_row=1, v_col=64, blayout=row_major, slayout=none_box, fractal=512, pad=0>
%dst = pto.alloc_tile : !pto.tile_buf<loc=vec, dtype=i8, rows=2, cols=128, v_row=2, v_col=128, blayout=row_major, slayout=none_box, fractal=512, pad=0>

pto.tsel ins(%mask, %src0, %src1, %tmp : !pto.tile_buf<loc=vec, dtype=i8, rows=2, cols=128, v_row=2, v_col=128, blayout=row_major, slayout=none_box, fractal=512, pad=0>, !pto.tile_buf<loc=vec, dtype=i8, rows=2, cols=128, v_row=2, v_col=128, blayout=row_major, slayout=none_box, fractal=512, pad=0>, !pto.tile_buf<loc=vec, dtype=i8, rows=2, cols=128, v_row=2, v_col=128, blayout=row_major, slayout=none_box, fractal=512, pad=0>, !pto.tile_buf<loc=vec, dtype=i8, rows=1, cols=64, v_row=1, v_col=64, blayout=row_major, slayout=none_box, fractal=512, pad=0>)
outs(%dst : !pto.tile_buf<loc=vec, dtype=i8, rows=2, cols=128, v_row=2, v_col=128, blayout=row_major, slayout=none_box, fractal=512, pad=0>)
return
}
}

// A5-NOT: error: 'pto.tsel' op expects A2/A3 tsel src0, src1, and dst element type to be i16/i32/f16/f32
// A5: AICORE void f0()
// A5: AICORE void f1()

// A3: error: 'pto.tsel' op expects A2/A3 tsel src0, src1, and dst element type to be i16/i32/f16/f32
4 changes: 2 additions & 2 deletions tools/ptoas/ptoas.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1038,8 +1038,8 @@ int main(int argc, char **argv) {
llvm::SourceMgr sourceMgr;
sourceMgr.AddNewSourceBuffer(std::move(*fileOrErr), llvm::SMLoc());
pto::ScopedPTOParserTargetArch scopedParserArch(
arch == "a5" ? pto::PTOParserTargetArch::A5
: pto::PTOParserTargetArch::A3);
&context, arch == "a5" ? pto::PTOParserTargetArch::A5
: pto::PTOParserTargetArch::A3);
module = parseSourceFile<ModuleOp>(sourceMgr, &context);
if (!module) {
llvm::errs() << "Error: Failed to parse MLIR.\n";
Expand Down
Loading