diff --git a/.gitignore b/.gitignore index 82f9275..728f902 100644 --- a/.gitignore +++ b/.gitignore @@ -121,6 +121,9 @@ celerybeat.pid # SageMath parsed files *.sage.py +# Local idea tracking +NEW_IDEAS.md + # Environments .env .venv diff --git a/CLAUDE.md b/CLAUDE.md new file mode 100644 index 0000000..2588586 --- /dev/null +++ b/CLAUDE.md @@ -0,0 +1,103 @@ +# CLAUDE.md + +This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository. + +## Commands + +### Install dependencies +```bash +pip install -r requirements.txt +pip install -r requirements_ci.txt # dev/linting tools +``` + +### Run tests +```bash +# All tests +python -m unittest discover -s core_tests -p "*.py" + +# Single test file +python -m unittest core_tests.runner_test +``` + +### Lint and format +```bash +pylint core/ +isort --profile=black --check-only core/ core_tests/ +isort --profile=black core/ core_tests/ # to fix imports +``` + +### Build +```bash +python -m build +``` + +### Run evals + +Evals verify that the V2 agent scanner still catches expected vulnerabilities across 5 fixture files (`evals/fixtures/`). Run them after any change to `core/agent.py`, `core/code_scanner/agent_scanner.py`, or the system prompt. + +```bash +# Standard model (gpt-4o-mini) — all fixtures should pass +set -a && source .env && set +a +python3 evals/run_evals.py --provider openai --model gpt-4o-mini + +# Advanced model (gpt-4o) — stricter; also requires YAML deserialization, +# race condition, JWT algorithm confusion, and timing attack findings +python3 evals/run_evals.py --provider openai --model gpt-4o + +# Single fixture only +python3 evals/run_evals.py --provider openai --model gpt-4o-mini --fixture auth_service.py +``` + +Findings are split into two tiers in `evals/expected_findings.json`: +- **standard** — required from any model (SQL injection, XSS, path traversal, pickle, hardcoded secrets, etc.) +- **advanced** — only required when running gpt-4o (YAML code execution, race conditions, JWT algorithm confusion, timing attacks) + +Exit code 0 = all fixtures at or above the 80% threshold. Exit code 1 = one or more failed. +### Run the CLI locally +```bash +# V1 runner (sends all files to AI at once) +python3 -m core.runner --provider openai + +# V2 runner (file-by-file via Pydantic-AI agent) +python3 -m core.runner_v2 --provider openai +``` + +## Architecture + +CodeScanAI is a CLI tool that scans codebases for security vulnerabilities using AI models. + +### Two parallel codepaths + +There are two scanner implementations that share the same CLI argument surface (`core/utils/argument_parser.py`): + +1. **V1 (`core/runner.py` → `core/code_scanner/code_scanner.py`)**: Aggregates all file content into a single code summary, sends it to the AI provider in one call, returns a markdown string. Uses the provider abstraction layer. + +2. **V2 (`core/runner_v2.py` → `core/code_scanner/agent_scanner.py`)**: Iterates file-by-file, runs a Pydantic-AI `Agent` synchronously on each, and streams structured `FileScanResult` output to stdout. Also supports posting inline PR review comments via `GithubIntegration`. This is the more feature-rich path. + +The active entrypoint is `core.runner_v2:main` (V2), set in `pyproject.toml` under `[project.scripts]`. + +### Provider abstraction (V1 only) + +`core/providers/base_ai_provider.py` defines the `BaseAIProvider` interface with a single `scan_code(code_summary)` method. Concrete implementations: +- `OpenAIProvider` — uses `openai` SDK +- `GoogleGeminiAIProvider` — uses `google-generativeai` +- `CustomAIProvider` — HTTP requests to a self-hosted server (Ollama, etc.) + +`core/utils/provider_creator.py` maps CLI `--provider` values to provider classes. + +### Pydantic-AI agent (V2 only) + +`core/agent.py` defines structured output types (`Vulnerability`, `FileScanResult`) and factory functions. It also holds pre-configured system prompts for different scan modes: `SECURITY_AGENT_PROMPT`, `PERFORMANCE_AGENT_PROMPT`, `CLEAN_CODE_AGENT_PROMPT`. Custom providers route through the OpenAI-compatible interface via `OPENAI_BASE_URL`. + +### GitHub integration + +`core/utils/github_integration.py` (`GithubIntegration`) is used only in V2. It posts inline PR review comments using PyGithub. Falls back to a regular issue comment if the line isn't in the PR diff. + +`core/utils/file_extractor.py` handles both local git-diff file discovery and PR file listing via the GitHub API, shared by both V1 and V2. + +### Scan modes + +Both runners support three modes driven by CLI args: +- **Full scan** (default): walks `--directory` and scans all files +- **Changes only** (`--changes_only`): scans files changed in local git repo +- **PR scan** (`--repo` + `--pr_number` + `--github_token`): fetches changed files from a GitHub PR diff --git a/core/agent.py b/core/agent.py new file mode 100644 index 0000000..bd77388 --- /dev/null +++ b/core/agent.py @@ -0,0 +1,108 @@ +""" +Defines the structured output types, agent factory, and pre-configured system prompts +used by the V2 Pydantic-AI scanner. +""" + +from typing import Optional, Type + +from pydantic import BaseModel, Field +from pydantic_ai import Agent + + +class Vulnerability(BaseModel): + """Represents a single security vulnerability found in a file.""" + + line_number: Optional[int] = Field( + default=None, + description=( + "The exact line number where the issue is found inside the file. " + "Omit if the issue is architectural or spans multiple lines." + ), + ) + description: str = Field( + description="A detailed description of the issue and why it is a security risk." + ) + remediation: str = Field( + description="Actionable suggestion or code snippet on how to fix this specific vulnerability." + ) + severity: str = Field(description="Severity measure (Low, Medium, High, Critical)") + vulnerability_type: str = Field( + description="The category of issue (e.g. SQL Injection, Big-O inefficiency, etc.)" + ) + + +class FileScanResult(BaseModel): + """Structured result returned by the agent for a single scanned file.""" + + vulnerabilities: list[Vulnerability] = Field( + description="List of issues found in the file. Empty if zero issues are found." + ) + + +def get_pydantic_ai_model(provider: str, model: Optional[str]) -> str: + """Map a CLI provider name and optional model string to a Pydantic-AI model identifier.""" + if provider == "openai": + return f"openai:{model or 'gpt-4o-mini'}" + if provider == "gemini": + return f"gemini:{model or 'gemini-1.5-flash'}" + if provider == "custom": + # Falls back to the OpenAI-compatible interface via OPENAI_BASE_URL + return f"openai:{model or 'custom-model'}" + return "openai:gpt-4o-mini" + + +def create_agent( + model_str: str, system_prompt: str, result_type: Type[BaseModel] = FileScanResult +) -> Agent: + """ + Creates and returns a Pydantic-AI Agent configured for a custom, laser-focused task. + By passing different `system_prompt` and `result_type` schemas, you can deploy + multiple types of agents. + """ + return Agent( + model_str, + result_type=result_type, + system_prompt=system_prompt, + ) + + +# --- Define pre-configured Agent Prompts for laser-focused tasks --- + +SECURITY_AGENT_PROMPT = ( + "You are an expert in software security analysis, adept at identifying and explaining " + "potential vulnerabilities in code. " + "You will be given complete code snippets from various applications. " + "EVERY line of the source code is prefixed with its exact line number " + "(e.g. `14: def foo():`). " + "Your task is to analyze the provided code, pinpoint potential security risks, " + "and offer clear suggestions for enhancing the application's security posture. " + "Focus on the critical issues that could impact the overall security of the application. " + "You MUST be exhaustive. Carefully audit the entire script from top to bottom " + "and return EVERY vulnerability you find. Do not stop at the first issue. " + "If any are found, use the explicitly provided line numbers to pinpoint the defect " + "where possible. For architectural or multi-line issues, you may omit the line number. " + "Also, strictly provide an actionable `remediation` that makes suggestions on how to " + "rewrite or fix the code securely. " + "If no vulnerabilities are found, return an empty list. " + "When scanning a pull request or diff, some lines will be marked with `[CHANGED]` " + "after the line number (e.g. `14: [CHANGED] def foo():`). " + "These lines are newly added or modified in the change under review. " + "Prioritise your analysis on `[CHANGED]` lines, but use the full file context — " + "imports, surrounding functions, class definitions, and data flow — to assess " + "whether those changes introduce or worsen a vulnerability." +) + +PERFORMANCE_AGENT_PROMPT = ( + "You are a Senior Staff Software Engineer laser-focused on performance optimization. " + "Analyze the following code for memory leaks, O(N^2) bottlenecks, or CPU inefficiencies. " + "Pinpoint exact line numbers and return a list of performance issues. " + "If none are found, return an empty list." +) + +CLEAN_CODE_AGENT_PROMPT = ( + "You are an expert in code refactoring and Clean Code methodologies. " + "Analyze the code for anti-patterns, confusing variable names, massive functions, " + "or high cyclomatic complexity. " + "Pinpoint exact line numbers and return a list of maintainability issues. " + "If the code is perfectly clean, return an empty list." +) diff --git a/core/code_scanner/agent_scanner.py b/core/code_scanner/agent_scanner.py new file mode 100644 index 0000000..eedc176 --- /dev/null +++ b/core/code_scanner/agent_scanner.py @@ -0,0 +1,162 @@ +""" +V2 agent-based scanner. Scans source files one at a time using a Pydantic-AI Agent +and streams structured FileScanResult output to stdout. +""" + +import logging +import os + +from core.agent import ( + SECURITY_AGENT_PROMPT, + FileScanResult, + create_agent, + get_pydantic_ai_model, +) +from core.utils.file_extractor import ( + get_changed_files_in_pr, + get_changed_files_in_repo, + get_local_changed_line_numbers, + get_pr_changed_line_numbers, +) +from core.utils.github_integration import GithubIntegration + +logging.basicConfig( + level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s" +) + + +class AgentScanner: + """ + Scans source code via the Pydantic AI agent, file by file. + Optionally posts inline review comments to GitHub PRs. + """ + + def __init__(self, args) -> None: + self.args = args + model_str = get_pydantic_ai_model(args.provider, args.model) + + # Set OPENAI_BASE_URL for custom providers so Pydantic-AI targets the correct backend. + if args.provider == "custom" and args.host: + host_url = f"{args.host}:{args.port}" if args.port else args.host + if args.endpoint: + host_url += args.endpoint + os.environ["OPENAI_BASE_URL"] = host_url + if args.token: + os.environ["OPENAI_API_KEY"] = args.token + + self.agent = create_agent( + model_str=model_str, + system_prompt=SECURITY_AGENT_PROMPT, + result_type=FileScanResult, + ) + self.github_integration = ( + GithubIntegration(args) + if args.repo and args.pr_number and args.github_token + else None + ) + + def scan(self): + """ + Scans the code by identifying files based on PR context or local directory + and iterates through them using the Pydantic AI agent. + """ + if self.args.changes_only or (self.args.repo and self.args.pr_number): + return self._scan_changes() + return self._scan_files() + + def _scan_changes(self): + try: + if self.args.repo and self.args.pr_number: + changed_files = get_changed_files_in_pr( + self.args.repo, self.args.pr_number, self.args.github_token + ) + changed_line_map = get_pr_changed_line_numbers( + self.args.repo, self.args.pr_number, self.args.github_token + ) + else: + changed_files = get_changed_files_in_repo(self.args.directory) + changed_line_map = None + except ValueError as e: + logging.error(e) + return + + if not changed_files: + logging.info("No changes detected.") + return + + for filename in changed_files: + filepath = os.path.join(self.args.directory, filename) + if changed_line_map is not None: + changed_lines = changed_line_map.get(filename, set()) + else: + changed_lines = get_local_changed_line_numbers(self.args.directory, filename) + self._scan_single_file(filepath, display_name=filename, changed_lines=changed_lines) + + def _scan_files(self): + file_paths = [] + for root, _, files in os.walk(self.args.directory): + for file in files: + file_paths.append(os.path.join(root, file)) + + for filepath in file_paths: + self._scan_single_file( + filepath, display_name=os.path.relpath(filepath, self.args.directory) + ) + + def _scan_single_file(self, file_path: str, display_name: str, changed_lines: set = None): + """Scan a single file and print any vulnerabilities found.""" + if not os.path.isfile(file_path): + logging.warning("Skipping %s: Not a valid file or not found locally.", file_path) + return + + try: + with open(file_path, "r", encoding="utf-8") as f: + content = f.read() + except Exception as e: # pylint: disable=broad-exception-caught + logging.warning("Skipping %s: %s", file_path, e) + return + + if not content.strip(): + return + + logging.info("Scanning file: %s ...", display_name) + + def _format_line(idx, line): + lineno = idx + 1 + if changed_lines and lineno in changed_lines: + return f"{lineno}: [CHANGED] {line}" + return f"{lineno}: {line}" + + numbered_content = "\n".join([_format_line(idx, line) for idx, line in enumerate(content.splitlines())]) + + try: + result = self.agent.run_sync(f"File: {display_name}\n\n{numbered_content}") + scan_result = result.data + + if scan_result.vulnerabilities: + print(f"\n--- Vulnerabilities found in {display_name} ---") + md_output = "" + for vuln in scan_result.vulnerabilities: + line_info = f"Line {vuln.line_number}: " if vuln.line_number else "" + md_output += f" - **{line_info}[{vuln.severity}] {vuln.vulnerability_type}**\n" + md_output += f" - **Issue**: {vuln.description}\n" + md_output += f" - **Fix**: {vuln.remediation}\n" + print(md_output) + + if self.github_integration: + for vuln in scan_result.vulnerabilities: + comment_body = ( + f"**[{vuln.severity.upper()} SEVERITY] {vuln.vulnerability_type}**" + f"\n\n{vuln.description}" + f"\n\n**Suggested Fix:**\n{vuln.remediation}" + ) + self.github_integration.post_inline_comment( + path=display_name, + line=vuln.line_number, + body=comment_body, + ) + else: + logging.info("No vulnerabilities found in %s.", display_name) + + except Exception as e: # pylint: disable=broad-exception-caught + logging.error("Error scanning %s: %s", display_name, e) diff --git a/core/runner_v2.py b/core/runner_v2.py new file mode 100644 index 0000000..630d4b6 --- /dev/null +++ b/core/runner_v2.py @@ -0,0 +1,21 @@ +""" +This is the V2 runner of the codescan-ai CLI tool. +It utilizes pydantic-ai orchestrations to process data file by file. +""" + +from core.code_scanner.agent_scanner import AgentScanner +from core.utils.argument_parser import parse_arguments + + +def main(): + """ + Main entry point for the V2 CLI. Parses arguments, calls the AgentScanner + (which performs the file-by-file scanning by using the AI agent), + and displays the results progressively. + """ + args = parse_arguments() + AgentScanner(args).scan() + + +if __name__ == "__main__": + main() diff --git a/core/utils/file_extractor.py b/core/utils/file_extractor.py index 4bcf383..7ac2c92 100644 --- a/core/utils/file_extractor.py +++ b/core/utils/file_extractor.py @@ -1,11 +1,12 @@ """ -This module contains utilities for checking -if a directory is a Git repository, retrieving changed files from local repositories +This module contains utilities for checking +if a directory is a Git repository, retrieving changed files from local repositories or GitHub pull requests. """ import logging import os +import re import subprocess from github import Github @@ -36,6 +37,75 @@ def is_git_repo(directory): return False +def _parse_changed_lines(patch: str) -> set: + """ + Parse a unified diff patch string and return the set of new-file line numbers + that correspond to added or modified lines (i.e. lines prefixed with '+'). + """ + changed = set() + new_line = 0 + for line in patch.split("\n"): + hunk = re.match(r"^@@ -\d+(?:,\d+)? \+(\d+)(?:,\d+)? @@", line) + if hunk: + new_line = int(hunk.group(1)) + elif line.startswith("+"): + changed.add(new_line) + new_line += 1 + elif line.startswith("-"): + pass # removed line — does not advance the new-file counter + else: + new_line += 1 # context line + return changed + + +def get_pr_changed_line_numbers(repo_name, pr_number, github_token): + """ + Returns a mapping of filename -> set of new-file line numbers that were + added or modified in the pull request. + + Parameters: + repo_name (string): The name of the repository (e.g. 'owner/repo'). + pr_number (int): The pull request number. + github_token (string): A GitHub personal access token. + + Returns: + dict[str, set[int]]: Filename to changed line numbers. + """ + if not github_token: + raise ValueError("GitHub token is required for scanning PR changes.") + + files = Github(github_token).get_repo(repo_name).get_pull(pr_number).get_files() + result = {} + for f in files: + if f.patch: + result[f.filename] = _parse_changed_lines(f.patch) + else: + result[f.filename] = set() + return result + + +def get_local_changed_line_numbers(directory, filename): + """ + Returns the set of new-file line numbers that are modified (unstaged) for a + given file in a local git repository. + + Parameters: + directory (string): The path to the git repository root. + filename (string): The file path relative to the repository root. + + Returns: + set[int]: Changed line numbers, or an empty set on any error. + """ + try: + patch = subprocess.check_output( + ["git", "-C", directory, "diff", "--", filename], text=True + ) + return _parse_changed_lines(patch) + except subprocess.CalledProcessError as e: + logging.warning("Could not get diff for %s: %s", filename, e) + return set() + + def get_changed_files_in_pr(repo_name, pr_number, github_token): """ Returns a list of files that have been changed in the specified pull request. diff --git a/core/utils/github_integration.py b/core/utils/github_integration.py new file mode 100644 index 0000000..db990a9 --- /dev/null +++ b/core/utils/github_integration.py @@ -0,0 +1,58 @@ +"""Utilities for posting inline review comments to GitHub pull requests.""" + +import logging + +from github import Github + +logging.basicConfig( + level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s" +) + + +class GithubIntegration: + """Wraps PyGithub to post inline PR review comments for scanner findings.""" + + def __init__(self, args): + self.repo_name = args.repo + self.pr_number = args.pr_number + self.github_token = args.github_token + + if self.github_token: + self.gh = Github(self.github_token) + self.repo = self.gh.get_repo(self.repo_name) + self.pr = self.repo.get_pull(self.pr_number) + + commits = list(self.pr.get_commits()) + self.latest_commit = commits[-1] if commits else None + else: + self.gh = None + + def post_inline_comment(self, path: str, line: int, body: str): + """ + Post an inline review comment on a specific line of a PR diff. + Falls back to a regular issue comment if the line is not part of the diff. + """ + if not self.gh or not self.latest_commit: + return + + try: + logging.info("Attempting to post PR review comment to %s:%s", path, line) + self.pr.create_review_comment( + body=body, + commit=self.latest_commit, + path=path, + line=line, + ) + except Exception: # pylint: disable=broad-exception-caught + logging.warning( + "Could not post inline comment on %s:%s " + "(line might not be part of pull request diff). " + "Falling back to normal PR comment.", + path, + line, + ) + try: + fallback_body = f"**File: `{path}` (Line {line})**\n\n{body}" + self.pr.create_issue_comment(fallback_body) + except Exception: # pylint: disable=broad-exception-caught + logging.error("Failed to post fallback PR comment for %s:%s", path, line) diff --git a/core_tests/agent_scanner_test.py b/core_tests/agent_scanner_test.py new file mode 100644 index 0000000..f841652 --- /dev/null +++ b/core_tests/agent_scanner_test.py @@ -0,0 +1,292 @@ +import os +import unittest +from unittest.mock import MagicMock, mock_open, patch + +from core.agent import FileScanResult, Vulnerability, get_pydantic_ai_model +from core.code_scanner.agent_scanner import AgentScanner + + +class TestGetPydanticAiModel(unittest.TestCase): + def test__openai__returnsDefaultModel(self): + self.assertEqual(get_pydantic_ai_model("openai", None), "openai:gpt-4o-mini") + + def test__openai__returnsSpecifiedModel(self): + self.assertEqual(get_pydantic_ai_model("openai", "gpt-4o"), "openai:gpt-4o") + + def test__gemini__returnsDefaultModel(self): + self.assertEqual(get_pydantic_ai_model("gemini", None), "gemini:gemini-1.5-flash") + + def test__gemini__returnsSpecifiedModel(self): + self.assertEqual(get_pydantic_ai_model("gemini", "gemini-pro"), "gemini:gemini-pro") + + def test__custom__returnsDefaultModel(self): + self.assertEqual(get_pydantic_ai_model("custom", None), "openai:custom-model") + + def test__custom__returnsSpecifiedModel(self): + self.assertEqual(get_pydantic_ai_model("custom", "my-model"), "openai:my-model") + + +class TestVulnerabilityModel(unittest.TestCase): + def test__vulnerability__withAllFields(self): + vuln = Vulnerability( + line_number=14, + description="SQL Injection via f-string", + remediation="Use parameterized queries", + severity="Critical", + vulnerability_type="SQL Injection", + ) + self.assertEqual(vuln.line_number, 14) + self.assertEqual(vuln.severity, "Critical") + self.assertEqual(vuln.vulnerability_type, "SQL Injection") + + def test__vulnerability__lineNumberIsOptional(self): + vuln = Vulnerability( + description="MD5 is a weak hashing algorithm", + remediation="Use bcrypt or Argon2", + severity="High", + vulnerability_type="Weak Cryptography", + ) + self.assertIsNone(vuln.line_number) + + def test__fileScanResult__emptyVulnerabilities(self): + result = FileScanResult(vulnerabilities=[]) + self.assertEqual(result.vulnerabilities, []) + + def test__fileScanResult__withVulnerabilities(self): + vuln = Vulnerability( + line_number=6, + description="Pickle deserialization", + remediation="Use json.loads", + severity="Critical", + vulnerability_type="Insecure Deserialization", + ) + result = FileScanResult(vulnerabilities=[vuln]) + self.assertEqual(len(result.vulnerabilities), 1) + self.assertEqual(result.vulnerabilities[0].line_number, 6) + + +class TestAgentScanner(unittest.TestCase): + + def _make_args(self, **kwargs): + defaults = dict( + provider="openai", + model=None, + host=None, + port=None, + token=None, + endpoint=None, + directory=".", + changes_only=False, + repo=None, + pr_number=None, + github_token=None, + ) + defaults.update(kwargs) + return MagicMock(**defaults) + + @patch("core.code_scanner.agent_scanner.create_agent") + def test__init__createsAgentWithCorrectModelString(self, mock_create_agent): + args = self._make_args(provider="openai", model="gpt-4o") + AgentScanner(args) + mock_create_agent.assert_called_once() + call_kwargs = mock_create_agent.call_args[1] + self.assertEqual(call_kwargs["model_str"], "openai:gpt-4o") + + @patch("core.code_scanner.agent_scanner.create_agent") + def test__init__customProvider_setsOpenAIBaseURL(self, mock_create_agent): + args = self._make_args(provider="custom", host="http://localhost", port=11434, endpoint="/v1", token="tok") + with patch.dict(os.environ, {}, clear=False): + AgentScanner(args) + self.assertEqual(os.environ.get("OPENAI_BASE_URL"), "http://localhost:11434/v1") + self.assertEqual(os.environ.get("OPENAI_API_KEY"), "tok") + + @patch("core.code_scanner.agent_scanner.create_agent") + def test__init__noGithubArgs_githubIntegrationIsNone(self, mock_create_agent): + args = self._make_args(repo=None, pr_number=None, github_token=None) + scanner = AgentScanner(args) + self.assertIsNone(scanner.github_integration) + + @patch("core.code_scanner.agent_scanner.create_agent") + def test__scan__routesToScanFiles_whenNoChangesFlag(self, mock_create_agent): + args = self._make_args(changes_only=False, repo=None, pr_number=None) + scanner = AgentScanner(args) + scanner._scan_files = MagicMock() + scanner.scan() + scanner._scan_files.assert_called_once() + + @patch("core.code_scanner.agent_scanner.create_agent") + def test__scan__routesToScanChanges_whenChangesOnlyTrue(self, mock_create_agent): + args = self._make_args(changes_only=True, repo=None, pr_number=None) + scanner = AgentScanner(args) + scanner._scan_changes = MagicMock() + scanner.scan() + scanner._scan_changes.assert_called_once() + + @patch("core.code_scanner.agent_scanner.create_agent") + def test__scanSingleFile__skipsNonExistentFile(self, mock_create_agent): + args = self._make_args() + scanner = AgentScanner(args) + with patch("os.path.isfile", return_value=False): + scanner._scan_single_file("/fake/path.py", "path.py") + scanner.agent.run_sync.assert_not_called() + + @patch("core.code_scanner.agent_scanner.create_agent") + def test__scanSingleFile__skipsEmptyFile(self, mock_create_agent): + args = self._make_args() + scanner = AgentScanner(args) + with patch("os.path.isfile", return_value=True), \ + patch("builtins.open", mock_open(read_data=" ")): + scanner._scan_single_file("/fake/path.py", "path.py") + scanner.agent.run_sync.assert_not_called() + + @patch("core.code_scanner.agent_scanner.create_agent") + def test__scanSingleFile__prefixesLinesWithNumbers(self, mock_create_agent): + args = self._make_args() + scanner = AgentScanner(args) + mock_result = MagicMock() + mock_result.data = FileScanResult(vulnerabilities=[]) + scanner.agent.run_sync = MagicMock(return_value=mock_result) + + file_content = "line one\nline two\nline three" + with patch("os.path.isfile", return_value=True), \ + patch("builtins.open", mock_open(read_data=file_content)): + scanner._scan_single_file("/fake/path.py", "path.py") + + call_arg = scanner.agent.run_sync.call_args[0][0] + self.assertIn("1: line one", call_arg) + self.assertIn("2: line two", call_arg) + self.assertIn("3: line three", call_arg) + + @patch("core.code_scanner.agent_scanner.create_agent") + def test__scanSingleFile__annotatesChangedLines(self, mock_create_agent): + args = self._make_args() + scanner = AgentScanner(args) + mock_result = MagicMock() + mock_result.data = FileScanResult(vulnerabilities=[]) + scanner.agent.run_sync = MagicMock(return_value=mock_result) + + file_content = "line one\nline two\nline three" + with patch("os.path.isfile", return_value=True), \ + patch("builtins.open", mock_open(read_data=file_content)): + scanner._scan_single_file("/fake/path.py", "path.py", changed_lines={2}) + + call_arg = scanner.agent.run_sync.call_args[0][0] + self.assertIn("1: line one", call_arg) + self.assertIn("2: [CHANGED] line two", call_arg) + self.assertIn("3: line three", call_arg) + self.assertNotIn("1: [CHANGED]", call_arg) + self.assertNotIn("3: [CHANGED]", call_arg) + + @patch("core.code_scanner.agent_scanner.create_agent") + def test__scanSingleFile__noChangedLinesMarker_whenChangedLinesIsNone(self, mock_create_agent): + args = self._make_args() + scanner = AgentScanner(args) + mock_result = MagicMock() + mock_result.data = FileScanResult(vulnerabilities=[]) + scanner.agent.run_sync = MagicMock(return_value=mock_result) + + file_content = "line one\nline two" + with patch("os.path.isfile", return_value=True), \ + patch("builtins.open", mock_open(read_data=file_content)): + scanner._scan_single_file("/fake/path.py", "path.py", changed_lines=None) + + call_arg = scanner.agent.run_sync.call_args[0][0] + self.assertNotIn("[CHANGED]", call_arg) + + @patch("core.code_scanner.agent_scanner.create_agent") + def test__scanSingleFile__printsVulnerabilitiesWithLineNumber(self, mock_create_agent): + args = self._make_args() + scanner = AgentScanner(args) + vuln = Vulnerability( + line_number=14, + description="SQL Injection", + remediation="Use parameterized queries", + severity="Critical", + vulnerability_type="SQL Injection", + ) + mock_result = MagicMock() + mock_result.data = FileScanResult(vulnerabilities=[vuln]) + scanner.agent.run_sync = MagicMock(return_value=mock_result) + + with patch("os.path.isfile", return_value=True), \ + patch("builtins.open", mock_open(read_data="some code")), \ + patch("builtins.print") as mock_print: + scanner._scan_single_file("/fake/path.py", "path.py") + + printed = " ".join(str(c) for c in mock_print.call_args_list) + self.assertIn("Line 14", printed) + self.assertIn("Critical", printed) + self.assertIn("SQL Injection", printed) + + @patch("core.code_scanner.agent_scanner.create_agent") + def test__scanSingleFile__printsVulnerabilitiesWithoutLineNumber(self, mock_create_agent): + args = self._make_args() + scanner = AgentScanner(args) + vuln = Vulnerability( + line_number=None, + description="MD5 is weak", + remediation="Use bcrypt", + severity="High", + vulnerability_type="Weak Cryptography", + ) + mock_result = MagicMock() + mock_result.data = FileScanResult(vulnerabilities=[vuln]) + scanner.agent.run_sync = MagicMock(return_value=mock_result) + + with patch("os.path.isfile", return_value=True), \ + patch("builtins.open", mock_open(read_data="some code")), \ + patch("builtins.print") as mock_print: + scanner._scan_single_file("/fake/path.py", "path.py") + + printed = " ".join(str(c) for c in mock_print.call_args_list) + self.assertNotIn("Line None", printed) + self.assertIn("High", printed) + self.assertIn("Weak Cryptography", printed) + + @patch("core.code_scanner.agent_scanner.create_agent") + def test__scanChanges__usesGitRepoFiles_whenNoPrArgs(self, mock_create_agent): + args = self._make_args(changes_only=True, repo=None, pr_number=None, directory="/repo") + scanner = AgentScanner(args) + scanner._scan_single_file = MagicMock() + + with patch("core.code_scanner.agent_scanner.get_changed_files_in_repo", return_value=["a.py"]), \ + patch("core.code_scanner.agent_scanner.get_local_changed_line_numbers", return_value={3, 5}) as mock_local_lines: + scanner._scan_changes() + + mock_local_lines.assert_called_once_with("/repo", "a.py") + scanner._scan_single_file.assert_called_once_with( + os.path.join("/repo", "a.py"), display_name="a.py", changed_lines={3, 5} + ) + + @patch("core.code_scanner.agent_scanner.GithubIntegration") + @patch("core.code_scanner.agent_scanner.create_agent") + def test__scanChanges__usesPrFiles_whenRepoAndPrNumberSet(self, mock_create_agent, mock_github): + args = self._make_args( + changes_only=True, repo="owner/repo", pr_number=42, + github_token="tok", directory="/repo" + ) + scanner = AgentScanner(args) + scanner._scan_single_file = MagicMock() + + with patch("core.code_scanner.agent_scanner.get_changed_files_in_pr", return_value=["b.py"]), \ + patch("core.code_scanner.agent_scanner.get_pr_changed_line_numbers", return_value={"b.py": {10, 11}}): + scanner._scan_changes() + + scanner._scan_single_file.assert_called_once_with( + os.path.join("/repo", "b.py"), display_name="b.py", changed_lines={10, 11} + ) + + @patch("core.code_scanner.agent_scanner.create_agent") + def test__scanSingleFile__handlesAgentException(self, mock_create_agent): + args = self._make_args() + scanner = AgentScanner(args) + scanner.agent.run_sync = MagicMock(side_effect=Exception("API error")) + + with patch("os.path.isfile", return_value=True), \ + patch("builtins.open", mock_open(read_data="some code")): + # Should not raise — errors are caught and logged + scanner._scan_single_file("/fake/path.py", "path.py") + + +if __name__ == "__main__": + unittest.main() diff --git a/evals/expected_findings.json b/evals/expected_findings.json new file mode 100644 index 0000000..4c63d98 --- /dev/null +++ b/evals/expected_findings.json @@ -0,0 +1,198 @@ +{ + "_comment": "Ground-truth vulnerability manifest for eval fixtures. Each finding has a 'tier': 'standard' (expected from any model) or 'advanced' (only required when running gpt-4o-class models). When running a standard-tier model, advanced findings are treated as optional bonuses. 'threshold' is the minimum fraction of tier-appropriate required findings that must be detected to pass.", + "threshold": 0.80, + "model_tiers": { + "standard": ["gpt-4o-mini", "gemini-1.5-flash", "gemini-flash", "custom-model"], + "advanced": ["gpt-4o", "gpt-4-turbo", "gpt-4", "gemini-1.5-pro", "gemini-pro", "claude"] + }, + "fixtures": { + "web_api.py": { + "required": [ + { + "id": "web_api-1", + "tier": "standard", + "vulnerability_type_keywords": ["sql injection", "sql"], + "line_hint": 37, + "severity": "Critical", + "note": "Unsanitised user input concatenated into SQL query in search_users()" + }, + { + "id": "web_api-2", + "tier": "standard", + "vulnerability_type_keywords": ["xss", "cross-site scripting", "stored xss"], + "line_hint": 53, + "severity": "High", + "note": "Raw HTML/JS bio stored and returned without sanitisation in update_profile()" + }, + { + "id": "web_api-3", + "tier": "standard", + "vulnerability_type_keywords": ["open redirect", "redirect"], + "line_hint": 78, + "severity": "Medium", + "note": "next_url from request redirected without validation in login_callback()" + } + ], + "optional": [ + { + "id": "web_api-4", + "tier": "standard", + "vulnerability_type_keywords": ["ssrf", "server-side request forgery"], + "line_hint": 68, + "severity": "High", + "note": "Arbitrary URL fetched server-side in proxy_image() — SSRF" + } + ] + }, + "auth_service.py": { + "required": [ + { + "id": "auth-1", + "tier": "standard", + "vulnerability_type_keywords": ["hardcoded secret", "hardcoded", "secret", "jwt secret"], + "line_hint": 18, + "severity": "Critical", + "note": "JWT_SECRET embedded in source code" + }, + { + "id": "auth-2", + "tier": "advanced", + "vulnerability_type_keywords": ["algorithm confusion", "jwt", "algorithm"], + "line_hint": 46, + "severity": "High", + "note": "jwt.decode accepts both HS256 and RS256 — algorithm confusion attack" + }, + { + "id": "auth-3", + "tier": "standard", + "vulnerability_type_keywords": ["weak cryptography", "md5", "weak hash", "predictable token"], + "line_hint": 57, + "severity": "High", + "note": "MD5 used to generate password reset token — predictable and broken" + }, + { + "id": "auth-4", + "tier": "standard", + "vulnerability_type_keywords": ["sensitive data", "logging", "password", "log"], + "line_hint": 98, + "severity": "High", + "note": "Plaintext password written to logs in register_user()" + } + ], + "optional": [ + { + "id": "auth-5", + "tier": "advanced", + "vulnerability_type_keywords": ["timing attack", "timing", "comparison"], + "line_hint": 72, + "severity": "Medium", + "note": "Non-constant-time string comparison in validate_api_key()" + } + ] + }, + "file_handler.py": { + "required": [ + { + "id": "file-1", + "tier": "standard", + "vulnerability_type_keywords": ["path traversal", "directory traversal"], + "line_hint": 40, + "severity": "Critical", + "note": "Unsanitised filename joined onto upload dir in save_upload()" + }, + { + "id": "file-2", + "tier": "standard", + "vulnerability_type_keywords": ["command injection", "os.system", "shell injection"], + "line_hint": 56, + "severity": "Critical", + "note": "Unsanitised filename interpolated into os.system() shell command in generate_thumbnail()" + }, + { + "id": "file-3", + "tier": "standard", + "vulnerability_type_keywords": ["insecure deserialization", "pickle", "deserialization"], + "line_hint": 76, + "severity": "Critical", + "note": "pickle.loads on untrusted input in load_user_session()" + }, + { + "id": "file-4", + "tier": "standard", + "vulnerability_type_keywords": ["zip slip", "tar slip", "path traversal", "archive"], + "line_hint": 87, + "severity": "High", + "note": "tarfile.extractall without path validation — tar slip in extract_archive()" + } + ], + "optional": [] + }, + "mobile_backend.py": { + "required": [ + { + "id": "mobile-1", + "tier": "standard", + "vulnerability_type_keywords": ["xxe", "xml external entity", "xml"], + "line_hint": 56, + "severity": "Critical", + "note": "lxml parses XML with external entity resolution enabled in import_user_preferences()" + }, + { + "id": "mobile-2", + "tier": "standard", + "vulnerability_type_keywords": ["idor", "insecure direct object reference", "authorization"], + "line_hint": 70, + "severity": "High", + "note": "No ownership check in get_notification_settings() — any user can read another's settings" + }, + { + "id": "mobile-3", + "tier": "standard", + "vulnerability_type_keywords": ["mass assignment", "mass assignment", "privilege escalation"], + "line_hint": 82, + "severity": "High", + "note": "All request body fields applied to user record without allowlist in update_user_settings()" + } + ], + "optional": [] + }, + "data_pipeline.py": { + "required": [ + { + "id": "pipeline-1", + "tier": "standard", + "vulnerability_type_keywords": ["hardcoded", "credential", "password", "secret"], + "line_hint": 24, + "severity": "Critical", + "note": "Production DB password hardcoded in DB_CONFIG dict" + }, + { + "id": "pipeline-2", + "tier": "advanced", + "vulnerability_type_keywords": ["yaml", "deserialization", "code execution"], + "line_hint": 49, + "severity": "High", + "note": "yaml.load() without SafeLoader allows arbitrary code execution in load_pipeline_config()" + }, + { + "id": "pipeline-3", + "tier": "advanced", + "vulnerability_type_keywords": ["race condition", "toctou", "thread", "concurrency"], + "line_hint": 75, + "severity": "Medium", + "note": "Shared list _batch_results modified without lock in worker_ingest()" + } + ], + "optional": [ + { + "id": "pipeline-4", + "tier": "advanced", + "vulnerability_type_keywords": ["tempfile", "toctou", "mktemp", "symlink"], + "line_hint": 92, + "severity": "Medium", + "note": "tempfile.mktemp() has TOCTOU window — use mkstemp() instead" + } + ] + } + } +} diff --git a/evals/fixtures/auth_service.py b/evals/fixtures/auth_service.py new file mode 100644 index 0000000..aff58ba --- /dev/null +++ b/evals/fixtures/auth_service.py @@ -0,0 +1,141 @@ +""" +auth_service.py — Authentication and session management service. +Handles user registration, login, password reset, and JWT issuance. +""" + +import hmac +import hashlib +import time +import json +import base64 +import jwt +import bcrypt +from datetime import datetime, timedelta +from typing import Optional + + +# ------------------------------------------------------------------ # +# VULNERABILITY 1: Hardcoded Secret — JWT signing key is embedded in +# source; anyone with repo access can forge tokens. +# ------------------------------------------------------------------ # +JWT_SECRET = "s3cr3t-jwt-key-do-not-share" +JWT_ALGORITHM = "HS256" +TOKEN_EXPIRY_HOURS = 24 + +PASSWORD_RESET_EXPIRY_MINUTES = 30 +MAX_LOGIN_ATTEMPTS = 5 + + +def hash_password(password: str) -> str: + return bcrypt.hashpw(password.encode(), bcrypt.gensalt()).decode() + + +def verify_password(password: str, hashed: str) -> bool: + return bcrypt.checkpw(password.encode(), hashed.encode()) + + +def issue_token(user_id: int, email: str) -> str: + payload = { + "sub": user_id, + "email": email, + "iat": datetime.utcnow(), + "exp": datetime.utcnow() + timedelta(hours=TOKEN_EXPIRY_HOURS), + } + return jwt.encode(payload, JWT_SECRET, algorithm=JWT_ALGORITHM) + + +# ------------------------------------------------------------------ # +# VULNERABILITY 2: JWT Algorithm Confusion — `algorithms` accepts both +# HS256 and RS256. An attacker can obtain the public key, re-sign a +# forged token using HS256, and it will be accepted as valid. +# ------------------------------------------------------------------ # +def decode_token(token: str) -> Optional[dict]: + try: + # Dangerous: accepting multiple algorithms enables algorithm confusion attacks + return jwt.decode(token, JWT_SECRET, algorithms=["HS256", "RS256"]) + except jwt.ExpiredSignatureError: + return None + except jwt.InvalidTokenError: + return None + + +# ------------------------------------------------------------------ # +# VULNERABILITY 3: Weak Password Reset Token — MD5 of email + timestamp +# is predictable. Attackers who know the email can brute-force the token +# or precompute it if the timestamp window is known. +# ------------------------------------------------------------------ # +def generate_password_reset_token(email: str) -> str: + seed = f"{email}{int(time.time())}" + # Dangerous: MD5 is cryptographically broken and predictable + return hashlib.md5(seed.encode()).hexdigest() + + +def validate_reset_token(email: str, token: str, issued_at: int) -> bool: + elapsed = time.time() - issued_at + if elapsed > PASSWORD_RESET_EXPIRY_MINUTES * 60: + return False + expected = hashlib.md5(f"{email}{issued_at}".encode()).hexdigest() + return hmac.compare_digest(token, expected) + + +# ------------------------------------------------------------------ # +# VULNERABILITY 4: Timing Attack on login — string comparison via `==` +# leaks information about how many characters matched, enabling +# character-by-character brute-forcing of valid tokens/session IDs. +# ------------------------------------------------------------------ # +def validate_api_key(provided_key: str, stored_key: str) -> bool: + # Dangerous: non-constant-time comparison leaks timing information + return provided_key == stored_key + + +def validate_api_key_safe(provided_key: str, stored_key: str) -> bool: + return hmac.compare_digest(provided_key.encode(), stored_key.encode()) + + +def build_login_response(user_id: int, email: str, roles: list) -> dict: + token = issue_token(user_id, email) + return { + "token": token, + "user_id": user_id, + "email": email, + "roles": roles, + "issued_at": datetime.utcnow().isoformat(), + } + + +def is_account_locked(failed_attempts: int) -> bool: + return failed_attempts >= MAX_LOGIN_ATTEMPTS + + +def log_login_attempt(email: str, success: bool, ip: str): + entry = { + "email": email, + "success": success, + "ip": ip, + "timestamp": datetime.utcnow().isoformat(), + } + print(json.dumps(entry)) + + +# ------------------------------------------------------------------ # +# VULNERABILITY 5: Sensitive data in logs — password is written to +# stdout/log output; log aggregators (Splunk, CloudWatch) will store it. +# ------------------------------------------------------------------ # +def register_user(email: str, password: str, db): + print(f"[DEBUG] register_user called with email={email} password={password}") + if db.find_user(email): + raise ValueError("Email already registered") + hashed = hash_password(password) + db.insert_user(email, hashed) + return True + + +def change_password(user_id: int, old_password: str, new_password: str, db): + user = db.get_user_by_id(user_id) + if not user: + raise ValueError("User not found") + if not verify_password(old_password, user["password_hash"]): + raise ValueError("Incorrect current password") + new_hash = hash_password(new_password) + db.update_password(user_id, new_hash) + return True diff --git a/evals/fixtures/data_pipeline.py b/evals/fixtures/data_pipeline.py new file mode 100644 index 0000000..68c2e1b --- /dev/null +++ b/evals/fixtures/data_pipeline.py @@ -0,0 +1,162 @@ +""" +data_pipeline.py — ETL pipeline that ingests external data feeds, +transforms records, and writes results to a PostgreSQL data warehouse. +Runs as a scheduled batch job. +""" + +import os +import threading +import tempfile +import logging +import psycopg2 +import requests +import yaml +from typing import Optional + +logger = logging.getLogger(__name__) + + +# ------------------------------------------------------------------ # +# VULNERABILITY 1: Hardcoded Database Credentials — production DB +# password is embedded in source. Anyone with repo read access (or who +# finds the binary) can connect to the production warehouse directly. +# ------------------------------------------------------------------ # +DB_CONFIG = { + "host": "prod-db.internal", + "port": 5432, + "dbname": "analytics", + "user": "pipeline_user", + "password": "Sup3rS3cr3tProdP@ssw0rd!", # Dangerous: hardcoded secret +} + + +def get_db_connection(): + return psycopg2.connect(**DB_CONFIG) + + +def build_dsn(config: dict) -> str: + return ( + f"postgresql://{config['user']}:{config['password']}" + f"@{config['host']}:{config['port']}/{config['dbname']}" + ) + + +def ping_db() -> bool: + try: + conn = get_db_connection() + conn.close() + return True + except psycopg2.OperationalError: + return False + + +# ------------------------------------------------------------------ # +# VULNERABILITY 2: YAML Deserialization (yaml.load without Loader) — +# PyYAML's full Loader allows arbitrary Python object construction. +# A crafted YAML file can execute OS commands on load. +# ------------------------------------------------------------------ # +def load_pipeline_config(config_path: str) -> dict: + with open(config_path, "r") as f: + # Dangerous: yaml.load without Loader=yaml.SafeLoader + return yaml.load(f) + + +def load_pipeline_config_safe(config_path: str) -> dict: + with open(config_path, "r") as f: + return yaml.safe_load(f) + + +# Shared mutable state for inter-thread communication +_batch_results: list = [] +_batch_lock = threading.Lock() + + +def fetch_feed(url: str, api_key: str) -> list: + resp = requests.get(url, headers={"Authorization": f"Bearer {api_key}"}, timeout=30) + resp.raise_for_status() + return resp.json().get("records", []) + + +def transform_record(raw: dict) -> dict: + return { + "source_id": raw["id"], + "value": float(raw.get("value", 0)), + "category": raw.get("category", "unknown").lower(), + "ingested_at": raw.get("timestamp"), + } + + +# ------------------------------------------------------------------ # +# VULNERABILITY 3: Race Condition — `_batch_results` is a shared list +# modified by multiple worker threads without holding the lock for +# the entire read-modify-write cycle. Concurrent appends can corrupt +# the list or produce duplicate / dropped records under load. +# ------------------------------------------------------------------ # +def worker_ingest(records: list): + transformed = [transform_record(r) for r in records] + # Dangerous: lock not held while appending — race condition + _batch_results.extend(transformed) + + +def worker_ingest_safe(records: list): + transformed = [transform_record(r) for r in records] + with _batch_lock: + _batch_results.extend(transformed) + + +def run_parallel_ingest(batches: list): + threads = [threading.Thread(target=worker_ingest, args=(b,)) for b in batches] + for t in threads: + t.start() + for t in threads: + t.join() + + +# ------------------------------------------------------------------ # +# VULNERABILITY 4: Insecure Temporary File — tempfile.mktemp() returns +# a filename but does NOT create the file atomically. Between the name +# being returned and the caller opening the file, an attacker can +# create a symlink at that path (TOCTOU / symlink attack). +# ------------------------------------------------------------------ # +def write_staging_file(data: bytes) -> str: + # Dangerous: mktemp() has a TOCTOU window + tmp_path = tempfile.mktemp(suffix=".csv") + with open(tmp_path, "wb") as f: + f.write(data) + return tmp_path + + +def write_staging_file_safe(data: bytes) -> str: + fd, tmp_path = tempfile.mkstemp(suffix=".csv") + with os.fdopen(fd, "wb") as f: + f.write(data) + return tmp_path + + +def flush_results_to_db(results: list, conn) -> int: + if not results: + return 0 + cur = conn.cursor() + rows = [ + (r["source_id"], r["value"], r["category"], r["ingested_at"]) + for r in results + ] + cur.executemany( + "INSERT INTO staging (source_id, value, category, ingested_at) VALUES (%s, %s, %s, %s)", + rows, + ) + conn.commit() + return len(rows) + + +def clear_batch(): + with _batch_lock: + _batch_results.clear() + + +def get_batch_summary() -> dict: + with _batch_lock: + return { + "total_records": len(_batch_results), + "categories": list({r["category"] for r in _batch_results}), + } diff --git a/evals/fixtures/file_handler.py b/evals/fixtures/file_handler.py new file mode 100644 index 0000000..d3a0bfa --- /dev/null +++ b/evals/fixtures/file_handler.py @@ -0,0 +1,143 @@ +""" +file_handler.py — File upload, export, and report generation service. +Handles user-submitted documents, archives, and scheduled exports. +""" + +import os +import pickle +import subprocess +import tarfile +import tempfile +import shutil +from pathlib import Path +from typing import Optional + + +UPLOAD_DIR = "/var/app/uploads" +EXPORT_DIR = "/var/app/exports" +ALLOWED_EXTENSIONS = {".pdf", ".png", ".jpg", ".jpeg", ".txt", ".csv"} +MAX_FILE_SIZE_MB = 10 + + +def get_upload_path(user_id: int, filename: str) -> Path: + return Path(UPLOAD_DIR) / str(user_id) / filename + + +def allowed_extension(filename: str) -> bool: + return Path(filename).suffix.lower() in ALLOWED_EXTENSIONS + + +def human_readable_size(size_bytes: int) -> str: + for unit in ["B", "KB", "MB", "GB"]: + if size_bytes < 1024: + return f"{size_bytes:.1f} {unit}" + size_bytes /= 1024 + return f"{size_bytes:.1f} TB" + + +# ------------------------------------------------------------------ # +# VULNERABILITY 1: Path Traversal — filename from user input is joined +# directly onto the base upload dir. A filename like +# `../../etc/passwd` lets the caller read or overwrite arbitrary files. +# ------------------------------------------------------------------ # +def save_upload(user_id: int, filename: str, data: bytes) -> str: + # Dangerous: filename is not sanitised — allows directory traversal + dest = os.path.join(UPLOAD_DIR, str(user_id), filename) + os.makedirs(os.path.dirname(dest), exist_ok=True) + with open(dest, "wb") as f: + f.write(data) + return dest + + +def safe_save_upload(user_id: int, filename: str, data: bytes) -> str: + safe_name = Path(filename).name # strip any directory components + dest = Path(UPLOAD_DIR) / str(user_id) / safe_name + dest.parent.mkdir(parents=True, exist_ok=True) + dest.write_bytes(data) + return str(dest) + + +# ------------------------------------------------------------------ # +# VULNERABILITY 2: Command Injection — the filename is interpolated +# directly into a shell command string. A filename like +# `file.pdf; rm -rf /` will execute arbitrary shell commands. +# ------------------------------------------------------------------ # +def generate_thumbnail(filename: str, width: int = 200) -> str: + output = filename.replace(".", "_thumb.") + # Dangerous: unsanitised filename in shell command + os.system(f"convert {filename} -resize {width}x {output}") + return output + + +def generate_thumbnail_safe(filename: str, width: int = 200) -> str: + output = filename.replace(".", "_thumb.") + # Safe: passes args as a list — no shell interpolation + subprocess.run( + ["convert", filename, "-resize", f"{width}x", output], + check=True, + shell=False, + ) + return output + + +# ------------------------------------------------------------------ # +# VULNERABILITY 3: Insecure Deserialization — pickle.loads on untrusted +# data allows arbitrary code execution. An attacker-crafted payload can +# run any Python code on the server at deserialization time. +# ------------------------------------------------------------------ # +def load_user_session(session_blob: bytes) -> dict: + # Dangerous: pickle.loads on attacker-controlled bytes + return pickle.loads(session_blob) + + +def load_user_session_safe(session_blob: bytes) -> dict: + import json + return json.loads(session_blob.decode("utf-8")) + + +# ------------------------------------------------------------------ # +# VULNERABILITY 4: Zip/Tar Slip — extracting a tar archive without +# checking member paths allows files to be written outside the target +# directory (e.g. overwriting ~/.bashrc or /etc/cron.d entries). +# ------------------------------------------------------------------ # +def extract_archive(archive_path: str, dest_dir: str): + # Dangerous: no path validation on archive members + with tarfile.open(archive_path, "r:gz") as tar: + tar.extractall(path=dest_dir) + + +def extract_archive_safe(archive_path: str, dest_dir: str): + dest = Path(dest_dir).resolve() + with tarfile.open(archive_path, "r:gz") as tar: + for member in tar.getmembers(): + member_path = (dest / member.name).resolve() + if not str(member_path).startswith(str(dest)): + raise ValueError(f"Path traversal detected in archive: {member.name}") + tar.extractall(path=dest_dir) + + +def delete_upload(user_id: int, filename: str) -> bool: + path = get_upload_path(user_id, filename) + if path.exists(): + path.unlink() + return True + return False + + +def list_uploads(user_id: int) -> list: + base = Path(UPLOAD_DIR) / str(user_id) + if not base.exists(): + return [] + return [f.name for f in base.iterdir() if f.is_file()] + + +def get_file_metadata(path: str) -> Optional[dict]: + p = Path(path) + if not p.exists(): + return None + stat = p.stat() + return { + "name": p.name, + "size": human_readable_size(stat.st_size), + "modified": stat.st_mtime, + } diff --git a/evals/fixtures/mobile_backend.py b/evals/fixtures/mobile_backend.py new file mode 100644 index 0000000..4e5a193 --- /dev/null +++ b/evals/fixtures/mobile_backend.py @@ -0,0 +1,133 @@ +""" +mobile_backend.py — Backend API for a mobile application. +Manages push notifications, user preferences, XML config imports, +and device registration. +""" + +import xml.etree.ElementTree as ET +from lxml import etree +import requests +import json +import logging +from typing import Optional + +logger = logging.getLogger(__name__) + +PUSH_SERVICE_URL = "https://push.internal/notify" +DEVICE_REGISTRY = {} + + +def register_device(user_id: int, device_token: str, platform: str) -> dict: + if platform not in ("ios", "android"): + raise ValueError(f"Unsupported platform: {platform}") + DEVICE_REGISTRY[device_token] = { + "user_id": user_id, + "platform": platform, + "active": True, + } + return {"status": "registered", "token": device_token} + + +def deregister_device(device_token: str) -> bool: + if device_token in DEVICE_REGISTRY: + DEVICE_REGISTRY[device_token]["active"] = False + return True + return False + + +def get_active_devices(user_id: int) -> list: + return [ + {"token": tok, **info} + for tok, info in DEVICE_REGISTRY.items() + if info["user_id"] == user_id and info["active"] + ] + + +# ------------------------------------------------------------------ # +# VULNERABILITY 1: XXE (XML External Entity) — lxml's default parser +# resolves external entities. A crafted XML payload can read local +# files (e.g. /etc/passwd) or trigger SSRF to internal services. +# ------------------------------------------------------------------ # +def import_user_preferences(xml_payload: str) -> dict: + # Dangerous: default lxml parser resolves external entities + root = etree.fromstring(xml_payload.encode()) + prefs = {} + for child in root: + prefs[child.tag] = child.text + return prefs + + +def import_user_preferences_safe(xml_payload: str) -> dict: + parser = etree.XMLParser(resolve_entities=False, no_network=True) + root = etree.fromstring(xml_payload.encode(), parser=parser) + prefs = {} + for child in root: + prefs[child.tag] = child.text + return prefs + + +# ------------------------------------------------------------------ # +# VULNERABILITY 2: Insecure Direct Object Reference (IDOR) — the +# endpoint accepts a `target_user_id` parameter without checking +# that the requester owns that account. Any authenticated user can +# read or overwrite another user's notification preferences. +# ------------------------------------------------------------------ # +def get_notification_settings(requester_id: int, target_user_id: int, db) -> dict: + # Dangerous: no ownership check — IDOR + return db.get_notification_settings(target_user_id) + + +def get_notification_settings_safe(requester_id: int, target_user_id: int, db) -> dict: + if requester_id != target_user_id: + raise PermissionError("Cannot access another user's notification settings") + return db.get_notification_settings(target_user_id) + + +# ------------------------------------------------------------------ # +# VULNERABILITY 3: Mass Assignment — the entire request body is applied +# to the user record without filtering. An attacker can set privileged +# fields such as `is_admin=true` or `subscription_tier=enterprise`. +# ------------------------------------------------------------------ # +def update_user_settings(user_id: int, request_body: dict, db) -> dict: + # Dangerous: no field allowlist — all keys from request are written + db.update_user(user_id, **request_body) + return {"status": "updated"} + + +def update_user_settings_safe(user_id: int, request_body: dict, db) -> dict: + ALLOWED_FIELDS = {"display_name", "language", "timezone", "notifications_enabled"} + filtered = {k: v for k, v in request_body.items() if k in ALLOWED_FIELDS} + db.update_user(user_id, **filtered) + return {"status": "updated"} + + +def send_push_notification(device_token: str, title: str, body: str) -> bool: + payload = {"token": device_token, "title": title, "body": body} + try: + resp = requests.post(PUSH_SERVICE_URL, json=payload, timeout=5) + resp.raise_for_status() + return True + except requests.RequestException as e: + logger.error("Push notification failed: %s", e) + return False + + +def broadcast_to_user(user_id: int, title: str, body: str) -> int: + devices = get_active_devices(user_id) + sent = 0 + for device in devices: + if send_push_notification(device["token"], title, body): + sent += 1 + return sent + + +def parse_device_metadata(raw_json: str) -> Optional[dict]: + try: + data = json.loads(raw_json) + return { + "os_version": data.get("os_version"), + "app_version": data.get("app_version"), + "locale": data.get("locale"), + } + except (json.JSONDecodeError, KeyError): + return None diff --git a/evals/fixtures/web_api.py b/evals/fixtures/web_api.py new file mode 100644 index 0000000..8607818 --- /dev/null +++ b/evals/fixtures/web_api.py @@ -0,0 +1,142 @@ +""" +web_api.py — User management REST API built with Flask + SQLite. +Handles account creation, login, profile updates, and image proxying. +""" + +import sqlite3 +import hashlib +import urllib.request +from flask import Flask, request, jsonify, redirect + +app = Flask(__name__) +DB_PATH = "users.db" + + +def get_db(): + conn = sqlite3.connect(DB_PATH) + conn.row_factory = sqlite3.Row + return conn + + +def validate_email(email: str) -> bool: + return "@" in email and "." in email.split("@")[-1] + + +def normalize_username(username: str) -> str: + return username.strip().lower() + + +# ------------------------------------------------------------------ # +# VULNERABILITY 1: SQL Injection (line ~37) +# User input is concatenated directly into the SQL query string. +# ------------------------------------------------------------------ # +@app.route("/api/users/search") +def search_users(): + query = request.args.get("q", "") + conn = get_db() + # Dangerous: unsanitised `query` is injected directly into SQL + sql = f"SELECT id, username, email FROM users WHERE username LIKE '%{query}%'" + results = conn.execute(sql).fetchall() + conn.close() + return jsonify([dict(r) for r in results]) + + +@app.route("/api/users/") +def get_user(user_id): + conn = get_db() + row = conn.execute( + "SELECT id, username, email, bio FROM users WHERE id = ?", (user_id,) + ).fetchone() + conn.close() + if row is None: + return jsonify({"error": "not found"}), 404 + return jsonify(dict(row)) + + +# ------------------------------------------------------------------ # +# VULNERABILITY 2: Stored XSS — bio is stored raw and returned +# without sanitisation; any HTML/JS in bio will execute in the browser. +# ------------------------------------------------------------------ # +@app.route("/api/users//profile", methods=["PUT"]) +def update_profile(user_id): + data = request.get_json() + bio = data.get("bio", "") + # No sanitisation — raw HTML/JS stored directly + conn = get_db() + conn.execute("UPDATE users SET bio = ? WHERE id = ?", (bio, user_id)) + conn.commit() + conn.close() + return jsonify({"status": "updated"}) + + +@app.route("/api/users//avatar") +def get_avatar(user_id): + conn = get_db() + row = conn.execute( + "SELECT avatar_url FROM users WHERE id = ?", (user_id,) + ).fetchone() + conn.close() + if not row: + return jsonify({"error": "not found"}), 404 + return jsonify({"url": row["avatar_url"]}) + + +# ------------------------------------------------------------------ # +# VULNERABILITY 3: SSRF — the `url` param is fetched server-side with +# no allowlist; attackers can target internal services (e.g. metadata API). +# ------------------------------------------------------------------ # +@app.route("/api/proxy/image") +def proxy_image(): + url = request.args.get("url", "") + # Dangerous: fetches arbitrary URLs — enables SSRF + with urllib.request.urlopen(url) as response: + data = response.read() + return data, 200, {"Content-Type": "image/jpeg"} + + +# ------------------------------------------------------------------ # +# VULNERABILITY 4: Open Redirect — destination is taken directly from +# user input; attackers craft phishing links via this endpoint. +# ------------------------------------------------------------------ # +@app.route("/login/callback") +def login_callback(): + token = request.args.get("token") + next_url = request.args.get("next", "/dashboard") + if not token: + return jsonify({"error": "missing token"}), 400 + # Dangerous: redirects to attacker-controlled URL with no validation + return redirect(next_url) + + +@app.route("/health") +def health(): + return jsonify({"status": "ok"}) + + +@app.route("/api/stats") +def stats(): + conn = get_db() + total = conn.execute("SELECT COUNT(*) FROM users").fetchone()[0] + conn.close() + return jsonify({"total_users": total}) + + +def create_tables(): + conn = get_db() + conn.execute( + """CREATE TABLE IF NOT EXISTS users ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + username TEXT UNIQUE NOT NULL, + email TEXT UNIQUE NOT NULL, + password_hash TEXT NOT NULL, + bio TEXT DEFAULT '', + avatar_url TEXT DEFAULT '' + )""" + ) + conn.commit() + conn.close() + + +if __name__ == "__main__": + create_tables() + app.run(debug=True) diff --git a/evals/run_evals.py b/evals/run_evals.py new file mode 100644 index 0000000..f021af7 --- /dev/null +++ b/evals/run_evals.py @@ -0,0 +1,265 @@ +""" +run_evals.py — Eval harness for CodeScanAI's V2 agent scanner. + +For each fixture file, runs the agent and checks the structured output +against the expected findings in expected_findings.json. + +A required finding is matched if the scanner returns at least one +Vulnerability whose vulnerability_type (lowercased) contains any of the +expected keywords. + +Exit code 0 = all files passed their threshold. +Exit code 1 = one or more files failed. + +Usage: + python evals/run_evals.py [--provider openai] [--model gpt-4o-mini] +""" + +import argparse +import json +import os +import sys +from pathlib import Path +from types import SimpleNamespace +from typing import Optional + +# Make sure the repo root is on the path when running from evals/ +sys.path.insert(0, str(Path(__file__).parent.parent)) + +from core.code_scanner.agent_scanner import AgentScanner +from core.agent import FileScanResult, Vulnerability + + +FIXTURES_DIR = Path(__file__).parent / "fixtures" +MANIFEST_PATH = Path(__file__).parent / "expected_findings.json" + +# ANSI colours +GREEN = "\033[92m" +RED = "\033[91m" +YELLOW = "\033[93m" +CYAN = "\033[96m" +BOLD = "\033[1m" +RESET = "\033[0m" + + +def detect_model_tier(model: Optional[str], manifest: dict) -> str: + """ + Return 'advanced' if the model is an advanced-tier model, otherwise 'standard'. + Standard-tier patterns are checked first so that 'gpt-4o-mini' is not + accidentally matched by the 'gpt-4o' advanced pattern. + """ + if not model: + return "standard" + model_lower = model.lower() + tiers = manifest.get("model_tiers", {}) + # Standard wins if any standard pattern matches — prevents substring false-positives + for pattern in tiers.get("standard", []): + if pattern.lower() == model_lower: + return "standard" + for pattern in tiers.get("advanced", []): + if pattern.lower() == model_lower: + return "advanced" + return "standard" + + +def _match_finding(vuln: Vulnerability, expected: dict) -> bool: + """Return True if `vuln` matches an expected finding by keyword.""" + vtype = vuln.vulnerability_type.lower() + desc = vuln.description.lower() + combined = vtype + " " + desc + return any(kw in combined for kw in expected["vulnerability_type_keywords"]) + + +def eval_fixture( + scanner: AgentScanner, + fixture_path: Path, + manifest: dict, + threshold: float, + model_tier: str = "standard", +) -> dict: + """ + Scan a single fixture and compare against the manifest. + Advanced-tier findings are only required when model_tier == 'advanced'; + otherwise they are treated as optional bonuses. + Returns a result dict with pass/fail details. + """ + filename = fixture_path.name + fixture_manifest = manifest["fixtures"].get(filename) + if not fixture_manifest: + return {"file": filename, "skipped": True} + + all_required = fixture_manifest["required"] + all_optional = fixture_manifest.get("optional", []) + + # Partition required findings by whether the current tier demands them + if model_tier == "advanced": + required = all_required + optional = all_optional + advanced_promoted = [] # nothing demoted in advanced mode + else: + required = [f for f in all_required if f.get("tier", "standard") == "standard"] + # Advanced-tier required findings become optional bonuses on standard models + advanced_promoted = [f for f in all_required if f.get("tier") == "advanced"] + optional = all_optional + advanced_promoted + + # Capture structured results by monkey-patching _scan_single_file + found_vulns: list[Vulnerability] = [] + original = scanner._scan_single_file + + def capturing_scan(file_path, display_name=""): + nonlocal found_vulns + import os as _os + if not _os.path.isfile(file_path): + return + with open(file_path, "r", encoding="utf-8") as f: + content = f.read() + if not content.strip(): + return + lines = content.splitlines() + numbered = "\n".join(f"{i+1}: {l}" for i, l in enumerate(lines)) + try: + result = scanner.agent.run_sync(f"File: {display_name}\n\n{numbered}") + found_vulns.extend(result.data.vulnerabilities) + except Exception as e: + print(f" {RED}Agent error:{RESET} {e}") + + scanner._scan_single_file = capturing_scan + + try: + scanner._scan_single_file(str(fixture_path), display_name=filename) + finally: + scanner._scan_single_file = original + + # Match required findings + matched_required = [] + missed_required = [] + for exp in required: + if any(_match_finding(v, exp) for v in found_vulns): + matched_required.append(exp) + else: + missed_required.append(exp) + + # Match optional findings + matched_optional = [ + exp for exp in optional if any(_match_finding(v, exp) for v in found_vulns) + ] + + total_required = len(required) + score = len(matched_required) / total_required if total_required else 1.0 + passed = score >= threshold + + return { + "file": filename, + "skipped": False, + "passed": passed, + "score": score, + "threshold": threshold, + "required_total": total_required, + "required_matched": len(matched_required), + "matched_required": matched_required, + "missed_required": missed_required, + "optional_matched": matched_optional, + "advanced_promoted": advanced_promoted, + "total_vulnerabilities_found": len(found_vulns), + } + + +def print_result(result: dict): + if result.get("skipped"): + print(f" {YELLOW}SKIPPED{RESET} (not in manifest)") + return + + status = f"{GREEN}PASS{RESET}" if result["passed"] else f"{RED}FAIL{RESET}" + pct = f"{result['score']*100:.0f}%" + print( + f" {status} {pct} required findings detected " + f"({result['required_matched']}/{result['required_total']}) " + f"[threshold: {result['threshold']*100:.0f}%]" + ) + + if result["missed_required"]: + print(f" {RED}Missed required:{RESET}") + for m in result["missed_required"]: + print(f" - [{m['severity']}] {m['note']} (id: {m['id']})") + + if result["optional_matched"]: + tier_note = "" + print(f" {GREEN}Bonus findings detected:{RESET}") + for m in result["optional_matched"]: + tier_label = f" {CYAN}[advanced]{RESET}" if m.get("tier") == "advanced" else "" + print(f" +{tier_label} {m['note']} (id: {m['id']})") + + if result.get("advanced_promoted"): + print(f" {CYAN}Advanced-tier findings treated as optional (standard model):{RESET}") + for m in result["advanced_promoted"]: + print(f" ~ [{m['severity']}] {m['note']} (id: {m['id']})") + + print(f" Total vulnerabilities returned by scanner: {result['total_vulnerabilities_found']}") + + +def main(): + parser = argparse.ArgumentParser(description="Run CodeScanAI evals") + parser.add_argument("--provider", default="openai") + parser.add_argument("--model", default=None) + parser.add_argument("--host", default=None) + parser.add_argument("--port", default=None, type=int) + parser.add_argument("--token", default=None) + parser.add_argument("--endpoint", default=None) + parser.add_argument("--fixture", default=None, help="Run a single fixture by filename") + cli_args = parser.parse_args() + + with open(MANIFEST_PATH) as f: + manifest = json.load(f) + + threshold = manifest["threshold"] + + # Build a minimal args namespace that AgentScanner expects + scanner_args = SimpleNamespace( + provider=cli_args.provider, + model=cli_args.model, + host=cli_args.host, + port=cli_args.port, + token=cli_args.token, + endpoint=cli_args.endpoint, + directory=str(FIXTURES_DIR), + changes_only=False, + repo=None, + pr_number=None, + github_token=None, + ) + + model_tier = detect_model_tier(cli_args.model, manifest) + tier_label = f"{CYAN}advanced{RESET}" if model_tier == "advanced" else f"{YELLOW}standard{RESET}" + + print(f"\n{BOLD}CodeScanAI Evals{RESET}") + print(f"Provider: {cli_args.provider} Model: {cli_args.model or 'default'} Tier: {tier_label}") + print(f"Threshold: {threshold*100:.0f}% Fixtures: {FIXTURES_DIR}\n") + + scanner = AgentScanner(scanner_args) + + fixtures = sorted(FIXTURES_DIR.glob("*.py")) + if cli_args.fixture: + fixtures = [f for f in fixtures if f.name == cli_args.fixture] + if not fixtures: + print(f"{RED}No fixture named '{cli_args.fixture}' found.{RESET}") + sys.exit(1) + + all_passed = True + for fixture_path in fixtures: + print(f"{BOLD}{fixture_path.name}{RESET}") + result = eval_fixture(scanner, fixture_path, manifest, threshold, model_tier=model_tier) + print_result(result) + print() + if not result.get("skipped") and not result["passed"]: + all_passed = False + + if all_passed: + print(f"{GREEN}{BOLD}All evals passed.{RESET}") + sys.exit(0) + else: + print(f"{RED}{BOLD}One or more evals failed.{RESET}") + sys.exit(1) + + +if __name__ == "__main__": + main() diff --git a/pyproject.toml b/pyproject.toml index 6d3d074..3dc4291 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "codescanai" -version = "0.1.1" +version = "0.2.0" description = "A CLI tool that scans your codebases for security vulnerabilities powered by powerful AI models." readme = "README.md" authors = [{ name = "Caleb Abhulimhen", email = "calebabhulimhen@gmail.com" }] @@ -20,7 +20,9 @@ dependencies = [ "PyGithub", "requests", "google-generativeai", - "ipython" + "ipython", + "pydantic", + "pydantic-ai" ] requires-python = ">=3.10" @@ -31,7 +33,7 @@ dev = ["pylint", "black", "isort"] Homepage = "https://github.com/codescan-ai/codescan" [project.scripts] -codescanai = "core.runner:main" +codescanai = "core.runner_v2:main" [tool.setuptools.packages.find] include = ["core", "core.*"] \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 9075510..14f21d5 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,4 +2,6 @@ openai==1.70.0 PyGithub==2.6.1 requests==2.32.3 google-generativeai==0.8.4 -ipython==8.27.0 \ No newline at end of file +ipython +pydantic +pydantic-ai \ No newline at end of file