2020
2121#include " llvm/Support/Debug.h"
2222#include " llvm/Support/ErrorHandling.h"
23+ #include " llvm/Support/raw_ostream.h"
2324
2425#include < algorithm>
2526#include < optional>
27+ #include < string>
2628#include < vector>
2729
2830#define DEBUG_TYPE " pto-plan-memory"
@@ -69,6 +71,19 @@ static LocalMemSpec getLocalMemSpec(Operation *op, AddressSpace as) {
6971 }
7072}
7173
74+ static std::string getStableValueKey (Value value) {
75+ std::string key;
76+ llvm::raw_string_ostream os (key);
77+ value.printAsOperand (os, OpPrintingFlags ());
78+ return os.str ();
79+ }
80+
81+ static void sortValuesByStableKey (SmallVectorImpl<Value> &values) {
82+ std::stable_sort (values.begin (), values.end (), [](Value lhs, Value rhs) {
83+ return getStableValueKey (lhs) < getStableValueKey (rhs);
84+ });
85+ }
86+
7287static SmallVector<Value> getScratchBuffersFromEffects (Operation *op,
7388 ValueRange dpsInits) {
7489 SmallVector<Value> scratchBuffers;
@@ -91,6 +106,7 @@ static SmallVector<Value> getScratchBuffersFromEffects(Operation *op,
91106 if (!llvm::is_contained (scratchBuffers, value))
92107 scratchBuffers.push_back (value);
93108 }
109+ sortValuesByStableKey (scratchBuffers);
94110 return scratchBuffers;
95111}
96112
@@ -517,6 +533,7 @@ SmallVector<Value> MemLivenessAnalysis::GetLiveBuffersInLoop(scf::ForOp forOp,
517533 allocBeforeLoopBuffers.push_back (Buffer);
518534 }
519535 }
536+ sortValuesByStableKey (allocBeforeLoopBuffers);
520537 return allocBeforeLoopBuffers;
521538}
522539
@@ -654,8 +671,9 @@ void MemLivenessAnalysis::OpKillHandle(OpInfo *opInfo, Liveness live,
654671 if (currentLiveValues.empty ()) {
655672 return ;
656673 }
657- SetVector<Value> liveValues (currentLiveValues.begin (),
658- currentLiveValues.end ());
674+ SmallVector<Value> liveValues (currentLiveValues.begin (),
675+ currentLiveValues.end ());
676+ sortValuesByStableKey (liveValues);
659677 for (const Value &operand : liveValues) {
660678 UpdateOpKillInfo (opInfo, operand, live);
661679 }
@@ -800,21 +818,28 @@ SmallVector<ValuePair> MemPlan::GenerateInplaceList() {
800818 continue ;
801819 if (hasTouchOp[operationSeq->operation ]) {
802820 continue ;
821+ }
822+
823+ SmallVector<Value> genBuffers (it->second .gen .begin (), it->second .gen .end ());
824+ SmallVector<Value> killBuffers (it->second .kill .begin (), it->second .kill .end ());
825+ sortValuesByStableKey (genBuffers);
826+ sortValuesByStableKey (killBuffers);
827+
828+ for (const Value &genBuffer : genBuffers) {
829+ auto genBufferIter = bufferInfos.find (genBuffer);
830+ if (genBufferIter == bufferInfos.end ())
831+ llvm::report_fatal_error (" gen buffer missing from buffer info map" );
832+ if (genBufferIter->second .ignoreInplace ) {
833+ continue ;
803834 }
804- for (const Value &genBuffer : it->second .gen ) {
805- auto genBufferIter = bufferInfos.find (genBuffer);
806- if (genBufferIter == bufferInfos.end ())
807- llvm::report_fatal_error (" gen buffer missing from buffer info map" );
808- if (genBufferIter->second .ignoreInplace ) {
835+
836+ for (const Value &killBuffer : killBuffers) {
837+ auto killBufferIter = bufferInfos.find (killBuffer);
838+ if (killBufferIter == bufferInfos.end ())
839+ llvm::report_fatal_error (" kill buffer missing from buffer info map" );
840+ if (killBufferIter->second .ignoreInplace ) {
809841 continue ;
810842 }
811- for (const Value &killBuffer : it->second .kill ) {
812- auto killBufferIter = bufferInfos.find (killBuffer);
813- if (killBufferIter == bufferInfos.end ())
814- llvm::report_fatal_error (" kill buffer missing from buffer info map" );
815- if (killBufferIter->second .ignoreInplace ) {
816- continue ;
817- }
818843
819844 bool bufferSizeMatch =
820845 killBufferIter->second .constBits >= genBufferIter->second .constBits ;
@@ -915,6 +940,47 @@ LogicalResult MemPlan::plan() {
915940 EmitPlanMemoryFailureInfo ();
916941 return failure ();
917942 }
943+ auto hasAddressOverlap = [](const StorageEntry *lhs, const StorageEntry *rhs) {
944+ uint64_t lhsBegin = lhs->bitsOffset ;
945+ uint64_t lhsEnd = lhs->bitsOffset + lhs->alignedConstBits ;
946+ uint64_t rhsBegin = rhs->bitsOffset ;
947+ uint64_t rhsEnd = rhs->bitsOffset + rhs->alignedConstBits ;
948+ return lhsBegin < rhsEnd && rhsBegin < lhsEnd;
949+ };
950+ SmallVector<const StorageEntry *> plannedEntries;
951+ plannedEntries.reserve (StorageEntryVec.size () + pingEntry2RelationPongEntry.size ());
952+ for (const auto &entry : StorageEntryVec) {
953+ plannedEntries.push_back (entry.get ());
954+ }
955+ for (const auto &entry : pingEntry2RelationPongEntry) {
956+ plannedEntries.push_back (entry.second .get ());
957+ }
958+ for (size_t i = 0 ; i < plannedEntries.size (); ++i) {
959+ for (size_t j = i + 1 ; j < plannedEntries.size (); ++j) {
960+ const StorageEntry *lhs = plannedEntries[i];
961+ const StorageEntry *rhs = plannedEntries[j];
962+ if (!lhs || !rhs) {
963+ continue ;
964+ }
965+ if (lhs->bufInfo ->bufferScope != rhs->bufInfo ->bufferScope ) {
966+ continue ;
967+ }
968+ if (!hasAddressOverlap (lhs, rhs)) {
969+ continue ;
970+ }
971+ bool lifeOverlap =
972+ !GetOverlapBufferLife (lhs->bufferLifeVec , rhs->bufferLifeVec ).empty ();
973+ bool semanticConflict = HasSemanticConflict (lhs, rhs->bufferLifeVec );
974+ if (!lifeOverlap && !semanticConflict) {
975+ continue ;
976+ }
977+ func_.emitError ()
978+ << " PlanMemory produced overlapping local buffers in "
979+ << stringifyEnum (lhs->bufInfo ->bufferScope )
980+ << " at offsets " << lhs->bitsOffset << " and " << rhs->bitsOffset ;
981+ return failure ();
982+ }
983+ }
918984 // Update the address information of each buffer after memory buffer.
919985 UpdateBuffer2Offsets ();
920986 if (enablePrintMemoryAllocatedSize) {
@@ -929,7 +995,9 @@ void MemPlan::GenerateStorageEntry() {
929995 auto it = genKillMap.find (operation.get ());
930996 if (it == genKillMap.end ())
931997 continue ;
932- for (const Value &genBuffer : it->second .gen ) {
998+ SmallVector<Value> genBuffers (it->second .gen .begin (), it->second .gen .end ());
999+ sortValuesByStableKey (genBuffers);
1000+ for (const Value &genBuffer : genBuffers) {
9331001 auto iter = bufferInfos.find (genBuffer);
9341002 if (iter == bufferInfos.end ()) {
9351003 continue ;
0 commit comments