Skip to content
59 changes: 52 additions & 7 deletions libs/knowledge-store/ragstack_knowledge_store/embedding_model.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,67 @@
from abc import ABC, abstractmethod
from typing import List

from abc import ABC
from typing import List, Any, Optional

class EmbeddingModel(ABC):
"""Embedding model."""

@abstractmethod
def __init__(self, embeddings: Any, method_map: Optional[dict] = None, other_methods: Optional[List[str]] = None):
self.embeddings = embeddings
self.method_name = {}
method_map = method_map if method_map else {}
other_methods = other_methods if other_methods else []

base_methods = ['embed_texts', 'aembed_texts', 'embed_query', 'aembed_query']
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think we should try to add all of these as methods; it's definitely pretty messy.

I think we should just have `embed_mime(self, mime_type: str, content: Union[str, bytes])` or something like that. Then there is only a single abstract method to use for any MIME type, and the names can be different, etc.

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

100% but right now, LangChain doesn't have "embed_mime" :)

extended_methods = ['embed_images', 'aembed_images', 'embed_image', 'aembed_image']

# Combining all method names, including those mapped
all_methods = set(base_methods + extended_methods + other_methods + list(method_map.values()))

for method in all_methods:
mapped_method = method_map.get(method)
if hasattr(embeddings, method):
self.method_name[method] = method
elif hasattr(embeddings, mapped_method) if mapped_method else False:
self.method_name[method] = mapped_method
else:
self.method_name[method] = None

def does_implement(self, method_name: str) -> bool:
    """Report whether ``method_name`` resolved to an implementation.

    Returns ``True`` only when the lookup table holds a non-``None`` target
    for ``method_name``; names that are absent from the table, or present
    with a ``None`` target, yield ``False``.
    """
    resolved = self.method_name.get(method_name)
    if resolved is None:
        return False
    return True

def implements(self) -> List[str]:
    """Return the names of all methods that resolved to an implementation.

    Names are reported in the iteration order of the lookup table; entries
    whose target is ``None`` are left out.
    """
    available = []
    for name, target in self.method_name.items():
        if target is not None:
            available.append(name)
    return available

def invoke(self, method_name: str, *args, **kwargs):
    """Call the synchronous method registered under ``method_name``.

    The logical name is translated through ``self.method_name`` to the
    attribute name on the wrapped embeddings object, which is then called
    with ``*args`` and ``**kwargs``.

    Returns:
        Whatever the underlying embeddings method returns.

    Raises:
        NotImplementedError: if ``method_name`` has no resolved target, or
            the wrapped embeddings object lacks the target attribute.
    """
    target_method = self.method_name.get(method_name)
    if not target_method or not hasattr(self.embeddings, target_method):
        # Report the name the caller asked for: ``target_method`` is None
        # precisely when nothing was resolved, so it is useless in the message.
        raise NotImplementedError(
            f"{self.embeddings.__class__.__name__} does not implement {method_name}"
        )
    return getattr(self.embeddings, target_method)(*args, **kwargs)

async def ainvoke(self, method_name: str, *args, **kwargs):
    """Await the asynchronous method registered under ``method_name``.

    The logical name is translated through ``self.method_name`` to the
    attribute name on the wrapped embeddings object; that attribute is
    called with ``*args`` and ``**kwargs`` and the result is awaited.

    Returns:
        The awaited result of the underlying embeddings method.

    Raises:
        NotImplementedError: if ``method_name`` has no resolved target, or
            the wrapped embeddings object lacks the target attribute.
    """
    target_method = self.method_name.get(method_name)
    if not target_method or not hasattr(self.embeddings, target_method):
        # Report the name the caller asked for: ``target_method`` is None
        # precisely when nothing was resolved, so it is useless in the message.
        raise NotImplementedError(
            f"{self.embeddings.__class__.__name__} does not implement {method_name}"
        )
    return await getattr(self.embeddings, target_method)(*args, **kwargs)

