Skip to content

Commit ff6e0df

Browse files
committed
refactor: rename context to entities when referring to entity nodes
1 parent 5b964fb commit ff6e0df

1 file changed

Lines changed: 12 additions & 12 deletions

File tree

graph/rag.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -191,15 +191,15 @@ def run(inputs: dict[str, Any]) -> dict[str, Any]:
191191
params = inputs.get("params")
192192

193193
try:
194-
context_df = pd.DataFrame(self.graph.query(query, params))
194+
entities_df = pd.DataFrame(self.graph.query(query, params))
195195

196196
if shuffle:
197-
context_df = context_df.sample(frac=1)
197+
entities_df = entities_df.sample(frac=1)
198198

199199
if limit is not None:
200-
context_df = context_df.head(limit)
200+
entities_df = entities_df.head(limit)
201201

202-
return dict(context=context_df)
202+
return dict(entities=entities_df)
203203
except:
204204
raise GraphRetrievalException("Graph query failed", query=query)
205205

@@ -233,12 +233,12 @@ def combined_knn(self, k: int) -> RunnableFn:
233233
knn_per_node_dfs = []
234234

235235
def run(inputs: dict[str, Any]) -> dict[str, Any]:
236-
context = inputs["context"]
236+
entities = inputs["entities"]
237237

238-
if context is None or len(context) == 0:
239-
raise ContextAssemblerException("Context not found")
238+
if entities is None or len(entities) == 0:
239+
raise ContextAssemblerException("Entities not found")
240240

241-
node_ids = context.node_id.to_list()
241+
node_ids = entities.node_id.to_list()
242242

243243
for node_id in node_ids:
244244
knn_df = self.ops.knn(
@@ -271,12 +271,12 @@ def nn_sample_shortest_paths(
271271
max_length: int,
272272
) -> RunnableFn:
273273
def run(inputs: dict[str, Any]) -> dict[str, Any]:
274-
context = inputs["graph_retrieval"]["context"]
274+
entities = inputs["graph_retrieval"]["entities"]
275275

276-
if context is None or len(context) == 0:
277-
raise ContextAssemblerException("Context not found")
276+
if entities is None or len(entities) == 0:
277+
raise ContextAssemblerException("Entities not found")
278278

279-
source_node_ids = context.node_id.to_list()
279+
source_node_ids = entities.node_id.to_list()
280280
target_node_ids = inputs["combined_knn"]["knn"]
281281

282282
if target_node_ids is None or len(target_node_ids) == 0:

0 commit comments

Comments
 (0)