Skip to content
21 changes: 11 additions & 10 deletions document_qa/document_qa_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ class TextMerger:

Args:
model_name: A tiktoken model name (e.g. ``"gpt-4"``). When given,
the tokenizer for that model is used.
the tokenizer for that model is used.
encoding_name: A tiktoken encoding name (default ``"gpt2"``).
Ignored when *model_name* is provided.
"""
Expand Down Expand Up @@ -174,7 +174,7 @@ class DataStorage:

Args:
embedding_function: A LangChain-compatible ``Embeddings`` instance
root_path: Optional directory for persisted embeddings.
root_path: Optional directory for persisted embeddings.
engine: The vector-store class to use.

"""
Expand Down Expand Up @@ -278,7 +278,7 @@ class DocumentQAEngine:
Args:
llm: A LangChain chat model (e.g. ``ChatOpenAI``).
data_storage: A `DataStorage` instance for managing embeddings.
grobid_url: URL of the GROBID server.
grobid_url: URL of the GROBID server.
memory: Optional ``ConversationBufferMemory`` for multi-turn context.

"""
Expand All @@ -297,7 +297,8 @@ def __init__(self,
llm,
data_storage: DataStorage,
grobid_url=None,
memory=None
memory=None,
ping_grobid_server: bool = True
):

self.llm = llm
Expand All @@ -307,7 +308,7 @@ def __init__(self,
self.data_storage = data_storage

if grobid_url:
self.grobid_processor = GrobidProcessor(grobid_url)
self.grobid_processor = GrobidProcessor(grobid_url, ping_server=ping_grobid_server)

def query_document(
self,
Expand All @@ -317,7 +318,7 @@ def query_document(
context_size=4,
extraction_schema=None,
verbose=False
) -> tuple[Any, str]:
) -> tuple[Any, str, list]:
"""Ask a question and get an LLM-generated answer.

