From 149fa627f90ef18f0499331c1fbacc9e754f17e9 Mon Sep 17 00:00:00 2001 From: Luca Foppiano Date: Sun, 7 Jun 2026 21:30:03 +0100 Subject: [PATCH 1/5] feat: add additional script for en-only embeddings --- document_qa/deployment/modal_embeddings_en.py | 117 ++++++++++++++++++ ...dings.py => modal_embeddings_multilang.py} | 0 2 files changed, 117 insertions(+) create mode 100644 document_qa/deployment/modal_embeddings_en.py rename document_qa/deployment/{modal_embeddings.py => modal_embeddings_multilang.py} (100%) diff --git a/document_qa/deployment/modal_embeddings_en.py b/document_qa/deployment/modal_embeddings_en.py new file mode 100644 index 0000000..19ff4e1 --- /dev/null +++ b/document_qa/deployment/modal_embeddings_en.py @@ -0,0 +1,117 @@ +import os +from typing import Annotated, List +from fastapi import Request, HTTPException, Form + +import modal +import torch +import torch.nn.functional as F +from torch import Tensor +from transformers import AutoTokenizer, AutoModel + +image = ( + modal.Image.debian_slim(python_version="3.11") + .pip_install( + "transformers", + "huggingface_hub[hf_transfer]==0.26.2", + "flashinfer-python==0.2.0.post2", # pinning, very unstable + "fastapi[standard]", + extra_index_url="https://flashinfer.ai/whl/cu124/torch2.5", + ) + .env({"HF_HUB_ENABLE_HF_TRANSFER": "1"}) # faster model transfers +) + +MODELS_DIR = "/llamas" +MODEL_NAME = "intfloat/e5-large-v2" +MODEL_REVISION = "756b8ddb6e4bda943d3b6f5d131355825efda70c" + +hf_cache_vol = modal.Volume.from_name("huggingface-cache", create_if_missing=True) +vllm_cache_vol = modal.Volume.from_name("vllm-cache", create_if_missing=True) + +app = modal.App("intfloat-e5-large-v2-embeddings") + + +def get_device(): + return torch.device('cuda' if torch.cuda.is_available() else 'cpu') + +def load_model(): + print("Loading model...") + device = get_device() + print(f"Using device: {device}") + + tokenizer = AutoTokenizer.from_pretrained('intfloat/multilingual-e5-large-instruct') + model = 
AutoModel.from_pretrained('intfloat/multilingual-e5-large-instruct').to(device) + print("Model loaded successfully.") + + return tokenizer, model, device + + +N_GPU = 1 +MINUTES = 60 # seconds +VLLM_PORT = 8000 + + +def average_pool(last_hidden_states: Tensor, + attention_mask: Tensor) -> Tensor: + last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0) + return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None] + + +@app.function( + image=image, + gpu=f"L40S:{N_GPU}", + # gpu=f"A10G:{N_GPU}", + # how long should we stay up with no requests? + scaledown_window=3 * MINUTES, + volumes={ + "/root/.cache/huggingface": hf_cache_vol, + "/root/.cache/vllm": vllm_cache_vol, + }, + secrets=[modal.Secret.from_name("document-qa-embedding-key")] +) +@modal.concurrent( + max_inputs=5 +) # how many requests can one replica handle? tune carefully! +@modal.fastapi_endpoint(method="POST") +def embed(request: Request, text: Annotated[str, Form()]): + api_key = request.headers.get("x-api-key") + expected_key = os.environ["API_KEY"] + + if api_key != expected_key: + raise HTTPException(status_code=401, detail="Unauthorized") + + + texts = [t for t in text.split("\n") if t.strip()] + if not texts: + return [] + + tokenizer, model, device = load_model() + model.eval() + + print(f"Start embedding {len(texts)} texts") + try: + with torch.no_grad(): + # Move inputs to the same device as model + batch_dict = tokenizer(texts, padding=True, truncation=True, return_tensors='pt') + batch_dict = {k: v.to(device) for k, v in batch_dict.items()} + + # Forward pass + outputs = model(**batch_dict) + + # Process embeddings + embeddings = average_pool( + outputs.last_hidden_state, + batch_dict['attention_mask'] + ) + embeddings = F.normalize(embeddings, p=2, dim=1) + + # Move to CPU and convert to list for serialization + embeddings = embeddings.cpu().numpy().tolist() + + print("Finished embedding texts.") + return embeddings + + except RuntimeError as e: + 
print(f"Error during embedding: {str(e)}") + if "CUDA out of memory" in str(e): + print("CUDA out of memory error. Try reducing batch size or using a smaller model.") + raise diff --git a/document_qa/deployment/modal_embeddings.py b/document_qa/deployment/modal_embeddings_multilang.py similarity index 100% rename from document_qa/deployment/modal_embeddings.py rename to document_qa/deployment/modal_embeddings_multilang.py From d9a0f88a9dd6383c3634a30fad9c72c229d9d121 Mon Sep 17 00:00:00 2001 From: Luca Foppiano Date: Sun, 7 Jun 2026 22:15:20 +0100 Subject: [PATCH 2/5] feat: switch from flake8 to ruff --- .env.example | 27 +++++++++++++-------------- .github/workflows/ci-build.yml | 10 +++++----- .github/workflows/ci-release.yml | 10 +++++----- README.md | 4 +++- docs/README.md | 7 +++++++ pyproject.toml | 10 +++++++++- 6 files changed, 42 insertions(+), 26 deletions(-) diff --git a/.env.example b/.env.example index eaf4bb2..3a1382e 100644 --- a/.env.example +++ b/.env.example @@ -1,18 +1,17 @@ -PHI_URL=.... -QWEN_URL=... +# ── LLM endpoints (OpenAI-compatible vLLM servers on Modal) ── +PHI_URL=https://--phi-4-mini-instruct-qa-vllm-serve.modal.run/v1 +QWEN_URL=https://--qwen-0-6b-qa-vllm-serve.modal.run/v1 +API_KEY=your-llm-api-key -EMBEDS_URL=... +# ── Embedding endpoint ─────────────────────────────────────── +EMBEDS_URL=https://--intfloat-multilingual-e5-large-instruct-embeddings-embed.modal.run +EMBEDS_API_KEY=your-embedding-api-key + +# ── Defaults pre-selected in the UI ────────────────────────── DEFAULT_MODEL=microsoft/Phi-4-mini-instruct DEFAULT_EMBEDDING=intfloat/multilingual-e5-large-instruct-modal -API_KEY=... -EMBEDS_API_KEY=... - -GROBID_URL=... -GROBID_QUANTITIES_URL=... - - -QWEN_URL=... -GROBID_MATERIALS_URL=... -API_KEY=... -EMBEDS_API_KEY=... 
\ No newline at end of file +# ── GROBID services ────────────────────────────────────────── +GROBID_URL=https://your-grobid-url +GROBID_QUANTITIES_URL=https://your-grobid-quantities-url/ # optional (measurements NER) +GROBID_MATERIALS_URL=https://your-grobid-superconductors-url/ # optional (materials NER) diff --git a/.github/workflows/ci-build.yml b/.github/workflows/ci-build.yml index 3a40f63..207a4cf 100644 --- a/.github/workflows/ci-build.yml +++ b/.github/workflows/ci-build.yml @@ -31,14 +31,14 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip - pip install --upgrade flake8 pytest pycodestyle pytest-cov huggingface_hub + pip install --upgrade ruff pytest pytest-cov huggingface_hub if [ -f requirements.txt ]; then pip install -r requirements.txt; fi - - name: Lint with flake8 + - name: Lint with ruff run: | # stop the build if there are Python syntax errors or undefined names - flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics - # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide - flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics + ruff check --select=E9,F63,F7,F82 --output-format=full . + # non-blocking: report all remaining style issues without failing the build + ruff check --exit-zero --statistics . - name: Test with pytest run: | pytest diff --git a/.github/workflows/ci-release.yml b/.github/workflows/ci-release.yml index f29a43b..20b9d1b 100644 --- a/.github/workflows/ci-release.yml +++ b/.github/workflows/ci-release.yml @@ -25,14 +25,14 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip - pip install --upgrade flake8 pytest pycodestyle + pip install --upgrade ruff pytest if [ -f requirements.txt ]; then pip install -r requirements.txt; fi - - name: Lint with flake8 + - name: Lint with ruff run: | # stop the build if there are Python syntax errors or undefined names - flake8 . 
--count --select=E9,F63,F7,F82 --show-source --statistics - # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide - flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics + ruff check --select=E9,F63,F7,F82 --output-format=full . + # non-blocking: report all remaining style issues without failing the build + ruff check --exit-zero --statistics . # - name: Test with pytest # run: | # pytest diff --git a/README.md b/README.md index 1ea14bf..8324d22 100644 --- a/README.md +++ b/README.md @@ -46,6 +46,8 @@ Additionally, this frontend provides the visualisation of named entities on LLM **For full technical documentation** of the `document-qa-engine` library **[`docs/README.md`](docs/README.md)**. + **To deploy the LLM and embedding endpoints** on Modal.com, see **[`document_qa/deployment/README.md`](document_qa/deployment/README.md)**. + ### Embedding selection In the latest version, there is the possibility to select both embedding functions and LLMs. There are some limitations, OpenAI embeddings cannot be used with open source models, and vice-versa. @@ -83,7 +85,7 @@ For more information, see the [details](https://docs.trychroma.com/troubleshooti Please read carefully: - Avoid uploading sensitive data. We temporarily store text from the uploaded PDF documents only for processing your request, and we disclaim any responsibility for subsequent use or handling of the submitted data by third-party LLMs. -- Mistral and Zephyr are FREE to use and do not require any API, but as we leverage the free API entrypoint, there is no guarantee that all requests will go through. Use at your own risk. +- The public demo serves open models (Phi-4-mini-instruct, Qwen3) self-hosted on [Modal.com](https://www.modal.com) under a limited monthly compute budget, so there is no guarantee that all requests will go through. Use at your own risk. - We do not assume responsibility for how the data is utilized by the LLM end-points API. 
## Development notes diff --git a/docs/README.md b/docs/README.md index 67048b2..bac6f4e 100644 --- a/docs/README.md +++ b/docs/README.md @@ -97,6 +97,13 @@ GROBID_MATERIALS_URL=https://your-grobid-superconductors-url/ | `GROBID_QUANTITIES_URL` | URL to a grobid-quantities server (for measurement NER) | | `GROBID_MATERIALS_URL` | URL to a grobid-superconductors server (for materials NER) | +### Deploying the model endpoints + +The `PHI_URL`, `QWEN_URL`, and `EMBEDS_URL` endpoints above are served by the Modal apps +in [`../document_qa/deployment/`](../document_qa/deployment/README.md). That README covers +the required secrets, deploy commands, and how each printed `*.modal.run` URL maps back to +these variables. + --- ## Quick Start — Streamlit App diff --git a/pyproject.toml b/pyproject.toml index 1a42ed5..2df49bf 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -38,4 +38,12 @@ dependencies = {file = ["requirements.txt"]} [project.urls] Homepage = "https://document-insights.streamlit.app" Repository = "https://github.com/lfoppiano/document-qa" -Changelog = "https://github.com/lfoppiano/document-qa/blob/main/CHANGELOG.md" \ No newline at end of file +Changelog = "https://github.com/lfoppiano/document-qa/blob/main/CHANGELOG.md" + +[tool.ruff] +# Mirrors the previous flake8 configuration (line length 127, GitHub editor width). +line-length = 127 + +[tool.ruff.lint] +# pycodestyle errors (E) + pyflakes (F), matching the old flake8 default rule set. 
+select = ["E", "F"] \ No newline at end of file From 96674b8ded20c4fb78cdba26f9911f31edba8223 Mon Sep 17 00:00:00 2001 From: Luca Foppiano Date: Sun, 7 Jun 2026 22:15:41 +0100 Subject: [PATCH 3/5] refactor: embedding deployment scripts --- document_qa/custom_embeddings.py | 10 +- document_qa/deployment/README.md | 104 ++++++++++++++ document_qa/deployment/_embeddings_app.py | 120 ++++++++++++++++ document_qa/deployment/modal_embeddings_en.py | 131 ++++-------------- .../deployment/modal_embeddings_multilang.py | 131 ++++-------------- document_qa/deployment/modal_inference_phi.py | 2 +- .../deployment/modal_inference_qwen.py | 5 +- 7 files changed, 286 insertions(+), 217 deletions(-) create mode 100644 document_qa/deployment/README.md create mode 100644 document_qa/deployment/_embeddings_app.py diff --git a/document_qa/custom_embeddings.py b/document_qa/custom_embeddings.py index 893986d..abe1852 100644 --- a/document_qa/custom_embeddings.py +++ b/document_qa/custom_embeddings.py @@ -92,9 +92,15 @@ def get_model_name(self) -> str: if __name__ == "__main__": + # Smoke test against a deployed Modal embedding endpoint. The endpoint requires + # the x-api-key header, so set EMBEDS_URL and EMBEDS_API_KEY in the environment + # (see document_qa/deployment/README.md). + import os + embeds = ModalEmbeddings( - url="https://lfoppiano--intfloat-multilingual-e5-large-instruct-embed-5da184.modal.run/", - model_name="intfloat/multilingual-e5-large-instruct" + url=os.environ["EMBEDS_URL"], + model_name="intfloat/multilingual-e5-large-instruct", + api_key=os.environ.get("EMBEDS_API_KEY"), ) print(embeds.embed( diff --git a/document_qa/deployment/README.md b/document_qa/deployment/README.md new file mode 100644 index 0000000..3a7f497 --- /dev/null +++ b/document_qa/deployment/README.md @@ -0,0 +1,104 @@ +# Modal deployment scripts + +This folder contains the [Modal](https://modal.com) apps that serve the LLM and +embedding endpoints used by document-qa. 
Each script is an independent Modal app: +deploy the ones you need, then point the matching `.env` variables at the URLs +Modal prints. + +| Script | Modal app | Serves | Maps to `.env` | +|--------|-----------|--------|----------------| +| `modal_inference_phi.py` | `phi-4-mini-instruct-qa-vllm` | `microsoft/Phi-4-mini-instruct` (vLLM, OpenAI-compatible) | `PHI_URL` | +| `modal_inference_qwen.py` | `qwen-0.6b-qa-vllm` | `Qwen/Qwen3-0.6B` (vLLM, reasoning) | `QWEN_URL` | +| `modal_embeddings_multilang.py` | `intfloat-multilingual-e5-large-instruct-embeddings` | `intfloat/multilingual-e5-large-instruct` | `EMBEDS_URL` | +| `modal_embeddings_en.py` | `intfloat-e5-large-v2-embeddings` | `intfloat/e5-large-v2` (English-only) | `EMBEDS_URL` | + +> Both embedding scripts define a tiny global `EmbeddingModel` class that delegates +> to the shared helpers in `_embeddings_app.py` (`cls_kwargs`, `load_embedding_model`, +> `run_embed`). The shared module holds the container image and the embedding logic; +> the model is loaded **once per container** via `@modal.enter()`. To add another +> embedding model, copy one wrapper and change `MODEL_NAME` / `MODEL_REVISION` / the +> app name. + +## Prerequisites + +```bash +pip install modal +modal token new # one-time browser auth +``` + +## Secrets + +The scripts read an `API_KEY` from a Modal [Secret](https://modal.com/docs/guide/secrets). 
+Create the two secrets once (the value is the bearer token clients must send): + +```bash +# Used by the inference scripts (phi, qwen) +modal secret create document-qa-api-key API_KEY= + +# Used by the embedding scripts +modal secret create document-qa-embedding-key API_KEY= +``` + +| Secret | Used by | Provides | +|--------|---------|----------| +| `document-qa-api-key` | `modal_inference_phi.py`, `modal_inference_qwen.py` | `API_KEY` for the vLLM `--api-key` flag | +| `document-qa-embedding-key` | `modal_embeddings_*.py` | `API_KEY` checked against the `x-api-key` header | + +## Deploy + +```bash +modal deploy document_qa/deployment/modal_inference_phi.py +modal deploy document_qa/deployment/modal_inference_qwen.py +modal deploy document_qa/deployment/modal_embeddings_multilang.py +# modal deploy document_qa/deployment/modal_embeddings_en.py # optional English-only +``` + +Each deploy prints a public `https://<...>.modal.run` URL. Copy it into `.env`: + +```env +PHI_URL=https://--phi-4-mini-instruct-qa-vllm-serve.modal.run/v1 +QWEN_URL=https://--qwen-0-6b-qa-vllm-serve.modal.run/v1 +EMBEDS_URL=https://--intfloat-multilingual-e5-large-instruct-embeddings-embed.modal.run +API_KEY= # matches document-qa-api-key +EMBEDS_API_KEY= # matches document-qa-embedding-key +``` + +> **Inference endpoints** are OpenAI-compatible vLLM servers, so their URLs end in +> `/v1`. **Embedding endpoints** are a custom form endpoint (see below), so their +> URL has no `/v1` suffix. + +## Endpoint contracts + +### Inference (vLLM) + +Standard OpenAI Chat Completions API at ``, authenticated with the +`Authorization: Bearer ` header. Used by `langchain_openai.ChatOpenAI` in +`streamlit_app.py`. + +### Embeddings + +A custom `POST` endpoint consumed by +[`ModalEmbeddings`](../custom_embeddings.py): + +- **Auth**: `x-api-key: ` header. +- **Body**: form field `text` with newline-separated strings. +- **Response**: JSON list of L2-normalised vectors, one per input line. 
+ +Smoke test: + +```bash +curl -X POST "$EMBEDS_URL" \ + -H "x-api-key: $EMBEDS_API_KEY" \ + -F $'text=first sentence\nsecond sentence' +``` + +## Tuning + +These knobs live near the top of each script (or in `_embeddings_app.py`): + +| Setting | Where | Notes | +|---------|-------|-------| +| `gpu` | `@app.function` / `@app.cls` | `A10G` is cheaper; `L40S` is faster. Embeddings default to `L40S`, inference to `A10G`. | +| `scaledown_window` | decorator | Idle time before a replica is stopped (cost vs. cold starts). | +| `max_inputs` | `@modal.concurrent` | Concurrent requests per replica — tune to GPU memory. | +| `FAST_BOOT` | `modal_inference_phi.py` | `--enforce-eager` for faster cold starts vs. peak throughput. | diff --git a/document_qa/deployment/_embeddings_app.py b/document_qa/deployment/_embeddings_app.py new file mode 100644 index 0000000..caf00b0 --- /dev/null +++ b/document_qa/deployment/_embeddings_app.py @@ -0,0 +1,120 @@ +"""Shared building blocks for the Modal embedding endpoints. + +``modal_embeddings_en.py`` and ``modal_embeddings_multilang.py`` each define a tiny +``EmbeddingModel`` class at module scope (Modal requires globally-defined classes +with stacked ``@app.cls`` / ``@modal.concurrent`` decorators) that delegates to the +helpers here. All the heavy lifting — the container image, model loading, pooling, +and the embedding request handler — lives in this module so it is written once. + +The endpoint contract (consumed by ``document_qa.custom_embeddings.ModalEmbeddings``): + +- **Method**: ``POST`` +- **Auth**: ``x-api-key`` header, compared against the ``API_KEY`` secret. +- **Body**: form field ``text`` containing newline-separated strings. +- **Response**: JSON list of L2-normalised embedding vectors, one per input line. 
+""" + +import os + +import modal +import torch +import torch.nn.functional as F +from fastapi import HTTPException, Request +from torch import Tensor + +MINUTES = 60 # seconds +N_GPU = 1 + +# Shared container image for every embedding model. +image = ( + modal.Image.debian_slim(python_version="3.11") + .pip_install( + "transformers", + "huggingface_hub[hf_transfer]==0.26.2", + "flashinfer-python==0.2.0.post2", # pinning, very unstable + "fastapi[standard]", + extra_index_url="https://flashinfer.ai/whl/cu124/torch2.5", + ) + .env({"HF_HUB_ENABLE_HF_TRANSFER": "1"}) # faster model transfers + # Modal 1.0 no longer auto-mounts imported local modules; the wrapper scripts + # import this module by name, so it must be added explicitly. Kept last so it + # doesn't invalidate the (expensive) pip layer above on every code edit. + .add_local_python_source("_embeddings_app") +) + +hf_cache_vol = modal.Volume.from_name("huggingface-cache", create_if_missing=True) +vllm_cache_vol = modal.Volume.from_name("vllm-cache", create_if_missing=True) + + +def cls_kwargs() -> dict: + """Common ``@app.cls`` configuration shared by every embedding endpoint.""" + return dict( + image=image, + gpu=f"L40S:{N_GPU}", + # how long should we stay up with no requests? + scaledown_window=3 * MINUTES, + volumes={ + "/root/.cache/huggingface": hf_cache_vol, + "/root/.cache/vllm": vllm_cache_vol, + }, + secrets=[modal.Secret.from_name("document-qa-embedding-key")], + ) + + +def average_pool(last_hidden_states: Tensor, attention_mask: Tensor) -> Tensor: + """Mean-pool token embeddings, ignoring padding positions.""" + last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0) + return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None] + + +def load_embedding_model(model_name: str, model_revision: str): + """Load a tokenizer + model onto the best available device, once per container. 
+ + Returns: + tuple: ``(tokenizer, model, device)`` with ``model`` already in eval mode. + """ + # transformers is only available inside the Modal image, so import lazily. + from transformers import AutoModel, AutoTokenizer + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + print(f"Loading {model_name} on {device}...") + tokenizer = AutoTokenizer.from_pretrained(model_name, revision=model_revision) + model = AutoModel.from_pretrained(model_name, revision=model_revision).to(device) + model.eval() + print("Model loaded successfully.") + return tokenizer, model, device + + +def run_embed(tokenizer, model, device, request: Request, text: str): + """Authenticate, embed newline-separated ``text``, and return normalised vectors.""" + api_key = request.headers.get("x-api-key") + if api_key != os.environ["API_KEY"]: + raise HTTPException(status_code=401, detail="Unauthorized") + + texts = [t for t in text.split("\n") if t.strip()] + if not texts: + return [] + + print(f"Start embedding {len(texts)} texts") + try: + with torch.no_grad(): + batch_dict = tokenizer( + texts, padding=True, truncation=True, return_tensors="pt" + ) + batch_dict = {k: v.to(device) for k, v in batch_dict.items()} + + outputs = model(**batch_dict) + embeddings = average_pool( + outputs.last_hidden_state, batch_dict["attention_mask"] + ) + embeddings = F.normalize(embeddings, p=2, dim=1) + embeddings = embeddings.cpu().numpy().tolist() + + print("Finished embedding texts.") + return embeddings + + except RuntimeError as e: + print(f"Error during embedding: {str(e)}") + if "CUDA out of memory" in str(e): + print("CUDA OOM. 
Try reducing batch size or using a smaller model.") + raise diff --git a/document_qa/deployment/modal_embeddings_en.py b/document_qa/deployment/modal_embeddings_en.py index 19ff4e1..6ccd126 100644 --- a/document_qa/deployment/modal_embeddings_en.py +++ b/document_qa/deployment/modal_embeddings_en.py @@ -1,117 +1,36 @@ -import os -from typing import Annotated, List -from fastapi import Request, HTTPException, Form +"""Modal deployment: English-only embeddings (``intfloat/e5-large-v2``). -import modal -import torch -import torch.nn.functional as F -from torch import Tensor -from transformers import AutoTokenizer, AutoModel - -image = ( - modal.Image.debian_slim(python_version="3.11") - .pip_install( - "transformers", - "huggingface_hub[hf_transfer]==0.26.2", - "flashinfer-python==0.2.0.post2", # pinning, very unstable - "fastapi[standard]", - extra_index_url="https://flashinfer.ai/whl/cu124/torch2.5", - ) - .env({"HF_HUB_ENABLE_HF_TRANSFER": "1"}) # faster model transfers -) - -MODELS_DIR = "/llamas" -MODEL_NAME = "intfloat/e5-large-v2" -MODEL_REVISION = "756b8ddb6e4bda943d3b6f5d131355825efda70c" - -hf_cache_vol = modal.Volume.from_name("huggingface-cache", create_if_missing=True) -vllm_cache_vol = modal.Volume.from_name("vllm-cache", create_if_missing=True) - -app = modal.App("intfloat-e5-large-v2-embeddings") +Deploy with:: + modal deploy document_qa/deployment/modal_embeddings_en.py -def get_device(): - return torch.device('cuda' if torch.cuda.is_available() else 'cpu') +See ``document_qa/deployment/README.md`` for secrets, tuning, and how the +resulting URL maps to ``EMBEDS_URL`` in ``.env``. The shared logic lives in +``_embeddings_app.py``. 
+""" -def load_model(): - print("Loading model...") - device = get_device() - print(f"Using device: {device}") - - tokenizer = AutoTokenizer.from_pretrained('intfloat/multilingual-e5-large-instruct') - model = AutoModel.from_pretrained('intfloat/multilingual-e5-large-instruct').to(device) - print("Model loaded successfully.") +from typing import Annotated - return tokenizer, model, device - - -N_GPU = 1 -MINUTES = 60 # seconds -VLLM_PORT = 8000 - - -def average_pool(last_hidden_states: Tensor, - attention_mask: Tensor) -> Tensor: - last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0) - return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None] +import modal +from fastapi import Form, Request +from _embeddings_app import cls_kwargs, load_embedding_model, run_embed -@app.function( - image=image, - gpu=f"L40S:{N_GPU}", - # gpu=f"A10G:{N_GPU}", - # how long should we stay up with no requests? - scaledown_window=3 * MINUTES, - volumes={ - "/root/.cache/huggingface": hf_cache_vol, - "/root/.cache/vllm": vllm_cache_vol, - }, - secrets=[modal.Secret.from_name("document-qa-embedding-key")] -) -@modal.concurrent( - max_inputs=5 -) # how many requests can one replica handle? tune carefully! -@modal.fastapi_endpoint(method="POST") -def embed(request: Request, text: Annotated[str, Form()]): - api_key = request.headers.get("x-api-key") - expected_key = os.environ["API_KEY"] +MODEL_NAME = "intfloat/e5-large-v2" +MODEL_REVISION = "756b8ddb6e4bda943d3b6f5d131355825efda70c" - if api_key != expected_key: - raise HTTPException(status_code=401, detail="Unauthorized") +app = modal.App("intfloat-e5-large-v2-embeddings") - texts = [t for t in text.split("\n") if t.strip()] - if not texts: - return [] - - tokenizer, model, device = load_model() - model.eval() +@app.cls(**cls_kwargs()) +@modal.concurrent(max_inputs=5) # requests per replica; tune carefully! 
+class EmbeddingModel: + @modal.enter() + def load_model(self): + self.tokenizer, self.model, self.device = load_embedding_model( + MODEL_NAME, MODEL_REVISION + ) - print(f"Start embedding {len(texts)} texts") - try: - with torch.no_grad(): - # Move inputs to the same device as model - batch_dict = tokenizer(texts, padding=True, truncation=True, return_tensors='pt') - batch_dict = {k: v.to(device) for k, v in batch_dict.items()} - - # Forward pass - outputs = model(**batch_dict) - - # Process embeddings - embeddings = average_pool( - outputs.last_hidden_state, - batch_dict['attention_mask'] - ) - embeddings = F.normalize(embeddings, p=2, dim=1) - - # Move to CPU and convert to list for serialization - embeddings = embeddings.cpu().numpy().tolist() - - print("Finished embedding texts.") - return embeddings - - except RuntimeError as e: - print(f"Error during embedding: {str(e)}") - if "CUDA out of memory" in str(e): - print("CUDA out of memory error. Try reducing batch size or using a smaller model.") - raise + @modal.fastapi_endpoint(method="POST") + def embed(self, request: Request, text: Annotated[str, Form()]): + return run_embed(self.tokenizer, self.model, self.device, request, text) diff --git a/document_qa/deployment/modal_embeddings_multilang.py b/document_qa/deployment/modal_embeddings_multilang.py index 47f2ac8..39a4a7c 100644 --- a/document_qa/deployment/modal_embeddings_multilang.py +++ b/document_qa/deployment/modal_embeddings_multilang.py @@ -1,117 +1,36 @@ -import os -from typing import Annotated, List -from fastapi import Request, HTTPException, Form +"""Modal deployment: multilingual embeddings (``intfloat/multilingual-e5-large-instruct``). 
-import modal -import torch -import torch.nn.functional as F -from torch import Tensor -from transformers import AutoTokenizer, AutoModel - -image = ( - modal.Image.debian_slim(python_version="3.11") - .pip_install( - "transformers", - "huggingface_hub[hf_transfer]==0.26.2", - "flashinfer-python==0.2.0.post2", # pinning, very unstable - "fastapi[standard]", - extra_index_url="https://flashinfer.ai/whl/cu124/torch2.5", - ) - .env({"HF_HUB_ENABLE_HF_TRANSFER": "1"}) # faster model transfers -) - -MODELS_DIR = "/llamas" -MODEL_NAME = "intfloat/multilingual-e5-large-instruct" -MODEL_REVISION = "84344a23ee1820ac951bc365f1e91d094a911763" - -hf_cache_vol = modal.Volume.from_name("huggingface-cache", create_if_missing=True) -vllm_cache_vol = modal.Volume.from_name("vllm-cache", create_if_missing=True) - -app = modal.App("intfloat-multilingual-e5-large-instruct-embeddings") +Deploy with:: + modal deploy document_qa/deployment/modal_embeddings_multilang.py -def get_device(): - return torch.device('cuda' if torch.cuda.is_available() else 'cpu') +See ``document_qa/deployment/README.md`` for secrets, tuning, and how the +resulting URL maps to ``EMBEDS_URL`` in ``.env``. The shared logic lives in +``_embeddings_app.py``. 
+""" -def load_model(): - print("Loading model...") - device = get_device() - print(f"Using device: {device}") - - tokenizer = AutoTokenizer.from_pretrained('intfloat/multilingual-e5-large-instruct') - model = AutoModel.from_pretrained('intfloat/multilingual-e5-large-instruct').to(device) - print("Model loaded successfully.") +from typing import Annotated - return tokenizer, model, device - - -N_GPU = 1 -MINUTES = 60 # seconds -VLLM_PORT = 8000 - - -def average_pool(last_hidden_states: Tensor, - attention_mask: Tensor) -> Tensor: - last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0) - return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None] +import modal +from fastapi import Form, Request +from _embeddings_app import cls_kwargs, load_embedding_model, run_embed -@app.function( - image=image, - gpu=f"L40S:{N_GPU}", - # gpu=f"A10G:{N_GPU}", - # how long should we stay up with no requests? - scaledown_window=3 * MINUTES, - volumes={ - "/root/.cache/huggingface": hf_cache_vol, - "/root/.cache/vllm": vllm_cache_vol, - }, - secrets=[modal.Secret.from_name("document-qa-embedding-key")] -) -@modal.concurrent( - max_inputs=5 -) # how many requests can one replica handle? tune carefully! -@modal.fastapi_endpoint(method="POST") -def embed(request: Request, text: Annotated[str, Form()]): - api_key = request.headers.get("x-api-key") - expected_key = os.environ["API_KEY"] +MODEL_NAME = "intfloat/multilingual-e5-large-instruct" +MODEL_REVISION = "84344a23ee1820ac951bc365f1e91d094a911763" - if api_key != expected_key: - raise HTTPException(status_code=401, detail="Unauthorized") +app = modal.App("intfloat-multilingual-e5-large-instruct-embeddings") - texts = [t for t in text.split("\n") if t.strip()] - if not texts: - return [] - - tokenizer, model, device = load_model() - model.eval() +@app.cls(**cls_kwargs()) +@modal.concurrent(max_inputs=5) # requests per replica; tune carefully! 
+class EmbeddingModel: + @modal.enter() + def load_model(self): + self.tokenizer, self.model, self.device = load_embedding_model( + MODEL_NAME, MODEL_REVISION + ) - print(f"Start embedding {len(texts)} texts") - try: - with torch.no_grad(): - # Move inputs to the same device as model - batch_dict = tokenizer(texts, padding=True, truncation=True, return_tensors='pt') - batch_dict = {k: v.to(device) for k, v in batch_dict.items()} - - # Forward pass - outputs = model(**batch_dict) - - # Process embeddings - embeddings = average_pool( - outputs.last_hidden_state, - batch_dict['attention_mask'] - ) - embeddings = F.normalize(embeddings, p=2, dim=1) - - # Move to CPU and convert to list for serialization - embeddings = embeddings.cpu().numpy().tolist() - - print("Finished embedding texts.") - return embeddings - - except RuntimeError as e: - print(f"Error during embedding: {str(e)}") - if "CUDA out of memory" in str(e): - print("CUDA out of memory error. Try reducing batch size or using a smaller model.") - raise + @modal.fastapi_endpoint(method="POST") + def embed(self, request: Request, text: Annotated[str, Form()]): + return run_embed(self.tokenizer, self.model, self.device, request, text) diff --git a/document_qa/deployment/modal_inference_phi.py b/document_qa/deployment/modal_inference_phi.py index 3bc95a7..dba7970 100644 --- a/document_qa/deployment/modal_inference_phi.py +++ b/document_qa/deployment/modal_inference_phi.py @@ -3,7 +3,7 @@ import modal vllm_image = ( - modal.Image.debian_slim(python_version="3.10") + modal.Image.debian_slim(python_version="3.11") .pip_install( "vllm", "huggingface_hub[hf_transfer]==0.26.2", diff --git a/document_qa/deployment/modal_inference_qwen.py b/document_qa/deployment/modal_inference_qwen.py index 1e4f0ee..bab93d5 100644 --- a/document_qa/deployment/modal_inference_qwen.py +++ b/document_qa/deployment/modal_inference_qwen.py @@ -22,7 +22,7 @@ vllm_cache_vol = modal.Volume.from_name("vllm-cache", create_if_missing=True) -app = 
modal.App("gwen-0.6b-qa-vllm") +app = modal.App("qwen-0.6b-qa-vllm") N_GPU = 1 MINUTES = 60 # seconds @@ -55,7 +55,8 @@ def serve(): MODEL_NAME, "--revision", MODEL_REVISION, - "--enable-reasoning", + # --reasoning-parser alone enables reasoning; the old --enable-reasoning + # flag was removed in recent vLLM releases. "--reasoning-parser", "deepseek_r1", "--max-model-len", From d3cfa2b82c8b46e5dff9e5068bcc6520981c8d07 Mon Sep 17 00:00:00 2001 From: Luca Foppiano Date: Sun, 7 Jun 2026 22:52:45 +0100 Subject: [PATCH 4/5] chore: code formatting --- document_qa/custom_embeddings.py | 18 +- document_qa/deployment/_embeddings_app.py | 8 +- document_qa/deployment/modal_embeddings_en.py | 4 +- .../deployment/modal_embeddings_multilang.py | 4 +- document_qa/deployment/modal_inference_phi.py | 8 +- .../deployment/modal_inference_qwen.py | 8 +- document_qa/document_qa_engine.py | 182 ++--- document_qa/grobid_processors.py | 662 +++++++++--------- document_qa/langchain.py | 75 +- document_qa/ner_client_generic.py | 201 ++---- streamlit_app.py | 338 +++++---- tests/conftest.py | 9 +- tests/test_document_qa_engine.py | 68 +- tests/test_grobid_processors.py | 38 +- 14 files changed, 732 insertions(+), 891 deletions(-) diff --git a/document_qa/custom_embeddings.py b/document_qa/custom_embeddings.py index abe1852..6f89307 100644 --- a/document_qa/custom_embeddings.py +++ b/document_qa/custom_embeddings.py @@ -47,18 +47,13 @@ def embed(self, text: List[str]) -> List[List[float]]: # Newlines degrade embedding quality for most models cleaned_text = [t.replace("\n", " ") for t in text] - payload = {'text': "\n".join(cleaned_text)} + payload = {"text": "\n".join(cleaned_text)} headers = {} if self.api_key: - headers = {'x-api-key': self.api_key} - - response = requests.post( - self.url, - data=payload, - files=[], - headers=headers - ) + headers = {"x-api-key": self.api_key} + + response = requests.post(self.url, data=payload, files=[], headers=headers) response.raise_for_status() # 
print(response.text) @@ -103,7 +98,4 @@ def get_model_name(self) -> str: api_key=os.environ.get("EMBEDS_API_KEY"), ) - print(embeds.embed( - ["We are surrounded by stupid kids", - "We are interested in the future of AI"] - )) + print(embeds.embed(["We are surrounded by stupid kids", "We are interested in the future of AI"])) diff --git a/document_qa/deployment/_embeddings_app.py b/document_qa/deployment/_embeddings_app.py index caf00b0..1bf20a8 100644 --- a/document_qa/deployment/_embeddings_app.py +++ b/document_qa/deployment/_embeddings_app.py @@ -98,15 +98,11 @@ def run_embed(tokenizer, model, device, request: Request, text: str): print(f"Start embedding {len(texts)} texts") try: with torch.no_grad(): - batch_dict = tokenizer( - texts, padding=True, truncation=True, return_tensors="pt" - ) + batch_dict = tokenizer(texts, padding=True, truncation=True, return_tensors="pt") batch_dict = {k: v.to(device) for k, v in batch_dict.items()} outputs = model(**batch_dict) - embeddings = average_pool( - outputs.last_hidden_state, batch_dict["attention_mask"] - ) + embeddings = average_pool(outputs.last_hidden_state, batch_dict["attention_mask"]) embeddings = F.normalize(embeddings, p=2, dim=1) embeddings = embeddings.cpu().numpy().tolist() diff --git a/document_qa/deployment/modal_embeddings_en.py b/document_qa/deployment/modal_embeddings_en.py index 6ccd126..2690957 100644 --- a/document_qa/deployment/modal_embeddings_en.py +++ b/document_qa/deployment/modal_embeddings_en.py @@ -27,9 +27,7 @@ class EmbeddingModel: @modal.enter() def load_model(self): - self.tokenizer, self.model, self.device = load_embedding_model( - MODEL_NAME, MODEL_REVISION - ) + self.tokenizer, self.model, self.device = load_embedding_model(MODEL_NAME, MODEL_REVISION) @modal.fastapi_endpoint(method="POST") def embed(self, request: Request, text: Annotated[str, Form()]): diff --git a/document_qa/deployment/modal_embeddings_multilang.py b/document_qa/deployment/modal_embeddings_multilang.py index 
39a4a7c..bd6a808 100644 --- a/document_qa/deployment/modal_embeddings_multilang.py +++ b/document_qa/deployment/modal_embeddings_multilang.py @@ -27,9 +27,7 @@ class EmbeddingModel: @modal.enter() def load_model(self): - self.tokenizer, self.model, self.device = load_embedding_model( - MODEL_NAME, MODEL_REVISION - ) + self.tokenizer, self.model, self.device = load_embedding_model(MODEL_NAME, MODEL_REVISION) @modal.fastapi_endpoint(method="POST") def embed(self, request: Request, text: Annotated[str, Form()]): diff --git a/document_qa/deployment/modal_inference_phi.py b/document_qa/deployment/modal_inference_phi.py index dba7970..0f66d46 100644 --- a/document_qa/deployment/modal_inference_phi.py +++ b/document_qa/deployment/modal_inference_phi.py @@ -40,11 +40,9 @@ "/root/.cache/huggingface": hf_cache_vol, "/root/.cache/vllm": vllm_cache_vol, }, - secrets=[modal.Secret.from_name("document-qa-api-key")] + secrets=[modal.Secret.from_name("document-qa-api-key")], ) -@modal.concurrent( - max_inputs=5 -) # how many requests can one replica handle? tune carefully! +@modal.concurrent(max_inputs=5) # how many requests can one replica handle? tune carefully! 
@modal.web_server(port=VLLM_PORT, startup_timeout=5 * MINUTES) def serve(): import subprocess @@ -73,4 +71,4 @@ def serve(): # assume multiple GPUs are for splitting up large matrix multiplications cmd += ["--tensor-parallel-size", str(N_GPU)] - subprocess.Popen(" ".join(cmd), shell=True) \ No newline at end of file + subprocess.Popen(" ".join(cmd), shell=True) diff --git a/document_qa/deployment/modal_inference_qwen.py b/document_qa/deployment/modal_inference_qwen.py index bab93d5..ce63fbb 100644 --- a/document_qa/deployment/modal_inference_qwen.py +++ b/document_qa/deployment/modal_inference_qwen.py @@ -39,11 +39,9 @@ "/root/.cache/huggingface": hf_cache_vol, "/root/.cache/vllm": vllm_cache_vol, }, - secrets=[modal.Secret.from_name("document-qa-api-key")] + secrets=[modal.Secret.from_name("document-qa-api-key")], ) -@modal.concurrent( - max_inputs=5 -) # how many requests can one replica handle? tune carefully! +@modal.concurrent(max_inputs=5) # how many requests can one replica handle? tune carefully! 
@modal.web_server(port=VLLM_PORT, startup_timeout=5 * MINUTES) def serve(): import subprocess @@ -69,4 +67,4 @@ def serve(): os.environ["API_KEY"], ] - subprocess.Popen(" ".join(cmd), shell=True) \ No newline at end of file + subprocess.Popen(" ".join(cmd), shell=True) diff --git a/document_qa/document_qa_engine.py b/document_qa/document_qa_engine.py index 0ac2048..25947aa 100644 --- a/document_qa/document_qa_engine.py +++ b/document_qa/document_qa_engine.py @@ -12,8 +12,7 @@ import tiktoken from langchain.chains import create_extraction_chain from langchain.chains.combine_documents import create_stuff_documents_chain -from langchain.chains.question_answering import stuff_prompt, refine_prompts, map_reduce_prompt, \ - map_rerank_prompt +from langchain.chains.question_answering import stuff_prompt, refine_prompts, map_reduce_prompt, map_rerank_prompt from langchain.prompts import SystemMessagePromptTemplate, HumanMessagePromptTemplate, ChatPromptTemplate from langchain.retrievers import MultiQueryRetriever from langchain.schema import Document @@ -89,8 +88,8 @@ def merge_passages(self, passages, chunk_size, tolerance=0.2): current_texts = [] current_coordinates = [] for idx, passage in enumerate(passages): - text = passage['text'] - coordinates = passage['coordinates'] + text = passage["text"] + coordinates = passage["coordinates"] current_texts.append(text) current_coordinates.append(coordinates) @@ -131,7 +130,7 @@ def merge_passages(self, passages, chunk_size, tolerance=0.2): "coordinates": coordinates, "type": "aggregated chunks", "section": "mixed", - "subSection": "mixed" + "subSection": "mixed", } ) @@ -139,14 +138,9 @@ def merge_passages(self, passages, chunk_size, tolerance=0.2): class BaseRetrieval: - """Abstract base for retrieval backends. 
- """ + """Abstract base for retrieval backends.""" - def __init__( - self, - persist_directory: Path, - embedding_function - ): + def __init__(self, persist_directory: Path, embedding_function): self.embedding_function = embedding_function self.persist_directory = persist_directory @@ -156,13 +150,11 @@ class NER_Retrival(VectorStore): This class implement a retrieval based on NER models. This is an alternative retrieval to embeddings that relies on extracted entities. """ + pass -engines = { - 'chroma': ChromaAdvancedRetrieval, - 'ner': NER_Retrival -} +engines = {"chroma": ChromaAdvancedRetrieval, "ner": NER_Retrival} class DataStorage: @@ -184,10 +176,10 @@ class DataStorage: embeddings_map_to_md5 = {} def __init__( - self, - embedding_function, - root_path: Path = None, - engine=ChromaAdvancedRetrieval, + self, + embedding_function, + root_path: Path = None, + engine=ChromaAdvancedRetrieval, ) -> None: self.root_path = root_path self.engine = engine @@ -214,11 +206,10 @@ def load_embeddings(self, embeddings_root_path: Union[str, Path]) -> None: for embedding_document_dir in embeddings_directories: self.embeddings_dict[embedding_document_dir.name] = self.engine( - persist_directory=embedding_document_dir.path, - embedding_function=self.embedding_function + persist_directory=embedding_document_dir.path, embedding_function=self.embedding_function ) - filename_list = list(Path(embedding_document_dir).glob('*.storage_filename')) + filename_list = list(Path(embedding_document_dir).glob("*.storage_filename")) if filename_list: filenam = filename_list[0].name.replace(".storage_filename", "") self.embeddings_map_from_md5[embedding_document_dir.name] = filenam @@ -248,18 +239,14 @@ def embed_document(self, doc_id, texts, metadatas): """ if doc_id not in self.embeddings_dict.keys(): self.embeddings_dict[doc_id] = self.engine.from_texts( - texts, - embedding=self.embedding_function, - metadatas=metadatas, - collection_name=doc_id) + texts, 
embedding=self.embedding_function, metadatas=metadatas, collection_name=doc_id + ) else: # Workaround Chroma (?) breaking change self.embeddings_dict[doc_id].delete_collection() self.embeddings_dict[doc_id] = self.engine.from_texts( - texts, - embedding=self.embedding_function, - metadatas=metadatas, - collection_name=doc_id) + texts, embedding=self.embedding_function, metadatas=metadatas, collection_name=doc_id + ) self.embeddings_root_path = None @@ -287,23 +274,17 @@ class DocumentQAEngine: qa_chain_type = None default_prompts = { - 'stuff': stuff_prompt, - 'refine': refine_prompts, + "stuff": stuff_prompt, + "refine": refine_prompts, "map_reduce": map_reduce_prompt, - "map_rerank": map_rerank_prompt + "map_rerank": map_rerank_prompt, } - def __init__(self, - llm, - data_storage: DataStorage, - grobid_url=None, - memory=None, - ping_grobid_server: bool = True - ): + def __init__(self, llm, data_storage: DataStorage, grobid_url=None, memory=None, ping_grobid_server: bool = True): self.llm = llm self.memory = memory - self.chain = create_stuff_documents_chain(llm, self.default_prompts['stuff'].PROMPT) + self.chain = create_stuff_documents_chain(llm, self.default_prompts["stuff"].PROMPT) self.text_merger = TextMerger() self.data_storage = data_storage @@ -311,13 +292,7 @@ def __init__(self, self.grobid_processor = GrobidProcessor(grobid_url, ping_server=ping_grobid_server) def query_document( - self, - query: str, - doc_id, - output_parser=None, - context_size=4, - extraction_schema=None, - verbose=False + self, query: str, doc_id, output_parser=None, context_size=4, extraction_schema=None, verbose=False ) -> tuple[Any, str, list]: """Ask a question and get an LLM-generated answer. 
@@ -348,7 +323,7 @@ def query_document( print(query) response, coordinates = self._run_query(doc_id, query, context_size=context_size) - response = response['output_text'] if 'output_text' in response else response + response = response["output_text"] if "output_text" in response else response if verbose: print(doc_id, "->", response) @@ -410,10 +385,7 @@ def query_storage_and_embeddings(self, query: str, doc_id, context_size=4) -> Li embedding metadata. """ db = self.data_storage.embeddings_dict[doc_id] - retriever = db.as_retriever( - search_kwargs={"k": context_size}, - search_type="similarity_with_embeddings" - ) + retriever = db.as_retriever(search_kwargs={"k": context_size}, search_type="similarity_with_embeddings") relevant_documents = retriever.invoke(query) return relevant_documents @@ -440,10 +412,10 @@ def analyse_query(self, query, doc_id, context_size=4): # ) retriever = db.as_retriever(search_kwargs={"k": context_size}, search_type="similarity_with_embeddings") relevant_documents = retriever.invoke(query) - relevant_document_coordinates = [doc.metadata['coordinates'].split(";") if 'coordinates' in doc.metadata else [] - for doc in - relevant_documents] - all_documents = db.get(include=['documents', 'metadatas', 'embeddings']) + relevant_document_coordinates = [ + doc.metadata["coordinates"].split(";") if "coordinates" in doc.metadata else [] for doc in relevant_documents + ] + all_documents = db.get(include=["documents", "metadatas", "embeddings"]) # all_documents_embeddings = all_documents["embeddings"] # query_embedding = db._embedding_function.embed_query(query) @@ -453,16 +425,21 @@ def analyse_query(self, query, doc_id, context_size=4): # distance_evaluator.evaluate_string_pairs(query=query_embedding, documents="") - similarities = [doc.metadata['__similarity'] for doc in relevant_documents] + similarities = [doc.metadata["__similarity"] for doc in relevant_documents] min_similarity = min(similarities) mean_similarity = sum(similarities) / 
len(similarities) coefficient = min_similarity - mean_similarity - return f"Coefficient: {coefficient}, (Min similarity {min_similarity}, Mean similarity: {mean_similarity})", relevant_document_coordinates + return ( + f"Coefficient: {coefficient}, (Min similarity {min_similarity}, Mean similarity: {mean_similarity})", + relevant_document_coordinates, + ) def _parse_json(self, response, output_parser): - system_message = "You are an useful assistant expert in materials science, physics, and chemistry " \ - "that can process text and transform it to JSON." + system_message = ( + "You are an useful assistant expert in materials science, physics, and chemistry " + "that can process text and transform it to JSON." + ) human_message = """Transform the text between three double quotes in JSON.\n\n\n\n {format_instructions}\n\nText: \"\"\"{text}\"\"\"""" @@ -473,8 +450,7 @@ def _parse_json(self, response, output_parser): results = self.llm( prompt_template.format_prompt( - text=response, - format_instructions=output_parser.get_format_instructions() + text=response, format_instructions=output_parser.get_format_instructions() ).to_messages() ) parsed_output = output_parser.parse(results.content) @@ -491,15 +467,15 @@ def _get_context(self, doc_id, query, context_size=4) -> tuple[List[Document], l retriever = db.as_retriever(search_kwargs={"k": context_size}) relevant_documents = retriever.invoke(query) relevant_document_coordinates = [ - doc.metadata['coordinates'].split(";") if 'coordinates' in doc.metadata else [] - for doc in - relevant_documents + doc.metadata["coordinates"].split(";") if "coordinates" in doc.metadata else [] for doc in relevant_documents ] if self.memory and len(self.memory.buffer_as_messages) > 0: relevant_documents.append( Document( page_content="""Following, the previous question and answers. 
Use these information only when in the question there are unspecified references:\n{}\n\n""".format( - self.memory.buffer_as_str)) + self.memory.buffer_as_str + ) + ) ) return relevant_documents, relevant_document_coordinates @@ -509,7 +485,7 @@ def get_full_context_by_document(self, doc_id): """ db = self.data_storage.embeddings_dict[doc_id] docs = db.get() - return docs['documents'] + return docs["documents"] def _get_context_multiquery(self, doc_id, query, context_size=4): db = self.data_storage.embeddings_dict[doc_id].as_retriever(search_kwargs={"k": context_size}) @@ -547,8 +523,8 @@ def get_text_from_document(self, pdf_file_path, chunk_size=-1, perc_overlap=0.1, coordinates = True # if chunk_size == -1 else False structure = self.grobid_processor.process_structure(pdf_file_path, coordinates=coordinates) - biblio = structure['biblio'] - biblio['filename'] = filename.replace(" ", "_") + biblio = structure["biblio"] + biblio["filename"] = filename.replace(" ", "_") if verbose: print("Generating embeddings for filename: ", filename) @@ -558,19 +534,19 @@ def get_text_from_document(self, pdf_file_path, chunk_size=-1, perc_overlap=0.1, ids = [] if chunk_size > 0: - new_passages = self.text_merger.merge_passages(structure['passages'], chunk_size=chunk_size) + new_passages = self.text_merger.merge_passages(structure["passages"], chunk_size=chunk_size) else: - new_passages = structure['passages'] + new_passages = structure["passages"] for passage in new_passages: biblio_copy = copy.copy(biblio) - if len(str.strip(passage['text'])) > 0: - texts.append(passage['text']) + if len(str.strip(passage["text"])) > 0: + texts.append(passage["text"]) - biblio_copy['type'] = passage['type'] - biblio_copy['section'] = passage['section'] - biblio_copy['subSection'] = passage['subSection'] - biblio_copy['coordinates'] = passage['coordinates'] + biblio_copy["type"] = passage["type"] + biblio_copy["section"] = passage["section"] + biblio_copy["subSection"] = passage["subSection"] + 
biblio_copy["coordinates"] = passage["coordinates"] metadatas.append(biblio_copy) # ids.append(passage['passage_id']) @@ -579,13 +555,7 @@ def get_text_from_document(self, pdf_file_path, chunk_size=-1, perc_overlap=0.1, return texts, metadatas, ids - def create_memory_embeddings( - self, - pdf_path, - doc_id=None, - chunk_size=500, - perc_overlap=0.1 - ): + def create_memory_embeddings(self, pdf_path, doc_id=None, chunk_size=500, perc_overlap=0.1): """Parse a PDF and create an in-memory vector collection. This is the main entry-point for ingesting a new document. It @@ -602,26 +572,17 @@ def create_memory_embeddings( Returns: str: The document ID. """ - texts, metadata, ids = self.get_text_from_document( - pdf_path, - chunk_size=chunk_size, - perc_overlap=perc_overlap) + texts, metadata, ids = self.get_text_from_document(pdf_path, chunk_size=chunk_size, perc_overlap=perc_overlap) if doc_id: hash = doc_id else: - hash = metadata[0]['hash'] if len(metadata) > 0 and 'hash' in metadata[0] else "" + hash = metadata[0]["hash"] if len(metadata) > 0 and "hash" in metadata[0] else "" self.data_storage.embed_document(hash, texts, metadata) return hash - def create_embeddings( - self, - pdfs_dir_path: Path, - chunk_size=500, - perc_overlap=0.1, - include_biblio=False - ): + def create_embeddings(self, pdfs_dir_path: Path, chunk_size=500, perc_overlap=0.1, include_biblio=False): """Batch-process a directory of PDFs and persist their embeddings. 
Walks *pdfs_dir_path*, processes each ``.pdf`` file through GROBID, @@ -641,9 +602,7 @@ def create_embeddings( continue input_files.append(os.path.join(root, file_)) - for input_file in tqdm(input_files, total=len(input_files), unit='document', - desc="Grobid + embeddings processing"): - + for input_file in tqdm(input_files, total=len(input_files), unit="document", desc="Grobid + embeddings processing"): md5 = self.calculate_md5(input_file) data_path = os.path.join(self.data_storage.embeddings_root_path, md5) @@ -651,19 +610,15 @@ def create_embeddings( print(data_path, "exists. Skipping it ") continue # include = ["biblio"] if include_biblio else [] - texts, metadata, ids = self.get_text_from_document( - input_file, - chunk_size=chunk_size, - perc_overlap=perc_overlap) - filename = metadata[0]['filename'] - - vector_db_document = Chroma.from_texts(texts, - metadatas=metadata, - embedding=self.embedding_function, - persist_directory=data_path) + texts, metadata, ids = self.get_text_from_document(input_file, chunk_size=chunk_size, perc_overlap=perc_overlap) + filename = metadata[0]["filename"] + + vector_db_document = Chroma.from_texts( + texts, metadatas=metadata, embedding=self.embedding_function, persist_directory=data_path + ) vector_db_document.persist() - with open(os.path.join(data_path, filename + ".storage_filename"), 'w') as fo: + with open(os.path.join(data_path, filename + ".storage_filename"), "w") as fo: fo.write("") @staticmethod @@ -671,7 +626,8 @@ def calculate_md5(input_file: Union[Path, str]): """Return the uppercase hex MD5 digest of *input_file*.""" import hashlib + md5_hash = hashlib.md5() - with open(input_file, 'rb') as fi: + with open(input_file, "rb") as fi: md5_hash.update(fi.read()) return md5_hash.hexdigest().upper() diff --git a/document_qa/grobid_processors.py b/document_qa/grobid_processors.py index 285e39b..38d514b 100644 --- a/document_qa/grobid_processors.py +++ b/document_qa/grobid_processors.py @@ -36,11 +36,11 @@ def 
__init__(self, message="Grobid service error", status_code=None): def get_span_start(type, title=None): """Return an opening ```` tag for an annotation of the given *type*.""" title_ = ' title="' + title + '"' if title is not None else "" - return '' + return '" def get_span_end(): - return '' + return "" def get_rs_start(type): @@ -48,11 +48,11 @@ def get_rs_start(type): def get_rs_end(): - return '' + return "" def has_space_between_value_and_unit(quantity): - return quantity['offsetEnd'] < quantity['rawUnit']['offsetStart'] + return quantity["offsetEnd"] < quantity["rawUnit"]["offsetStart"] def decorate_text_with_annotations(text, spans, tag="span"): @@ -70,27 +70,27 @@ def decorate_text_with_annotations(text, spans, tag="span"): Returns: str: The text with inline annotation markup. """ - sorted_spans = list(sorted(spans, key=lambda item: item['offset_start'])) + sorted_spans = list(sorted(spans, key=lambda item: item["offset_start"])) annotated_text = "" start = 0 for span in sorted_spans: - type = span['type'].replace("<", "").replace(">", "") - if 'unit_type' in span and span['unit_type'] is not None: - type = span['unit_type'].replace(" ", "_") - annotated_text += escape(text[start: span['offset_start']]) - title = span['quantified'] if 'quantified' in span else None + type = span["type"].replace("<", "").replace(">", "") + if "unit_type" in span and span["unit_type"] is not None: + type = span["unit_type"].replace(" ", "_") + annotated_text += escape(text[start : span["offset_start"]]) + title = span["quantified"] if "quantified" in span else None annotated_text += get_span_start(type, title) if tag == "span" else get_rs_start(type) - annotated_text += escape(text[span['offset_start']: span['offset_end']]) + annotated_text += escape(text[span["offset_start"] : span["offset_end"]]) annotated_text += get_span_end() if tag == "span" else get_rs_end() - start = span['offset_end'] - annotated_text += escape(text[start: len(text)]) + start = span["offset_end"] + 
annotated_text += escape(text[start : len(text)]) return annotated_text def get_parsed_value_type(quantity): - if 'parsedValue' in quantity and 'structure' in quantity['parsedValue']: - return quantity['parsedValue']['structure']['type'] + if "parsedValue" in quantity and "structure" in quantity["parsedValue"]: + return quantity["parsedValue"]["structure"]["type"] class BaseProcessor(object): @@ -101,9 +101,7 @@ class BaseProcessor(object): inherit :meth:`post_process` from here. """ - patterns = [ - r'\d+e\d+' - ] + patterns = [r"\d+e\d+"] def post_process(self, text): """Clean encoding artefacts and normalise special characters. @@ -114,16 +112,16 @@ def post_process(self, text): Returns: str: Cleaned text. """ - output = text.replace('À', '-') - output = output.replace('¼', '=') - output = output.replace('þ', '+') - output = output.replace('Â', 'x') - output = output.replace('$', '~') - output = output.replace('−', '-') - output = output.replace('–', '-') + output = text.replace("À", "-") + output = output.replace("¼", "=") + output = output.replace("þ", "+") + output = output.replace("Â", "x") + output = output.replace("$", "~") + output = output.replace("−", "-") + output = output.replace("–", "-") for pattern in self.patterns: - output = re.sub(pattern, lambda match: match.group().replace('e', '-'), output) + output = re.sub(pattern, lambda match: match.group().replace("e", "-"), output) return output @@ -154,7 +152,7 @@ def __init__(self, grobid_url, ping_server=True): coordinates=["p", "title", "persName"], sleep_time=5, timeout=60, - check_server=ping_server + check_server=ping_server, ) self.grobid_client = grobid_client @@ -178,15 +176,17 @@ def process_structure(self, input_path, coordinates=False): Returns ``None`` if GROBID returns a non-200 status. 
""" try: - pdf_file, status, text = self.grobid_client.process_pdf("processFulltextDocument", - input_path, - consolidate_header=True, - consolidate_citations=False, - segment_sentences=False, - tei_coordinates=coordinates, - include_raw_citations=False, - include_raw_affiliations=False, - generateIDs=True) + pdf_file, status, text = self.grobid_client.process_pdf( + "processFulltextDocument", + input_path, + consolidate_header=True, + consolidate_citations=False, + segment_sentences=False, + tei_coordinates=coordinates, + include_raw_citations=False, + include_raw_affiliations=False, + generateIDs=True, + ) except requests.exceptions.RequestException as exc: # Transport-level failure (connection refused, timeout, …). # Local/usage errors (bad path, parsing bugs) are intentionally @@ -205,10 +205,7 @@ def process_structure(self, input_path, coordinates=False): # Grobid can answer 200 with an empty body (e.g. it gave up on the PDF). if not text or not text.strip(): - raise GrobidServiceError( - "Grobid returned an empty response.", - status_code=status - ) + raise GrobidServiceError("Grobid returned an empty response.", status_code=status) # A truncated/corrupted TEI payload makes the XML parser blow up; map # that to a clear service error instead of an opaque parsing traceback. @@ -217,29 +214,23 @@ def process_structure(self, input_path, coordinates=False): except GrobidServiceError: raise except Exception as exc: - raise GrobidServiceError( - "Grobid returned a malformed or truncated response.", - status_code=status - ) from exc + raise GrobidServiceError("Grobid returned a malformed or truncated response.", status_code=status) from exc - document_object['filename'] = Path(pdf_file).stem.replace(".tei", "") + document_object["filename"] = Path(pdf_file).stem.replace(".tei", "") # Well-formed XML can still carry no usable text (e.g. an image-only or # truncated PDF). Nothing to embed downstream, so fail loudly here. 
- if not any(passage.get('text', '').strip() for passage in document_object.get('passages', [])): - raise GrobidServiceError( - "Grobid returned a document with no extractable text.", - status_code=status - ) + if not any(passage.get("text", "").strip() for passage in document_object.get("passages", [])): + raise GrobidServiceError("Grobid returned a document with no extractable text.", status_code=status) return document_object def process_single(self, input_file): doc = self.process_structure(input_file) - for paragraph in doc['passages']: - entities = self.process_single_text(paragraph['text']) - paragraph['spans'] = entities + for paragraph in doc["passages"]: + entities = self.process_single_text(paragraph["text"]) + paragraph["spans"] = entities return doc @@ -264,7 +255,7 @@ def parse_grobid_xml(self, text, coordinates=False): "doi": doc_biblio.header.doi if doc_biblio.header.doi is not None else "", "authors": ", ".join([author.full_name for author in doc_biblio.header.authors]), "title": doc_biblio.header.title, - "hash": doc_biblio.pdf_md5 + "hash": doc_biblio.pdf_md5, } try: year = dateparser.parse(doc_biblio.header.date).year @@ -272,12 +263,12 @@ def parse_grobid_xml(self, text, coordinates=False): except Exception: pass - output_data['biblio'] = biblio + output_data["biblio"] = biblio passages = [] - output_data['passages'] = passages + output_data["passages"] = passages passage_type = "paragraph" - soup = BeautifulSoup(text, 'xml') + soup = BeautifulSoup(text, "xml") blocks_header = get_xml_nodes_header(soup, use_paragraphs=True) # passages.append({ @@ -290,99 +281,132 @@ def parse_grobid_xml(self, text, coordinates=False): # blocks_header['authors']]) # }) - passages.append({ - "text": self.post_process(" ".join([node.text for node in blocks_header['title']])), - "type": passage_type, - "section": "
", - "subSection": "", - "passage_id": "htitle", - "coordinates": ";".join([node['coords'] if coordinates and node.has_attr('coords') else "" for node in - blocks_header['title']]) - }) - - passages.append({ - "text": self.post_process( - ''.join(node.text for node in blocks_header['abstract'] for text in node.find_all(text=True) if - text.parent.name != "ref" or ( - text.parent.name == "ref" and text.parent.attrs[ - 'type'] != 'bibr'))), - "type": passage_type, - "section": "<header>", - "subSection": "<abstract>", - "passage_id": "habstract", - "coordinates": ";".join([node['coords'] if coordinates and node.has_attr('coords') else "" for node in - blocks_header['abstract']]) - }) + passages.append( + { + "text": self.post_process(" ".join([node.text for node in blocks_header["title"]])), + "type": passage_type, + "section": "<header>", + "subSection": "<title>", + "passage_id": "htitle", + "coordinates": ";".join( + [node["coords"] if coordinates and node.has_attr("coords") else "" for node in blocks_header["title"]] + ), + } + ) + + passages.append( + { + "text": self.post_process( + "".join( + node.text + for node in blocks_header["abstract"] + for text in node.find_all(text=True) + if text.parent.name != "ref" or (text.parent.name == "ref" and text.parent.attrs["type"] != "bibr") + ) + ), + "type": passage_type, + "section": "<header>", + "subSection": "<abstract>", + "passage_id": "habstract", + "coordinates": ";".join( + [node["coords"] if coordinates and node.has_attr("coords") else "" for node in blocks_header["abstract"]] + ), + } + ) text_blocks_body = get_xml_nodes_body(soup, verbose=False, use_paragraphs=True) text_blocks_body.extend(get_xml_nodes_back(soup, verbose=False, use_paragraphs=True)) use_paragraphs = True if not use_paragraphs: - passages.extend([ - { - "text": self.post_process(''.join(text for text in sentence.find_all(text=True) if - text.parent.name != "ref" or ( - text.parent.name == "ref" and text.parent.attrs[ - 'type'] != 'bibr'))), 
- "type": passage_type, - "section": "<body>", - "subSection": "<paragraph>", - "passage_id": str(paragraph_id), - "coordinates": paragraph['coords'] if coordinates and sentence.has_attr('coords') else "" - } - for paragraph_id, paragraph in enumerate(text_blocks_body) for - sentence_id, sentence in enumerate(paragraph) - ]) + passages.extend( + [ + { + "text": self.post_process( + "".join( + text + for text in sentence.find_all(text=True) + if text.parent.name != "ref" + or (text.parent.name == "ref" and text.parent.attrs["type"] != "bibr") + ) + ), + "type": passage_type, + "section": "<body>", + "subSection": "<paragraph>", + "passage_id": str(paragraph_id), + "coordinates": paragraph["coords"] if coordinates and sentence.has_attr("coords") else "", + } + for paragraph_id, paragraph in enumerate(text_blocks_body) + for sentence_id, sentence in enumerate(paragraph) + ] + ) else: - passages.extend([ - { - "text": self.post_process(''.join(text for text in paragraph.find_all(text=True) if - text.parent.name != "ref" or ( - text.parent.name == "ref" and text.parent.attrs[ - 'type'] != 'bibr'))), - "type": passage_type, - "section": "<body>", - "subSection": "<paragraph>", - "passage_id": str(paragraph_id), - "coordinates": paragraph['coords'] if coordinates and paragraph.has_attr('coords') else "" - } - for paragraph_id, paragraph in enumerate(text_blocks_body) - ]) + passages.extend( + [ + { + "text": self.post_process( + "".join( + text + for text in paragraph.find_all(text=True) + if text.parent.name != "ref" + or (text.parent.name == "ref" and text.parent.attrs["type"] != "bibr") + ) + ), + "type": passage_type, + "section": "<body>", + "subSection": "<paragraph>", + "passage_id": str(paragraph_id), + "coordinates": paragraph["coords"] if coordinates and paragraph.has_attr("coords") else "", + } + for paragraph_id, paragraph in enumerate(text_blocks_body) + ] + ) text_blocks_figures = get_xml_nodes_figures(soup, verbose=False) if not use_paragraphs: - 
passages.extend([ - { - "text": self.post_process(''.join(text for text in sentence.find_all(text=True) if - text.parent.name != "ref" or ( - text.parent.name == "ref" and text.parent.attrs[ - 'type'] != 'bibr'))), - "type": passage_type, - "section": "<body>", - "subSection": "<figure>", - "passage_id": str(paragraph_id) + str(sentence_id), - "coordinates": sentence['coords'] if coordinates and 'coords' in sentence else "" - } - for paragraph_id, paragraph in enumerate(text_blocks_figures) for - sentence_id, sentence in enumerate(paragraph) - ]) + passages.extend( + [ + { + "text": self.post_process( + "".join( + text + for text in sentence.find_all(text=True) + if text.parent.name != "ref" + or (text.parent.name == "ref" and text.parent.attrs["type"] != "bibr") + ) + ), + "type": passage_type, + "section": "<body>", + "subSection": "<figure>", + "passage_id": str(paragraph_id) + str(sentence_id), + "coordinates": sentence["coords"] if coordinates and "coords" in sentence else "", + } + for paragraph_id, paragraph in enumerate(text_blocks_figures) + for sentence_id, sentence in enumerate(paragraph) + ] + ) else: - passages.extend([ - { - "text": self.post_process(''.join(text for text in paragraph.find_all(text=True) if - text.parent.name != "ref" or ( - text.parent.name == "ref" and text.parent.attrs[ - 'type'] != 'bibr'))), - "type": passage_type, - "section": "<body>", - "subSection": "<figure>", - "passage_id": str(paragraph_id), - "coordinates": paragraph['coords'] if coordinates and paragraph.has_attr('coords') else "" - } - for paragraph_id, paragraph in enumerate(text_blocks_figures) - ]) + passages.extend( + [ + { + "text": self.post_process( + "".join( + text + for text in paragraph.find_all(text=True) + if text.parent.name != "ref" + or (text.parent.name == "ref" and text.parent.attrs["type"] != "bibr") + ) + ), + "type": passage_type, + "section": "<body>", + "subSection": "<figure>", + "passage_id": str(paragraph_id), + "coordinates": 
paragraph["coords"] if coordinates and paragraph.has_attr("coords") else "", + } + for paragraph_id, paragraph in enumerate(text_blocks_figures) + ] + ) return output_data @@ -418,26 +442,26 @@ def process(self, text) -> list: spans = [] - if 'measurements' in result: + if "measurements" in result: found_measurements = self.parse_measurements_output(result) for m in found_measurements: item = { - "text": text[m['offset_start']:m['offset_end']], - 'offset_start': m['offset_start'], - 'offset_end': m['offset_end'] + "text": text[m["offset_start"] : m["offset_end"]], + "offset_start": m["offset_start"], + "offset_end": m["offset_end"], } - if 'raw' in m and m['raw'] != item['text']: - item['text'] = m['raw'] + if "raw" in m and m["raw"] != item["text"]: + item["text"] = m["raw"] - if 'quantified_substance' in m: - item['quantified'] = m['quantified_substance'] + if "quantified_substance" in m: + item["quantified"] = m["quantified_substance"] - if 'type' in m: - item["unit_type"] = m['type'] + if "type" in m: + item["unit_type"] = m["type"] - item['type'] = 'property' + item["type"] = "property" # if 'raw_value' in m: # item['raw_value'] = m['raw_value'] @@ -449,21 +473,21 @@ def process(self, text) -> list: def parse_measurements_output(result): measurements_output = [] - for measurement in result['measurements']: - type = measurement['type'] + for measurement in result["measurements"]: + type = measurement["type"] measurement_output_object = {} quantity_type = None has_unit = False parsed_value_type = None - if 'quantified' in measurement: - if 'normalizedName' in measurement['quantified']: - quantified_substance = measurement['quantified']['normalizedName'] + if "quantified" in measurement: + if "normalizedName" in measurement["quantified"]: + quantified_substance = measurement["quantified"]["normalizedName"] measurement_output_object["quantified_substance"] = quantified_substance - if 'measurementOffsets' in measurement: - measurement_output_object["offset_start"] 
= measurement["measurementOffsets"]['start'] - measurement_output_object["offset_end"] = measurement["measurementOffsets"]['end'] + if "measurementOffsets" in measurement: + measurement_output_object["offset_start"] = measurement["measurementOffsets"]["start"] + measurement_output_object["offset_end"] = measurement["measurementOffsets"]["end"] else: # If there are no offsets we skip the measurement continue @@ -471,66 +495,66 @@ def parse_measurements_output(result): # if 'measurementRaw' in measurement: # measurement_output_object['raw_value'] = measurement['measurementRaw'] - if type == 'value': - quantity = measurement['quantity'] + if type == "value": + quantity = measurement["quantity"] parsed_value = GrobidQuantitiesProcessor.get_parsed(quantity) if parsed_value: - measurement_output_object['parsed'] = parsed_value + measurement_output_object["parsed"] = parsed_value normalized_value = GrobidQuantitiesProcessor.get_normalized(quantity) if normalized_value: - measurement_output_object['normalized'] = normalized_value + measurement_output_object["normalized"] = normalized_value raw_value = GrobidQuantitiesProcessor.get_raw(quantity) if raw_value: - measurement_output_object['raw'] = raw_value + measurement_output_object["raw"] = raw_value - if 'type' in quantity: - quantity_type = quantity['type'] + if "type" in quantity: + quantity_type = quantity["type"] - if 'rawUnit' in quantity: + if "rawUnit" in quantity: has_unit = True parsed_value_type = get_parsed_value_type(quantity) - elif type == 'interval': - if 'quantityMost' in measurement: - quantityMost = measurement['quantityMost'] - if 'type' in quantityMost: - quantity_type = quantityMost['type'] + elif type == "interval": + if "quantityMost" in measurement: + quantityMost = measurement["quantityMost"] + if "type" in quantityMost: + quantity_type = quantityMost["type"] - if 'rawUnit' in quantityMost: + if "rawUnit" in quantityMost: has_unit = True parsed_value_type = get_parsed_value_type(quantityMost) - if 
'quantityLeast' in measurement: - quantityLeast = measurement['quantityLeast'] + if "quantityLeast" in measurement: + quantityLeast = measurement["quantityLeast"] - if 'type' in quantityLeast: - quantity_type = quantityLeast['type'] + if "type" in quantityLeast: + quantity_type = quantityLeast["type"] - if 'rawUnit' in quantityLeast: + if "rawUnit" in quantityLeast: has_unit = True parsed_value_type = get_parsed_value_type(quantityLeast) - elif type == 'listc': - quantities = measurement['quantities'] + elif type == "listc": + quantities = measurement["quantities"] - if 'type' in quantities[0]: - quantity_type = quantities[0]['type'] + if "type" in quantities[0]: + quantity_type = quantities[0]["type"] - if 'rawUnit' in quantities[0]: + if "rawUnit" in quantities[0]: has_unit = True parsed_value_type = get_parsed_value_type(quantities[0]) if quantity_type is not None or has_unit: - measurement_output_object['type'] = quantity_type + measurement_output_object["type"] = quantity_type - if parsed_value_type is None or parsed_value_type not in ['ALPHABETIC', 'TIME']: + if parsed_value_type is None or parsed_value_type not in ["ALPHABETIC", "TIME"]: measurements_output.append(measurement_output_object) return measurements_output @@ -538,10 +562,10 @@ def parse_measurements_output(result): @staticmethod def get_parsed(quantity): parsed_value = parsed_unit = None - if 'parsedValue' in quantity and 'parsed' in quantity['parsedValue']: - parsed_value = quantity['parsedValue']['parsed'] - if 'parsedUnit' in quantity and 'name' in quantity['parsedUnit']: - parsed_unit = quantity['parsedUnit']['name'] + if "parsedValue" in quantity and "parsed" in quantity["parsedValue"]: + parsed_value = quantity["parsedValue"]["parsed"] + if "parsedUnit" in quantity and "name" in quantity["parsedUnit"]: + parsed_unit = quantity["parsedUnit"]["name"] if parsed_value and parsed_unit: if has_space_between_value_and_unit(quantity): @@ -552,10 +576,10 @@ def get_parsed(quantity): @staticmethod 
def get_normalized(quantity): normalized_value = normalized_unit = None - if 'normalizedQuantity' in quantity: - normalized_value = quantity['normalizedQuantity'] - if 'normalizedUnit' in quantity and 'name' in quantity['normalizedUnit']: - normalized_unit = quantity['normalizedUnit']['name'] + if "normalizedQuantity" in quantity: + normalized_value = quantity["normalizedQuantity"] + if "normalizedUnit" in quantity and "name" in quantity["normalizedUnit"]: + normalized_unit = quantity["normalizedUnit"]["name"] if normalized_value and normalized_unit: if has_space_between_value_and_unit(quantity): @@ -566,10 +590,10 @@ def get_normalized(quantity): @staticmethod def get_raw(quantity): raw_value = raw_unit = None - if 'rawValue' in quantity: - raw_value = quantity['rawValue'] - if 'rawUnit' in quantity and 'name' in quantity['rawUnit']: - raw_unit = quantity['rawUnit']['name'] + if "rawValue" in quantity: + raw_value = quantity["rawValue"] + if "rawUnit" in quantity and "name" in quantity["rawUnit"]: + raw_unit = quantity["rawUnit"]["name"] if raw_value and raw_unit: if has_space_between_value_and_unit(quantity): @@ -603,28 +627,27 @@ def process(self, text): ``type`` (``"material"``), and optional ``formula`` keys. 
""" preprocessed_text = text.strip() - status, result = self.grobid_superconductors_client.process_text(preprocessed_text, - "processText_disable_linking") + status, result = self.grobid_superconductors_client.process_text(preprocessed_text, "processText_disable_linking") if status != 200: result = {} spans = [] - if 'passages' in result: + if "passages" in result: materials = self.parse_superconductors_output(result, preprocessed_text) for m in materials: - item = {"text": preprocessed_text[m['offset_start']:m['offset_end']]} + item = {"text": preprocessed_text[m["offset_start"] : m["offset_end"]]} - item['offset_start'] = m['offset_start'] - item['offset_end'] = m['offset_end'] + item["offset_start"] = m["offset_start"] + item["offset_end"] = m["offset_end"] - if 'formula' in m: - item["formula"] = m['formula'] + if "formula" in m: + item["formula"] = m["formula"] - item['type'] = 'material' - item['raw_value'] = m['text'] + item["type"] = "material" + item["raw_value"] = m["text"] spans.append(item) @@ -640,13 +663,13 @@ def parse_materials(self, text): for position_material in result: compositions = [] for material in position_material: - if 'resolvedFormulas' in material: - for resolved_formula in material['resolvedFormulas']: - if 'formulaComposition' in resolved_formula: - compositions.append(resolved_formula['formulaComposition']) - elif 'formula' in material: - if 'formulaComposition' in material['formula']: - compositions.append(material['formula']['formulaComposition']) + if "resolvedFormulas" in material: + for resolved_formula in material["resolvedFormulas"]: + if "formulaComposition" in resolved_formula: + compositions.append(resolved_formula["formulaComposition"]) + elif "formula" in material: + if "formulaComposition" in material["formula"]: + compositions.append(material["formula"]["formulaComposition"]) results.append(compositions) return results @@ -664,32 +687,32 @@ def parse_material(self, text): def output_info(self, result): compositions = [] 
for material in result: - if 'resolvedFormulas' in material: - for resolved_formula in material['resolvedFormulas']: - if 'formulaComposition' in resolved_formula: - compositions.append(resolved_formula['formulaComposition']) - elif 'formula' in material: - if 'formulaComposition' in material['formula']: - compositions.append(material['formula']['formulaComposition']) - if 'name' in material: - compositions.append(material['name']) + if "resolvedFormulas" in material: + for resolved_formula in material["resolvedFormulas"]: + if "formulaComposition" in resolved_formula: + compositions.append(resolved_formula["formulaComposition"]) + elif "formula" in material: + if "formulaComposition" in material["formula"]: + compositions.append(material["formula"]["formulaComposition"]) + if "name" in material: + compositions.append(material["name"]) return compositions @staticmethod def parse_superconductors_output(result, original_text): materials = [] - for passage in result['passages']: - sentence_offset = original_text.index(passage['text']) - if 'spans' in passage: - spans = passage['spans'] - for material_span in filter(lambda s: s['type'] == '<material>', spans): - text_ = material_span['text'] + for passage in result["passages"]: + sentence_offset = original_text.index(passage["text"]) + if "spans" in passage: + spans = passage["spans"] + for material_span in filter(lambda s: s["type"] == "<material>", spans): + text_ = material_span["text"] base_material_information = { "text": text_, - "offset_start": sentence_offset + material_span['offset_start'], - 'offset_end': sentence_offset + material_span['offset_end'] + "offset_start": sentence_offset + material_span["offset_start"], + "offset_end": sentence_offset + material_span["offset_end"], } materials.append(base_material_information) @@ -765,13 +788,13 @@ def box_to_dict(box, color=None, type=None, border=None): item = {"page": box[0], "x": box[1], "y": box[2], "width": box[3], "height": box[4]} if color: - 
item['color'] = color + item["color"] = color if type: - item['type'] = type + item["type"] = type if border: - item['border'] = border + item["border"] = border return item @@ -790,7 +813,7 @@ def prune_overlapping_annotations(entities: list) -> list: list[dict]: Pruned, non-overlapping spans sorted by offset. """ # Sorting by offsets - sorted_entities = sorted(entities, key=lambda d: d['offset_start']) + sorted_entities = sorted(entities, key=lambda d: d["offset_start"]) if len(entities) <= 1: return sorted_entities @@ -806,96 +829,104 @@ def prune_overlapping_annotations(entities: list) -> list: previous = current continue - if previous['offset_start'] < current['offset_start'] \ - and previous['offset_end'] < current['offset_end'] \ - and (previous['offset_end'] < current['offset_start'] \ - and not (previous['text'] == "-" and current['text'][0].isdigit())): + if ( + previous["offset_start"] < current["offset_start"] + and previous["offset_end"] < current["offset_end"] + and ( + previous["offset_end"] < current["offset_start"] + and not (previous["text"] == "-" and current["text"][0].isdigit()) + ) + ): previous = current continue - if previous['offset_end'] < current['offset_end']: - if current['type'] == previous['type']: + if previous["offset_end"] < current["offset_end"]: + if current["type"] == previous["type"]: # Type is the same - if current['offset_start'] == previous['offset_end']: - if current['type'] == 'property': - if current['text'].startswith("."): + if current["offset_start"] == previous["offset_end"]: + if current["type"] == "property": + if current["text"].startswith("."): print( - f"Merging. {current['text']} <{current['type']}> with {previous['text']} <{previous['type']}>") + f"Merging. 
{current['text']} <{current['type']}> with {previous['text']} <{previous['type']}>" + ) # current entity starts with a ".", suspiciously look like a truncated value to_be_removed.append(previous) - current['text'] = previous['text'] + current['text'] - current['raw_value'] = current['text'] - current['offset_start'] = previous['offset_start'] - elif previous['text'].endswith(".") and current['text'][0].isdigit(): + current["text"] = previous["text"] + current["text"] + current["raw_value"] = current["text"] + current["offset_start"] = previous["offset_start"] + elif previous["text"].endswith(".") and current["text"][0].isdigit(): print( - f"Merging. {current['text']} <{current['type']}> with {previous['text']} <{previous['type']}>") + f"Merging. {current['text']} <{current['type']}> with {previous['text']} <{previous['type']}>" + ) # previous entity ends with ".", current entity starts with a number to_be_removed.append(previous) - current['text'] = previous['text'] + current['text'] - current['raw_value'] = current['text'] - current['offset_start'] = previous['offset_start'] - elif previous['text'].startswith("-"): + current["text"] = previous["text"] + current["text"] + current["raw_value"] = current["text"] + current["offset_start"] = previous["offset_start"] + elif previous["text"].startswith("-"): print( - f"Merging. {current['text']} <{current['type']}> with {previous['text']} <{previous['type']}>") + f"Merging. 
{current['text']} <{current['type']}> with {previous['text']} <{previous['type']}>" + ) # previous starts with a `-`, sherlock this is another truncated value - current['text'] = previous['text'] + current['text'] - current['raw_value'] = current['text'] - current['offset_start'] = previous['offset_start'] + current["text"] = previous["text"] + current["text"] + current["raw_value"] = current["text"] + current["offset_start"] = previous["offset_start"] to_be_removed.append(previous) else: print("Other cases to be considered: ", previous, current) else: - if current['text'].startswith("-"): + if current["text"].startswith("-"): print( - f"Merging. {current['text']} <{current['type']}> with {previous['text']} <{previous['type']}>") + f"Merging. {current['text']} <{current['type']}> with {previous['text']} <{previous['type']}>" + ) # previous starts with a `-`, sherlock this is another truncated value - current['text'] = previous['text'] + current['text'] - current['raw_value'] = current['text'] - current['offset_start'] = previous['offset_start'] + current["text"] = previous["text"] + current["text"] + current["raw_value"] = current["text"] + current["offset_start"] = previous["offset_start"] to_be_removed.append(previous) else: print("Other cases to be considered: ", previous, current) - elif previous['text'] == "-" and current['text'][0].isdigit(): - print( - f"Merging. {current['text']} <{current['type']}> with {previous['text']} <{previous['type']}>") + elif previous["text"] == "-" and current["text"][0].isdigit(): + print(f"Merging. 
{current['text']} <{current['type']}> with {previous['text']} <{previous['type']}>") # previous starts with a `-`, sherlock this is another truncated value - current['text'] = previous['text'] + " " * (current['offset_start'] - previous['offset_end']) + \ - current['text'] - current['raw_value'] = current['text'] - current['offset_start'] = previous['offset_start'] + current["text"] = ( + previous["text"] + " " * (current["offset_start"] - previous["offset_end"]) + current["text"] + ) + current["raw_value"] = current["text"] + current["offset_start"] = previous["offset_start"] to_be_removed.append(previous) else: print( - f"Overlapping. {current['text']} <{current['type']}> with {previous['text']} <{previous['type']}>") + f"Overlapping. {current['text']} <{current['type']}> with {previous['text']} <{previous['type']}>" + ) # take the largest one - if len(previous['text']) > len(current['text']): + if len(previous["text"]) > len(current["text"]): to_be_removed.append(current) - elif len(previous['text']) < len(current['text']): + elif len(previous["text"]) < len(current["text"]): to_be_removed.append(previous) else: to_be_removed.append(previous) - elif current['type'] != previous['type']: - print( - f"Overlapping. {current['text']} <{current['type']}> with {previous['text']} <{previous['type']}>") + elif current["type"] != previous["type"]: + print(f"Overlapping. 
{current['text']} <{current['type']}> with {previous['text']} <{previous['type']}>") - if len(previous['text']) > len(current['text']): + if len(previous["text"]) > len(current["text"]): to_be_removed.append(current) - elif len(previous['text']) < len(current['text']): + elif len(previous["text"]) < len(current["text"]): to_be_removed.append(previous) else: - if current['type'] == "material": + if current["type"] == "material": to_be_removed.append(previous) else: to_be_removed.append(current) previous = current - elif previous['offset_end'] > current['offset_end']: + elif previous["offset_end"] > current["offset_end"]: to_be_removed.append(current) # the previous goes after the current, so we keep the previous and we discard the current else: - if current['type'] == "material": + if current["type"] == "material": to_be_removed.append(previous) else: to_be_removed.append(current) @@ -912,11 +943,11 @@ def __init__(self): def process_structure(self, input_file): text = "" - with open(input_file, encoding='utf-8') as fi: + with open(input_file, encoding="utf-8") as fi: text = fi.read() output_data = self.parse_xml(text) - output_data['filename'] = Path(input_file).stem.replace(".tei", "") + output_data["filename"] = Path(input_file).stem.replace(".tei", "") return output_data @@ -931,25 +962,30 @@ def process_structure(self, input_file): def process(self, text): output_data = OrderedDict() - soup = BeautifulSoup(text, 'xml') + soup = BeautifulSoup(text, "xml") text_blocks_children = get_children_list_supermat(soup, verbose=False) passages = [] - output_data['passages'] = passages - passages.extend([ - { - "text": self.post_process(''.join(text for text in sentence.find_all(text=True) if - text.parent.name != "ref" or ( - text.parent.name == "ref" and text.parent.attrs[ - 'type'] != 'bibr'))), - "type": "paragraph", - "section": "<body>", - "subSection": "<paragraph>", - "passage_id": str(paragraph_id) + str(sentence_id) - } - for paragraph_id, paragraph in 
enumerate(text_blocks_children) for - sentence_id, sentence in enumerate(paragraph) - ]) + output_data["passages"] = passages + passages.extend( + [ + { + "text": self.post_process( + "".join( + text + for text in sentence.find_all(text=True) + if text.parent.name != "ref" or (text.parent.name == "ref" and text.parent.attrs["type"] != "bibr") + ) + ), + "type": "paragraph", + "section": "<body>", + "subSection": "<paragraph>", + "passage_id": str(paragraph_id) + str(sentence_id), + } + for paragraph_id, paragraph in enumerate(text_blocks_children) + for sentence_id, sentence in enumerate(paragraph) + ] + ) return output_data @@ -959,12 +995,12 @@ def get_children_list_supermat(soup, use_paragraphs=False, verbose=False): child_name = "p" if use_paragraphs else "s" for child in soup.tei.children: - if child.name == 'teiHeader': + if child.name == "teiHeader": pass children.append(child.find_all("title")) children.extend([subchild.find_all(child_name) for subchild in child.find_all("abstract")]) children.extend([subchild.find_all(child_name) for subchild in child.find_all("ab", {"type": "keywords"})]) - elif child.name == 'text': + elif child.name == "text": children.extend([subchild.find_all(child_name) for subchild in child.find_all("body")]) if verbose: @@ -978,11 +1014,11 @@ def get_children_list_grobid(soup: object, use_paragraphs: object = True, verbos child_name = "p" if use_paragraphs else "s" for child in soup.TEI.children: - if child.name == 'teiHeader': + if child.name == "teiHeader": pass # children.extend(child.find_all("title", attrs={"level": "a"}, limit=1)) # children.extend([subchild.find_all(child_name) for subchild in child.find_all("abstract")]) - elif child.name == 'text': + elif child.name == "text": children.extend([subchild.find_all(child_name) for subchild in child.find_all("body")]) children.extend([subchild.find_all("figDesc") for subchild in child.find_all("body")]) @@ -997,9 +1033,12 @@ def get_xml_nodes_header(soup: object, 
use_paragraphs: bool = True) -> list: header_elements = { "authors": [persNameNode for persNameNode in soup.teiHeader.find_all("persName")], - "abstract": [p_in_abstract for abstractNodes in soup.teiHeader.find_all("abstract") for p_in_abstract in - abstractNodes.find_all(sub_tag)], - "title": [soup.teiHeader.fileDesc.title] + "abstract": [ + p_in_abstract + for abstractNodes in soup.teiHeader.find_all("abstract") + for p_in_abstract in abstractNodes.find_all(sub_tag) + ], + "title": [soup.teiHeader.fileDesc.title], } return header_elements @@ -1009,10 +1048,9 @@ def get_xml_nodes_body(soup: object, use_paragraphs: bool = True, verbose: bool nodes = [] tag_name = "p" if use_paragraphs else "s" for child in soup.TEI.children: - if child.name == 'text': + if child.name == "text": # nodes.extend([subchild.find_all(tag_name) for subchild in child.find_all("body")]) - nodes.extend( - [subsubchild for subchild in child.find_all("body") for subsubchild in subchild.find_all(tag_name)]) + nodes.extend([subsubchild for subchild in child.find_all("body") for subsubchild in subchild.find_all(tag_name)]) if verbose: print(str(nodes)) @@ -1024,9 +1062,8 @@ def get_xml_nodes_back(soup: object, use_paragraphs: bool = True, verbose: bool nodes = [] tag_name = "p" if use_paragraphs else "s" for child in soup.TEI.children: - if child.name == 'text': - nodes.extend( - [subsubchild for subchild in child.find_all("back") for subsubchild in subchild.find_all(tag_name)]) + if child.name == "text": + nodes.extend([subsubchild for subchild in child.find_all("back") for subsubchild in subchild.find_all(tag_name)]) if verbose: print(str(nodes)) @@ -1037,9 +1074,8 @@ def get_xml_nodes_back(soup: object, use_paragraphs: bool = True, verbose: bool def get_xml_nodes_figures(soup: object, verbose: bool = False) -> list: children = [] for child in soup.TEI.children: - if child.name == 'text': - children.extend( - [subchild for subchilds in child.find_all("body") for subchild in 
subchilds.find_all("figDesc")]) + if child.name == "text": + children.extend([subchild for subchilds in child.find_all("body") for subchild in subchilds.find_all("figDesc")]) if verbose: print(str(children)) diff --git a/document_qa/langchain.py b/document_qa/langchain.py index 4056dfd..196e1d4 100644 --- a/document_qa/langchain.py +++ b/document_qa/langchain.py @@ -29,12 +29,10 @@ class AdvancedVectorStoreRetriever(VectorStoreRetriever): "similarity", "similarity_score_threshold", "mmr", - "similarity_with_embeddings" + "similarity_with_embeddings", ) - def _get_relevant_documents( - self, query: str, *, run_manager: CallbackManagerForRetrieverRun - ) -> List[Document]: + def _get_relevant_documents(self, query: str, *, run_manager: CallbackManagerForRetrieverRun) -> List[Document]: """Fetch relevant documents for the configured search type. Supports all standard search types plus @@ -51,28 +49,20 @@ def _get_relevant_documents( """ if self.search_type == "similarity_with_embeddings": - docs_scores_and_embeddings = ( - self.vectorstore.advanced_similarity_search( - query, **self.search_kwargs - ) - ) + docs_scores_and_embeddings = self.vectorstore.advanced_similarity_search(query, **self.search_kwargs) for doc, score, embeddings in docs_scores_and_embeddings: - if '__embeddings' not in doc.metadata.keys(): - doc.metadata['__embeddings'] = embeddings - if '__similarity' not in doc.metadata.keys(): - doc.metadata['__similarity'] = score + if "__embeddings" not in doc.metadata.keys(): + doc.metadata["__embeddings"] = embeddings + if "__similarity" not in doc.metadata.keys(): + doc.metadata["__similarity"] = score docs = [doc for doc, _, _ in docs_scores_and_embeddings] elif self.search_type == "similarity_score_threshold": - docs_and_similarities = ( - self.vectorstore.similarity_search_with_relevance_scores( - query, **self.search_kwargs - ) - ) + docs_and_similarities = self.vectorstore.similarity_search_with_relevance_scores(query, **self.search_kwargs) for doc, 
similarity in docs_and_similarities: - if '__similarity' not in doc.metadata.keys(): - doc.metadata['__similarity'] = similarity + if "__similarity" not in doc.metadata.keys(): + doc.metadata["__similarity"] = similarity docs = [doc for doc, _ in docs_and_similarities] else: @@ -110,22 +100,19 @@ def __init__(self, **kwargs): @xor_args(("query_texts", "query_embeddings")) def __query_collection( - self, - query_texts: Optional[List[str]] = None, - query_embeddings: Optional[List[List[float]]] = None, - n_results: int = 4, - where: Optional[Dict[str, str]] = None, - where_document: Optional[Dict[str, str]] = None, - **kwargs: Any, + self, + query_texts: Optional[List[str]] = None, + query_embeddings: Optional[List[List[float]]] = None, + n_results: int = 4, + where: Optional[Dict[str, str]] = None, + where_document: Optional[Dict[str, str]] = None, + **kwargs: Any, ) -> List[Document]: """Query the chroma collection.""" try: import chromadb # noqa: F401 except ImportError: - raise ValueError( - "Could not import chromadb python package. " - "Please install it with `pip install chromadb`." - ) + raise ValueError("Could not import chromadb python package. Please install it with `pip install chromadb`.") return self._collection.query( query_texts=query_texts, query_embeddings=query_embeddings, @@ -136,11 +123,11 @@ def __query_collection( ) def advanced_similarity_search( - self, - query: str, - k: int = DEFAULT_K, - filter: Optional[Dict[str, str]] = None, - **kwargs: Any, + self, + query: str, + k: int = DEFAULT_K, + filter: Optional[Dict[str, str]] = None, + **kwargs: Any, ) -> List[Tuple[Document, float, List[float]]]: """Return documents, similarity scores, and embeddings for *query*. 
@@ -157,12 +144,12 @@ def advanced_similarity_search( return docs_scores_and_embeddings def similarity_search_with_scores_and_embeddings( - self, - query: str, - k: int = DEFAULT_K, - filter: Optional[Dict[str, str]] = None, - where_document: Optional[Dict[str, str]] = None, - **kwargs: Any, + self, + query: str, + k: int = DEFAULT_K, + filter: Optional[Dict[str, str]] = None, + where_document: Optional[Dict[str, str]] = None, + **kwargs: Any, ) -> List[Tuple[Document, float, List[float]]]: """Low-level search returning docs with scores and embeddings. @@ -186,7 +173,7 @@ def similarity_search_with_scores_and_embeddings( n_results=k, where=filter, where_document=where_document, - include=['metadatas', 'documents', 'embeddings', 'distances'] + include=["metadatas", "documents", "embeddings", "distances"], ) else: query_embedding = self._embedding_function.embed_query(query) @@ -195,7 +182,7 @@ def similarity_search_with_scores_and_embeddings( n_results=k, where=filter, where_document=where_document, - include=['metadatas', 'documents', 'embeddings', 'distances'] + include=["metadatas", "documents", "embeddings", "distances"], ) return _results_to_docs_scores_and_embeddings(results) diff --git a/document_qa/ner_client_generic.py b/document_qa/ner_client_generic.py index fe4b846..9cccb30 100644 --- a/document_qa/ner_client_generic.py +++ b/document_qa/ner_client_generic.py @@ -3,12 +3,12 @@ import yaml -''' +""" This client is a generic client for any Grobid application and sub-modules. At the moment, it supports only single document processing. Source: https://github.com/kermitt2/grobid-client-python -''' +""" """ Generic API Client """ from copy import deepcopy @@ -22,24 +22,17 @@ class ApiClient(object): - """ Client to interact with a generic Rest API. + """Client to interact with a generic Rest API. Subclasses should implement functionality accordingly with the provided service methods, i.e. ``get``, ``post``, ``put`` and ``delete``. 
""" - accept_type = 'application/xml' + accept_type = "application/xml" api_base = None - def __init__( - self, - base_url, - username=None, - api_key=None, - status_endpoint=None, - timeout=60 - ): - """ Initialise client. + def __init__(self, base_url, username=None, api_key=None, status_endpoint=None, timeout=60): + """Initialise client. Args: base_url (str): The base URL to the service being used. @@ -55,7 +48,7 @@ def __init__( @staticmethod def encode(request, data): - """ Add request content data to request body, set Content-type header. + """Add request content data to request body, set Content-type header. Should be overridden by subclasses if not using JSON encoding. @@ -69,14 +62,14 @@ def encode(request, data): if data is None: return request - request.add_header('Content-Type', 'application/json') + request.add_header("Content-Type", "application/json") request.extracted_data = json.dumps(data) return request @staticmethod def decode(response): - """ Decode the returned data in the response. + """Decode the returned data in the response. Should be overridden by subclasses if something else than JSON is expected. @@ -93,7 +86,7 @@ def decode(response): return e.message def get_credentials(self): - """ Returns parameters to be added to authenticate the request. + """Returns parameters to be added to authenticate the request. This lives on its own to make it easier to re-implement it if needed. @@ -103,16 +96,16 @@ def get_credentials(self): return {"username": self.username, "api_key": self.api_key} def call_api( - self, - method, - url, - headers=None, - params=None, - data=None, - files=None, - timeout=None, + self, + method, + url, + headers=None, + params=None, + data=None, + files=None, + timeout=None, ): - """ Call API. + """Call API. This returns object containing data, with error details if applicable. @@ -129,7 +122,7 @@ def call_api( ResultParser or ErrorParser. 
""" headers = deepcopy(headers) or {} - headers['Accept'] = self.accept_type if 'Accept' not in headers else headers['Accept'] + headers["Accept"] = self.accept_type if "Accept" not in headers else headers["Accept"] params = deepcopy(params) or {} data = data or {} files = files or {} @@ -148,7 +141,7 @@ def call_api( return r, r.status_code def get(self, url, params=None, **kwargs): - """ Call the API with a GET request. + """Call the API with a GET request. Args: url (str): Resource location relative to the base URL. @@ -157,15 +150,10 @@ def get(self, url, params=None, **kwargs): Returns: ResultParser or ErrorParser. """ - return self.call_api( - "GET", - url, - params=params, - **kwargs - ) + return self.call_api("GET", url, params=params, **kwargs) def delete(self, url, params=None, **kwargs): - """ Call the API with a DELETE request. + """Call the API with a DELETE request. Args: url (str): Resource location relative to the base URL. @@ -174,15 +162,10 @@ def delete(self, url, params=None, **kwargs): Returns: ResultParser or ErrorParser. """ - return self.call_api( - "DELETE", - url, - params=params, - **kwargs - ) + return self.call_api("DELETE", url, params=params, **kwargs) def put(self, url, params=None, data=None, files=None, **kwargs): - """ Call the API with a PUT request. + """Call the API with a PUT request. Args: url (str): Resource location relative to the base URL. @@ -193,17 +176,10 @@ def put(self, url, params=None, data=None, files=None, **kwargs): Returns: An instance of ResultParser or ErrorParser. """ - return self.call_api( - "PUT", - url, - params=params, - data=data, - files=files, - **kwargs - ) + return self.call_api("PUT", url, params=params, data=data, files=files, **kwargs) def post(self, url, params=None, data=None, files=None, **kwargs): - """ Call the API with a POST request. + """Call the API with a POST request. Args: url (str): Resource location relative to the base URL. 
@@ -214,63 +190,50 @@ def post(self, url, params=None, data=None, files=None, **kwargs): Returns: An instance of ResultParser or ErrorParser. """ - return self.call_api( - method="POST", - url=url, - params=params, - data=data, - files=files, - **kwargs - ) + return self.call_api(method="POST", url=url, params=params, data=data, files=files, **kwargs) def service_status(self, **kwargs): - """ Call the API to get the status of the service. + """Call the API to get the status of the service. Returns: An instance of ResultParser or ErrorParser. """ - return self.call_api( - 'GET', - self.status_endpoint, - params={'format': 'json'}, - **kwargs - ) + return self.call_api("GET", self.status_endpoint, params={"format": "json"}, **kwargs) class NERClientGeneric(ApiClient): - def __init__(self, config_path=None, ping=False): self.config = None if config_path is not None: self.config = self._load_yaml_config_from_file(path=config_path) - super().__init__(self.config['grobid']['server']) + super().__init__(self.config["grobid"]["server"]) if ping: result = self.ping_service() if not result: raise Exception("Grobid is down.") - os.environ['NO_PROXY'] = "nims.go.jp" + os.environ["NO_PROXY"] = "nims.go.jp" @staticmethod - def _load_json_config_from_file(path='./config.json'): + def _load_json_config_from_file(path="./config.json"): """ Load the json configuration """ config = {} - with open(path, 'r') as fp: + with open(path, "r") as fp: config = json.load(fp) return config @staticmethod - def _load_yaml_config_from_file(path='./config.yaml'): + def _load_yaml_config_from_file(path="./config.yaml"): """ Load the YAML configuration """ config = {} try: - with open(path, 'r') as the_file: + with open(path, "r") as the_file: raw_configuration = the_file.read() config = yaml.safe_load(raw_configuration) @@ -298,130 +261,86 @@ def ping_service(self): status = r.status_code if status != 200: - print('GROBID server does not appear up and running ' + str(status)) + print("GROBID server 
does not appear up and running " + str(status)) return False else: print("GROBID server is up and running") return True def get_url(self, action): - grobid_config = self.config['grobid'] - base_url = grobid_config['server'] - action_url = base_url + grobid_config['url_mapping'][action] + grobid_config = self.config["grobid"] + base_url = grobid_config["server"] + action_url = base_url + grobid_config["url_mapping"][action] return action_url - def process_texts(self, input, method_name='superconductors', params={}, headers={"Accept": "application/json"}): + def process_texts(self, input, method_name="superconductors", params={}, headers={"Accept": "application/json"}): - files = { - 'texts': input - } + files = {"texts": input} the_url = self.get_url(method_name) params, the_url = self.get_params_from_url(the_url) - res, status = self.post( - url=the_url, - files=files, - data=params, - headers=headers - ) + res, status = self.post(url=the_url, files=files, data=params, headers=headers) if status == 503: - time.sleep(self.config['sleep_time']) + time.sleep(self.config["sleep_time"]) return self.process_texts(input, method_name, params, headers) elif status != 200: - print('Processing failed with error ' + str(status)) + print("Processing failed with error " + str(status)) return status, None else: return status, json.loads(res.text) - def process_text(self, input, method_name='superconductors', params={}, headers={"Accept": "application/json"}): + def process_text(self, input, method_name="superconductors", params={}, headers={"Accept": "application/json"}): - files = { - 'text': input - } + files = {"text": input} the_url = self.get_url(method_name) params, the_url = self.get_params_from_url(the_url) - res, status = self.post( - url=the_url, - files=files, - data=params, - headers=headers - ) + res, status = self.post(url=the_url, files=files, data=params, headers=headers) if status == 503: - time.sleep(self.config['sleep_time']) + 
time.sleep(self.config["sleep_time"]) return self.process_text(input, method_name, params, headers) elif status != 200: - print('Processing failed with error ' + str(status)) + print("Processing failed with error " + str(status)) return status, None else: return status, json.loads(res.text) - def process_pdf(self, - form_data: dict, - method_name='superconductors', - params={}, - headers={"Accept": "application/json"} - ): + def process_pdf(self, form_data: dict, method_name="superconductors", params={}, headers={"Accept": "application/json"}): the_url = self.get_url(method_name) params, the_url = self.get_params_from_url(the_url) - res, status = self.post( - url=the_url, - files=form_data, - data=params, - headers=headers - ) + res, status = self.post(url=the_url, files=form_data, data=params, headers=headers) if status == 503: - time.sleep(self.config['sleep_time']) + time.sleep(self.config["sleep_time"]) return self.process_text(input, method_name, params, headers) elif status != 200: - print('Processing failed with error ' + str(status)) + print("Processing failed with error " + str(status)) else: return res.text def process_pdfs(self, pdf_files, params={}): pass - def process_pdf( - self, - pdf_file, - method_name, - params={}, - headers={"Accept": "application/json"}, - verbose=False, - retry=None - ): + def process_pdf(self, pdf_file, method_name, params={}, headers={"Accept": "application/json"}, verbose=False, retry=None): - files = { - 'input': ( - pdf_file, - open(pdf_file, 'rb'), - 'application/pdf', - {'Expires': '0'} - ) - } + files = {"input": (pdf_file, open(pdf_file, "rb"), "application/pdf", {"Expires": "0"})} the_url = self.get_url(method_name) params, the_url = self.get_params_from_url(the_url) - res, status = self.post( - url=the_url, - files=files, - data=params, - headers=headers - ) + res, status = self.post(url=the_url, files=files, data=params, headers=headers) if status == 503 or status == 429: if retry is None: - retry = 
self.config['max_retry'] - 1 + retry = self.config["max_retry"] - 1 else: if retry - 1 == 0: if verbose: @@ -430,7 +349,7 @@ def process_pdf( else: retry -= 1 - sleep_time = self.config['sleep_time'] + sleep_time = self.config["sleep_time"] if verbose: print("Server is saturated, waiting", sleep_time, "seconds and trying again. ") time.sleep(sleep_time) @@ -439,7 +358,7 @@ def process_pdf( desc = None if res.content: c = json.loads(res.text) - desc = c['description'] if 'description' in c else None + desc = c["description"] if "description" in c else None return desc, status elif status == 204: # print('No content returned. Moving on. ') diff --git a/streamlit_app.py b/streamlit_app.py index fe71ddd..61ff970 100644 --- a/streamlit_app.py +++ b/streamlit_app.py @@ -28,69 +28,64 @@ dotenv.load_dotenv(override=True) -API_MODELS = { - "microsoft/Phi-4-mini-instruct": os.environ["PHI_URL"], - "Qwen/Qwen3-0.6B": os.environ["QWEN_URL"] -} +API_MODELS = {"microsoft/Phi-4-mini-instruct": os.environ["PHI_URL"], "Qwen/Qwen3-0.6B": os.environ["QWEN_URL"]} -API_EMBEDDINGS = { - 'intfloat/multilingual-e5-large-instruct-modal': os.environ['EMBEDS_URL'] -} +API_EMBEDDINGS = {"intfloat/multilingual-e5-large-instruct-modal": os.environ["EMBEDS_URL"]} -if 'rqa' not in st.session_state: - st.session_state['rqa'] = {} +if "rqa" not in st.session_state: + st.session_state["rqa"] = {} -if 'model' not in st.session_state: - st.session_state['model'] = None +if "model" not in st.session_state: + st.session_state["model"] = None -if 'api_keys' not in st.session_state: - st.session_state['api_keys'] = {} +if "api_keys" not in st.session_state: + st.session_state["api_keys"] = {} -if 'doc_id' not in st.session_state: - st.session_state['doc_id'] = None +if "doc_id" not in st.session_state: + st.session_state["doc_id"] = None -if 'loaded_embeddings' not in st.session_state: - st.session_state['loaded_embeddings'] = None +if "loaded_embeddings" not in st.session_state: + 
st.session_state["loaded_embeddings"] = None -if 'hash' not in st.session_state: - st.session_state['hash'] = None +if "hash" not in st.session_state: + st.session_state["hash"] = None -if 'git_rev' not in st.session_state: - st.session_state['git_rev'] = "unknown" +if "git_rev" not in st.session_state: + st.session_state["git_rev"] = "unknown" if os.path.exists("revision.txt"): - with open("revision.txt", 'r') as fr: + with open("revision.txt", "r") as fr: from_file = fr.read() - st.session_state['git_rev'] = from_file if len(from_file) > 0 else "unknown" + st.session_state["git_rev"] = from_file if len(from_file) > 0 else "unknown" if "messages" not in st.session_state: st.session_state.messages = [] -if 'ner_processing' not in st.session_state: - st.session_state['ner_processing'] = False +if "ner_processing" not in st.session_state: + st.session_state["ner_processing"] = False -if 'uploaded' not in st.session_state: - st.session_state['uploaded'] = False +if "uploaded" not in st.session_state: + st.session_state["uploaded"] = False -if 'memory' not in st.session_state: - st.session_state['memory'] = None +if "memory" not in st.session_state: + st.session_state["memory"] = None -if 'binary' not in st.session_state: - st.session_state['binary'] = None +if "binary" not in st.session_state: + st.session_state["binary"] = None -if 'annotations' not in st.session_state: - st.session_state['annotations'] = None +if "annotations" not in st.session_state: + st.session_state["annotations"] = None -if 'should_show_annotations' not in st.session_state: - st.session_state['should_show_annotations'] = True +if "should_show_annotations" not in st.session_state: + st.session_state["should_show_annotations"] = True -if 'pdf' not in st.session_state: - st.session_state['pdf'] = None +if "pdf" not in st.session_state: + st.session_state["pdf"] = None -if 'embeddings' not in st.session_state: - st.session_state['embeddings'] = None +if "embeddings" not in st.session_state: + 
st.session_state["embeddings"] = None -if 'scroll_to_first_annotation' not in st.session_state: - st.session_state['scroll_to_first_annotation'] = False +if "scroll_to_first_annotation" not in st.session_state: + st.session_state["scroll_to_first_annotation"] = False st.set_page_config( page_title="Scientific Document Insights Q/A", @@ -98,10 +93,10 @@ initial_sidebar_state="expanded", layout="wide", menu_items={ - 'Get Help': 'https://github.com/lfoppiano/document-qa', - 'Report a bug': "https://github.com/lfoppiano/document-qa/issues", - 'About': "Upload a scientific article in PDF, ask questions, get insights." - } + "Get Help": "https://github.com/lfoppiano/document-qa", + "Report a bug": "https://github.com/lfoppiano/document-qa/issues", + "About": "Upload a scientific article in PDF, ask questions, get insights.", + }, ) st.markdown( @@ -115,7 +110,7 @@ } </style> """, - unsafe_allow_html=True + unsafe_allow_html=True, ) @@ -125,17 +120,17 @@ def new_file(): Clears previous embeddings, annotations, and conversation memory so the pipeline starts fresh for the new document. """ - st.session_state['loaded_embeddings'] = None - st.session_state['doc_id'] = None - st.session_state['uploaded'] = True - st.session_state['annotations'] = [] - if st.session_state['memory']: - st.session_state['memory'].clear() + st.session_state["loaded_embeddings"] = None + st.session_state["doc_id"] = None + st.session_state["uploaded"] = True + st.session_state["annotations"] = [] + if st.session_state["memory"]: + st.session_state["memory"].clear() def clear_memory(): """Clear the conversation buffer memory (chat history).""" - st.session_state['memory'].clear() + st.session_state["memory"].clear() # @st.cache_resource @@ -150,30 +145,16 @@ def init_qa(model_name, embeddings_name): Returns: DocumentQAEngine: Ready-to-use engine instance. 
""" - st.session_state['memory'] = ConversationBufferMemory( - memory_key="chat_history", - return_messages=True - ) - chat = ChatOpenAI( - model=model_name, - temperature=0.0, - base_url=API_MODELS[model_name], - api_key=os.environ.get('API_KEY') - ) + st.session_state["memory"] = ConversationBufferMemory(memory_key="chat_history", return_messages=True) + chat = ChatOpenAI(model=model_name, temperature=0.0, base_url=API_MODELS[model_name], api_key=os.environ.get("API_KEY")) embeddings = ModalEmbeddings( - url=API_EMBEDDINGS[embeddings_name], - model_name=embeddings_name, - api_key=os.environ.get('EMBEDS_API_KEY') + url=API_EMBEDDINGS[embeddings_name], model_name=embeddings_name, api_key=os.environ.get("EMBEDS_API_KEY") ) storage = DataStorage(embeddings) return DocumentQAEngine( - chat, - storage, - grobid_url=os.environ['GROBID_URL'], - memory=st.session_state['memory'], - ping_grobid_server=False + chat, storage, grobid_url=os.environ["GROBID_URL"], memory=st.session_state["memory"], ping_grobid_server=False ) @@ -187,25 +168,26 @@ def init_ner(): Returns: GrobidAggregationProcessor: Configured processor instance. 
""" - quantities_client = QuantitiesAPI(os.environ['GROBID_QUANTITIES_URL'], check_server=True) + quantities_client = QuantitiesAPI(os.environ["GROBID_QUANTITIES_URL"], check_server=True) materials_client = NERClientGeneric(ping=True) config_materials = { - 'grobid': { - "server": os.environ['GROBID_MATERIALS_URL'], - 'sleep_time': 5, - 'timeout': 60, - 'url_mapping': { - 'processText_disable_linking': "/service/process/text?disableLinking=True", + "grobid": { + "server": os.environ["GROBID_MATERIALS_URL"], + "sleep_time": 5, + "timeout": 60, + "url_mapping": { + "processText_disable_linking": "/service/process/text?disableLinking=True", # 'processText_disable_linking': "/service/process/text" - } + }, } } materials_client.set_config(config_materials) - gqa = GrobidAggregationProcessor(grobid_quantities_client=quantities_client, - grobid_superconductors_client=materials_client) + gqa = GrobidAggregationProcessor( + grobid_quantities_client=quantities_client, grobid_superconductors_client=materials_client + ) return gqa @@ -230,15 +212,15 @@ def play_old_messages(container): Called on Streamlit reruns to restore the visible conversation history from ``st.session_state['messages']``. 
""" - if st.session_state['messages']: - for message in st.session_state['messages']: - if message['role'] == 'user': - container.chat_message("user").markdown(message['content']) - elif message['role'] == 'assistant': + if st.session_state["messages"]: + for message in st.session_state["messages"]: + if message["role"] == "user": + container.chat_message("user").markdown(message["content"]) + elif message["role"] == "assistant": if mode == "LLM": - container.chat_message("assistant").markdown(message['content'], unsafe_allow_html=True) + container.chat_message("assistant").markdown(message["content"], unsafe_allow_html=True) else: - container.chat_message("assistant").write(message['content']) + container.chat_message("assistant").write(message["content"]) # is_api_key_provided = st.session_state['api_key'] @@ -247,37 +229,39 @@ def play_old_messages(container): st.title("📝 Document Q/A") st.markdown("Upload a scientific article in PDF, ask questions, get insights.") st.markdown( - ":warning: [Usage disclaimer](https://github.com/lfoppiano/document-qa?tab=readme-ov-file#disclaimer-on-data-security-and-privacy-%EF%B8%8F) :warning: ") + ":warning: [Usage disclaimer](https://github.com/lfoppiano/document-qa?tab=readme-ov-file#disclaimer-on-data-security-and-privacy-%EF%B8%8F) :warning: " + ) st.markdown("LM and Embeddings are powered by [Modal.com](https://modal.com/)") st.divider() - st.session_state['model'] = model = st.selectbox( + st.session_state["model"] = model = st.selectbox( "Model:", options=API_MODELS.keys(), - index=(list(API_MODELS.keys())).index( - os.environ["DEFAULT_MODEL"]) if "DEFAULT_MODEL" in os.environ and os.environ["DEFAULT_MODEL"] else 0, + index=(list(API_MODELS.keys())).index(os.environ["DEFAULT_MODEL"]) + if "DEFAULT_MODEL" in os.environ and os.environ["DEFAULT_MODEL"] + else 0, placeholder="Select model", help="Select the LLM model:", - disabled=st.session_state['doc_id'] is not None or st.session_state['uploaded'] + 
disabled=st.session_state["doc_id"] is not None or st.session_state["uploaded"], ) - st.session_state['embeddings'] = embedding_name = st.selectbox( + st.session_state["embeddings"] = embedding_name = st.selectbox( "Embeddings:", options=API_EMBEDDINGS.keys(), - index=(list(API_EMBEDDINGS.keys())).index( - os.environ["DEFAULT_EMBEDDING"]) if "DEFAULT_EMBEDDING" in os.environ and os.environ[ - "DEFAULT_EMBEDDING"] else 0, + index=(list(API_EMBEDDINGS.keys())).index(os.environ["DEFAULT_EMBEDDING"]) + if "DEFAULT_EMBEDDING" in os.environ and os.environ["DEFAULT_EMBEDDING"] + else 0, placeholder="Select embedding", help="Select the Embedding function:", - disabled=st.session_state['doc_id'] is not None or st.session_state['uploaded'] + disabled=st.session_state["doc_id"] is not None or st.session_state["uploaded"], ) - api_key = os.environ['API_KEY'] + api_key = os.environ["API_KEY"] - if model not in st.session_state['rqa'] or model not in st.session_state['api_keys']: + if model not in st.session_state["rqa"] or model not in st.session_state["api_keys"]: with st.spinner("Preparing environment"): - st.session_state['rqa'][model] = init_qa(model, st.session_state['embeddings']) - st.session_state['api_keys'][model] = api_key + st.session_state["rqa"][model] = init_qa(model, st.session_state["embeddings"]) + st.session_state["api_keys"][model] = api_key left_column, right_column = st.columns([5, 4]) right_column = right_column.container(border=True) @@ -288,9 +272,8 @@ def play_old_messages(container): "Upload a scientific article", type=("pdf"), on_change=new_file, - disabled=st.session_state['model'] is not None and st.session_state['model'] not in - st.session_state['api_keys'], - help="The full-text is extracted using Grobid." 
+ disabled=st.session_state["model"] is not None and st.session_state["model"] not in st.session_state["api_keys"], + help="The full-text is extracted using Grobid.", ) placeholder = st.empty() @@ -299,14 +282,10 @@ def play_old_messages(container): question = st.chat_input( "Ask something about the article", # placeholder="Can you give me a short summary?", - disabled=not uploaded_file + disabled=not uploaded_file, ) -query_modes = { - "llm": "LLM Q/A", - "embeddings": "Embeddings", - "question_coefficient": "Question coefficient" -} +query_modes = {"llm": "LLM Q/A", "embeddings": "Embeddings", "question_coefficient": "Question coefficient"} with st.sidebar: st.header("Settings") @@ -318,53 +297,73 @@ def play_old_messages(container): horizontal=True, format_func=lambda x: query_modes[x], help="LLM will respond the question, Embedding will show the " - "relevant paragraphs to the question in the paper. " - "Question coefficient attempt to estimate how effective the question will be answered." + "relevant paragraphs to the question in the paper. " + "Question coefficient attempt to estimate how effective the question will be answered.", ) - st.session_state['scroll_to_first_annotation'] = st.checkbox( - "Scroll to context", - help='The PDF viewer will automatically scroll to the first relevant passage in the document.' + st.session_state["scroll_to_first_annotation"] = st.checkbox( + "Scroll to context", help="The PDF viewer will automatically scroll to the first relevant passage in the document." ) - st.session_state['ner_processing'] = st.checkbox( + st.session_state["ner_processing"] = st.checkbox( "Identify materials and properties.", - help='The LLM responses undergo post-processing to extract physical quantities, measurements, and materials mentions.' 
+ help="The LLM responses undergo post-processing to extract physical quantities, measurements, and materials mentions.", ) # Add a checkbox for showing annotations # st.session_state['show_annotations'] = st.checkbox("Show annotations", value=True) # st.session_state['should_show_annotations'] = st.checkbox("Show annotations", value=True) - chunk_size = st.slider("Text chunks size", -1, 2000, value=-1, - help="Size of chunks in which split the document. -1: use paragraphs, > 0 paragraphs are aggregated.", - disabled=uploaded_file is not None) + chunk_size = st.slider( + "Text chunks size", + -1, + 2000, + value=-1, + help="Size of chunks in which split the document. -1: use paragraphs, > 0 paragraphs are aggregated.", + disabled=uploaded_file is not None, + ) if chunk_size == -1: - context_size = st.slider("Context size (paragraphs)", 3, 20, value=10, - help="Number of paragraphs to consider when answering a question", - disabled=not uploaded_file) + context_size = st.slider( + "Context size (paragraphs)", + 3, + 20, + value=10, + help="Number of paragraphs to consider when answering a question", + disabled=not uploaded_file, + ) else: - context_size = st.slider("Context size (chunks)", 3, 10, value=4, - help="Number of chunks to consider when answering a question", - disabled=not uploaded_file) + context_size = st.slider( + "Context size (chunks)", + 3, + 10, + value=4, + help="Number of chunks to consider when answering a question", + disabled=not uploaded_file, + ) st.divider() st.header("Documentation") st.markdown("https://github.com/lfoppiano/document-qa") st.markdown( - """Upload a scientific article as PDF document. Once the spinner stops, you can proceed to ask your questions.""") + """Upload a scientific article as PDF document. 
Once the spinner stops, you can proceed to ask your questions.""" + ) - if st.session_state['git_rev'] != "unknown": - st.markdown("**Revision number**: [" + st.session_state[ - 'git_rev'] + "](https://github.com/lfoppiano/document-qa/commit/" + st.session_state['git_rev'] + ")") + if st.session_state["git_rev"] != "unknown": + st.markdown( + "**Revision number**: [" + + st.session_state["git_rev"] + + "](https://github.com/lfoppiano/document-qa/commit/" + + st.session_state["git_rev"] + + ")" + ) if uploaded_file and not st.session_state.loaded_embeddings: - if model not in st.session_state['api_keys']: + if model not in st.session_state["api_keys"]: st.error("Before uploading a document, you must enter the API key. ") st.stop() with left_column: try: - with st.spinner('Reading file, calling Grobid, and creating in-memory embeddings...'): + with st.spinner("Reading file, calling Grobid, and creating in-memory embeddings..."): binary = uploaded_file.getvalue() tmp_path = None try: @@ -372,22 +371,20 @@ def play_old_messages(container): tmp_file.write(bytearray(binary)) tmp_file.flush() tmp_path = tmp_file.name - st.session_state['binary'] = binary + st.session_state["binary"] = binary - st.session_state['doc_id'] = st.session_state['rqa'][model].create_memory_embeddings( - tmp_path, - chunk_size=chunk_size, - perc_overlap=0.1 + st.session_state["doc_id"] = st.session_state["rqa"][model].create_memory_embeddings( + tmp_path, chunk_size=chunk_size, perc_overlap=0.1 ) finally: if tmp_path and os.path.exists(tmp_path): os.unlink(tmp_path) - st.session_state['loaded_embeddings'] = True + st.session_state["loaded_embeddings"] = True st.session_state.messages = [] except GrobidServiceError as exc: - st.session_state['doc_id'] = None - st.session_state['loaded_embeddings'] = False - st.session_state['uploaded'] = False + st.session_state["doc_id"] = None + st.session_state["loaded_embeddings"] = False + st.session_state["uploaded"] = False message = str(exc).strip() or 
"Grobid is not responding." if not message.endswith((".", "!", "?")): message += "." @@ -418,8 +415,9 @@ def generate_color_gradient(num_elements): # Generate a linear gradient of colors color_gradient = [ - rgb_to_hex(tuple(int(warm * (1 - i / num_elements) + cold * (i / num_elements)) for warm, cold in - zip(warm_color, cold_color))) + rgb_to_hex( + tuple(int(warm * (1 - i / num_elements) + cold * (i / num_elements)) for warm, cold in zip(warm_color, cold_color)) + ) for i in range(num_elements) ] @@ -432,13 +430,13 @@ def generate_color_gradient(num_elements): for message in st.session_state.messages: # with messages.chat_message(message["role"]): - if message['mode'] == "llm": + if message["mode"] == "llm": messages.chat_message(message["role"]).markdown(message["content"], unsafe_allow_html=True) - elif message['mode'] == "embeddings": + elif message["mode"] == "embeddings": messages.chat_message(message["role"]).write(message["content"]) - elif message['mode'] == "question_coefficient": + elif message["mode"] == "question_coefficient": messages.chat_message(message["role"]).markdown(message["content"], unsafe_allow_html=True) - if model not in st.session_state['rqa']: + if model not in st.session_state["rqa"]: st.error("The API Key for the " + model + " is missing. Please add it before sending any query. 
`") st.stop() @@ -446,45 +444,40 @@ def generate_color_gradient(num_elements): if mode == "embeddings": with placeholder: with st.spinner("Fetching the relevant context..."): - text_response, coordinates = st.session_state['rqa'][model].query_storage( - question, - st.session_state.doc_id, - context_size=context_size + text_response, coordinates = st.session_state["rqa"][model].query_storage( + question, st.session_state.doc_id, context_size=context_size ) elif mode == "llm": with placeholder: with st.spinner("Generating LLM response..."): - _, text_response, coordinates = st.session_state['rqa'][model].query_document( - question, - st.session_state.doc_id, - context_size=context_size + _, text_response, coordinates = st.session_state["rqa"][model].query_document( + question, st.session_state.doc_id, context_size=context_size ) elif mode == "question_coefficient": with st.spinner("Estimate question/context relevancy..."): - text_response, coordinates = st.session_state['rqa'][model].analyse_query( - question, - st.session_state.doc_id, - context_size=context_size + text_response, coordinates = st.session_state["rqa"][model].analyse_query( + question, st.session_state.doc_id, context_size=context_size ) - annotations = [[GrobidAggregationProcessor.box_to_dict([cs for cs in c.split(",")]) for c in coord_doc] - for coord_doc in coordinates] + annotations = [ + [GrobidAggregationProcessor.box_to_dict([cs for cs in c.split(",")]) for c in coord_doc] + for coord_doc in coordinates + ] gradients = generate_color_gradient(len(annotations)) for i, color in enumerate(gradients): for annotation in annotations[i]: - annotation['color'] = color + annotation["color"] = color if i == 0: - annotation['border'] = "dotted" + annotation["border"] = "dotted" - st.session_state['annotations'] = [annotation for annotation_doc in annotations for annotation in - annotation_doc] + st.session_state["annotations"] = [annotation for annotation_doc in annotations for annotation in 
annotation_doc] if not text_response: st.error("Something went wrong. Contact info AT sciencialab.com to report the issue through GitHub.") if mode == "llm": - if st.session_state['ner_processing']: + if st.session_state["ner_processing"]: with st.spinner("Processing NER on LLM response..."): entities = gqa.process_single_text(text_response) decorated_text = decorate_text_with_annotations(text_response.strip(), entities) @@ -500,13 +493,14 @@ def generate_color_gradient(num_elements): play_old_messages(messages) with left_column: - if st.session_state['binary']: + if st.session_state["binary"]: with st.container(height=600): pdf_viewer( - input=st.session_state['binary'], + input=st.session_state["binary"], annotation_outline_size=2, - annotations=st.session_state['annotations'] if st.session_state['annotations'] else [], + annotations=st.session_state["annotations"] if st.session_state["annotations"] else [], render_text=True, - scroll_to_annotation=1 if (st.session_state['annotations'] and st.session_state[ - 'scroll_to_first_annotation']) else None + scroll_to_annotation=1 + if (st.session_state["annotations"] and st.session_state["scroll_to_first_annotation"]) + else None, ) diff --git a/tests/conftest.py b/tests/conftest.py index e1ac83d..e6663e0 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -10,11 +10,11 @@ LOGGER = logging.getLogger(__name__) -@pytest.fixture(scope='session', autouse=True) +@pytest.fixture(scope="session", autouse=True) def setup_logging(): logging.root.handlers = [] - logging.basicConfig(level='INFO') - logging.getLogger('tests').setLevel('DEBUG') + logging.basicConfig(level="INFO") + logging.getLogger("tests").setLevel("DEBUG") # logging.getLogger('sciencebeam_trainer_delft').setLevel('DEBUG') @@ -22,7 +22,7 @@ def _backport_assert_called(mock: MagicMock): assert mock.called -@pytest.fixture(scope='session', autouse=True) +@pytest.fixture(scope="session", autouse=True) def patch_magicmock(): try: MagicMock.assert_called @@ 
-34,4 +34,3 @@ def patch_magicmock(): def temp_dir(tmpdir: LocalPath): # convert to standard Path return Path(str(tmpdir)) - diff --git a/tests/test_document_qa_engine.py b/tests/test_document_qa_engine.py index 959846d..109c6cd 100644 --- a/tests/test_document_qa_engine.py +++ b/tests/test_document_qa_engine.py @@ -5,67 +5,43 @@ def test_merge_passages_small_chunk(): merger = TextMerger() passages = [ - { - 'text': "The quick brown fox jumps over the tree", - 'coordinates': '1' - }, - { - 'text': "and went straight into the mouth of a bear.", - 'coordinates': '2' - }, - { - 'text': "The color of the colors is a color with colors", - 'coordinates': '3' - }, - { - 'text': "the main colors are not the colorw we show", - 'coordinates': '4' - } + {"text": "The quick brown fox jumps over the tree", "coordinates": "1"}, + {"text": "and went straight into the mouth of a bear.", "coordinates": "2"}, + {"text": "The color of the colors is a color with colors", "coordinates": "3"}, + {"text": "the main colors are not the colorw we show", "coordinates": "4"}, ] new_passages = merger.merge_passages(passages, chunk_size=10, tolerance=0) assert len(new_passages) == 4 - assert new_passages[0]['coordinates'] == "1" - assert new_passages[0]['text'] == "The quick brown fox jumps over the tree" + assert new_passages[0]["coordinates"] == "1" + assert new_passages[0]["text"] == "The quick brown fox jumps over the tree" - assert new_passages[1]['coordinates'] == "2" - assert new_passages[1]['text'] == "and went straight into the mouth of a bear." + assert new_passages[1]["coordinates"] == "2" + assert new_passages[1]["text"] == "and went straight into the mouth of a bear." 
- assert new_passages[2]['coordinates'] == "3" - assert new_passages[2]['text'] == "The color of the colors is a color with colors" + assert new_passages[2]["coordinates"] == "3" + assert new_passages[2]["text"] == "The color of the colors is a color with colors" - assert new_passages[3]['coordinates'] == "4" - assert new_passages[3]['text'] == "the main colors are not the colorw we show" + assert new_passages[3]["coordinates"] == "4" + assert new_passages[3]["text"] == "the main colors are not the colorw we show" def test_merge_passages_big_chunk(): merger = TextMerger() passages = [ - { - 'text': "The quick brown fox jumps over the tree", - 'coordinates': '1' - }, - { - 'text': "and went straight into the mouth of a bear.", - 'coordinates': '2' - }, - { - 'text': "The color of the colors is a color with colors", - 'coordinates': '3' - }, - { - 'text': "the main colors are not the colorw we show", - 'coordinates': '4' - } + {"text": "The quick brown fox jumps over the tree", "coordinates": "1"}, + {"text": "and went straight into the mouth of a bear.", "coordinates": "2"}, + {"text": "The color of the colors is a color with colors", "coordinates": "3"}, + {"text": "the main colors are not the colorw we show", "coordinates": "4"}, ] new_passages = merger.merge_passages(passages, chunk_size=20, tolerance=0) assert len(new_passages) == 2 - assert new_passages[0]['coordinates'] == "1;2" - assert new_passages[0][ - 'text'] == "The quick brown fox jumps over the tree and went straight into the mouth of a bear." + assert new_passages[0]["coordinates"] == "1;2" + assert new_passages[0]["text"] == "The quick brown fox jumps over the tree and went straight into the mouth of a bear." 
- assert new_passages[1]['coordinates'] == "3;4" - assert new_passages[1][ - 'text'] == "The color of the colors is a color with colors the main colors are not the colorw we show" + assert new_passages[1]["coordinates"] == "3;4" + assert ( + new_passages[1]["text"] == "The color of the colors is a color with colors the main colors are not the colorw we show" + ) diff --git a/tests/test_grobid_processors.py b/tests/test_grobid_processors.py index b736d3e..dd29810 100644 --- a/tests/test_grobid_processors.py +++ b/tests/test_grobid_processors.py @@ -14,8 +14,8 @@ def test_get_xml_nodes_body_paragraphs(): - with open(os.path.join(TEST_DATA_PATH, "2312.07559.paragraphs.tei.xml"), 'r') as fo: - soup = BeautifulSoup(fo, 'xml') + with open(os.path.join(TEST_DATA_PATH, "2312.07559.paragraphs.tei.xml"), "r") as fo: + soup = BeautifulSoup(fo, "xml") nodes = get_xml_nodes_body(soup, use_paragraphs=True) @@ -23,8 +23,8 @@ def test_get_xml_nodes_body_paragraphs(): def test_get_xml_nodes_body_sentences(): - with open(os.path.join(TEST_DATA_PATH, "2312.07559.sentences.tei.xml"), 'r') as fo: - soup = BeautifulSoup(fo, 'xml') + with open(os.path.join(TEST_DATA_PATH, "2312.07559.sentences.tei.xml"), "r") as fo: + soup = BeautifulSoup(fo, "xml") children = get_xml_nodes_body(soup, use_paragraphs=False) @@ -32,8 +32,8 @@ def test_get_xml_nodes_body_sentences(): def test_get_xml_nodes_figures(): - with open(os.path.join(TEST_DATA_PATH, "2312.07559.paragraphs.tei.xml"), 'r') as fo: - soup = BeautifulSoup(fo, 'xml') + with open(os.path.join(TEST_DATA_PATH, "2312.07559.paragraphs.tei.xml"), "r") as fo: + soup = BeautifulSoup(fo, "xml") children = get_xml_nodes_figures(soup) @@ -41,8 +41,8 @@ def test_get_xml_nodes_figures(): def test_get_xml_nodes_header_paragraphs(): - with open(os.path.join(TEST_DATA_PATH, "2312.07559.paragraphs.tei.xml"), 'r') as fo: - soup = BeautifulSoup(fo, 'xml') + with open(os.path.join(TEST_DATA_PATH, "2312.07559.paragraphs.tei.xml"), "r") as fo: + soup = 
BeautifulSoup(fo, "xml") children = get_xml_nodes_header(soup) @@ -50,13 +50,14 @@ def test_get_xml_nodes_header_paragraphs(): def test_get_xml_nodes_header_sentences(): - with open(os.path.join(TEST_DATA_PATH, "2312.07559.sentences.tei.xml"), 'r') as fo: - soup = BeautifulSoup(fo, 'xml') + with open(os.path.join(TEST_DATA_PATH, "2312.07559.sentences.tei.xml"), "r") as fo: + soup = BeautifulSoup(fo, "xml") children = get_xml_nodes_header(soup, use_paragraphs=False) assert sum([len(child) for k, child in children.items()]) == 15 + def test_grobid_service_error_default_status_code(): error = GrobidServiceError("Something went wrong") assert error.status_code is None @@ -68,6 +69,7 @@ def test_grobid_service_error_stores_status_code(): assert error.status_code == 502 assert "Bad gateway" in str(error) + @pytest.fixture def grobid_processor(): with patch("document_qa.grobid_processors.GrobidClient") as mock_client_class: @@ -79,9 +81,7 @@ def grobid_processor(): # Connection/timeout failures def test_process_structure_raises_on_connection_error(grobid_processor): - grobid_processor.grobid_client.process_pdf.side_effect = requests.exceptions.ConnectionError( - "Connection refused" - ) + grobid_processor.grobid_client.process_pdf.side_effect = requests.exceptions.ConnectionError("Connection refused") with pytest.raises(GrobidServiceError) as exc_info: grobid_processor.process_structure("fake.pdf") @@ -90,9 +90,7 @@ def test_process_structure_raises_on_connection_error(grobid_processor): def test_process_structure_raises_on_timeout(grobid_processor): - grobid_processor.grobid_client.process_pdf.side_effect = requests.exceptions.Timeout( - "Request timed out" - ) + grobid_processor.grobid_client.process_pdf.side_effect = requests.exceptions.Timeout("Request timed out") with pytest.raises(GrobidServiceError) as exc_info: grobid_processor.process_structure("fake.pdf") @@ -102,9 +100,7 @@ def test_process_structure_raises_on_timeout(grobid_processor): # Local/usage errors 
must NOT be masked as a Grobid outage def test_process_structure_does_not_mask_local_errors(grobid_processor): - grobid_processor.grobid_client.process_pdf.side_effect = FileNotFoundError( - "no such file" - ) + grobid_processor.grobid_client.process_pdf.side_effect = FileNotFoundError("no such file") with pytest.raises(FileNotFoundError): grobid_processor.process_structure("fake.pdf") @@ -131,9 +127,7 @@ def test_process_structure_raises_on_500_status(grobid_processor): def test_process_structure_includes_reason_from_error_body(grobid_processor): - grobid_processor.grobid_client.process_pdf.return_value = ( - "fake.pdf", 500, "[BAD_INPUT_DATA] PDF could not be parsed." - ) + grobid_processor.grobid_client.process_pdf.return_value = ("fake.pdf", 500, "[BAD_INPUT_DATA] PDF could not be parsed.") with pytest.raises(GrobidServiceError) as exc_info: grobid_processor.process_structure("fake.pdf") From 46c48e0e94fb9fde7b14beb79cae9126be4a0f57 Mon Sep 17 00:00:00 2001 From: Luca Foppiano <luca@foppiano.org> Date: Sun, 7 Jun 2026 23:01:53 +0100 Subject: [PATCH 5/5] feat: pin public URL labels for embeddings endpoints --- document_qa/deployment/README.md | 3 ++- document_qa/deployment/modal_embeddings_en.py | 8 +++++++- document_qa/deployment/modal_embeddings_multilang.py | 8 +++++++- 3 files changed, 16 insertions(+), 3 deletions(-) diff --git a/document_qa/deployment/README.md b/document_qa/deployment/README.md index 3a7f497..61f534e 100644 --- a/document_qa/deployment/README.md +++ b/document_qa/deployment/README.md @@ -58,7 +58,7 @@ Each deploy prints a public `https://<...>.modal.run` URL. 
Copy it into `.env`: ```env PHI_URL=https://<account>--phi-4-mini-instruct-qa-vllm-serve.modal.run/v1 QWEN_URL=https://<account>--qwen-0-6b-qa-vllm-serve.modal.run/v1 -EMBEDS_URL=https://<account>--intfloat-multilingual-e5-large-instruct-embeddings-embed.modal.run +EMBEDS_URL=https://<account>--embeddings-multilang.modal.run # English-only: https://<account>--embeddings-en.modal.run API_KEY=<your-llm-token> # matches document-qa-api-key EMBEDS_API_KEY=<your-embedding-token> # matches document-qa-embedding-key ``` @@ -101,4 +101,5 @@ These knobs live near the top of each script (or in `_embeddings_app.py`): | `gpu` | `@app.function` / `@app.cls` | `A10G` is cheaper; `L40S` is faster. Embeddings default to `L40S`, inference to `A10G`. | | `scaledown_window` | decorator | Idle time before a replica is stopped (cost vs. cold starts). | | `max_inputs` | `@modal.concurrent` | Concurrent requests per replica — tune to GPU memory. | +| `LABEL` | `modal_embeddings_*.py` | Pins the public URL (`--<label>.modal.run`). Without it Modal truncates the long auto-name and appends a random hash. | | `FAST_BOOT` | `modal_inference_phi.py` | `--enforce-eager` for faster cold starts vs. peak throughput. | diff --git a/document_qa/deployment/modal_embeddings_en.py b/document_qa/deployment/modal_embeddings_en.py index 2690957..6a0a813 100644 --- a/document_qa/deployment/modal_embeddings_en.py +++ b/document_qa/deployment/modal_embeddings_en.py @@ -19,6 +19,12 @@ MODEL_NAME = "intfloat/e5-large-v2" MODEL_REVISION = "756b8ddb6e4bda943d3b6f5d131355825efda70c" +# Pin the public URL label. Without this, Modal derives the label from +# "<app>-<function>", which is too long here and gets truncated with a random +# hash suffix.
The label gives a stable URL: +# https://<workspace>--embeddings-en.modal.run +LABEL = "embeddings-en" + app = modal.App("intfloat-e5-large-v2-embeddings") @@ -29,6 +35,6 @@ class EmbeddingModel: def load_model(self): self.tokenizer, self.model, self.device = load_embedding_model(MODEL_NAME, MODEL_REVISION) - @modal.fastapi_endpoint(method="POST") + @modal.fastapi_endpoint(method="POST", label=LABEL) def embed(self, request: Request, text: Annotated[str, Form()]): return run_embed(self.tokenizer, self.model, self.device, request, text) diff --git a/document_qa/deployment/modal_embeddings_multilang.py b/document_qa/deployment/modal_embeddings_multilang.py index bd6a808..f49ad35 100644 --- a/document_qa/deployment/modal_embeddings_multilang.py +++ b/document_qa/deployment/modal_embeddings_multilang.py @@ -19,6 +19,12 @@ MODEL_NAME = "intfloat/multilingual-e5-large-instruct" MODEL_REVISION = "84344a23ee1820ac951bc365f1e91d094a911763" +# Pin the public URL label. Without this, Modal derives the label from +# "<app>-<function>", which is too long here and gets truncated with a random +# hash suffix (e.g. ...-embed-c5fe6f.modal.run). The label gives a stable URL: +# https://<workspace>--embeddings-multilang.modal.run +LABEL = "embeddings-multilang" + app = modal.App("intfloat-multilingual-e5-large-instruct-embeddings") @@ -29,6 +35,6 @@ class EmbeddingModel: def load_model(self): self.tokenizer, self.model, self.device = load_embedding_model(MODEL_NAME, MODEL_REVISION) - @modal.fastapi_endpoint(method="POST") + @modal.fastapi_endpoint(method="POST", label=LABEL) def embed(self, request: Request, text: Annotated[str, Form()]): return run_embed(self.tokenizer, self.model, self.device, request, text)