Skip to content

Commit dd3d727

Browse files
authored
Revert "[llvm-ir2vec] Adding Inst Embeddings Map API to ir2vec python bindings" (llvm#184179)
Reverts llvm#180140 Unblock bot: https://lab.llvm.org/buildbot/#/builders/140
1 parent f486fc9 commit dd3d727

4 files changed

Lines changed: 43 additions & 119 deletions

File tree

llvm/test/tools/llvm-ir2vec/bindings/ir2vec-bindings.py

Lines changed: 0 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -43,18 +43,6 @@
4343
print(f" BB: {bb_name}")
4444
print(f" Embedding: {emb.tolist()}")
4545

46-
# Test getInstEmbMap
47-
print("\n=== Instruction Embeddings ===")
48-
49-
# Test valid function names in sorted order
50-
for func_name in sorted(["add", "multiply", "conditional"]):
51-
inst_emb_map = tool.getInstEmbMap(func_name)
52-
print(f"Function: {func_name}")
53-
for inst_str in sorted(inst_emb_map.keys()):
54-
emb = inst_emb_map[inst_str]
55-
print(f" Inst: {inst_str}")
56-
print(f" Embedding: {emb.tolist()}")
57-
5846
# CHECK: SUCCESS: Tool initialized
5947
# CHECK: Tool type: IR2VecTool
6048
# CHECK: === Function Embeddings ===
@@ -87,29 +75,3 @@
8775
# CHECK: Function: multiply
8876
# CHECK: BB: entry
8977
# CHECK-NEXT: Embedding: [50.0, 52.0, 54.0]
90-
# CHECK: === Instruction Embeddings ===
91-
# CHECK: Function: add
92-
# CHECK: Inst: %sum = add i32 %a, %b
93-
# CHECK-NEXT: Embedding: [37.0, 38.0, 39.0]
94-
# CHECK: Inst: ret i32 %sum
95-
# CHECK-NEXT: Embedding: [1.0, 2.0, 3.0]
96-
# CHECK: Function: conditional
97-
# CHECK: Inst: %cmp = icmp sgt i32 %n, 0
98-
# CHECK-NEXT: Embedding: [157.20000000298023, 158.20000000298023, 159.20000000298023]
99-
# CHECK: Inst: %neg_val = sub i32 %n, 10
100-
# CHECK-NEXT: Embedding: [43.0, 44.0, 45.0]
101-
# CHECK: Inst: %pos_val = add i32 %n, 10
102-
# CHECK-NEXT: Embedding: [37.0, 38.0, 39.0]
103-
# CHECK: Inst: %result = phi i32 [ %pos_val, %positive ], [ %neg_val, %negative ]
104-
# CHECK-NEXT: Embedding: [163.0, 164.0, 165.0]
105-
# CHECK: Inst: br i1 %cmp, label %positive, label %negative
106-
# CHECK-NEXT: Embedding: [4.0, 5.0, 6.0]
107-
# CHECK: Inst: br label %exit
108-
# CHECK-NEXT: Embedding: [4.0, 5.0, 6.0]
109-
# CHECK: Inst: ret i32 %result
110-
# CHECK-NEXT: Embedding: [1.0, 2.0, 3.0]
111-
# CHECK: Function: multiply
112-
# CHECK: Inst: %prod = mul i32 %x, %y
113-
# CHECK-NEXT: Embedding: [49.0, 50.0, 51.0]
114-
# CHECK: Inst: ret i32 %prod
115-
# CHECK-NEXT: Embedding: [1.0, 2.0, 3.0]

llvm/tools/llvm-ir2vec/Bindings/PyIR2Vec.cpp

Lines changed: 1 addition & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -150,41 +150,6 @@ class PyIR2VecTool {
150150

151151
return NbBBEmbMap;
152152
}
153-
154-
nb::dict getInstEmbMap(const std::string &FuncName) {
155-
const Function *F = M->getFunction(FuncName);
156-
157-
if (!F)
158-
throw nb::value_error(
159-
("Function '" + FuncName + "' not found in module").c_str());
160-
161-
auto ToolInstEmbMap = Tool->getInstEmbeddingsMap(*F, OutputEmbeddingMode);
162-
163-
if (!ToolInstEmbMap)
164-
throw nb::value_error(toString(ToolInstEmbMap.takeError()).c_str());
165-
166-
nb::dict NbInstEmbMap;
167-
168-
for (const auto &[InstPtr, InstEmb] : *ToolInstEmbMap) {
169-
auto InstEmbVec = InstEmb.getData();
170-
double *NbInstEmbVec = new double[InstEmbVec.size()];
171-
std::copy(InstEmbVec.begin(), InstEmbVec.end(), NbInstEmbVec);
172-
173-
auto NbArray = nb::ndarray<nb::numpy, double>(
174-
NbInstEmbVec, {InstEmbVec.size()},
175-
nb::capsule(NbInstEmbVec, [](void *P) noexcept {
176-
delete[] static_cast<double *>(P);
177-
}));
178-
179-
std::string InstStr;
180-
raw_string_ostream OS(InstStr);
181-
InstPtr->print(OS);
182-
183-
NbInstEmbMap[nb::str(OS.str().c_str())] = NbArray;
184-
}
185-
186-
return NbInstEmbMap;
187-
}
188153
};
189154

