Skip to content

Commit 6f1b9b3

Browse files
Merge pull request #3 from DataLabTechTV/dev
feat: graph rag cli options for interactive and direct querying
2 parents 5b107a1 + f94ade9 commit 6f1b9b3

6 files changed

Lines changed: 93 additions & 21 deletions

File tree

graph/cli.py

Lines changed: 31 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -1,11 +1,12 @@
11
import os
2+
from typing import Optional
23

34
import click
45
from loguru import logger as log
56

67
from graph.embedding import NodeEmbedding, NodeEmbeddingAlgo
78
from graph.ops import KuzuOps
8-
from graph.rag import GraphRAG
9+
from graph.rag import ContextAssemblerException, GraphRAG, GraphRetrievalException
910
from shared.lakehouse import Lakehouse
1011
from shared.settings import env
1112

@@ -96,7 +97,7 @@ def embeddings(schema: str, dimension: int, batch_size: int, epochs: int, algo:
9697
log.exception(e)
9798

9899

99-
@graph.command()
100+
@graph.command(help="Reindex embedding property")
100101
@click.argument("schema", type=click.STRING)
101102
def reindex(schema: str):
102103
try:
@@ -106,11 +107,36 @@ def reindex(schema: str):
106107
log.error(e)
107108

108109

109-
@graph.command()
110+
@graph.command(help="Run GraphRAG pipeline")
110111
@click.argument("schema", type=click.STRING)
111-
def rag(schema: str):
112+
@click.option(
113+
"--interactive",
114+
"-i",
115+
is_flag=True,
116+
help="Run in interactive mode using a REPL",
117+
)
118+
@click.option(
119+
"--query",
120+
"-q",
121+
type=click.STRING,
122+
help="User query prompt",
123+
)
124+
def rag(schema: str, interactive: bool, query: Optional[str]):
112125
gr = GraphRAG(schema)
113-
gr.interactive()
126+
127+
if interactive and query is not None:
128+
raise click.UsageError("--interactive and --query cannot be used together")
129+
130+
if query is not None:
131+
try:
132+
response = gr.invoke(dict(user_query=query))
133+
log.info("Final response:\n{}", response.content)
134+
except GraphRetrievalException as e:
135+
log.error("{}\n{}", e, e.query)
136+
except ContextAssemblerException as e:
137+
log.error(e)
138+
else:
139+
gr.interactive()
114140

115141

116142
if __name__ == "__main__":

graph/ops.py

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,7 +1,6 @@
11
import os
22
import shutil
33
import tempfile
4-
import textwrap
54
from enum import Enum
65
from string import Template
76
from typing import Any, Optional

graph/rag.py

Lines changed: 57 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -6,6 +6,7 @@
66

77
import ollama
88
import pandas as pd
9+
from colorama import Fore
910
from langchain.prompts import ChatPromptTemplate
1011
from langchain.schema import AIMessage
1112
from langchain.schema.runnable import Runnable
@@ -56,6 +57,16 @@ def get_line(lineno):
5657
return get_line
5758

5859

60+
class GraphRetrievalException(Exception):
61+
def __init__(self, message, query):
62+
self.query = query
63+
super().__init__(message)
64+
65+
66+
class ContextAssemblerException(Exception):
67+
pass
68+
69+
5970
class GraphRAG(Runnable):
6071
def __init__(
6172
self,
@@ -170,18 +181,27 @@ def query_graph(
170181
limit: Optional[int] = None,
171182
) -> RunnableFn:
172183
def run(inputs: dict[str, Any]) -> dict[str, Any]:
184+
log.info(
185+
"Querying graph for matching entities (shuffle={}, limit={})",
186+
shuffle,
187+
limit,
188+
)
189+
173190
query = inputs["query"]
174191
params = inputs.get("params")
175192

176-
context_df = pd.DataFrame(self.graph.query(query, params))
193+
try:
194+
context_df = pd.DataFrame(self.graph.query(query, params))
177195

178-
if shuffle:
179-
context_df = context_df.sample(frac=1)
196+
if shuffle:
197+
context_df = context_df.sample(frac=1)
180198

181-
if limit is not None:
182-
context_df = context_df.head(limit)
199+
if limit is not None:
200+
context_df = context_df.head(limit)
183201

184-
return dict(context=context_df)
202+
return dict(context=context_df)
203+
except:
204+
raise GraphRetrievalException("Graph query failed", query=query)
185205

186206
return run
187207

@@ -214,6 +234,10 @@ def combined_knn(self, k: int) -> RunnableFn:
214234

215235
def run(inputs: dict[str, Any]) -> dict[str, Any]:
216236
context = inputs["context"]
237+
238+
if context is None or len(context) == 0:
239+
raise ContextAssemblerException("Context not found")
240+
217241
node_ids = context.node_id.to_list()
218242

219243
for node_id in node_ids:
@@ -247,9 +271,17 @@ def nn_sample_shortest_paths(
247271
max_length: int,
248272
) -> RunnableFn:
249273
def run(inputs: dict[str, Any]) -> dict[str, Any]:
250-
source_node_ids = inputs["graph_retrieval"]["context"].node_id.to_list()
274+
context = inputs["graph_retrieval"]["context"]
275+
276+
if context is None or len(context) == 0:
277+
raise ContextAssemblerException("Context not found")
278+
279+
source_node_ids = context.node_id.to_list()
251280
target_node_ids = inputs["combined_knn"]["knn"]
252281

282+
if target_node_ids is None or len(target_node_ids) == 0:
283+
raise ContextAssemblerException("Nearest neighbors not found")
284+
253285
paths_df = self.ops.sample_shortest_paths(
254286
source_node_ids,
255287
target_node_ids,
@@ -271,6 +303,9 @@ def nn_random_walks(
271303
def run(inputs: dict[str, Any]) -> dict[str, Any]:
272304
source_node_ids = inputs["combined_knn"]["knn"]
273305

306+
if source_node_ids is None or len(source_node_ids) == 0:
307+
raise ContextAssemblerException("Nearest neighbors not found")
308+
274309
paths_dfs = []
275310

276311
for source_node_id in source_node_ids:
@@ -412,7 +447,7 @@ def loader(self, stop_event: threading.Event):
412447

413448
time.sleep(0.1)
414449

415-
print(f"\r{elapsed}\n")
450+
print("\b\b\b ", end="\n\n", flush=True)
416451

417452
def interactive(self):
418453
config_path = user_config_path("datalab", "DataLabTechTV")
@@ -461,9 +496,17 @@ def interactive(self):
461496
)
462497
loader_thread.start()
463498

464-
output = self.invoke(dict(user_query=user_query))
465-
466-
stop_event.set()
467-
loader_thread.join()
468-
469-
print(output["context"])
499+
try:
500+
response = self.invoke(dict(user_query=user_query))
501+
stop_event.set()
502+
loader_thread.join()
503+
print(response.content)
504+
except GraphRetrievalException as e:
505+
stop_event.set()
506+
loader_thread.join()
507+
print(Fore.RED + "Error: " + str(e))
508+
print(Fore.MAGENTA + e.query + Fore.RESET)
509+
except ContextAssemblerException as e:
510+
stop_event.set()
511+
loader_thread.join()
512+
print(Fore.RED + "Error: " + str(e) + Fore.RESET)

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -12,6 +12,7 @@ dependencies = [
1212
"boto3>=1.38.29",
1313
"boto3-stubs[s3]>=1.38.29",
1414
"click>=8.2.1",
15+
"colorama>=0.4.6",
1516
"dbt-core>=1.9.6",
1617
"dbt-duckdb",
1718
"environs>=14.2.0",

tests/test_graph_rag.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -4,6 +4,7 @@
44

55
PROMPTS = (
66
"If I like metal artists like Metallica or Iron Maiden, but also listen to IDM, what other artists and genres could I listen to?",
7+
"What other bands like Anthrax are there?",
78
)
89

910

uv.lock

Lines changed: 3 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)