Skip to content

Commit c29447b

Browse files
Merge pull request #545 from hw-native-sys/codex/planmemory-determinism-issue541
[PlanMemory] Stabilize ordering and reject overlapping UB layouts
2 parents 7235dba + 275104c commit c29447b

1 file changed

Lines changed: 83 additions & 15 deletions

File tree

lib/PTO/Transforms/PTOPlanMemory.cpp

Lines changed: 83 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,11 @@
2020

2121
#include "llvm/Support/Debug.h"
2222
#include "llvm/Support/ErrorHandling.h"
23+
#include "llvm/Support/raw_ostream.h"
2324

2425
#include <algorithm>
2526
#include <optional>
27+
#include <string>
2628
#include <vector>
2729

2830
#define DEBUG_TYPE "pto-plan-memory"
@@ -69,6 +71,19 @@ static LocalMemSpec getLocalMemSpec(Operation *op, AddressSpace as) {
6971
}
7072
}
7173

74+
static std::string getStableValueKey(Value value) {
75+
std::string key;
76+
llvm::raw_string_ostream os(key);
77+
value.printAsOperand(os, OpPrintingFlags());
78+
return os.str();
79+
}
80+
81+
static void sortValuesByStableKey(SmallVectorImpl<Value> &values) {
82+
std::stable_sort(values.begin(), values.end(), [](Value lhs, Value rhs) {
83+
return getStableValueKey(lhs) < getStableValueKey(rhs);
84+
});
85+
}
86+
7287
static SmallVector<Value> getScratchBuffersFromEffects(Operation *op,
7388
ValueRange dpsInits) {
7489
SmallVector<Value> scratchBuffers;
@@ -91,6 +106,7 @@ static SmallVector<Value> getScratchBuffersFromEffects(Operation *op,
91106
if (!llvm::is_contained(scratchBuffers, value))
92107
scratchBuffers.push_back(value);
93108
}
109+
sortValuesByStableKey(scratchBuffers);
94110
return scratchBuffers;
95111
}
96112

@@ -517,6 +533,7 @@ SmallVector<Value> MemLivenessAnalysis::GetLiveBuffersInLoop(scf::ForOp forOp,
517533
allocBeforeLoopBuffers.push_back(Buffer);
518534
}
519535
}
536+
sortValuesByStableKey(allocBeforeLoopBuffers);
520537
return allocBeforeLoopBuffers;
521538
}
522539

