Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions src/sqlite_rag/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,17 +242,33 @@ def add(
help="Optional metadata in JSON format to associate with the document",
metavar="JSON",
),
only_extensions: Optional[str] = typer.Option(
None,
"--only",
help="Only process these file extensions from supported list (comma-separated, e.g. 'py,js')",
),
exclude_extensions: Optional[str] = typer.Option(
None,
"--exclude",
help="File extensions to exclude (comma-separated, e.g. 'py,js')",
),
):
"""Add a file path to the database"""
rag_context = ctx.obj["rag_context"]
start_time = time.time()

# Parse extension lists
only_list = only_extensions.split(",") if only_extensions else None
exclude_list = exclude_extensions.split(",") if exclude_extensions else None
Comment thread
danielebriggi marked this conversation as resolved.
Outdated

rag = rag_context.get_rag()
rag.add(
path,
recursive=recursive,
use_relative_paths=use_relative_paths,
metadata=json.loads(metadata or "{}"),
only_extensions=only_list,
exclude_extensions=exclude_list,
)

elapsed_time = time.time() - start_time
Expand Down
99 changes: 62 additions & 37 deletions src/sqlite_rag/reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,44 +6,61 @@

class FileReader:
extensions = [
".c",
".cpp",
".css",
".csv",
".docx",
".go",
".h",
".hpp",
".html",
".java",
".js",
".json",
".kt",
".md",
".mdx",
".mjs",
".pdf",
".php",
".pptx",
".py",
".rb",
".rs",
".svelte",
".swift",
".ts",
".tsx",
".txt",
".vue",
".xml",
".xlsx",
".yaml",
".yml",
"c",
"cpp",
"css",
"csv",
"docx",
"go",
"h",
"hpp",
"html",
"java",
"js",
"json",
"kt",
"md",
"mdx",
"mjs",
"pdf",
"php",
"pptx",
"py",
"rb",
"rs",
"svelte",
"swift",
"ts",
"tsx",
"txt",
"vue",
"xml",
"xlsx",
"yaml",
"yml",
]
Comment thread
danielebriggi marked this conversation as resolved.

@staticmethod
def is_supported(path: Path) -> bool:
def is_supported(
path: Path,
only_extensions: Optional[list[str]] = None,
exclude_extensions: Optional[list[str]] = None,
) -> bool:
"""Check if the file extension is supported"""
Comment thread
danielebriggi marked this conversation as resolved.
Outdated
return path.suffix.lower() in FileReader.extensions
extension = path.suffix.lower().lstrip(".")

supported_extensions = set(FileReader.extensions)
exclude_set = set()

# Only keep those that are in both lists
if only_extensions:
only_set = {ext.lower().lstrip(".") for ext in only_extensions}
supported_extensions &= only_set

if exclude_extensions:
exclude_set = {ext.lower().lstrip(".") for ext in exclude_extensions}

return extension in supported_extensions and extension not in exclude_set

@staticmethod
def parse_file(path: Path, max_document_size_bytes: Optional[int] = None) -> str:
Expand All @@ -65,12 +82,19 @@ def parse_file(path: Path, max_document_size_bytes: Optional[int] = None) -> str
raise ValueError(f"Failed to parse file {path}") from exc

@staticmethod
def collect_files(path: Path, recursive: bool = False) -> list[Path]:
def collect_files(
path: Path,
recursive: bool = False,
only_extensions: Optional[list[str]] = None,
exclude_extensions: Optional[list[str]] = None,
) -> list[Path]:
"""Collect files from the path, optionally recursively"""
if not path.exists():
raise FileNotFoundError(f"{path} does not exist.")

if path.is_file() and FileReader.is_supported(path):
if path.is_file() and FileReader.is_supported(
path, only_extensions, exclude_extensions
):
return [path]

files_to_process = []
Expand All @@ -83,7 +107,8 @@ def collect_files(path: Path, recursive: bool = False) -> list[Path]:
files_to_process = [
f
for f in files_to_process
if f.is_file() and FileReader.is_supported(f)
if f.is_file()
and FileReader.is_supported(f, only_extensions, exclude_extensions)
]

return files_to_process
20 changes: 18 additions & 2 deletions src/sqlite_rag/sqliterag.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,16 +72,32 @@ def add(
recursive: bool = False,
use_relative_paths: bool = False,
metadata: dict = {},
only_extensions: Optional[list[str]] = None,
exclude_extensions: Optional[list[str]] = None,
) -> int:
"""Add the file content into the database"""
"""Add the file content into the database

Args:
path: File or directory path to add
recursive: Recursively add files in directories
use_relative_paths: Store relative paths instead of absolute paths
metadata: Metadata to associate with documents
only_extensions: Only process these file extensions from the supported list (e.g. ['py', 'js'])
exclude_extensions: Skip these file extensions (e.g. ['py', 'js'])
"""
self._ensure_initialized()