def embed_texts(self, texts: List[str]) -> List[List[float]]:
    """Embed a batch of texts.

    Dispatches through ``self.invoke`` under the logical name
    ``'embed_texts'`` and returns its result unchanged.
    """
    vectors = self.invoke('embed_texts', texts)
    return vectors

def embed_query(self, text: str) -> List[float]:
    """Embed a single query text.

    Dispatches through ``self.invoke`` under the logical name
    ``'embed_query'`` and returns its result unchanged.

    This is a concrete default rather than an abstract method, matching
    ``embed_texts``: subclasses that only supply a constructor can still be
    instantiated.
    """
    return self.invoke('embed_query', text)

async def aembed_texts(self, texts: List[str]) -> List[List[float]]:
    """Asynchronously embed a batch of texts.

    Dispatches through ``self.ainvoke`` under the logical name
    ``'aembed_texts'`` and returns the awaited result unchanged.

    This is a concrete default rather than an abstract method, matching
    ``embed_texts``: subclasses that only supply a constructor can still be
    instantiated.
    """
    return await self.ainvoke('aembed_texts', texts)

async def aembed_query(self, text: str) -> List[float]:
    """Asynchronously embed a single query text.

    Dispatches through ``self.ainvoke`` under the logical name
    ``'aembed_query'`` and returns the awaited result unchanged.

    This is a concrete default rather than an abstract method, matching
    ``embed_texts``: subclasses that only supply a constructor can still be
    instantiated.
    """
    return await self.ainvoke('aembed_query', text)

110 changes: 69 additions & 41 deletions libs/knowledge-store/ragstack_knowledge_store/graph_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,21 @@ class Node:
"""Metadata for the node. May contain information used to link this node
with other nodes."""

content: str = None
"""Encoded content"""

mime_type: str = None
"""Type of content, e.g. text/plain or image/png."""

mime_encoding: str = None
"""Encoding format"""

@dataclass
class TextNode(Node):
    """Node whose payload is held in a plain-text ``text`` field."""

    text: str = None
    """Text contained by the node."""

    # NOTE(review): this assignment carries no annotation, so ``@dataclass``
    # does not register it as a field of ``TextNode``. If ``Node`` is itself a
    # dataclass declaring ``mime_type: str = None`` (its decorator is outside
    # this view — confirm), the generated ``__init__`` would still default
    # ``mime_type`` to ``None`` and the instance attribute would shadow this
    # class-level "text/plain". Verify before relying on
    # ``TextNode(...).mime_type``.
    mime_type = "text/plain"

class SetupMode(Enum):
SYNC = 1
Expand Down Expand Up @@ -326,52 +335,71 @@ def add_nodes(
self,
nodes: Iterable[Node] = None,
) -> Iterable[str]:
texts = []
metadatas = []
for node in nodes:
if not isinstance(node, TextNode):
raise ValueError("Only adding TextNode is supported at the moment")
texts.append(node.text)
metadatas.append(node.metadata)

text_embeddings = self._embedding.embed_texts(texts)

# Organize nodes by MIME type
mime_buckets = {}
ids = []

tag_to_new_sources: Dict[str, List[Tuple[str, str]]] = {}
tag_to_new_targets: Dict[str, Dict[str, Tuple[str, List[float]]]] = {}
# Prepare nodes based on their type
for node in nodes:
if isinstance(node, TextNode):
if 'text' not in mime_buckets:
mime_buckets['text'] = []
mime_buckets['text'].append(node)
if isinstance(node, Node) and node.mime_type:
main_mime_type = node.mime_type.split('/')[0] # Split and take the first part, e.g., "image" from "image/png"
if main_mime_type not in mime_buckets:
mime_buckets[main_mime_type] = []
mime_buckets[main_mime_type].append(node)
else:
raise ValueError("Unsupported node type")