@@ -654,8 +671,9 @@ void MemLivenessAnalysis::OpKillHandle(OpInfo *opInfo, Liveness live,
654671
if (currentLiveValues.empty()) {
655672
return;
656673
}
657-
SetVector<Value> liveValues(currentLiveValues.begin(),
658-
currentLiveValues.end());
674+
SmallVector<Value> liveValues(currentLiveValues.begin(),
675+
currentLiveValues.end());
676+
sortValuesByStableKey(liveValues);
659677
for (const Value &operand : liveValues) {
660678
UpdateOpKillInfo(opInfo, operand, live);
661679
}
@@ -800,21 +818,28 @@ SmallVector<ValuePair> MemPlan::GenerateInplaceList() {
800818
continue;
801819
if (hasTouchOp[operationSeq->operation]) {
802820
continue;
821+
}
822+
823+
SmallVector<Value> genBuffers(it->second.gen.begin(), it->second.gen.end());
824+
SmallVector<Value> killBuffers(it->second.kill.begin(), it->second.kill.end());
825+
sortValuesByStableKey(genBuffers);
826+
sortValuesByStableKey(killBuffers);
827+
828+
for (const Value &genBuffer : genBuffers) {
829+
auto genBufferIter = bufferInfos.find(genBuffer);
830+
if (genBufferIter == bufferInfos.end())
831+
llvm::report_fatal_error("gen buffer missing from buffer info map");
832+
if (genBufferIter->second.ignoreInplace) {
833+
continue;
803834
}
804-
for (const Value &genBuffer : it->second.gen) {
805-
auto genBufferIter = bufferInfos.find(genBuffer);
806-
if (genBufferIter == bufferInfos.end())
807-
llvm::report_fatal_error("gen buffer missing from buffer info map");
808-
if (genBufferIter->second.ignoreInplace) {
835+
836+
for (const Value &killBuffer : killBuffers) {
837+
auto killBufferIter = bufferInfos.find(killBuffer);
838+
if (killBufferIter == bufferInfos.end())
839+
llvm::report_fatal_error("kill buffer missing from buffer info map");
840+
if (killBufferIter->second.ignoreInplace) {
809841
continue;
810842
}
811-
for (const Value &killBuffer : it->second.kill) {
812-
auto killBufferIter = bufferInfos.find(killBuffer);
813-
if (killBufferIter == bufferInfos.end())
814-
llvm::report_fatal_error("kill buffer missing from buffer info map");
815-
if (killBufferIter->second.ignoreInplace) {
816-
continue;
817-
}
818843

819844
bool bufferSizeMatch =
820845
killBufferIter->second.constBits >= genBufferIter->second.constBits;
@@ -915,6 +940,47 @@ LogicalResult MemPlan::plan() {
915940
EmitPlanMemoryFailureInfo();
916941
return failure();
917942
}
943+
auto hasAddressOverlap = [](const StorageEntry *lhs, const StorageEntry *rhs) {
944+
uint64_t lhsBegin = lhs->bitsOffset;
945+
uint64_t lhsEnd = lhs->bitsOffset + lhs->alignedConstBits;
946+
uint64_t rhsBegin = rhs->bitsOffset;
947+
uint64_t rhsEnd = rhs->bitsOffset + rhs->alignedConstBits;
948+
return lhsBegin < rhsEnd && rhsBegin < lhsEnd;
949+
};
950+
SmallVector<const StorageEntry *> plannedEntries;
951+
plannedEntries.reserve(StorageEntryVec.size() + pingEntry2RelationPongEntry.size());
952+
for (const auto &entry : StorageEntryVec) {
953+
plannedEntries.push_back(entry.get());
954+
}
955+
for (const auto &entry : pingEntry2RelationPongEntry) {
956+
plannedEntries.push_back(entry.second.get());
957+
}
958+
for (size_t i = 0; i < plannedEntries.size(); ++i) {
959+
for (size_t j = i + 1; j < plannedEntries.size(); ++j) {
960+
const StorageEntry *lhs = plannedEntries[i];
961+
const StorageEntry *rhs = plannedEntries[j];
962+
if (!lhs || !rhs) {
963+
continue;
964+
}
965+
if (lhs->bufInfo->bufferScope != rhs->bufInfo->bufferScope) {
966+
continue;
967+
}
968+
if (!hasAddressOverlap(lhs, rhs)) {
969+
continue;
970+
}
971+
bool lifeOverlap =
972+
!GetOverlapBufferLife(lhs->bufferLifeVec, rhs->bufferLifeVec).empty();
973+
bool semanticConflict = HasSemanticConflict(lhs, rhs->bufferLifeVec);
974+
if (!lifeOverlap && !semanticConflict) {
975+
continue;
976+
}
977+
func_.emitError()
978+
<< "PlanMemory produced overlapping local buffers in "
979+
<< stringifyEnum(lhs->bufInfo->bufferScope)
980+
<< " at offsets " << lhs->bitsOffset << " and " << rhs->bitsOffset;
981+
return failure();
982+
}
983+
}
918984
// Update the address information of each buffer after memory buffer.
919985
UpdateBuffer2Offsets();
920986
if (enablePrintMemoryAllocatedSize) {
@@ -929,7 +995,9 @@ void MemPlan::GenerateStorageEntry() {
929995
auto it = genKillMap.find(operation.get());
930996
if (it == genKillMap.end())
931997
continue;
932-
for (const Value &genBuffer : it->second.gen) {
998+
SmallVector<Value> genBuffers(it->second.gen.begin(), it->second.gen.end());
999+
sortValuesByStableKey(genBuffers);
1000+
for (const Value &genBuffer : genBuffers) {
9331001
auto iter = bufferInfos.find(genBuffer);
9341002
if (iter == bufferInfos.end()) {
9351003
continue;

0 commit comments

Comments
 (0)