if not Path(path).exists():
raise FileNotFoundError(f"{path} does not exist.")

parent = Path(path).parent

files_to_process = FileReader.collect_files(Path(path), recursive=recursive)
files_to_process = FileReader.collect_files(
Path(path),
recursive=recursive,
only_extensions=only_extensions,
exclude_extensions=exclude_extensions,
)

self._engine.create_new_context()

Expand Down
95 changes: 92 additions & 3 deletions tests/test_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,13 +62,102 @@ def test_collect_files_recursive_directory(self):
assert file2 in files

def test_is_supported(self):
unsupported_extensions = [".exe", ".bin", ".jpg", ".png"]
unsupported_extensions = ["exe", "bin", "jpg", "png"]

for ext in FileReader.extensions:
assert FileReader.is_supported(Path(f"test{ext}"))
assert FileReader.is_supported(Path(f"test.{ext}"))

for ext in unsupported_extensions:
assert not FileReader.is_supported(Path(f"test{ext}"))
assert not FileReader.is_supported(Path(f"test.{ext}"))

def test_is_supported_with_only_extensions(self):
"""Test is_supported with only_extensions parameter"""
# Test with only_extensions - should only allow specified extensions
assert FileReader.is_supported(Path("test.py"), only_extensions=["py", "js"])
assert FileReader.is_supported(Path("test.js"), only_extensions=["py", "js"])
assert not FileReader.is_supported(
Path("test.txt"), only_extensions=["py", "js"]
)
assert not FileReader.is_supported(
Path("test.md"), only_extensions=["py", "js"]
)

# Test with dots in extensions (should be normalized)
assert FileReader.is_supported(Path("test.py"), only_extensions=[".py", ".js"])
assert FileReader.is_supported(Path("test.js"), only_extensions=[".py", ".js"])

# Test case insensitive
assert FileReader.is_supported(Path("test.py"), only_extensions=["PY", "JS"])
assert FileReader.is_supported(Path("test.JS"), only_extensions=["py", "js"])

def test_is_supported_with_exclude_extensions(self):
"""Test is_supported with exclude_extensions parameter"""
# Test basic exclusion - py files should be excluded
assert not FileReader.is_supported(Path("test.py"), exclude_extensions=["py"])
assert FileReader.is_supported(Path("test.js"), exclude_extensions=["py"])
assert FileReader.is_supported(Path("test.txt"), exclude_extensions=["py"])

# Test with dots in extensions (should be normalized)
assert not FileReader.is_supported(Path("test.py"), exclude_extensions=[".py"])
assert FileReader.is_supported(Path("test.js"), exclude_extensions=[".py"])

# Test case insensitive
assert not FileReader.is_supported(Path("test.py"), exclude_extensions=["PY"])
assert not FileReader.is_supported(Path("test.PY"), exclude_extensions=["py"])

# Test multiple exclusions
assert not FileReader.is_supported(
Path("test.py"), exclude_extensions=["py", "js"]
)
assert not FileReader.is_supported(
Path("test.js"), exclude_extensions=["py", "js"]
)
assert FileReader.is_supported(
Path("test.txt"), exclude_extensions=["py", "js"]
)

def test_is_supported_with_only_and_exclude_extensions(self):
"""Test is_supported with both only_extensions and exclude_extensions"""
# Include py and js, but exclude py - should only allow js
assert not FileReader.is_supported(
Path("test.py"), only_extensions=["py", "js"], exclude_extensions=["py"]
)
assert FileReader.is_supported(
Path("test.js"), only_extensions=["py", "js"], exclude_extensions=["py"]
)
assert not FileReader.is_supported(
Path("test.txt"), only_extensions=["py", "js"], exclude_extensions=["py"]
)

# Include py, txt, md, but exclude md - should only allow py and txt
assert FileReader.is_supported(
Path("test.py"),
only_extensions=["py", "txt", "md"],
exclude_extensions=["md"],
)
assert FileReader.is_supported(
Path("test.txt"),
only_extensions=["py", "txt", "md"],
exclude_extensions=["md"],
)
assert not FileReader.is_supported(
Path("test.md"),
only_extensions=["py", "txt", "md"],
exclude_extensions=["md"],
)
assert not FileReader.is_supported(
Path("test.js"),
only_extensions=["py", "txt", "md"],
exclude_extensions=["md"],
)

def test_is_supported_with_unsupported_extensions_in_only(self):
"""Test that only_extensions can't add unsupported extensions"""
# .exe is not in FileReader.extensions, so should not be supported even if in only_extensions
assert not FileReader.is_supported(
Path("test.exe"), only_extensions=["exe", "py"]
)
assert FileReader.is_supported(Path("test.py"), only_extensions=["exe", "py"])

def test_parse_html_into_markdown(self):
with tempfile.NamedTemporaryFile(suffix=".html", delete=False) as f:
Expand Down
Loading