# Process each MIME bucket
embeddings_dict = {}
for mime_type, nodes_list in mime_buckets.items():
method_name = f"embed_{mime_type}s"
if self._embedding.does_implement(method_name):
texts = [node.text if isinstance(node, TextNode) else node.content for node in nodes_list]
embeddings_dict[mime_type] = self._embedding.invoke(method_name, texts)
else:
# If no bulk method, try to call a singular method for each content
singular_method_name = f"embed_{mime_type}"
if self._embedding.does_implement(singular_method_name):
embeddings = []
for node in nodes_list:
embedding = self._embedding.invoke(singular_method_name, node.text if isinstance(node, TextNode) else node.content)
embeddings.append(embedding)
embeddings_dict[mime_type] = embeddings
else:
raise NotImplementedError(f"No embedding method available for MIME type: {mime_type}, implemented methods: {self._embedding.implements()}.")


# Step 1: Add the nodes, collecting the tags and new sources / targets.
tag_to_new_sources = {}
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This has changed significantly from the previous implementation. I think it will need to be reworked if we work on adding this right now.

tag_to_new_targets = {}
with self._concurrent_queries() as cq:
tuples = zip(texts, text_embeddings, metadatas)
for text, text_embedding, metadata in tuples:
if CONTENT_ID not in metadata:
metadata[CONTENT_ID] = secrets.token_hex(8)
id = metadata[CONTENT_ID]
ids.append(id)

link_to_tags = set() # link to these tags
link_from_tags = set() # link from these tags

for tag in get_link_tags(metadata):
tag_str = f"{tag.kind}:{tag.tag}"
if tag.direction == "incoming" or tag.direction == "bidir":
# An incoming link should be linked *from* nodes with the given tag.
link_from_tags.add(tag_str)
tag_to_new_targets.setdefault(tag_str, dict())[id] = (
tag.kind,
text_embedding,
)
if tag.direction == "outgoing" or tag.direction == "bidir":
link_to_tags.add(tag_str)
tag_to_new_sources.setdefault(tag_str, list()).append(
(tag.kind, id)
)

cq.execute(
self._insert_passage,
(id, text, text_embedding, link_to_tags, link_from_tags),
)
for mime_type, embeddings in embeddings_dict.items():
for node, embedding in zip(mime_buckets[mime_type], embeddings):
if CONTENT_ID not in node.metadata:
node.metadata[CONTENT_ID] = secrets.token_hex(8)
node_id = node.metadata[CONTENT_ID]
ids.append(node_id)

link_to_tags = set()
link_from_tags = set()

for tag in get_link_tags(node.metadata):
tag_str = f"{tag.kind}:{tag.tag}"
if tag.direction in ["incoming", "bidir"]:
link_from_tags.add(tag_str)
tag_to_new_targets.setdefault(tag_str, {})[node_id] = (tag.kind, embedding)
if tag.direction in ["outgoing", "bidir"]:
link_to_tags.add(tag_str)
tag_to_new_sources.setdefault(tag_str, []).append((tag.kind, node_id))

cq.execute(
self._insert_passage,
(node_id, node.text if isinstance(node, TextNode) else node.content, embedding, link_to_tags, link_from_tags),
)

# Step 2: Query information about those tags to determine the edges to add.
# Add edges as needed.
Expand Down
4 changes: 2 additions & 2 deletions libs/langchain/ragstack_langchain/graph_store/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from .base import GraphStore, Node, TextNode
from .base import GraphStore, Node
from .cassandra import CassandraGraphStore

__all__ = ["CassandraGraphStore", "GraphStore", "Node", "TextNode"]
__all__ = ["CassandraGraphStore", "GraphStore", "Node"]
21 changes: 14 additions & 7 deletions libs/langchain/ragstack_langchain/graph_store/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,11 +39,14 @@ class Node(Serializable):
"""Metadata for the node. May contain information used to link this node
with other nodes."""

content: str = None
"""Encoded content"""

class TextNode(Node):
text: str
"""Text contained by the node."""
mime_type: str = None
"""Type of content, e.g. text/plain or image/png."""

mime_encoding: str = None
"""Encoding format"""

