diff --git a/src/skillspector/cli.py b/src/skillspector/cli.py index deae40a6..eb43d778 100644 --- a/src/skillspector/cli.py +++ b/src/skillspector/cli.py @@ -25,6 +25,7 @@ import os import shutil import sys +from dataclasses import replace from enum import StrEnum from pathlib import Path from typing import Annotated @@ -33,10 +34,12 @@ from langchain_core.runnables import RunnableConfig from rich.console import Console -from skillspector import __version__ +from skillspector import __version__, transitive from skillspector.graph import graph from skillspector.logging_config import get_logger, set_level +from skillspector.models import Finding from skillspector.multi_skill import MultiSkillDetectionResult, detect_skills +from skillspector.nodes.report import report from skillspector.suppression import build_baseline_dict, dump_baseline, load_baseline logger = get_logger(__name__) @@ -232,6 +235,36 @@ def scan( "do not count toward the risk score).", ), ] = False, + transitive: Annotated[ + bool, + typer.Option( + "--transitive", + help="Follow transitive external references after the initial scan.", + ), + ] = False, + transitive_depth: Annotated[ + int, + typer.Option( + "--transitive-depth", + help="Maximum transitive depth to scan for external references.", + ), + ] = 1, + transitive_allow_prefix: Annotated[ + list[str] | None, + typer.Option( + "--transitive-allow-prefix", + help=( + "Only scan transitive targets matching at least one canonical prefix. Repeatable." + ), + ), + ] = None, + transitive_deny_prefix: Annotated[ + list[str] | None, + typer.Option( + "--transitive-deny-prefix", + help=("Skip transitive targets matching any canonical prefix. Repeatable."), + ), + ] = None, verbose: Annotated[ bool, typer.Option( @@ -275,10 +308,24 @@ def scan( set_level("DEBUG") resolved_path = Path(input_path).resolve() + yara_dir = str(yara_rules_dir.resolve()) if yara_rules_dir else None if recursive and resolved_path.is_dir(): detection = detect_skills(resolved_path) if detection.is_multi_skill: - _scan_multi_skill(detection, format, output, no_llm, yara_rules_dir, verbose) + _scan_multi_skill( + detection=detection, + format=format, + output=output, + no_llm=no_llm, + baseline=baseline, + show_suppressed=show_suppressed, + transitive_enabled=transitive, + transitive_depth=transitive_depth, + transitive_allow_prefix=transitive_allow_prefix, + transitive_deny_prefix=transitive_deny_prefix, + yara_dir=yara_dir, + verbose=verbose, + ) return if not detection.has_root_skill and len(detection.skills) == 0: console.print( @@ -295,26 +342,19 @@ def scan( result = None try: - yara_dir = str(yara_rules_dir.resolve()) if yara_rules_dir else None - state = _scan_state( - input_path, - format, - no_llm, - yara_rules_dir=yara_dir, + result = _scan_skill( + input_path=input_path, + format=format, + no_llm=no_llm, baseline=baseline, + yara_rules_dir=Path(yara_dir) if yara_dir else None, + verbose=verbose, show_suppressed=show_suppressed, + transitive_enabled=transitive, + transitive_depth=transitive_depth, + transitive_allow_prefix=transitive_allow_prefix, + transitive_deny_prefix=transitive_deny_prefix, ) - if verbose: - console.print("[dim]Running scan...[/dim]") - logger.debug( - "Scan started: input_path=%s, format=%s, use_llm=%s", - input_path, - format, - not no_llm, - ) - trace_config = _build_trace_config(input_path, format, no_llm) - result = graph.invoke(state, config=trace_config) - _write_result(result, output, format) if (result.get("risk_score") or 0) > 50: @@ -353,62 +393,320 @@ def _build_trace_config(input_path: str, format: FormatChoice, no_llm: bool) -> } +def _coerce_str_path_list(value: object) -> list[str]: + if not isinstance(value, list): + return [] + return [str(item) for item in value if isinstance(item, str)] + + +def _coerce_findings_list(value: object) -> list[Finding]: + if not isinstance(value, list): + return [] + return [finding for finding in value if isinstance(finding, Finding)] + + +def _merge_unique_by_path(items: list[dict[str, object]]) -> list[dict[str, object]]: + merged: list[dict[str, object]] = [] + seen: set[str] = set() + for item in items: + path = str(item.get("path", "")) + if path in seen: + continue + seen.add(path) + merged.append(item) + return merged + + +def _scan_state_with_baseline( + input_path: str, + format: FormatChoice, + no_llm: bool, + *, + yara_rules_dir: str | None = None, + baseline: Path | None = None, + show_suppressed: bool = False, +) -> dict[str, object]: + return _scan_state( + input_path=input_path, + format=format, + no_llm=no_llm, + yara_rules_dir=yara_rules_dir, + baseline=baseline, + show_suppressed=show_suppressed, + ) + + +def _run_graph_scan( + input_path: str, + format: FormatChoice, + no_llm: bool, + yara_dir: str | None = None, + baseline: Path | None = None, + show_suppressed: bool = False, +) -> dict[str, object]: + state = _scan_state_with_baseline( + input_path=input_path, + format=format, + no_llm=no_llm, + yara_rules_dir=yara_dir, + baseline=baseline, + show_suppressed=show_suppressed, + ) + trace_config = _build_trace_config(input_path, format, no_llm) + return graph.invoke(state, config=trace_config) + + +def _annotate_transitive_findings( + findings: list[Finding], + source_url: str, + transitive_depth: int, +) -> list[Finding]: + return [ + replace(finding, transitive_depth=transitive_depth, source_url=source_url) + for finding in findings + ] + + +def _scan_transitive( + initial_result: dict[str, object], + format: FormatChoice, + no_llm: bool, + max_depth: int, + transitive_allow_prefix: tuple[str, ...] | list[str] | None, + transitive_deny_prefix: tuple[str, ...] | list[str] | None, + baseline: Path | None, + show_suppressed: bool, + visited: set[str], + yara_dir: str | None = None, +) -> dict[str, object]: + if max_depth <= 0: + report_result = report(initial_result) + report_result["temp_dir_for_cleanup"] = initial_result.get("temp_dir_for_cleanup") + report_result["transitive_finding_count"] = 0 + report_result["transitive_sources"] = [] + return report_result + + transitive_sources: set[str] = set() + merged_filtered_findings: list[Finding] = _coerce_findings_list( + initial_result.get("filtered_findings") + ) + merged_findings: list[Finding] = _coerce_findings_list(initial_result.get("findings")) + merged_components = _coerce_str_path_list(initial_result.get("components")) + merged_file_cache = initial_result.get("file_cache") or {} + file_cache = merged_file_cache if isinstance(merged_file_cache, dict) else {} + component_metadata = _coerce_component_metadata(initial_result.get("component_metadata")) + has_executable_scripts = bool(initial_result.get("has_executable_scripts", False)) + + frontier: list[tuple[int, list[str]]] = [(1, transitive.extract_external_refs(file_cache))] + + while frontier: + current_depth, refs = frontier.pop(0) + targets = transitive.plan_transitive_targets( + refs=refs, + visited=visited, + current_depth=current_depth, + max_depth=max_depth, + allow_prefixes=transitive_allow_prefix, + deny_prefixes=transitive_deny_prefix, + ) + for target in targets: + child_result: dict[str, object] | None = None + try: + child_result = _run_graph_scan( + input_path=target, + format=format, + no_llm=no_llm, + yara_dir=yara_dir, + baseline=baseline, + show_suppressed=show_suppressed, + ) + transitive_sources.add(target) + child_filtered_findings = _coerce_findings_list( + child_result.get("filtered_findings") + ) + child_findings = _coerce_findings_list(child_result.get("findings")) + merged_filtered_findings.extend( + _annotate_transitive_findings( + child_filtered_findings, source_url=target, transitive_depth=current_depth + ) + ) + merged_findings.extend( + _annotate_transitive_findings( + child_findings, source_url=target, transitive_depth=current_depth + ) + ) + + child_metadata = _coerce_component_metadata(child_result.get("component_metadata")) + component_metadata.extend(child_metadata) + if any(entry.get("executable") for entry in child_metadata): + has_executable_scripts = True + merged_components.extend(_coerce_str_path_list(child_result.get("components"))) + + if current_depth < max_depth: + child_file_cache = child_result.get("file_cache") or {} + if isinstance(child_file_cache, dict): + child_refs = transitive.extract_external_refs(child_file_cache) + frontier.append((current_depth + 1, child_refs)) + except Exception as e: + if format == FormatChoice.json: + logger.warning("Transitive scan failed for %s: %s", target, e) + else: + console.print( + f"[yellow]Warning:[/yellow] Transitive scan failed for {target}: {e}" + ) + finally: + if child_result is not None: + _cleanup_result(child_result) + + merged_result: dict[str, object] = { + **initial_result, + "filtered_findings": merged_filtered_findings, + "findings": merged_findings, + "components": merged_components, + "component_metadata": _merge_unique_by_path(component_metadata), + "has_executable_scripts": has_executable_scripts, + } + transitive_finding_count = sum( + 1 for finding in merged_filtered_findings if finding.source_url is not None + ) + report_result = report(merged_result) + report_result["temp_dir_for_cleanup"] = initial_result.get("temp_dir_for_cleanup") + report_result["transitive_finding_count"] = transitive_finding_count + report_result["transitive_sources"] = sorted(transitive_sources) + return report_result + + +def _coerce_component_metadata(value: object) -> list[dict[str, object]]: + if not isinstance(value, list): + return [] + return [item for item in value if isinstance(item, dict)] + + +def _scan_skill( + input_path: str, + format: FormatChoice, + no_llm: bool, + baseline: Path | None, + yara_rules_dir: Path | None, + verbose: bool, + show_suppressed: bool, + transitive_enabled: bool, + transitive_depth: int, + transitive_allow_prefix: tuple[str, ...] | list[str] | None, + transitive_deny_prefix: tuple[str, ...] | list[str] | None, + visited: set[str] | None = None, +) -> dict[str, object]: + yara_dir = str(yara_rules_dir.resolve()) if yara_rules_dir else None + active_visited = visited if visited is not None else set() + try: + if verbose: + console.print("[dim]Running scan...[/dim]") + logger.debug( + "Scan started: input_path=%s, format=%s, use_llm=%s, transitive=%s", + input_path, + format, + not no_llm, + transitive_enabled, + ) + result = _run_graph_scan( + input_path=input_path, + format=format, + no_llm=no_llm, + yara_dir=yara_dir, + baseline=baseline, + show_suppressed=show_suppressed, + ) + if not transitive_enabled: + return result + transitive_allow_prefix = tuple(transitive_allow_prefix or ()) + transitive_deny_prefix = tuple(transitive_deny_prefix or ()) + try: + active_visited.add(transitive.canonicalize_source_identity(input_path)) + except ValueError: + pass + return _scan_transitive( + initial_result=result, + format=format, + no_llm=no_llm, + max_depth=transitive_depth, + transitive_allow_prefix=transitive_allow_prefix, + transitive_deny_prefix=transitive_deny_prefix, + baseline=baseline, + show_suppressed=show_suppressed, + visited=active_visited, + yara_dir=yara_dir, + ) + except Exception: + raise + + def _scan_multi_skill( detection: MultiSkillDetectionResult, format: FormatChoice, output: Path | None, no_llm: bool, - yara_rules_dir: Path | None, + baseline: Path | None, + show_suppressed: bool, + transitive_enabled: bool, + transitive_depth: int, + transitive_allow_prefix: tuple[str, ...] | list[str] | None, + transitive_deny_prefix: tuple[str, ...] | list[str] | None, + yara_dir: str | None, verbose: bool, ) -> None: """Scan each detected sub-skill independently and produce a combined report.""" skills = detection.skills console.print(f"[bold]Multi-skill directory detected:[/bold] {len(skills)} skills found\n") + visited: set[str] = set() results: list[dict[str, object]] = [] max_score = 0 + transitive_finding_count = 0 + transitive_sources: set[str] = set() for i, skill in enumerate(skills, 1): console.print( f" [{i}/{len(skills)}] Scanning [bold]{skill.name}[/bold] ({skill.relative_path}/)" ) - yara_dir = str(yara_rules_dir.resolve()) if yara_rules_dir else None - state = _scan_state(str(skill.path), format, no_llm, yara_rules_dir=yara_dir) - trace_config = _build_trace_config(str(skill.path), format, no_llm) - try: - result = graph.invoke(state, config=trace_config) + result = _scan_skill( + input_path=str(skill.path), + format=format, + no_llm=no_llm, + baseline=baseline, + yara_rules_dir=Path(yara_dir) if yara_dir else None, + verbose=verbose, + show_suppressed=show_suppressed, + transitive_enabled=transitive_enabled, + transitive_depth=transitive_depth, + transitive_allow_prefix=transitive_allow_prefix, + transitive_deny_prefix=transitive_deny_prefix, + visited=visited, + ) results.append(result) score = result.get("risk_score") or 0 if isinstance(score, int) and score > max_score: max_score = score + transitive_finding_count += int(result.get("transitive_finding_count") or 0) + for source in _coerce_str_path_list(result.get("transitive_sources")): + transitive_sources.add(source) severity = result.get("risk_severity") or "LOW" console.print(f" Score: {score}/100 ({severity})\n") except Exception as e: console.print(f" [red]Error:[/red] {e}\n") results.append({"skill_name": skill.name, "error": str(e)}) - console.print("\n[bold]═══ Multi-Skill Summary ═══[/bold]\n") - console.print(f" {'Skill':<30} {'Score':<8} {'Severity':<12} {'Findings':<10}") - console.print(f" {'─' * 30} {'─' * 8} {'─' * 12} {'─' * 10}") - - for skill, result in zip(skills, results, strict=True): - if "error" in result: - console.print(f" {skill.name:<30} {'ERROR':<8} {'—':<12} {'—':<10}") - continue - score = result.get("risk_score", 0) - severity = result.get("risk_severity", "LOW") - filtered = result.get("filtered_findings") or result.get("findings") - finding_count = len(filtered) if isinstance(filtered, list) else 0 - console.print(f" {skill.name:<30} {score:<8} {severity:<12} {finding_count:<10}") - - console.print("") + # Existing direct output behavior remains, but shared traversal and visited state + # are now handled by _scan_skill, including transitive helper path. + _print_multi_summary(skills, results) if output and format == FormatChoice.json: combined = { "multi_skill": True, "skill_count": len(skills), "max_risk_score": max_score, + "transitive_finding_count": transitive_finding_count, + "transitive_sources": sorted(transitive_sources), "skills": [], } for skill, result in zip(skills, results, strict=True): @@ -424,6 +722,8 @@ def _scan_multi_skill( "finding_count": len( result.get("filtered_findings") or result.get("findings") or [] ), + "transitive_finding_count": result.get("transitive_finding_count", 0), + "transitive_sources": result.get("transitive_sources", []), } ) Path(output).write_text(json.dumps(combined, indent=2), encoding="utf-8") @@ -437,6 +737,22 @@ def _scan_multi_skill( raise typer.Exit(code=1) +def _print_multi_summary(skills: list, results: list[dict[str, object]]) -> None: + console.print("\n[bold]=== Multi-Skill Summary ===[/bold]\n") + console.print(f" {'Skill':<30} {'Score':<8} {'Severity':<12} {'Findings':<10}") + console.print(f" {'-' * 30} {'-' * 8} {'-' * 12} {'-' * 10}") + + for skill, result in zip(skills, results, strict=True): + if "error" in result: + console.print(f" {skill.name:<30} {'ERROR':<8} {'n/a':<12} {'n/a':<10}") + continue + score = result.get("risk_score", 0) + severity = result.get("risk_severity", "LOW") + filtered = result.get("filtered_findings") or result.get("findings") + finding_count = len(filtered) if isinstance(filtered, list) else 0 + console.print(f" {skill.name:<30} {score:<8} {severity:<12} {finding_count:<10}") + + @app.command() def mcp( transport: Annotated[ diff --git a/src/skillspector/input_handler.py b/src/skillspector/input_handler.py index b70d0f20..6a993e43 100644 --- a/src/skillspector/input_handler.py +++ b/src/skillspector/input_handler.py @@ -35,7 +35,7 @@ import tempfile import zipfile from pathlib import Path -from urllib.parse import urlparse +from urllib.parse import urljoin, urlparse import httpx @@ -54,6 +54,7 @@ ALLOWED_DOWNLOAD_HOSTS = frozenset( { "github.com", + "codeload.github.com", "raw.githubusercontent.com", "gitlab.com", "bitbucket.org", @@ -61,6 +62,26 @@ } ) +_DIRECT_FILE_URL_SUFFIXES = ( + ".md", + ".py", + ".sh", + ".bash", + ".zsh", + ".js", + ".ts", + ".rb", + ".go", + ".rs", + ".pl", + ".json", + ".yaml", + ".yml", + ".toml", + ".txt", + ".zip", +) + def _is_private_ip(host: str) -> bool: """Return True if host resolves to a private/reserved IP address.""" @@ -147,7 +168,11 @@ def _is_git_url(self, path: str) -> bool: parsed = urlparse(path) host = parsed.hostname or "" if any(allowed in host for allowed in ALLOWED_GIT_HOSTS): - if "/raw/" in path or "/blob/" in path or path.endswith((".md", ".py", ".sh")): + if ( + "/raw/" in path + or "/blob/" in path + or path.lower().endswith(_DIRECT_FILE_URL_SUFFIXES) + ): return False return True if path.endswith(".git"): @@ -208,15 +233,12 @@ def _clone_git(self, url: str) -> Path: def _download_file(self, url: str) -> Path: """Download a file from URL to a temporary directory.""" - self._validate_url_host(url, ALLOWED_DOWNLOAD_HOSTS) temp_dir = self._get_temp_dir() - parsed = urlparse(url) - filename = Path(parsed.path).name or "SKILL.md" try: - with httpx.Client(follow_redirects=False, timeout=30) as client: - response = client.get(url) - response.raise_for_status() - content = response.content + response, final_url = self._download_with_redirect_validation(url) + parsed = urlparse(final_url) + filename = Path(parsed.path).name or "SKILL.md" + content = response.content except httpx.HTTPError as e: logger.warning("Download failed for %s: %s", url, e) raise ValueError(f"Failed to download file: {e}") from e @@ -230,6 +252,22 @@ def _download_file(self, url: str) -> Path: file_path.write_bytes(content) return temp_dir + def _download_with_redirect_validation(self, url: str) -> tuple[httpx.Response, str]: + current_url = url + for _ in range(5): + self._validate_url_host(current_url, ALLOWED_DOWNLOAD_HOSTS) + with httpx.Client(follow_redirects=False, timeout=30) as client: + response = client.get(current_url) + if response.status_code in {301, 302, 303, 307, 308}: + location = response.headers.get("location") + if not location: + raise ValueError(f"Redirect response missing location: {current_url}") + current_url = urljoin(current_url, location) + continue + response.raise_for_status() + return response, current_url + raise ValueError(f"Too many redirects while downloading: {url}") + def _extract_zip(self, zip_path: Path) -> Path: """Extract a zip file to a temporary directory with path traversal protection.""" if not zip_path.exists(): diff --git a/src/skillspector/models.py b/src/skillspector/models.py index 9a478219..4e9dd632 100644 --- a/src/skillspector/models.py +++ b/src/skillspector/models.py @@ -82,6 +82,8 @@ class Finding: tags: list[str] = field(default_factory=list) context: str | None = None matched_text: str | None = None + transitive_depth: int = 0 + source_url: str | None = None def to_dict(self) -> dict[str, object]: """Return a JSON-serializable dict representation (full finding shape).""" @@ -104,6 +106,8 @@ def to_dict(self) -> dict[str, object]: # Tags surface markers like "llm-unconfirmed" (a high-severity static # finding the LLM filter did not confirm but which is preserved anyway). "tags": list(self.tags), + "transitive_depth": self.transitive_depth, + "source_url": self.source_url, } def __str__(self) -> str: diff --git a/src/skillspector/nodes/report.py b/src/skillspector/nodes/report.py index 6295e12c..088343a8 100644 --- a/src/skillspector/nodes/report.py +++ b/src/skillspector/nodes/report.py @@ -206,6 +206,9 @@ def _build_sarif( results: list[SarifResult] = [] seen_rule_ids: dict[str, str] = {} + def _finding_properties(finding: Finding) -> dict[str, object]: + return {"transitiveDepth": finding.transitive_depth, "sourceUrl": finding.source_url} + for finding in findings: if not finding.rule_id or not finding.message: continue @@ -225,6 +228,7 @@ def _build_sarif( ) ) ], + properties=_finding_properties(finding), ) ) if finding.rule_id not in seen_rule_ids: @@ -251,6 +255,7 @@ def _build_sarif( ) ) ], + properties=_finding_properties(finding), suppressions=[SarifSuppression(kind="external", justification=sf.reason)], ) ) @@ -552,6 +557,8 @@ def _format_markdown( lines.append(f"### {emoji} {sev}: {f.rule_id}\n") end = f"–{f.end_line}" if f.end_line and f.end_line != f.start_line else "" lines.append(f"**Location:** `{f.file}:{f.start_line}{end}` ") + if f.transitive_depth > 0 and f.source_url: + lines.append(f"**Transitive:** depth={f.transitive_depth}, source={f.source_url} ") lines.append(f"**Confidence:** {f.confidence:.0%} ") lines.append("") lines.append(f"**Message:** {f.message}") diff --git a/src/skillspector/sarif_models.py b/src/skillspector/sarif_models.py index c3256ad8..08a8e51a 100644 --- a/src/skillspector/sarif_models.py +++ b/src/skillspector/sarif_models.py @@ -84,6 +84,7 @@ class SarifResult(BaseModel): # When present, the result is suppressed; SARIF consumers (e.g. GitHub code # scanning) exclude suppressed results from counts but keep them for audit. suppressions: list[SarifSuppression] | None = None + properties: dict[str, object] | None = None class SarifReportingDescriptor(BaseModel): diff --git a/src/skillspector/suppression.py b/src/skillspector/suppression.py index f01de61b..e6cd0d5d 100644 --- a/src/skillspector/suppression.py +++ b/src/skillspector/suppression.py @@ -97,6 +97,8 @@ def finding_fingerprint(finding: Finding) -> str: str(finding.start_line or ""), str(finding.end_line or ""), (finding.message or "").strip(), + finding.source_url or "", + str(finding.transitive_depth or 0), ] ) digest = hashlib.sha256(raw.encode("utf-8")).hexdigest()[:16] diff --git a/src/skillspector/transitive.py b/src/skillspector/transitive.py new file mode 100644 index 00000000..cc140f9d --- /dev/null +++ b/src/skillspector/transitive.py @@ -0,0 +1,301 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Shared helpers for opt-in transitive external-source traversal.""" + +from __future__ import annotations + +import re +from urllib.parse import ParseResult, unquote, urlparse, urlunparse + +from skillspector.input_handler import ALLOWED_DOWNLOAD_HOSTS, ALLOWED_GIT_HOSTS + +_URL_PATTERN = re.compile(r"(?:https?://|git@)[^\s\]{}'\"<>`!?,;.)}]+") +_LEADING_PUNCTUATION = "([{\"'<" +_TRAILING_PUNCTUATION = "),.!?;:>\"'`]}" + +_NON_GIT_FILE_EXTENSIONS = frozenset( + {".md", ".py", ".sh", ".bash", ".zsh", ".js", ".ts", ".rb", ".go", ".rs", ".pl"} +) +_SUPPORTED_FILE_EXTENSIONS = frozenset( + { + ".md", + ".py", + ".sh", + ".bash", + ".zsh", + ".js", + ".ts", + ".rb", + ".go", + ".rs", + ".pl", + ".json", + ".yaml", + ".yml", + ".toml", + ".txt", + ".zip", + } +) + +_EXTERNAL_REF_PATTERN = re.compile(r"(?:https?://|git@)[^\s\"'<>`]+") + +_EXCLUDED_HOSTS = frozenset( + { + "img.shields.io", + "badge.fury.io", + "travis-ci.com", + "travis-ci.org", + } +) + +_EXCLUDED_PATH_MARKERS = frozenset( + { + "/badge", + "/badges", + "/blob/", + "/issues/", + "/pull/", + "/pulls/", + "/actions/", + "/workflows/", + "/checks/", + "/wiki", + "/ci/", + } +) + + +def canonicalize_source_identity(url: str) -> str: + """Return canonical URL identity used for dedupe and visited-state control.""" + token = _clean_token(url).strip() + if not token: + raise ValueError(f"Unsupported URL: {url}") + + parsed = _parse_url(token) + host = (parsed.hostname or "").lower() + if host.startswith("www."): + host = host[4:] + + netloc = host + if parsed.port: + netloc = f"{host}:{parsed.port}" + + path = (parsed.path or "/").rstrip("/") + path = path.removesuffix(".git") + return urlunparse(("https", netloc, path if path else "/", "", "", "")) + + +def extract_external_refs(file_cache: dict[str, str]) -> list[str]: + """Extract candidate external references from a file cache.""" + refs: list[str] = [] + seen: set[str] = set() + for raw_content in file_cache.values(): + if not isinstance(raw_content, str): + continue + for match in _EXTERNAL_REF_PATTERN.finditer(raw_content): + token = match.group(0) + try: + identity = canonicalize_source_identity(token) + except ValueError: + continue + if identity in seen: + continue + if not _is_source_reference(identity): + continue + refs.append(identity) + seen.add(identity) + return refs + + +def plan_transitive_targets( + refs: list[str], + visited: set[str], + current_depth: int, + max_depth: int, + allow_prefixes: tuple[str, ...], + deny_prefixes: tuple[str, ...], +) -> list[str]: + """Plan the next transitive scan wave and mutate visited for approved targets.""" + if current_depth > max_depth or max_depth <= 0: + return [] + if current_depth < 1: + current_depth = 1 + + normalized_allow_prefixes = tuple(_normalize_prefix(p) for p in allow_prefixes) + normalized_deny_prefixes = tuple(_normalize_prefix(p) for p in deny_prefixes) + + targets: list[str] = [] + for ref in refs: + try: + identity = canonicalize_source_identity(ref) + except ValueError: + continue + if not _is_source_reference(identity): + continue + if identity in visited: + continue + if normalized_allow_prefixes and not _matches_any_prefix( + identity, normalized_allow_prefixes + ): + continue + if normalized_deny_prefixes and _matches_any_prefix(identity, normalized_deny_prefixes): + continue + visited.add(identity) + targets.append(identity) + return targets + + +def _parse_url(url: str) -> ParseResult: + token = _clean_token(url) + if token.startswith("git@"): + return _parse_git_ssh_url(token) + parsed = urlparse(token) + if not parsed.scheme or not parsed.netloc: + raise ValueError(f"Unsupported URL: {url}") + return parsed + + +def _parse_git_ssh_url(url: str) -> ParseResult: + match = re.fullmatch(r"git@([^:]+):(.+)", url) + if not match: + raise ValueError(f"Unsupported git URL format: {url}") + host = match.group(1).strip() + path = match.group(2).strip().lstrip("/") + return urlparse(f"https://{host}/{path}") + + +def _clean_token(token: str) -> str: + cleaned = token.strip() + while cleaned and cleaned[0] in _LEADING_PUNCTUATION: + cleaned = cleaned[1:] + while cleaned and cleaned[-1] in _TRAILING_PUNCTUATION: + cleaned = cleaned[:-1] + return cleaned.strip() + + +def _normalize_prefix(prefix: str) -> str: + if not prefix: + return "" + return canonicalize_source_identity(prefix) + + +def _matches_any_prefix(url: str, prefixes: tuple[str, ...]) -> bool: + return any(_matches_prefix(url, prefix) for prefix in prefixes if prefix) + + +def _matches_prefix(url: str, prefix: str) -> bool: + if url == prefix: + return True + if prefix.endswith("/"): + return url.startswith(prefix) + return url.startswith(prefix + "/") + + +def _is_source_reference(identity: str) -> bool: + parsed = urlparse(identity) + host = (parsed.hostname or "").lower() + if host.startswith("www."): + host = host[4:] + if not host: + return False + if host in _EXCLUDED_HOSTS: + return False + if not _is_allowed_host(host): + return False + + lower_path = unquote(parsed.path).lower() + if _has_excluded_path_marker(lower_path): + return False + + if _looks_like_git_reference(host, lower_path): + return True + return _looks_like_file_reference(host, lower_path, parsed.path) + + +def _has_excluded_path_marker(path: str) -> bool: + if path.endswith(".svg"): + return True + segments = [segment for segment in path.split("/") if segment] + if len(segments) < 3: + return False + ui_segment = segments[2] + return ui_segment in { + "actions", + "badge", + "badges", + "blob", + "checks", + "ci", + "issues", + "pull", + "pulls", + "tree", + "wiki", + "workflows", + } + + +def _looks_like_git_reference(host: str, path: str) -> bool: + if not _host_in_allowed_git_hosts(host): + return False + if not path or path == "/": + return False + if path.startswith("/raw/"): + return False + if path.startswith("/blob/"): + return False + if "/tree/" in path: + return False + + segments = [segment for segment in path.split("/") if segment] + if len(segments) < 2: + return False + if len(segments) >= 3 and segments[2] == "actions": + return False + + lower = path.lower() + return not any(lower.endswith(ext) for ext in _NON_GIT_FILE_EXTENSIONS) + + +def _looks_like_file_reference(host: str, lower_path: str, raw_path: str) -> bool: + if not _is_allowed_host(host): + return False + if raw_path.endswith("/"): + return False + extension = _split_extension(lower_path) + if not extension: + return False + return extension in _SUPPORTED_FILE_EXTENSIONS + + +def _is_allowed_host(host: str) -> bool: + return ( + host in ALLOWED_GIT_HOSTS + or host in {f"www.{entry}" for entry in ALLOWED_GIT_HOSTS} + or host in ALLOWED_DOWNLOAD_HOSTS + or host in {f"www.{entry}" for entry in ALLOWED_DOWNLOAD_HOSTS} + ) + + +def _host_in_allowed_git_hosts(host: str) -> bool: + return host in ALLOWED_GIT_HOSTS or host in {f"www.{entry}" for entry in ALLOWED_GIT_HOSTS} + + +def _split_extension(path: str) -> str: + return ( + "." + path.rsplit("/", 1)[-1].rsplit(".", 1)[-1] if "." in path.rsplit("/", 1)[-1] else "" + ) diff --git a/tests/nodes/test_report.py b/tests/nodes/test_report.py index 71445d65..475108e4 100644 --- a/tests/nodes/test_report.py +++ b/tests/nodes/test_report.py @@ -30,7 +30,7 @@ report, ) from skillspector.state import SkillspectorState -from skillspector.suppression import Baseline, SuppressionRule +from skillspector.suppression import Baseline, SuppressionRule, finding_fingerprint def _finding( @@ -474,6 +474,24 @@ def test_report_output_format_markdown(self) -> None: assert "## Components" in body assert "## Issues" in body + +def test_json_output_includes_transitive_provenance() -> None: + """JSON output includes transitive provenance fields from Finding.""" + finding = _finding("T1", severity="HIGH", confidence=1.0, file="dep.py") + finding.transitive_depth = 2 + finding.source_url = "https://github.com/org/transitive" + state: SkillspectorState = { + "filtered_findings": [finding], + "component_metadata": [], + "has_executable_scripts": False, + "manifest": {}, + "skill_path": "/tmp/skill", + "output_format": "json", + } + data = json.loads(report(state)["report_body"]) + assert data["issues"][0]["transitive_depth"] == 2 + assert data["issues"][0]["source_url"] == "https://github.com/org/transitive" + def test_report_output_format_terminal(self) -> None: """output_format terminal produces Rich-formatted output.""" state: SkillspectorState = { @@ -506,6 +524,25 @@ def test_report_output_format_sarif(self) -> None: assert "runs" in data assert data.get("$schema") or "runs" in data + +def test_markdown_output_labels_transitive_findings() -> None: + """Markdown output labels transitive findings with depth and source URL.""" + finding = _finding( + "T1", severity="HIGH", file="dep.py", confidence=0.9, message="transitive issue" + ) + finding.transitive_depth = 3 + finding.source_url = "https://github.com/org/transitive" + state: SkillspectorState = { + "filtered_findings": [finding], + "component_metadata": [], + "has_executable_scripts": False, + "manifest": {}, + "skill_path": "/tmp/skill", + "output_format": "markdown", + } + body = report(state)["report_body"] + assert "**Transitive:** depth=3, source=https://github.com/org/transitive" in body + def test_report_default_output_format_is_sarif(self) -> None: """When output_format is missing, report uses sarif.""" state: SkillspectorState = { @@ -573,6 +610,32 @@ def test_report_baseline_suppresses_finding_and_lowers_score() -> None: assert len(result["suppressed_findings"]) == 1 +def test_report_baseline_direct_fingerprint_does_not_suppress_transitive_finding() -> None: + """Transitive provenance keeps baseline fingerprints scoped to the original source.""" + direct = _finding("P5", "CRITICAL") + transitive = _finding("P5", "CRITICAL") + transitive.source_url = "https://github.com/evil/dep" + transitive.transitive_depth = 1 + + baseline = Baseline(fingerprints={finding_fingerprint(direct): "accepted root finding"}) + state: SkillspectorState = { + "findings": [transitive], + "filtered_findings": [transitive], + "component_metadata": [], + "has_executable_scripts": False, + "manifest": {}, + "skill_path": None, + "output_format": "json", + "baseline": baseline, + } + result = report(state) + body = json.loads(result["report_body"]) + assert result["risk_score"] > 0 + assert len(body["issues"]) == 1 + assert body["issues"][0]["source_url"] == "https://github.com/evil/dep" + assert result["suppressed_findings"] == [] + + def test_report_baseline_keeps_unmatched_finding() -> None: """Findings not matched by the baseline are kept and scored normally.""" baseline = Baseline(rules=[SuppressionRule(rule_id="SQP-1", reason="nit")]) diff --git a/tests/nodes/test_sarif_rules_and_empty_findings.py b/tests/nodes/test_sarif_rules_and_empty_findings.py index d4f9f945..820a31e8 100644 --- a/tests/nodes/test_sarif_rules_and_empty_findings.py +++ b/tests/nodes/test_sarif_rules_and_empty_findings.py @@ -19,6 +19,7 @@ from skillspector.models import Finding from skillspector.nodes.report import _build_sarif +from skillspector.sarif_models import validate_sarif_report def _make_finding(rule_id: str = "PE3", message: str = "Credential Access", **kwargs) -> Finding: @@ -155,3 +156,16 @@ def test_sarif_schema_present(self) -> None: sarif = _build_sarif(findings) assert "$schema" in sarif assert sarif["version"] == "2.1.0" + + +def test_sarif_transitive_properties_validate() -> None: + """Transitive provenance lands in SARIF properties and still validates.""" + finding = _make_finding("TR1", "Transitive Dependency") + finding.transitive_depth = 2 + finding.source_url = "https://github.com/org/dep" + sarif = _build_sarif([finding]) + validate_sarif_report(sarif) + result = sarif["runs"][0]["results"][0] + properties = result["properties"] + assert properties["transitiveDepth"] == 2 + assert properties["sourceUrl"] == "https://github.com/org/dep" diff --git a/tests/unit/test_cli.py b/tests/unit/test_cli.py index b8c88238..c21ee2c4 100644 --- a/tests/unit/test_cli.py +++ b/tests/unit/test_cli.py @@ -20,7 +20,39 @@ from typer.testing import CliRunner +from skillspector import cli as cli +from skillspector import transitive from skillspector.cli import app +from skillspector.models import Finding + + +def _mock_graph_result( + findings: list[Finding] | None = None, + file_cache: dict[str, str] | None = None, + output_format: str = "json", +) -> dict[str, object]: + return { + "findings": findings or [], + "filtered_findings": findings or [], + "components": ["SKILL.md"], + "component_metadata": [], + "file_cache": file_cache or {}, + "has_executable_scripts": False, + "output_format": output_format, + } + + +def _finding(rule_id: str, message: str, file: str = "SKILL.md", depth: int = 0) -> Finding: + return Finding( + rule_id=rule_id, + message=message, + severity="HIGH", + confidence=0.9, + file=file, + start_line=1, + transitive_depth=depth, + ) + runner = CliRunner() @@ -113,3 +145,588 @@ def test_cli_baseline_generate_then_scan_round_trip(tmp_path: Path) -> None: data = json.loads(scan.output) assert data["issues"] == [] assert data["risk_assessment"]["score"] == 0 + + +def test_scan_without_transitive_invokes_graph_once(tmp_path: Path, monkeypatch) -> None: + """Direct scan without --transitive runs exactly one graph scan.""" + (tmp_path / "SKILL.md").write_text("# Safe", encoding="utf-8") + calls: list[str] = [] + + def fake_run_graph_scan( + input_path: str, + format, + no_llm: bool, + yara_dir: str | None = None, + baseline=None, + show_suppressed: bool = False, + ) -> dict[str, object]: + calls.append(input_path) + return _mock_graph_result(output_format=format.value if format else "json") + + monkeypatch.setattr(cli, "_run_graph_scan", fake_run_graph_scan) + result = runner.invoke(app, ["scan", str(tmp_path), "--format", "json"]) + assert result.exit_code == 0 + assert len(calls) == 1 + + +def test_scan_transitive_depth_one_merges_provenance(tmp_path: Path, monkeypatch) -> None: + """--transitive-depth 1 follows one approved external target and merges provenance.""" + direct_output = "See dependency: https://github.com/org/transitive.git" + + def fake_run_graph_scan( + input_path: str, + format, + no_llm: bool, + yara_dir: str | None = None, + baseline=None, + show_suppressed: bool = False, + ) -> dict[str, object]: + if input_path == str(tmp_path): + return _mock_graph_result( + findings=[_finding("D1", "direct finding")], + file_cache={"SKILL.md": direct_output}, + output_format=format.value, + ) + return _mock_graph_result( + findings=[_finding("T1", "transitive finding", file="dep.py", depth=1)], + file_cache={}, + output_format=format.value, + ) + + monkeypatch.setattr(cli, "_run_graph_scan", fake_run_graph_scan) + result = runner.invoke( + app, + [ + "scan", + str(tmp_path), + "--format", + "json", + "--transitive", + "--transitive-depth", + "1", + "--no-llm", + ], + ) + assert result.exit_code == 0 + data = json.loads(result.output) + issues = data["issues"] + assert len(issues) == 2 + transitive_issue = next(issue for issue in issues if issue["source_url"] is not None) + assert transitive_issue["transitive_depth"] == 1 + assert transitive_issue["source_url"] == "https://github.com/org/transitive" + + +def test_scan_transitive_ignores_non_scannable_urls(tmp_path: Path, monkeypatch) -> None: + """Non-scannable documentation or badge URLs are not followed transitively.""" + calls: list[str] = [] + file_cache = { + "SKILL.md": ( + "badge: https://img.shields.io/github/stars/x/y " + "docs: https://github.com/org/repo/wiki/SkillSpector " + "issue: https://github.com/org/repo/issues/12" + ) + } + + def fake_run_graph_scan( + input_path: str, + format, + no_llm: bool, + yara_dir: str | None = None, + baseline=None, + show_suppressed: bool = False, + ) -> dict[str, object]: + calls.append(input_path) + return _mock_graph_result( + findings=[_finding("D1", "direct finding")], + file_cache=file_cache, + output_format=format.value, + ) + + monkeypatch.setattr(cli, "_run_graph_scan", fake_run_graph_scan) + result = runner.invoke( + app, + [ + "scan", + str(tmp_path), + "--format", + "json", + "--transitive", + "--no-llm", + ], + ) + assert result.exit_code == 0 + assert len(calls) == 1 + data = json.loads(result.output) + assert len(data["issues"]) == 1 + + +def test_scan_transitive_allow_prefix_filters_targets(tmp_path: Path, monkeypatch) -> None: + """Allow prefix limits transitive traversal to matching canonical roots.""" + file_cache = { + "SKILL.md": "refs: https://github.com/allowed/dep.git and https://github.com/blocked/dep.git" + } + calls: list[str] = [] + + def fake_run_graph_scan( + input_path: str, + format, + no_llm: bool, + yara_dir: str | None = None, + baseline=None, + show_suppressed: bool = False, + ) -> dict[str, object]: + calls.append(input_path) + if input_path == str(tmp_path): + return _mock_graph_result( + findings=[_finding("D1", "direct finding")], + file_cache=file_cache, + output_format=format.value, + ) + return _mock_graph_result( + findings=[_finding("T1", "transitive finding")], + output_format=format.value, + ) + + monkeypatch.setattr(cli, "_run_graph_scan", fake_run_graph_scan) + result = runner.invoke( + app, + [ + "scan", + str(tmp_path), + "--format", + "json", + "--transitive", + "--transitive-allow-prefix", + "https://github.com/allowed/", + "--no-llm", + ], + ) + assert result.exit_code == 0 + assert calls[0] == str(tmp_path) + assert len(calls) == 2 + assert calls[1] == "https://github.com/allowed/dep" + data = json.loads(result.output) + assert any(issue["source_url"] == "https://github.com/allowed/dep" for issue in data["issues"]) + + +def test_scan_transitive_deny_prefix_skips_targets(tmp_path: Path, monkeypatch) -> None: + """Deny prefix blocks matching targets while still scanning siblings.""" + file_cache = { + "SKILL.md": ( + "refs: https://github.com/allowed/dep.git and https://github.com/blocked/dep.git" + ) + } + calls: list[str] = [] + + def fake_run_graph_scan( + input_path: str, + format, + no_llm: bool, + yara_dir: str | None = None, + baseline=None, + show_suppressed: bool = False, + ) -> dict[str, object]: + calls.append(input_path) + if input_path == str(tmp_path): + return _mock_graph_result( + findings=[_finding("D1", "direct finding")], + file_cache=file_cache, + output_format=format.value, + ) + return _mock_graph_result( + findings=[_finding("T1", "transitive finding")], + output_format=format.value, + ) + + monkeypatch.setattr(cli, "_run_graph_scan", fake_run_graph_scan) + result = runner.invoke( + app, + [ + "scan", + str(tmp_path), + "--format", + "json", + "--transitive", + "--transitive-deny-prefix", + "https://github.com/blocked/", + "--no-llm", + ], + ) + assert result.exit_code == 0 + assert calls[0] == str(tmp_path) + assert len(calls) == 2 + assert calls[1] == "https://github.com/allowed/dep" + + +def test_cli_passes_result_file_cache_to_transitive_owner(tmp_path: Path, monkeypatch) -> None: + """CLI passes completed direct graph file_cache into the transitive owner.""" + file_cache = {"SKILL.md": "deps https://github.com/org/dep.git"} + captured: list[dict[str, str]] = [] + + def fake_extract_external_refs(value: dict[str, str]) -> list[str]: + captured.append(value) + return [] + + def fake_run_graph_scan( + input_path: str, format, no_llm: bool, *args, **kwargs + ) -> dict[str, object]: + return _mock_graph_result( + findings=[_finding("D1", "direct finding")], + file_cache=file_cache if input_path == str(tmp_path) else {}, + output_format=format.value, + ) + + monkeypatch.setattr(cli, "_run_graph_scan", fake_run_graph_scan) + monkeypatch.setattr(transitive, "extract_external_refs", fake_extract_external_refs) + result = runner.invoke( + app, + [ + "scan", + str(tmp_path), + "--format", + "json", + "--transitive", + "--no-llm", + ], + ) + assert result.exit_code == 0 + assert captured == [file_cache] + + +def test_single_and_recursive_transitive_route_through_shared_helper( + tmp_path: Path, monkeypatch +) -> None: + """Both single and recursive scans call _scan_transitive for follow-up scanning.""" + (tmp_path / "SKILL.md").write_text("# Root", encoding="utf-8") + parent = tmp_path / "collection" + parent.mkdir() + for name in ("skill-a", "skill-b"): + skill = parent / name + skill.mkdir() + (skill / "SKILL.md").write_text(f"---\nname: {name}\n---\n# {name}", encoding="utf-8") + + single_calls: list[object] = [] + recursive_calls: list[object] = [] + + def fake_scan_transitive(*args, **kwargs) -> dict[str, object]: + if not recursive_calls and not single_calls: + single_calls.append(args) + else: + recursive_calls.append(args) + return { + "report_body": "{}", + "risk_score": 0, + "risk_severity": "LOW", + "transitive_finding_count": 0, + "transitive_sources": [], + } + + def fake_run_graph_scan( + input_path: str, + format, + no_llm: bool, + yara_dir: str | None = None, + baseline=None, + show_suppressed: bool = False, + ) -> dict[str, object]: + return _mock_graph_result( + findings=[_finding("D1", "direct finding")], + file_cache={"SKILL.md": "x"}, + output_format=format.value, + ) + + monkeypatch.setattr(cli, "_run_graph_scan", fake_run_graph_scan) + monkeypatch.setattr(cli, "_scan_transitive", fake_scan_transitive) + + single = runner.invoke( + app, + ["scan", str(tmp_path), "--format", "json", "--transitive", "--no-llm"], + ) + assert single.exit_code == 0 + assert len(single_calls) == 1 + + multi_output = tmp_path / "multi.json" + recursive = runner.invoke( + app, + [ + "scan", + str(parent), + "--recursive", + "--format", + "json", + "--transitive", + "--output", + str(multi_output), + "--no-llm", + ], + ) + assert recursive.exit_code == 0 + # Each recursive child shares the same helper and runs once per child. + assert len(recursive_calls) == 2 + + +def test_transitive_resolver_failure_preserves_direct_report(tmp_path: Path, monkeypatch) -> None: + """A transitive resolver failure should preserve the direct report result.""" + target = "https://github.com/org/broken.git" + file_cache = {"SKILL.md": f"deps {target}"} + + def fake_run_graph_scan( + input_path: str, + format, + no_llm: bool, + yara_dir: str | None = None, + baseline=None, + show_suppressed: bool = False, + ) -> dict[str, object]: + if input_path == str(tmp_path): + return _mock_graph_result( + findings=[_finding("D1", "direct finding")], + file_cache=file_cache, + output_format=format.value, + ) + raise ValueError("resolver failure") + + monkeypatch.setattr(cli, "_run_graph_scan", fake_run_graph_scan) + result = runner.invoke( + app, + [ + "scan", + str(tmp_path), + "--format", + "json", + "--transitive", + "--no-llm", + ], + ) + assert result.exit_code == 0 + data = json.loads(result.output) + assert len(data["issues"]) == 1 + assert data["issues"][0]["id"] == "D1" + + +def test_scan_transitive_does_not_rescan_root_source(monkeypatch) -> None: + """A root external source is seeded in visited so self-references are not rescanned.""" + root_source = "https://github.com/org/root.git" + calls: list[str] = [] + + def fake_run_graph_scan( + input_path: str, + format, + no_llm: bool, + yara_dir: str | None = None, + baseline=None, + show_suppressed: bool = False, + ) -> dict[str, object]: + calls.append(input_path) + return _mock_graph_result( + findings=[_finding("D1", "direct finding")], + file_cache={"SKILL.md": root_source}, + output_format=format.value, + ) + + monkeypatch.setattr(cli, "_run_graph_scan", fake_run_graph_scan) + result = runner.invoke( + app, + ["scan", root_source, "--format", "json", "--transitive", "--no-llm"], + ) + assert result.exit_code == 0 + assert calls == [root_source] + + +def test_scan_transitive_preserves_root_cleanup_and_counts_findings( + tmp_path: Path, monkeypatch +) -> None: + """Transitive merge keeps the root cleanup path and counts findings, not sources.""" + cleanup_root = tmp_path / "cleanup-root" + initial_result = _mock_graph_result( + findings=[_finding("D1", "direct finding")], + file_cache={"SKILL.md": "https://github.com/org/transitive.git"}, + ) + initial_result["temp_dir_for_cleanup"] = str(cleanup_root) + + def fake_run_graph_scan( + input_path: str, + format, + no_llm: bool, + yara_dir: str | None = None, + baseline=None, + show_suppressed: bool = False, + ) -> dict[str, object]: + assert input_path == "https://github.com/org/transitive" + return _mock_graph_result( + findings=[ + _finding("T1", "transitive finding"), + _finding("T2", "second transitive finding"), + ], + output_format=format.value, + ) + + monkeypatch.setattr(cli, "_run_graph_scan", fake_run_graph_scan) + merged = cli._scan_transitive( + initial_result=initial_result, + format=cli.FormatChoice.json, + no_llm=True, + max_depth=1, + transitive_allow_prefix=(), + transitive_deny_prefix=(), + baseline=None, + show_suppressed=False, + visited=set(), + ) + + assert merged["temp_dir_for_cleanup"] == str(cleanup_root) + assert merged["transitive_finding_count"] == 2 + assert merged["transitive_sources"] == ["https://github.com/org/transitive"] + + +def test_scan_transitive_zero_depth_preserves_root_cleanup(tmp_path: Path, monkeypatch) -> None: + """Zero-depth transitive scans preserve root cleanup metadata and do not recurse.""" + cleanup_root = tmp_path / "cleanup-root" + initial_result = _mock_graph_result(findings=[_finding("D1", "direct finding")]) + initial_result["temp_dir_for_cleanup"] = str(cleanup_root) + + def fail_run_graph_scan(*args, **kwargs) -> dict[str, object]: + raise AssertionError("zero-depth transitive scan should not recurse") + + monkeypatch.setattr(cli, "_run_graph_scan", fail_run_graph_scan) + merged = cli._scan_transitive( + initial_result=initial_result, + format=cli.FormatChoice.json, + no_llm=True, + max_depth=0, + transitive_allow_prefix=(), + transitive_deny_prefix=(), + baseline=None, + show_suppressed=False, + visited=set(), + ) + + assert merged["temp_dir_for_cleanup"] == str(cleanup_root) + assert merged["transitive_finding_count"] == 0 + assert merged["transitive_sources"] == [] + + +def test_recursive_transitive_json_includes_sources(tmp_path: Path, monkeypatch) -> None: + """Recursive combined JSON output records transitive source summaries.""" + root = tmp_path / "root" + root.mkdir() + for name in ("weather", "email"): + sub = root / name + sub.mkdir() + (sub / "SKILL.md").write_text(f"---\nname: {name}\n---\n", encoding="utf-8") + + calls: list[int] = [] + expected_sources = [ + "https://github.com/org/weather-transitive", + "https://github.com/org/email-transitive", + ] + expected_counts = [2, 1] + + def fake_scan_transitive(*args, **kwargs) -> dict[str, object]: + index = len(calls) + calls.append(index) + return { + "report_body": "{}", + "risk_score": 0, + "risk_severity": "LOW", + "transitive_finding_count": expected_counts[index], + "transitive_sources": [expected_sources[index]], + } + + def fake_run_graph_scan( + input_path: str, + format, + no_llm: bool, + yara_dir: str | None = None, + baseline=None, + show_suppressed: bool = False, + ) -> dict[str, object]: + return _mock_graph_result( + findings=[_finding("D1", "direct finding")], + file_cache={"SKILL.md": "https://github.com/example/dummy.git"}, + output_format=format.value, + ) + + monkeypatch.setattr(cli, "_run_graph_scan", fake_run_graph_scan) + monkeypatch.setattr(cli, "_scan_transitive", fake_scan_transitive) + + out_file = root / "multi.json" + result = runner.invoke( + app, + [ + "scan", + str(root), + "--recursive", + "--format", + "json", + "--transitive", + "--output", + str(out_file), + "--no-llm", + ], + ) + assert result.exit_code == 0 + assert out_file.exists() + data = json.loads(out_file.read_text(encoding="utf-8")) + assert data["transitive_finding_count"] == sum(expected_counts) + assert sorted(data["transitive_sources"]) == sorted(expected_sources) + + +def test_recursive_transitive_reuses_shared_visited_set(tmp_path: Path, monkeypatch) -> None: + """Recursive scans reuse one visited set across sibling skills.""" + root = tmp_path / "root" + root.mkdir() + for name in ("weather", "email"): + sub = root / name + sub.mkdir() + (sub / "SKILL.md").write_text(f"---\nname: {name}\n---\n", encoding="utf-8") + + visited_snapshots: list[list[str]] = [] + + def fake_scan_transitive(*args, **kwargs) -> dict[str, object]: + visited = kwargs["visited"] + assert isinstance(visited, set) + visited_snapshots.append(sorted(str(item) for item in visited)) + visited.add(f"visit-{len(visited_snapshots)}") + return { + "report_body": "{}", + "risk_score": 0, + "risk_severity": "LOW", + "transitive_finding_count": 0, + "transitive_sources": [], + } + + def fake_run_graph_scan( + input_path: str, + format, + no_llm: bool, + yara_dir: str | None = None, + baseline=None, + show_suppressed: bool = False, + ) -> dict[str, object]: + return _mock_graph_result( + findings=[_finding("D1", "direct finding")], + file_cache={"SKILL.md": "https://github.com/example/dummy.git"}, + output_format=format.value, + ) + + monkeypatch.setattr(cli, "_run_graph_scan", fake_run_graph_scan) + monkeypatch.setattr(cli, "_scan_transitive", fake_scan_transitive) + + out_file = root / "multi.json" + result = runner.invoke( + app, + [ + "scan", + str(root), + "--recursive", + "--format", + "json", + "--transitive", + "--output", + str(out_file), + "--no-llm", + ], + ) + assert result.exit_code == 0 + assert visited_snapshots == [[], ["visit-1"]] diff --git a/tests/unit/test_transitive.py b/tests/unit/test_transitive.py new file mode 100644 index 00000000..eca02d8b --- /dev/null +++ b/tests/unit/test_transitive.py @@ -0,0 +1,251 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for transitive source extraction and traversal planning.""" + +from pathlib import Path + +import httpx + +from skillspector import input_handler as input_handler_module +from skillspector import transitive +from skillspector.input_handler import InputHandler + + +def test_plan_blocks_circular_reference() -> None: + """Visited identities block repeated canonical targets before second resolution.""" + refs = [ + "https://github.com/org/dup.git", + "git@github.com:org/dup.git", + "https://github.com/org/dup", + ] + visited: set[str] = set() + first = transitive.plan_transitive_targets( + refs, visited=visited, current_depth=1, max_depth=3, allow_prefixes=(), deny_prefixes=() + ) + second = transitive.plan_transitive_targets( + refs, visited=visited, current_depth=1, max_depth=3, allow_prefixes=(), deny_prefixes=() + ) + + assert first == ["https://github.com/org/dup"] + assert second == [] + assert visited == {"https://github.com/org/dup"} + + +def test_extract_excludes_badges_docs_and_issue_urls() -> None: + """Non-scan URLs should be filtered out, even when they look URL-like.""" + file_cache = { + "SKILL.md": ( + "badge https://img.shields.io/github/stars/user/repo?style=flat-square, " + "issue https://github.com/NVIDIA/SkillSpector/issues/12, " + "docs https://github.com/NVIDIA/SkillSpector/wiki, " + "ci https://github.com/NVIDIA/SkillSpector/actions, " + "src https://raw.githubusercontent.com/NVIDIA/SkillSpector/main/tool.py, " + "zip https://huggingface.co/abc/archive/main.zip" + ), + } + + refs = transitive.extract_external_refs(file_cache) + assert refs == [ + "https://raw.githubusercontent.com/NVIDIA/SkillSpector/main/tool.py", + "https://huggingface.co/abc/archive/main.zip", + ] + + +def test_extract_keeps_repos_with_reserved_word_names() -> None: + """Reserved UI words in org or repo names should not block valid repository targets.""" + file_cache = { + "SKILL.md": ( + "https://github.com/wiki-tools/skill.git " + "https://github.com/org/actions.git " + "https://github.com/badger/skill.git" + ), + } + + refs = transitive.extract_external_refs(file_cache) + assert refs == [ + "https://github.com/wiki-tools/skill", + "https://github.com/org/actions", + "https://github.com/badger/skill", + ] + + +def test_input_handler_treats_github_archive_zip_as_file_url() -> None: + """GitHub archive ZIP links should download as files, not route through git clone.""" + handler = InputHandler() + url = "https://github.com/org/repo/archive/refs/heads/main.zip" + + assert handler._is_git_url(url) is False + assert handler._is_file_url(url) is True + + +def test_input_handler_resolves_github_archive_zip_via_validated_redirect( + tmp_path: Path, monkeypatch +) -> None: + """GitHub archive ZIP redirects should still resolve as downloadable archives.""" + + class FakeResponse: + def __init__( + self, + status_code: int, + *, + headers: dict[str, str] | None = None, + content: bytes = b"", + ) -> None: + self.status_code = status_code + self.headers = headers or {} + self.content = content + + def raise_for_status(self) -> None: + if self.status_code >= 400: + request = httpx.Request("GET", "https://example.invalid") + response = httpx.Response( + self.status_code, + headers=self.headers, + content=self.content, + request=request, + ) + raise httpx.HTTPStatusError( + f"HTTP error {self.status_code}", request=request, response=response + ) + + class FakeClient: + def __init__(self, responses: list[FakeResponse], **kwargs) -> None: + self._responses = responses + + def __enter__(self) -> "FakeClient": + return self + + def __exit__(self, exc_type, exc, tb) -> bool: + return False + + def get(self, url: str) -> FakeResponse: + return self._responses.pop(0) + + archive_url = "https://github.com/org/repo/archive/refs/heads/main.zip" + redirected_url = "https://codeload.github.com/org/repo/zip/refs/heads/main" + responses = [ + FakeResponse(302, headers={"location": redirected_url}), + FakeResponse(200, headers={"content-type": "application/zip"}, content=b"zip-bytes"), + ] + handler = InputHandler() + + monkeypatch.setattr(input_handler_module, "_is_private_ip", lambda host: False) + monkeypatch.setattr(httpx, "Client", lambda **kwargs: FakeClient(responses, **kwargs)) + monkeypatch.setattr(handler, "_extract_zip", lambda zip_path: tmp_path / Path(zip_path).stem) + + resolved_path, source_type = handler.resolve(archive_url) + + assert source_type == "url" + assert resolved_path == tmp_path / "download" + + +def test_plan_depth_limit_prevents_next_wave() -> None: + """When current depth exceeds max depth, no targets are returned.""" + refs = ["https://github.com/org/repo.git"] + visited: set[str] = set() + result = transitive.plan_transitive_targets( + refs=refs, + visited=visited, + current_depth=4, + max_depth=3, + allow_prefixes=(), + deny_prefixes=(), + ) + + assert result == [] + assert visited == set() + + +def test_plan_applies_allow_prefix() -> None: + """Only identities matching allow prefixes are returned.""" + refs = [ + "https://github.com/ok/repo.git", + "https://github.com/skip/repo.git", + ] + visited: set[str] = set() + allowed = ("https://github.com/ok/",) + + result = transitive.plan_transitive_targets( + refs=refs, + visited=visited, + current_depth=1, + max_depth=2, + allow_prefixes=allowed, + deny_prefixes=(), + ) + + assert result == ["https://github.com/ok/repo"] + + +def test_plan_allow_prefix_respects_path_boundaries() -> None: + """Allow prefixes should not match sibling org names sharing a string prefix.""" + refs = [ + "https://github.com/trusted/repo.git", + "https://github.com/trusted-malicious/repo.git", + ] + visited: set[str] = set() + + result = transitive.plan_transitive_targets( + refs=refs, + visited=visited, + current_depth=1, + max_depth=2, + allow_prefixes=("https://github.com/trusted/",), + deny_prefixes=(), + ) + + assert result == ["https://github.com/trusted/repo"] + + +def test_plan_applies_deny_prefix() -> None: + """Deny prefixes skip matching identities even if they are otherwise valid.""" + refs = [ + "https://github.com/ok/repo.git", + "https://github.com/skip/repo.git", + ] + visited: set[str] = set() + denied = ("https://github.com/skip/",) + + result = transitive.plan_transitive_targets( + refs=refs, + visited=visited, + current_depth=1, + max_depth=2, + allow_prefixes=(), + deny_prefixes=denied, + ) + + assert result == ["https://github.com/ok/repo"] + + +def test_plan_deny_prefix_respects_path_boundaries() -> None: + """Deny prefixes should not block sibling org names that only share a string prefix.""" + refs = [ + "https://github.com/trusted/repo.git", + "https://github.com/trusted-malicious/repo.git", + ] + visited: set[str] = set() + + result = transitive.plan_transitive_targets( + refs=refs, + visited=visited, + current_depth=1, + max_depth=2, + allow_prefixes=(), + deny_prefixes=("https://github.com/trusted/",), + ) + + assert result == ["https://github.com/trusted-malicious/repo"]