190155
} // namespace
@@ -208,13 +173,7 @@ NB_MODULE(ir2vec, m) {
208173
"Generate embeddings for all basic blocks in a function\n"
209174
"Args: funcName (str) - IR-Name of the function\n"
210175
"Returns: dict[str, ndarray[float64]] - "
211-
"{basic_block_name: embedding vector}")
212-
.def("getInstEmbMap", &PyIR2VecTool::getInstEmbMap, nb::arg("funcName"),
213-
"Generate embeddings for all instructions in a function\n"
214-
"Args: funcName (str) - IR-Name of the function\n"
215-
"Returns: dict[str, ndarray[float64]] - "
216-
"{instruction_string: embedding_vector}");
217-
176+
"{basic_block_name: embedding vector}");
218177
m.def(
219178
"initEmbedding",
220179
[](const std::string &filename, const std::string &mode,

llvm/tools/llvm-ir2vec/lib/Utils.cpp

Lines changed: 42 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -152,8 +152,8 @@ void IR2VecTool::writeEntitiesToStream(raw_ostream &OS) {
152152
OS << Entities[EntityID] << '\t' << EntityID << '\n';
153153
}
154154

155-
Expected<std::unique_ptr<Embedder>>
156-
IR2VecTool::createIR2VecEmbedder(const Function &F, IR2VecKind Kind) const {
155+
Expected<Embedding> IR2VecTool::getFunctionEmbedding(const Function &F,
156+
IR2VecKind Kind) const {
157157
if (!Vocab || !Vocab->isValid())
158158
return createStringError(
159159
errc::invalid_argument,
@@ -169,20 +169,16 @@ IR2VecTool::createIR2VecEmbedder(const Function &F, IR2VecKind Kind) const {
169169
"Failed to create embedder for function '%s'.",
170170
F.getName().str().c_str());
171171

172-
return Emb;
173-
}
174-
175-
Expected<Embedding> IR2VecTool::getFunctionEmbedding(const Function &F,
176-
IR2VecKind Kind) const {
177-
auto Emb = createIR2VecEmbedder(F, Kind);
178-
if (!Emb)
179-
return Emb.takeError();
180-
181-
return (*Emb)->getFunctionVector();
172+
return Emb->getFunctionVector();
182173
}
183174

184175
Expected<FuncEmbMap>
185176
IR2VecTool::getFunctionEmbeddingsMap(IR2VecKind Kind) const {
177+
if (!Vocab || !Vocab->isValid())
178+
return createStringError(
179+
errc::invalid_argument,
180+
"Vocabulary is not valid. IR2VecTool not initialized.");
181+
186182
FuncEmbMap Result;
187183

188184
for (const Function &F : M.getFunctionDefs()) {
@@ -197,47 +193,61 @@ IR2VecTool::getFunctionEmbeddingsMap(IR2VecKind Kind) const {
197193

198194
Expected<BBEmbeddingsMap>
199195
IR2VecTool::getBBEmbeddingsMap(const Function &F, IR2VecKind Kind) const {
200-
auto Emb = createIR2VecEmbedder(F, Kind);
201-
if (!Emb)
202-
return Emb.takeError();
196+
if (!Vocab || !Vocab->isValid())
197+
return createStringError(
198+
errc::invalid_argument,
199+
"Vocabulary is not valid. IR2VecTool not initialized.");
203200

204201
BBEmbeddingsMap Result;
205202

206-
for (const BasicBlock &BB : F)
207-
Result.try_emplace(&BB, (*Emb)->getBBVector(BB));
208-
209-
return Result;
210-
}
203+
if (F.isDeclaration())
204+
return createStringError(errc::invalid_argument,
205+
"Function is a declaration.");
211206

212-
Expected<InstEmbeddingsMap>
213-
IR2VecTool::getInstEmbeddingsMap(const Function &F, IR2VecKind Kind) const {
214-
auto Emb = createIR2VecEmbedder(F, Kind);
207+
auto Emb = Embedder::create(Kind, F, *Vocab);
215208
if (!Emb)
216-
return Emb.takeError();
217-
218-
InstEmbeddingsMap Result;
209+
return createStringError(errc::invalid_argument,
210+
"Failed to create embedder for function '%s'.",
211+
F.getName().str().c_str());
219212

220-
for (const Instruction &I : instructions(F))
221-
Result.try_emplace(&I, (*Emb)->getInstVector(I));
213+
for (const BasicBlock &BB : F)
214+
Result.try_emplace(&BB, Emb->getBBVector(BB));
222215

223216
return Result;
224217
}
225218

226219
void IR2VecTool::writeEmbeddingsToStream(raw_ostream &OS,
227220
EmbeddingLevel Level) const {
221+
if (!Vocab || !Vocab->isValid()) {
222+
WithColor::error(errs(), ToolName)
223+
<< "Vocabulary is not valid. IR2VecTool not initialized.\n";
224+
return;
225+
}
226+
228227
for (const Function &F : M.getFunctionDefs())
229228
writeEmbeddingsToStream(F, OS, Level);
230229
}
231230

232231
void IR2VecTool::writeEmbeddingsToStream(const Function &F, raw_ostream &OS,
233232
EmbeddingLevel Level) const {
234-
auto IR2VecEmbedderObj = createIR2VecEmbedder(F, IR2VecEmbeddingKind);
235-
if (!IR2VecEmbedderObj) {
233+
if (!Vocab || !Vocab->isValid()) {
234+
WithColor::error(errs(), ToolName)
235+
<< "Vocabulary is not valid. IR2VecTool not initialized.\n";
236+
return;
237+
}
238+
239+
if (F.isDeclaration()) {
240+
OS << "Function " << F.getName() << " is a declaration, skipping.\n";
241+
return;
242+
}
243+
244+
// Create embedder for this function
245+
auto Emb = Embedder::create(IR2VecEmbeddingKind, F, *Vocab);
246+
if (!Emb) {
236247
WithColor::error(errs(), ToolName)
237-
<< toString(IR2VecEmbedderObj.takeError()) << "\n";
248+
<< "Failed to create embedder for function " << F.getName() << "\n";
238249
return;
239250
}
240-
auto Emb = std::move(*IR2VecEmbedderObj);
241251

242252
OS << "Function: " << F.getName() << "\n";
243253

llvm/tools/llvm-ir2vec/lib/Utils.h

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -94,10 +94,6 @@ class IR2VecTool {
9494
public:
9595
explicit IR2VecTool(Module &M) : M(M) {}
9696

97-
/// Creates the embedding object for downstream embedding streaming
98-
Expected<std::unique_ptr<Embedder>>
99-
createIR2VecEmbedder(const Function &F, IR2VecKind Kind) const;
100-
10197
/// Initialize the IR2Vec vocabulary from the specified file path.
10298
Error initializeVocabulary(StringRef VocabPath);
10399

@@ -131,9 +127,6 @@ class IR2VecTool {
131127
/// Get embeddings for all basic blocks in a function
132128
Expected<BBEmbeddingsMap> getBBEmbeddingsMap(const Function &F,
133129
IR2VecKind Kind) const;
134-
/// Get embeddings for all instructions in a function
135-
Expected<InstEmbeddingsMap> getInstEmbeddingsMap(const Function &F,
136-
IR2VecKind Kind) const;
137130

138131
/// Generate embeddings for the entire module
139132
void writeEmbeddingsToStream(raw_ostream &OS, EmbeddingLevel Level) const;

0 commit comments

Comments
 (0)