def _texts_to_nodes(
texts: Iterable[str],
Expand All @@ -61,10 +64,11 @@ def _texts_to_nodes(
_id = next(ids_it) if ids_it else None
except StopIteration:
raise ValueError("texts iterable longer than ids")
yield TextNode(
yield Node(
id=_id,
metadata=_metadata,
text=text,
mime_type="text/plain",
content=text,
)
if ids and _has_next(ids_it):
raise ValueError("ids iterable longer than texts")
Expand All @@ -81,10 +85,13 @@ def _documents_to_nodes(
_id = next(ids_it) if ids_it else None
except StopIteration:
raise ValueError("documents iterable longer than ids")
yield TextNode(

yield Node(
id=_id,
metadata=doc.metadata,
text=doc.page_content,
mime_type=doc.metadata.get('mime_type', 'text/plain'),
mime_encoding=doc.metadata.get('mime_encoding', None),
content=doc.page_content,
)
if ids and _has_next(ids_it):
raise ValueError("ids iterable longer than documents")
Expand Down
33 changes: 8 additions & 25 deletions libs/langchain/ragstack_langchain/graph_store/cassandra.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,26 +11,9 @@
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings

from .base import GraphStore, Node, TextNode
from ragstack_knowledge_store import EmbeddingModel, graph_store


class _EmbeddingModelAdapter(EmbeddingModel):
    """Expose an ``Embeddings`` object through the ``EmbeddingModel`` interface.

    The batch ``*_texts`` methods forward to the wrapped object's
    ``*_documents`` methods; the query methods forward under the same name.
    """

    def __init__(self, embeddings: Embeddings):
        # Keep a handle on the wrapped embeddings; every method delegates to it.
        self.embeddings = embeddings

    def embed_texts(self, texts: List[str]) -> List[List[float]]:
        """Embed a batch of texts via the wrapped ``embed_documents``."""
        backend = self.embeddings
        return backend.embed_documents(texts)

    def embed_query(self, text: str) -> List[float]:
        """Embed a single query via the wrapped ``embed_query``."""
        backend = self.embeddings
        return backend.embed_query(text)

    async def aembed_texts(self, texts: List[str]) -> List[List[float]]:
        """Asynchronously embed a batch of texts via ``aembed_documents``."""
        vectors = await self.embeddings.aembed_documents(texts)
        return vectors

    async def aembed_query(self, text: str) -> List[float]:
        """Asynchronously embed a single query via ``aembed_query``."""
        vector = await self.embeddings.aembed_query(text)
        return vector

from .base import GraphStore, Node
from .embedding_adapter import EmbeddingAdapter
from ragstack_knowledge_store import graph_store

def _row_to_document(row) -> Document:
return Document(
Expand Down Expand Up @@ -78,7 +61,7 @@ def __init__(
_setup_mode = getattr(graph_store.SetupMode, setup_mode.name)

self.store = graph_store.GraphStore(
embedding=_EmbeddingModelAdapter(embedding),
embedding=EmbeddingAdapter(embedding),
node_table=node_table,
edge_table=edge_table,
session=session,
Expand All @@ -98,11 +81,11 @@ def add_nodes(
):
_nodes = []
for node in nodes:
if not isinstance(node, TextNode):
raise ValueError("Only adding TextNode is supported at the moment")
if not isinstance(node, Node):
raise ValueError("Only adding Node is supported at the moment")
_nodes.append(
graph_store.TextNode(id=node.id, text=node.text, metadata=node.metadata)
)
graph_store.Node(id=node.id, content=node.content, mime_type=node.mime_type, mime_encoding=node.mime_encoding, metadata=node.metadata)
)
return self.store.add_nodes(_nodes)

@classmethod
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from typing import List
from ragstack_knowledge_store import EmbeddingModel

class EmbeddingAdapter(EmbeddingModel):
    """``EmbeddingModel`` wrapper that remaps the batch text method names.

    The logical names ``embed_texts`` / ``aembed_texts`` are mapped onto the
    wrapped object's ``embed_documents`` / ``aembed_documents``.
    """

    def __init__(self, embeddings):
        """Wrap ``embeddings`` and register the name mapping with the base class."""
        name_mapping = {
            'embed_texts': 'embed_documents',
            'aembed_texts': 'aembed_documents',
        }
        super().__init__(embeddings, method_map=name_mapping)