Skip to content

Commit 1658456

Browse files
authored
[AMDGPU] Introduce custom MIR formatting for s_wait_alu (llvm#176316)
This patch implements a custom printer/parser for the immediate operand of `s_wait_alu` that prints/parses the decoded counter values.

Format: `.<counter1>_<value1>_<counter2>_<value2>`

Example: `s_wait_alu .VaVdst_1_VmVsrc_1`, which is equivalent to `s_wait_alu 8167`.

Features:
- If a counter is at its maximum value, it is not printed.
- The parser reports an error if a counter value is greater than or equal to its maximum value.
- If all counters are disabled, 'AllOff' can be used.
- For now, numeric values are also accepted for backwards compatibility with older MIR.

Note: This is similar to llvm#96004, but for `s_wait_alu`.
1 parent 8523600 commit 1658456

23 files changed

Lines changed: 583 additions & 261 deletions

llvm/lib/Target/AMDGPU/AMDGPUMIRFormatter.cpp

Lines changed: 133 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,61 @@
1313

1414
#include "AMDGPUMIRFormatter.h"
1515
#include "SIMachineFunctionInfo.h"
16+
#include "llvm/TargetParser/TargetParser.h"
1617

1718
using namespace llvm;
1819

20+
const char SWaitAluImmPrefix = '.';
21+
StringLiteral SWaitAluDelim = "_";
22+
23+
StringLiteral VaVdstName = "VaVdst";
24+
StringLiteral VaSdstName = "VaSdst";
25+
StringLiteral VaSsrcName = "VaSsrc";
26+
StringLiteral HoldCntName = "HoldCnt";
27+
StringLiteral VmVsrcName = "VmVsrc";
28+
StringLiteral VaVccName = "VaVcc";
29+
StringLiteral SaSdstName = "SaSdst";
30+
31+
StringLiteral AllOff = "AllOff";
32+
33+
/// Print the symbolic form of an s_wait_alu immediate to \p OS.
///
/// The output starts with the prefix character and is followed by one
/// <Name>_<Value> pair per counter, joined by the delimiter. A counter whose
/// decoded value equals its bit mask is omitted; if every counter is omitted,
/// the `AllOff` keyword is printed instead.
///
/// \param Imm the encoded immediate to decode.
/// \param OS the stream that receives the text.
void AMDGPUMIRFormatter::printSWaitAluImm(uint64_t Imm, raw_ostream &OS) const {
  ListSeparator PairSep(SWaitAluDelim);
  bool AnyEmitted = false;

  // Emits "<Name>_<Value>" unless the value sits at its maximum.
  auto EmitUnlessMax = [&](StringRef Name, uint64_t Value, unsigned MaxValue) {
    if (Value == MaxValue)
      return;
    OS << PairSep << Name << SWaitAluDelim << Value;
    AnyEmitted = true;
  };

  OS << SWaitAluImmPrefix;

  EmitUnlessMax(VaVdstName, AMDGPU::DepCtr::decodeFieldVaVdst(Imm),
                AMDGPU::DepCtr::getVaVdstBitMask());
  EmitUnlessMax(VaSdstName, AMDGPU::DepCtr::decodeFieldVaSdst(Imm),
                AMDGPU::DepCtr::getVaSdstBitMask());
  EmitUnlessMax(VaSsrcName, AMDGPU::DepCtr::decodeFieldVaSsrc(Imm),
                AMDGPU::DepCtr::getVaSsrcBitMask());

  // HoldCnt is the only field whose width depends on the ISA version.
  const AMDGPU::IsaVersion IsaVer = AMDGPU::getIsaVersion(STI.getCPU());
  EmitUnlessMax(HoldCntName, AMDGPU::DepCtr::decodeFieldHoldCnt(Imm, IsaVer),
                AMDGPU::DepCtr::getHoldCntBitMask(IsaVer));

  EmitUnlessMax(VmVsrcName, AMDGPU::DepCtr::decodeFieldVmVsrc(Imm),
                AMDGPU::DepCtr::getVmVsrcBitMask());
  EmitUnlessMax(VaVccName, AMDGPU::DepCtr::decodeFieldVaVcc(Imm),
                AMDGPU::DepCtr::getVaVccBitMask());
  EmitUnlessMax(SaSdstName, AMDGPU::DepCtr::decodeFieldSaSdst(Imm),
                AMDGPU::DepCtr::getSaSdstBitMask());

  if (!AnyEmitted)
    OS << AllOff;
}
63+
1964
void AMDGPUMIRFormatter::printImm(raw_ostream &OS, const MachineInstr &MI,
2065
std::optional<unsigned int> OpIdx, int64_t Imm) const {
2166

2267
switch (MI.getOpcode()) {
68+
case AMDGPU::S_WAITCNT_DEPCTR:
69+
printSWaitAluImm(Imm, OS);
70+
break;
2371
case AMDGPU::S_DELAY_ALU:
2472
assert(OpIdx == 0);
2573
printSDelayAluImm(Imm, OS);
@@ -39,6 +87,8 @@ bool AMDGPUMIRFormatter::parseImmMnemonic(const unsigned OpCode,
3987
{
4088

4189
switch (OpCode) {
90+
case AMDGPU::S_WAITCNT_DEPCTR:
91+
return parseSWaitAluImmMnemonic(OpIdx, Imm, Src, ErrorCallback);
4292
case AMDGPU::S_DELAY_ALU:
4393
return parseSDelayAluImmMnemonic(OpIdx, Imm, Src, ErrorCallback);
4494
default:
@@ -90,6 +140,89 @@ void AMDGPUMIRFormatter::printSDelayAluImm(int64_t Imm,
90140
Outdep(Id1);
91141
}
92142

143+
bool AMDGPUMIRFormatter::parseSWaitAluImmMnemonic(
144+
const unsigned int OpIdx, int64_t &Imm, StringRef &Src,
145+
MIRFormatter::ErrorCallbackType &ErrorCallback) const {
146+
// TODO: For now accept integer masks for compatibility with old MIR.
147+
if (!Src.consumeInteger(10, Imm))
148+
return false;
149+
150+
// Initialize with all checks off.
151+
Imm = AMDGPU::DepCtr::getDefaultDepCtrEncoding(STI);
152+
// The input is in the form: .Name1_Num1_Name2_Num2
153+
// Drop the '.' prefix.
154+
bool ConsumePrefix = Src.consume_front(SWaitAluImmPrefix);
155+
if (!ConsumePrefix)
156+
return ErrorCallback(Src.begin(), "expected prefix");
157+
if (Src.empty())
158+
return ErrorCallback(Src.begin(), "expected <CounterName>_<CounterNum>");
159+
160+
// Special case for all off.
161+
if (Src == AllOff)
162+
return false;
163+
164+
// Parse a counter name, number pair in each iteration.
165+
while (!Src.empty()) {
166+
// Src: Name1_Num1_Name2_Num2
167+
// ^
168+
size_t DelimIdx = Src.find(SWaitAluDelim);
169+
if (DelimIdx == StringRef::npos)
170+
return ErrorCallback(Src.begin(), "expected <CounterName>_<CounterNum>");
171+
// Src: Name1_Num1_Name2_Num2
172+
// ^^^^^
173+
StringRef Name = Src.substr(0, DelimIdx);
174+
// Save the position of the name for accurate error reporting.
175+
StringRef::iterator NamePos = Src.begin();
176+
[[maybe_unused]] bool ConsumeName = Src.consume_front(Name);
177+
assert(ConsumeName && "Expected name");
178+
[[maybe_unused]] bool ConsumeDelim = Src.consume_front(SWaitAluDelim);
179+
assert(ConsumeDelim && "Expected delimiter");
180+
// Src: Num1_Name2_Num2
181+
// ^
182+
DelimIdx = Src.find(SWaitAluDelim);
183+
// Src: Num1_Name2_Num2
184+
// ^^^^
185+
int64_t Num;
186+
// Save the position of the number for accurate error reporting.
187+
StringRef::iterator NumPos = Src.begin();
188+
if (Src.consumeInteger(10, Num) || Num < 0)
189+
return ErrorCallback(NumPos,
190+
"expected non-negative integer counter number");
191+
unsigned Max;
192+
if (Name == VaVdstName) {
193+
Max = AMDGPU::DepCtr::getVaVdstBitMask();
194+
Imm = AMDGPU::DepCtr::encodeFieldVaVdst(Imm, Num);
195+
} else if (Name == VmVsrcName) {
196+
Max = AMDGPU::DepCtr::getVmVsrcBitMask();
197+
Imm = AMDGPU::DepCtr::encodeFieldVmVsrc(Imm, Num);
198+
} else if (Name == VaSdstName) {
199+
Max = AMDGPU::DepCtr::getVaSdstBitMask();
200+
Imm = AMDGPU::DepCtr::encodeFieldVaSdst(Imm, Num);
201+
} else if (Name == VaSsrcName) {
202+
Max = AMDGPU::DepCtr::getVaSsrcBitMask();
203+
Imm = AMDGPU::DepCtr::encodeFieldVaSsrc(Imm, Num);
204+
} else if (Name == HoldCntName) {
205+
const AMDGPU::IsaVersion &Version = AMDGPU::getIsaVersion(STI.getCPU());
206+
Max = AMDGPU::DepCtr::getHoldCntBitMask(Version);
207+
Imm = AMDGPU::DepCtr::encodeFieldHoldCnt(Imm, Num, Version);
208+
} else if (Name == VaVccName) {
209+
Max = AMDGPU::DepCtr::getVaVccBitMask();
210+
Imm = AMDGPU::DepCtr::encodeFieldVaVcc(Imm, Num);
211+
} else if (Name == SaSdstName) {
212+
Max = AMDGPU::DepCtr::getSaSdstBitMask();
213+
Imm = AMDGPU::DepCtr::encodeFieldSaSdst(Imm, Num);
214+
} else {
215+
return ErrorCallback(NamePos, "invalid counter name");
216+
}
217+
// Don't allow the values to reach their maximum value.
218+
if (Num >= Max)
219+
return ErrorCallback(NumPos, "counter value too large");
220+
// Src: Name2_Num2
221+
Src.consume_front(SWaitAluDelim);
222+
}
223+
return false;
224+
}
225+
93226
bool AMDGPUMIRFormatter::parseSDelayAluImmMnemonic(
94227
const unsigned int OpIdx, int64_t &Imm, llvm::StringRef &Src,
95228
llvm::MIRFormatter::ErrorCallbackType &ErrorCallback) const

llvm/lib/Target/AMDGPU/AMDGPUMIRFormatter.h

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
#ifndef LLVM_LIB_TARGET_AMDGPUMIRFORMATTER_H
1717
#define LLVM_LIB_TARGET_AMDGPUMIRFORMATTER_H
1818

19+
#include "Utils/AMDGPUBaseInfo.h"
1920
#include "llvm/CodeGen/MIRFormatter.h"
2021

2122
namespace llvm {
@@ -25,7 +26,7 @@ struct PerFunctionMIParsingState;
2526

2627
class AMDGPUMIRFormatter final : public MIRFormatter {
2728
public:
28-
AMDGPUMIRFormatter() = default;
29+
explicit AMDGPUMIRFormatter(const MCSubtargetInfo &STI) : STI(STI) {}
2930
~AMDGPUMIRFormatter() override = default;
3031

3132
/// Implement target specific printing for machine operand immediate value, so
@@ -48,9 +49,17 @@ class AMDGPUMIRFormatter final : public MIRFormatter {
4849
ErrorCallbackType ErrorCallback) const override;
4950

5051
private:
52+
const MCSubtargetInfo &STI;
53+
/// Prints the string to represent s_wait_alu immediate value.
54+
void printSWaitAluImm(uint64_t Imm, raw_ostream &OS) const;
5155
/// Print the string to represent s_delay_alu immediate value
5256
void printSDelayAluImm(int64_t Imm, llvm::raw_ostream &OS) const;
5357

58+
/// Parse the immediate pseudo literal for s_wait_alu
59+
bool parseSWaitAluImmMnemonic(
60+
const unsigned int OpIdx, int64_t &Imm, StringRef &Src,
61+
MIRFormatter::ErrorCallbackType &ErrorCallback) const;
62+
5463
/// Parse the immediate pseudo literal for s_delay_alu
5564
bool parseSDelayAluImmMnemonic(
5665
const unsigned int OpIdx, int64_t &Imm, llvm::StringRef &Src,

llvm/lib/Target/AMDGPU/AMDGPUWaitSGPRHazards.cpp

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
#include "MCTargetDesc/AMDGPUMCTargetDesc.h"
1818
#include "SIInstrInfo.h"
1919
#include "llvm/ADT/SetVector.h"
20+
#include "llvm/TargetParser/TargetParser.h"
2021

2122
using namespace llvm;
2223

@@ -182,9 +183,12 @@ class AMDGPUWaitSGPRHazards {
182183
Mask = AMDGPU::DepCtr::encodeFieldVaVdst(
183184
Mask, std::min(AMDGPU::DepCtr::decodeFieldVaVdst(Mask1),
184185
AMDGPU::DepCtr::decodeFieldVaVdst(Mask2)));
186+
const AMDGPU::IsaVersion &Version = AMDGPU::getIsaVersion(ST->getCPU());
185187
Mask = AMDGPU::DepCtr::encodeFieldHoldCnt(
186-
Mask, std::min(AMDGPU::DepCtr::decodeFieldHoldCnt(Mask1),
187-
AMDGPU::DepCtr::decodeFieldHoldCnt(Mask2)));
188+
Mask,
189+
std::min(AMDGPU::DepCtr::decodeFieldHoldCnt(Mask1, Version),
190+
AMDGPU::DepCtr::decodeFieldHoldCnt(Mask2, Version)),
191+
Version);
188192
Mask = AMDGPU::DepCtr::encodeFieldVaSsrc(
189193
Mask, std::min(AMDGPU::DepCtr::decodeFieldVaSsrc(Mask1),
190194
AMDGPU::DepCtr::decodeFieldVaSsrc(Mask2)));

llvm/lib/Target/AMDGPU/SIInstrInfo.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10688,6 +10688,12 @@ SIInstrInfo::getGenericInstructionUniformity(const MachineInstr &MI) const {
1068810688
return InstructionUniformity::Default;
1068910689
}
1069010690

10691+
/// \returns the AMDGPU MIR formatter, constructing it from the subtarget on
/// the first call and handing out the same instance afterwards.
const MIRFormatter *SIInstrInfo::getMIRFormatter() const {
  if (Formatter)
    return Formatter.get();

  Formatter = std::make_unique<AMDGPUMIRFormatter>(ST);
  return Formatter.get();
}
10696+
1069110697
InstructionUniformity
1069210698
SIInstrInfo::getInstructionUniformity(const MachineInstr &MI) const {
1069310699

llvm/lib/Target/AMDGPU/SIInstrInfo.h

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1673,11 +1673,7 @@ class SIInstrInfo final : public AMDGPUGenInstrInfo {
16731673
InstructionUniformity
16741674
getGenericInstructionUniformity(const MachineInstr &MI) const;
16751675

1676-
const MIRFormatter *getMIRFormatter() const override {
1677-
if (!Formatter)
1678-
Formatter = std::make_unique<AMDGPUMIRFormatter>();
1679-
return Formatter.get();
1680-
}
1676+
const MIRFormatter *getMIRFormatter() const override;
16811677

16821678
static unsigned getDSShaderTypeValue(const MachineFunction &MF);
16831679

llvm/lib/Target/AMDGPU/Utils/AMDGPUBaseInfo.cpp

Lines changed: 27 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -177,7 +177,13 @@ inline unsigned getVaSsrcBitWidth() { return 1; }
177177
inline unsigned getVaSsrcBitShift() { return 8; }
178178

179179
/// \returns HoldCnt bit width
180-
inline unsigned getHoldCntWidth() { return 1; }
180+
inline unsigned getHoldCntWidth(unsigned VersionMajor, unsigned VersionMinor) {
  // The field is one bit wide starting with ISA version 10.3; for any older
  // version a width of 0 is reported.
  constexpr unsigned FirstMajor = 10;
  constexpr unsigned FirstMinor = 3;

  // Lexicographic (major, minor) >= (FirstMajor, FirstMinor) comparison.
  if (VersionMajor != FirstMajor)
    return VersionMajor > FirstMajor ? 1 : 0;
  return VersionMinor >= FirstMinor ? 1 : 0;
}
181187

182188
/// \returns HoldCnt bit shift
183189
inline unsigned getHoldCntBitShift() { return 7; }
@@ -2074,8 +2080,20 @@ int encodeDepCtr(const StringRef Name, int64_t Val, unsigned &UsedOprMask,
20742080

20752081
unsigned getVaVdstBitMask() {
  // All-ones value spanning the VaVdst field width.
  const unsigned Width = getVaVdstBitWidth();
  return (1 << Width) - 1;
}
20762082

2083+
unsigned getVaSdstBitMask() {
  // All-ones value spanning the VaSdst field width.
  const unsigned Width = getVaSdstBitWidth();
  return (1 << Width) - 1;
}
2084+
2085+
unsigned getVaSsrcBitMask() {
  // All-ones value spanning the VaSsrc field width.
  const unsigned Width = getVaSsrcBitWidth();
  return (1 << Width) - 1;
}
2086+
2087+
unsigned getHoldCntBitMask(const IsaVersion &Version) {
2088+
return (1 << getHoldCntWidth(Version.Major, Version.Minor)) - 1;
2089+
}
2090+
20772091
unsigned getVmVsrcBitMask() {
  // All-ones value spanning the VmVsrc field width.
  const unsigned Width = getVmVsrcBitWidth();
  return (1 << Width) - 1;
}
20782092

2093+
unsigned getVaVccBitMask() {
  // All-ones value spanning the VaVcc field width.
  const unsigned Width = getVaVccBitWidth();
  return (1 << Width) - 1;
}
2094+
2095+
unsigned getSaSdstBitMask() {
  // All-ones value spanning the SaSdst field width.
  const unsigned Width = getSaSdstBitWidth();
  return (1 << Width) - 1;
}
2096+
20792097
unsigned decodeFieldVmVsrc(unsigned Encoded) {
20802098
return unpackBits(Encoded, getVmVsrcBitShift(), getVmVsrcBitWidth());
20812099
}
@@ -2100,8 +2118,9 @@ unsigned decodeFieldVaSsrc(unsigned Encoded) {
21002118
return unpackBits(Encoded, getVaSsrcBitShift(), getVaSsrcBitWidth());
21012119
}
21022120

2103-
unsigned decodeFieldHoldCnt(unsigned Encoded) {
2104-
return unpackBits(Encoded, getHoldCntBitShift(), getHoldCntWidth());
2121+
unsigned decodeFieldHoldCnt(unsigned Encoded, const IsaVersion &Version) {
2122+
return unpackBits(Encoded, getHoldCntBitShift(),
2123+
getHoldCntWidth(Version.Major, Version.Minor));
21052124
}
21062125

21072126
unsigned encodeFieldVmVsrc(unsigned Encoded, unsigned VmVsrc) {
@@ -2158,13 +2177,15 @@ unsigned encodeFieldVaSsrc(unsigned VaSsrc, const MCSubtargetInfo &STI) {
21582177
return encodeFieldVaSsrc(Encoded, VaSsrc);
21592178
}
21602179

2161-
unsigned encodeFieldHoldCnt(unsigned Encoded, unsigned HoldCnt) {
2162-
return packBits(HoldCnt, Encoded, getHoldCntBitShift(), getHoldCntWidth());
2180+
unsigned encodeFieldHoldCnt(unsigned Encoded, unsigned HoldCnt,
2181+
const IsaVersion &Version) {
2182+
return packBits(HoldCnt, Encoded, getHoldCntBitShift(),
2183+
getHoldCntWidth(Version.Major, Version.Minor));
21632184
}
21642185

21652186
unsigned encodeFieldHoldCnt(unsigned HoldCnt, const MCSubtargetInfo &STI) {
21662187
unsigned Encoded = getDefaultDepCtrEncoding(STI);
2167-
return encodeFieldHoldCnt(Encoded, HoldCnt);
2188+
return encodeFieldHoldCnt(Encoded, HoldCnt, getIsaVersion(STI.getCPU()));
21682189
}
21692190

21702191
} // namespace DepCtr

llvm/lib/Target/AMDGPU/Utils/AMDGPUBaseInfo.h

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1311,9 +1311,24 @@ bool decodeDepCtr(unsigned Code, int &Id, StringRef &Name, unsigned &Val,
13111311
/// \returns Maximum VaVdst value that can be encoded.
13121312
unsigned getVaVdstBitMask();
13131313

1314+
/// \returns Maximum VaSdst value that can be encoded.
1315+
unsigned getVaSdstBitMask();
1316+
1317+
/// \returns Maximum VaSsrc value that can be encoded.
1318+
unsigned getVaSsrcBitMask();
1319+
1320+
/// \returns Maximum HoldCnt value that can be encoded.
1321+
unsigned getHoldCntBitMask(const IsaVersion &Version);
1322+
13141323
/// \returns Maximum VmVsrc value that can be encoded.
13151324
unsigned getVmVsrcBitMask();
13161325

1326+
/// \returns Maximum VaVcc value that can be encoded.
1327+
unsigned getVaVccBitMask();
1328+
1329+
/// \returns Maximum SaSdst value that can be encoded.
1330+
unsigned getSaSdstBitMask();
1331+
13171332
/// \returns Decoded VaVdst from given immediate \p Encoded.
13181333
unsigned decodeFieldVaVdst(unsigned Encoded);
13191334

@@ -1333,7 +1348,7 @@ unsigned decodeFieldVaVcc(unsigned Encoded);
13331348
unsigned decodeFieldVaSsrc(unsigned Encoded);
13341349

13351350
/// \returns Decoded HoldCnt from given immediate \p Encoded.
1336-
unsigned decodeFieldHoldCnt(unsigned Encoded);
1351+
unsigned decodeFieldHoldCnt(unsigned Encoded, const IsaVersion &Version);
13371352

13381353
/// \returns \p VmVsrc as an encoded Depctr immediate.
13391354
unsigned encodeFieldVmVsrc(unsigned VmVsrc, const MCSubtargetInfo &STI);
@@ -1369,7 +1384,8 @@ unsigned encodeFieldVaVcc(unsigned Encoded, unsigned VaVcc);
13691384
unsigned encodeFieldHoldCnt(unsigned HoldCnt, const MCSubtargetInfo &STI);
13701385

13711386
/// \returns \p Encoded combined with encoded \p HoldCnt.
1372-
unsigned encodeFieldHoldCnt(unsigned Encoded, unsigned HoldCnt);
1387+
unsigned encodeFieldHoldCnt(unsigned Encoded, unsigned HoldCnt,
1388+
const IsaVersion &Version);
13731389

13741390
/// \returns \p VaSsrc as an encoded Depctr immediate.
13751391
unsigned encodeFieldVaSsrc(unsigned VaSsrc, const MCSubtargetInfo &STI);

0 commit comments

Comments
 (0)