Retrieves the most relevant chunks from the vector store, feeds
Expand Down Expand Up @@ -354,7 +355,7 @@ def query_document(

if output_parser:
try:
return self._parse_json(response, output_parser), response
return self._parse_json(response, output_parser), response, coordinates
except Exception as oe:
print("Failing to parse the response", oe)
return None, response, coordinates
Expand All @@ -369,7 +370,7 @@ def query_document(
else:
return None, response, coordinates

def query_storage(self, query: str, doc_id, context_size=4) -> tuple[List[Document], list]:
def query_storage(self, query: str, doc_id, context_size=4) -> tuple[List[str], list]:
"""Retrieve relevant text passages without calling the LLM.

Useful for debugging which chunks would be used as context, or for
Expand Down Expand Up @@ -480,7 +481,7 @@ def _parse_json(self, response, output_parser):

return parsed_output

def _run_query(self, doc_id, query, context_size=4) -> tuple[List[Document], list]:
def _run_query(self, doc_id, query, context_size=4) -> tuple[Any, list]:
relevant_documents, relevant_document_coordinates = self._get_context(doc_id, query, context_size)
response = self.chain.invoke({"context": relevant_documents, "question": query})
return response, relevant_document_coordinates
Expand Down Expand Up @@ -550,7 +551,7 @@ def get_text_from_document(self, pdf_file_path, chunk_size=-1, perc_overlap=0.1,
biblio['filename'] = filename.replace(" ", "_")

if verbose:
print("Generating embeddings for:", hash, ", filename: ", filename)
print("Generating embeddings for filename: ", filename)

texts = []
metadatas = []
Expand Down
72 changes: 60 additions & 12 deletions document_qa/grobid_processors.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,19 @@

import dateparser
import grobid_tei_xml
import requests
from bs4 import BeautifulSoup
from grobid_client.grobid_client import GrobidClient


class GrobidServiceError(RuntimeError):
    """Signal that the Grobid service could not process a document.

    Attributes:
        status_code: Status code supplied by the raiser, or ``None`` when
            no status code was given.
    """

    def __init__(self, message="Grobid service error", status_code=None):
        # Record the status first; the base initializer only stores the message.
        self.status_code = status_code
        super().__init__(message)


def get_span_start(type, title=None):
"""Return an opening ``<span>`` tag for an annotation of the given *type*."""
title_ = ' title="' + title + '"' if title is not None else ""
Expand Down Expand Up @@ -168,22 +177,61 @@ def process_structure(self, input_path, coordinates=False):

Raises ``GrobidServiceError`` if GROBID is unreachable, returns a non-200 status, or returns an empty, malformed, or text-less response.
"""
pdf_file, status, text = self.grobid_client.process_pdf("processFulltextDocument",
input_path,
consolidate_header=True,
consolidate_citations=False,
segment_sentences=False,
tei_coordinates=coordinates,
include_raw_citations=False,
include_raw_affiliations=False,
generateIDs=True)
try:
pdf_file, status, text = self.grobid_client.process_pdf("processFulltextDocument",
input_path,
consolidate_header=True,
consolidate_citations=False,
segment_sentences=False,
tei_coordinates=coordinates,
include_raw_citations=False,
include_raw_affiliations=False,
generateIDs=True)
except requests.exceptions.RequestException as exc:
# Transport-level failure (connection refused, timeout, …).
# Local/usage errors (bad path, parsing bugs) are intentionally
# not caught here so they surface with their real traceback.
raise GrobidServiceError("Grobid service did not respond.") from exc

if status != 200:
return
# Grobid attaches a human-readable reason to error responses
# (e.g. a 500 body explaining what went wrong). Surface it
# alongside the status code instead of discarding it.
reason = text.strip() if text else ""
message = f"Grobid service returned status {status}."
if reason:
message += f" {reason}"
raise GrobidServiceError(message, status_code=status)

# Grobid can answer 200 with an empty body (e.g. it gave up on the PDF).
if not text or not text.strip():
raise GrobidServiceError(
"Grobid returned an empty response.",
status_code=status
)

# A truncated/corrupted TEI payload makes the XML parser blow up; map
# that to a clear service error instead of an opaque parsing traceback.
try:
document_object = self.parse_grobid_xml(text, coordinates=coordinates)
except GrobidServiceError:
raise
except Exception as exc:
raise GrobidServiceError(
"Grobid returned a malformed or truncated response.",
status_code=status
) from exc

document_object = self.parse_grobid_xml(text, coordinates=coordinates)
document_object['filename'] = Path(pdf_file).stem.replace(".tei", "")

# Well-formed XML can still carry no usable text (e.g. an image-only or
# truncated PDF). Nothing to embed downstream, so fail loudly here.
if not any(passage.get('text', '').strip() for passage in document_object.get('passages', [])):
raise GrobidServiceError(
"Grobid returned a document with no extractable text.",
status_code=status
)

return document_object

def process_single(self, input_file):
Expand Down Expand Up @@ -221,7 +269,7 @@ def parse_grobid_xml(self, text, coordinates=False):
try:
year = dateparser.parse(doc_biblio.header.date).year
biblio["publication_year"] = year
except:
except Exception:
pass

output_data['biblio'] = biblio
Expand Down
4 changes: 2 additions & 2 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Grobid
grobid-quantities-client==0.4.0
grobid-client-python==0.0.9
grobid-client-python==0.1.4
grobid-tei-xml==0.1.3

# Utils
Expand Down Expand Up @@ -30,6 +30,6 @@ typing-inspect==0.9.0
typing_extensions==4.12.2
pydantic==2.10.6
sentence-transformers==2.6.1
streamlit-pdf-viewer==0.0.25
streamlit-pdf-viewer==0.0.29
umap-learn==0.5.6
plotly==5.20.0
58 changes: 40 additions & 18 deletions streamlit_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,20 +15,19 @@
from tempfile import NamedTemporaryFile

import dotenv
import streamlit as st
from grobid_quantities.quantities import QuantitiesAPI
from langchain.memory import ConversationBufferMemory
from langchain_openai import ChatOpenAI
from streamlit_pdf_viewer import pdf_viewer

from document_qa.custom_embeddings import ModalEmbeddings
from document_qa.document_qa_engine import DocumentQAEngine, DataStorage
from document_qa.grobid_processors import GrobidAggregationProcessor, decorate_text_with_annotations, GrobidServiceError
from document_qa.ner_client_generic import NERClientGeneric

dotenv.load_dotenv(override=True)

import streamlit as st
from document_qa.document_qa_engine import DocumentQAEngine, DataStorage
from document_qa.grobid_processors import GrobidAggregationProcessor, decorate_text_with_annotations

API_MODELS = {
"microsoft/Phi-4-mini-instruct": os.environ["PHI_URL"],
"Qwen/Qwen3-0.6B": os.environ["QWEN_URL"]
Expand Down Expand Up @@ -169,7 +168,13 @@ def init_qa(model_name, embeddings_name):
)

storage = DataStorage(embeddings)
return DocumentQAEngine(chat, storage, grobid_url=os.environ['GROBID_URL'], memory=st.session_state['memory'])
return DocumentQAEngine(
chat,
storage,
grobid_url=os.environ['GROBID_URL'],
memory=st.session_state['memory'],
ping_grobid_server=False
)


@st.cache_resource
Expand Down Expand Up @@ -358,19 +363,36 @@ def play_old_messages(container):
st.stop()

with left_column:
with st.spinner('Reading file, calling Grobid, and creating in-memory embeddings...'):
binary = uploaded_file.getvalue()
tmp_file = NamedTemporaryFile()
tmp_file.write(bytearray(binary))
st.session_state['binary'] = binary

st.session_state['doc_id'] = hash = st.session_state['rqa'][model].create_memory_embeddings(
tmp_file.name,
chunk_size=chunk_size,
perc_overlap=0.1
)
st.session_state['loaded_embeddings'] = True
st.session_state.messages = []
try:
with st.spinner('Reading file, calling Grobid, and creating in-memory embeddings...'):
binary = uploaded_file.getvalue()
tmp_path = None
try:
with NamedTemporaryFile(suffix=".pdf", delete=False) as tmp_file:
tmp_file.write(bytearray(binary))
tmp_file.flush()
tmp_path = tmp_file.name
st.session_state['binary'] = binary

st.session_state['doc_id'] = st.session_state['rqa'][model].create_memory_embeddings(
tmp_path,
chunk_size=chunk_size,
perc_overlap=0.1
)
finally:
if tmp_path and os.path.exists(tmp_path):
os.unlink(tmp_path)
st.session_state['loaded_embeddings'] = True
st.session_state.messages = []
except GrobidServiceError as exc:
st.session_state['doc_id'] = None
st.session_state['loaded_embeddings'] = False
st.session_state['uploaded'] = False
message = str(exc).strip() or "Grobid is not responding."
if not message.endswith((".", "!", "?")):
message += "."
st.error(f"{message} Please try again later.")
st.stop()


def rgb_to_hex(rgb):
Expand Down
Loading