diff --git a/docs/how_tos/debugging/index.md b/docs/how_tos/debugging/index.md index 7a7b403..f483175 100644 --- a/docs/how_tos/debugging/index.md +++ b/docs/how_tos/debugging/index.md @@ -2,6 +2,18 @@ # Debugging This page describes how to debug certain Spark errors when using sparkctl. +```{eval-rst} +.. toctree:: + :maxdepth: 1 + + mcp_server +``` + +(ai-assisted-debugging)= +## AI-Assisted Debugging +sparkctl includes an MCP server that provides AI-assisted diagnosis of Spark failures. +See {ref}`mcp-server` for details on using the MCP server with AI assistants like Claude. + (spark-web-ui)= ## Spark web UI The web UI is a good first place to look for problems. Connect to ports 8080 and 4040 on the nodes diff --git a/docs/how_tos/debugging/mcp_server.md b/docs/how_tos/debugging/mcp_server.md new file mode 100644 index 0000000..4fab6e2 --- /dev/null +++ b/docs/how_tos/debugging/mcp_server.md @@ -0,0 +1,146 @@ +(mcp-server)= +# MCP Server for Log Analysis + +sparkctl includes an MCP (Model Context Protocol) server that provides AI-assisted diagnosis +of Spark job failures. The server analyzes logs from master, worker, executor, thrift-server, +and connect-server components to detect error patterns and suggest recovery actions. + +## Installation + +The MCP server requires the optional `mcp` dependency: + +```console +$ pip install 'sparkctl[mcp]' +``` + +## Running the Server + +Start the MCP server: + +```console +$ sparkctl-mcp-server +``` + +The server communicates over stdio using the MCP protocol. It is designed to be used with +AI assistants like Claude that support MCP. + +## Available Tools + +The MCP server provides four tools: + +### get_spark_logs + +Retrieve and aggregate Spark logs from the cluster. 
+ +**Parameters:** +- `spark_scratch` (required): Path to the spark_scratch directory +- `log_type`: One of "master", "worker", "executor", "connect", "thrift", or "all" (default: "all") +- `app_id`: Filter executor logs by application ID +- `executor_id`: Filter by specific executor ID +- `tail_lines`: Number of lines from end of each log (default: 500) + +**Example use case:** "Show me the last 100 lines of executor logs for app-20240115120000-0000" + +### analyze_spark_failure + +Analyze logs for error patterns and provide diagnosis. This is the primary diagnostic tool. + +**Parameters:** +- `spark_scratch` (required): Path to the spark_scratch directory +- `app_id`: Specific application to analyze +- `include_stack_traces`: Include full stack traces (default: true) +- `max_errors`: Maximum errors to return (default: 50) + +**Detected error patterns:** +- Out of memory (OOM) +- Shuffle failures (FetchFailedException) +- Stage and task failures +- Resource exhaustion +- Connection/network issues +- Disk space issues +- Serialization errors +- Timeout errors + +**Example use case:** "Analyze why my Spark job failed" + +### get_recovery_suggestions + +Get prioritized recovery suggestions based on detected errors. + +**Parameters:** +- `error_types` (required): List of error types from analyze_spark_failure +- `current_config`: Current Spark configuration (optional) + +**Example use case:** "How do I fix the OOM errors you found?" + +### list_spark_applications + +List Spark applications found in spark_scratch. + +**Parameters:** +- `spark_scratch` (required): Path to the spark_scratch directory + +**Example use case:** "What applications have run in this cluster?" 
+ +## Integration with torc + +The sparkctl MCP server is designed to work alongside [torc](https://github.com/NREL/torc)'s +`analyze_workflow_logs` tool for full-stack diagnostics: + +| Layer | Tool | Diagnostics | +|-------|------|-------------| +| Application | sparkctl MCP | Spark-specific: OOM, shuffle, stage failures, serialization | +| Infrastructure | torc MCP | System-level: Slurm errors, node failures, filesystem issues | + +When sparkctl detects system-level issues (Slurm cancellation, filesystem errors, node health +problems), it will recommend using torc's analyze_workflow_logs tool for further investigation. + +## Example Workflow + +1. A Spark job fails on your HPC cluster +2. Ask your AI assistant: "Analyze my failed Spark job in ./spark_scratch" +3. The assistant uses `analyze_spark_failure` to detect error patterns +4. It identifies OOM errors in executors and shuffle failures +5. The assistant uses `get_recovery_suggestions` to get fixes +6. You apply the suggested configuration changes and rerun + +## Direct Python Usage + +The MCP tools can also be used directly in Python without the MCP server: + +```python +from sparkctl.mcp_server import ( + analyze_spark_failure, + get_recovery_suggestions, + get_spark_logs, + list_spark_applications, +) + +# Analyze failures +analysis = analyze_spark_failure("./spark_scratch") +print(f"Root cause: {analysis.likely_root_cause}") +print(f"Errors: {analysis.error_summary}") + +# Get recovery suggestions +suggestions = get_recovery_suggestions(list(analysis.error_summary.keys())) +for s in suggestions.suggestions: + print(f"[{s.priority}] {s.title}") + if s.sparkctl_command: + print(f" Run: {s.sparkctl_command}") +``` + +## Claude Code Configuration + +To use the sparkctl MCP server with Claude Code, add it to your MCP configuration. The server +requires no arguments and communicates over stdio. 
+ +```json +{ + "mcpServers": { + "sparkctl": { + "command": "sparkctl-mcp-server", + "args": [] + } + } +} +``` diff --git a/docs/reference/index.md b/docs/reference/index.md index 892a16f..c01932a 100644 --- a/docs/reference/index.md +++ b/docs/reference/index.md @@ -8,6 +8,7 @@ :hidden: sparkctl_api + mcp_server_api hpc/index cli_reference ``` diff --git a/docs/reference/mcp_server_api.md b/docs/reference/mcp_server_api.md new file mode 100644 index 0000000..fe1e9eb --- /dev/null +++ b/docs/reference/mcp_server_api.md @@ -0,0 +1,69 @@ +(mcp-server-api)= + +# MCP Server API + +The MCP server module provides tools for AI-assisted diagnosis of Spark job failures. + +## Tools + +These functions can be used directly in Python or through the MCP server. + +```{eval-rst} +.. autofunction:: sparkctl.mcp_server.get_spark_logs +``` + +```{eval-rst} +.. autofunction:: sparkctl.mcp_server.analyze_spark_failure +``` + +```{eval-rst} +.. autofunction:: sparkctl.mcp_server.get_recovery_suggestions +``` + +```{eval-rst} +.. autofunction:: sparkctl.mcp_server.list_spark_applications +``` + +## Response Models + +```{eval-rst} +.. autopydantic_model:: sparkctl.mcp_server.models.SparkLogsResponse + :members: +``` + +```{eval-rst} +.. autopydantic_model:: sparkctl.mcp_server.models.SparkFailureAnalysis + :members: +``` + +```{eval-rst} +.. autopydantic_model:: sparkctl.mcp_server.models.RecoverySuggestions + :members: +``` + +```{eval-rst} +.. autopydantic_model:: sparkctl.mcp_server.models.SparkApplicationList + :members: +``` + +## Utilities + +```{eval-rst} +.. autoclass:: sparkctl.mcp_server.SparkLogParser + :members: +``` + +```{eval-rst} +.. autoclass:: sparkctl.mcp_server.SparkLogLocator + :members: +``` + +```{eval-rst} +.. autoclass:: sparkctl.mcp_server.ErrorPatternRegistry + :members: +``` + +```{eval-rst} +.. 
autoclass:: sparkctl.mcp_server.RecoveryEngine + :members: +``` diff --git a/pyproject.toml b/pyproject.toml index 5d8a740..564f994 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -39,6 +39,9 @@ dependencies = [ pyspark = [ "pyspark == 4.0.0", ] +mcp = [ + "mcp >= 1.7.0", +] dev = [ "furo", "mypy >=1.13, < 2", @@ -64,6 +67,7 @@ Source = "https://github.com/NREL/sparkctl" [project.scripts] sparkctl = "sparkctl.cli.sparkctl:cli" +sparkctl-mcp-server = "sparkctl.mcp_server.server:main" [tool.setuptools.packages.find] where = ["src"] diff --git a/src/sparkctl/mcp_server/__init__.py b/src/sparkctl/mcp_server/__init__.py new file mode 100644 index 0000000..4640e87 --- /dev/null +++ b/src/sparkctl/mcp_server/__init__.py @@ -0,0 +1,62 @@ +"""MCP Server for sparkctl Spark job failure diagnostics. + +This module provides an MCP (Model Context Protocol) server that diagnoses +Spark job failures by analyzing logs from master, worker, executor, +thrift-server, and connect-server components. + +Tools provided: +- get_spark_logs: Retrieve and aggregate Spark logs +- analyze_spark_failure: Detect error patterns and diagnose issues +- get_recovery_suggestions: Get remediation suggestions for detected errors +- list_spark_applications: List Spark applications in spark_scratch + +Usage: + Run the MCP server with: sparkctl-mcp-server + +The server is designed to work alongside torc's analyze_workflow_logs tool +for full-stack diagnostics (Spark + Slurm/system-level). 
+""" + +from sparkctl.mcp_server.error_patterns import ErrorPatternRegistry +from sparkctl.mcp_server.log_parser import SparkLogLocator, SparkLogParser +from sparkctl.mcp_server.models import ( + ErrorCategory, + ErrorOccurrence, + LogEntry, + RecoverySuggestions, + SparkApplication, + SparkApplicationList, + SparkFailureAnalysis, + SparkLogsResponse, + Suggestion, +) +from sparkctl.mcp_server.recovery import RecoveryEngine +from sparkctl.mcp_server.tools import ( + analyze_spark_failure, + get_recovery_suggestions, + get_spark_logs, + list_spark_applications, +) + +__all__ = [ + # Models + "ErrorCategory", + "ErrorOccurrence", + "LogEntry", + "RecoverySuggestions", + "SparkApplication", + "SparkApplicationList", + "SparkFailureAnalysis", + "SparkLogsResponse", + "Suggestion", + # Tools (can be used directly without MCP) + "analyze_spark_failure", + "get_recovery_suggestions", + "get_spark_logs", + "list_spark_applications", + # Utilities + "ErrorPatternRegistry", + "RecoveryEngine", + "SparkLogLocator", + "SparkLogParser", +] diff --git a/src/sparkctl/mcp_server/error_patterns.py b/src/sparkctl/mcp_server/error_patterns.py new file mode 100644 index 0000000..d23b292 --- /dev/null +++ b/src/sparkctl/mcp_server/error_patterns.py @@ -0,0 +1,303 @@ +"""Error pattern registry for classifying Spark log errors.""" + +import re +from dataclasses import dataclass, field +from typing import Literal + +from sparkctl.mcp_server.models import ErrorCategory + + +@dataclass +class ErrorPattern: + """Defines an error pattern with its classification.""" + + category: ErrorCategory + severity: Literal["critical", "error", "warning"] + patterns: list[re.Pattern[str]] + description: str + common_causes: list[str] = field(default_factory=list) + + +class ErrorPatternRegistry: + """Registry of known Spark error patterns.""" + + PATTERNS: list[ErrorPattern] = [ + # OOM Errors + ErrorPattern( + category=ErrorCategory.OOM, + severity="critical", + patterns=[ + 
re.compile(r"java\.lang\.OutOfMemoryError", re.IGNORECASE), + re.compile(r"GC overhead limit exceeded"), + re.compile(r"Java heap space"), + re.compile(r"Unable to acquire.*memory", re.IGNORECASE), + re.compile(r"Required executor memory.*is above", re.IGNORECASE), + re.compile(r"Container killed by YARN for exceeding memory limits"), + re.compile(r"ExecutorLostFailure.*memory", re.IGNORECASE), + ], + description="Out of memory error in JVM heap or off-heap", + common_causes=[ + "Executor memory too small for data", + "Memory leak in UDF", + "Too many partitions cached", + "Large broadcast variables", + "Data skew causing some tasks to process much more data", + ], + ), + # Shuffle Failures + ErrorPattern( + category=ErrorCategory.SHUFFLE, + severity="error", + patterns=[ + re.compile(r"FetchFailedException"), + re.compile(r"Failed to connect to.*shuffle", re.IGNORECASE), + re.compile(r"Too Large Frame.*shuffle", re.IGNORECASE), + re.compile(r"Shuffle.*failed", re.IGNORECASE), + re.compile(r"ShuffleBlockFetcherIterator.*failed", re.IGNORECASE), + re.compile(r"Unable to fetch.*shuffle", re.IGNORECASE), + re.compile(r"MetadataFetchFailedException"), + ], + description="Shuffle data transfer failed between executors", + common_causes=[ + "Executor lost during shuffle", + "Network connectivity issues", + "Disk full on shuffle storage", + "Shuffle block too large (>2GB)", + "Too few shuffle partitions", + ], + ), + # Stage Failures + ErrorPattern( + category=ErrorCategory.STAGE, + severity="error", + patterns=[ + re.compile(r"Lost task \d+\.\d+ in stage"), + re.compile(r"Stage \d+ \(.*\) failed"), + re.compile(r"Job aborted due to stage failure"), + re.compile(r"TaskSetManager.*Lost task"), + re.compile(r"Task \d+ failed \d+ times"), + re.compile(r"Aborting job.*stage.*failed"), + ], + description="Task or stage execution failed", + common_causes=[ + "Exception in task code", + "Executor failure", + "Data corruption", + "Resource constraints", + "Task exceeded max 
retries", + ], + ), + # Resource Issues + ErrorPattern( + category=ErrorCategory.RESOURCE, + severity="critical", + patterns=[ + re.compile(r"requires more resource than any of Workers", re.IGNORECASE), + re.compile(r"Not enough workers", re.IGNORECASE), + re.compile(r"Could not find.*worker", re.IGNORECASE), + re.compile(r"No workers available", re.IGNORECASE), + re.compile(r"Executor.*lost", re.IGNORECASE), + re.compile(r"Worker.*removed", re.IGNORECASE), + re.compile(r"Initial job has not accepted any resources"), + ], + description="Cluster resources insufficient for job", + common_causes=[ + "Cluster too small for requested resources", + "Resource configuration mismatch", + "Workers crashed or became unhealthy", + "Memory requirements exceed node capacity", + ], + ), + # Connection Failures + ErrorPattern( + category=ErrorCategory.CONNECTION, + severity="error", + patterns=[ + re.compile(r"Unable to connect to", re.IGNORECASE), + re.compile(r"Connection refused"), + re.compile(r"Connection reset by peer"), + re.compile(r"Failed to connect", re.IGNORECASE), + re.compile(r"java\.net\.ConnectException"), + re.compile(r"java\.net\.SocketException"), + re.compile(r"Connection timed out"), + ], + description="Network connection failure", + common_causes=[ + "Node unreachable", + "Firewall blocking ports", + "Service not running", + "DNS resolution failure", + "Network partition", + ], + ), + # Serialization Errors + ErrorPattern( + category=ErrorCategory.SERIALIZATION, + severity="error", + patterns=[ + re.compile(r"NotSerializableException"), + re.compile(r"Task not serializable"), + re.compile(r"java\.io\.NotSerializableException"), + re.compile(r"Kryo.*serialization", re.IGNORECASE), + re.compile(r"Failed to serialize", re.IGNORECASE), + re.compile(r"InvalidClassException"), + ], + description="Object serialization failed", + common_causes=[ + "Non-serializable object in closure", + "Missing serializer registration", + "Complex object graph with circular 
references", + "Class version mismatch", + ], + ), + # Disk Failures + ErrorPattern( + category=ErrorCategory.DISK, + severity="critical", + patterns=[ + re.compile(r"No space left on device"), + re.compile(r"IOException.*disk", re.IGNORECASE), + re.compile(r"DiskBlockManager.*failed", re.IGNORECASE), + re.compile(r"Failed to write", re.IGNORECASE), + re.compile(r"Disk I/O error", re.IGNORECASE), + re.compile(r"BlockManager.*failed to persist", re.IGNORECASE), + ], + description="Disk I/O failure", + common_causes=[ + "Disk full", + "Disk hardware failure", + "Insufficient temp space for shuffle", + "Permission issues on storage directory", + ], + ), + # Timeout Errors + ErrorPattern( + category=ErrorCategory.TIMEOUT, + severity="warning", + patterns=[ + re.compile(r"TimeoutException"), + re.compile(r"timed out", re.IGNORECASE), + re.compile(r"heartbeat.*timeout", re.IGNORECASE), + re.compile(r"Connection timed out"), + re.compile(r"RPC.*timeout", re.IGNORECASE), + re.compile(r"Executor heartbeat timed out"), + ], + description="Operation timeout", + common_causes=[ + "Slow network", + "Overloaded node", + "Long GC pauses", + "Timeout configured too short", + "Executor doing expensive computation", + ], + ), + # Configuration Errors + ErrorPattern( + category=ErrorCategory.CONFIGURATION, + severity="error", + patterns=[ + re.compile(r"Invalid.*configuration", re.IGNORECASE), + re.compile(r"ClassNotFoundException"), + re.compile(r"NoSuchMethodError"), + re.compile(r"NoClassDefFoundError"), + re.compile(r"IllegalArgumentException.*config", re.IGNORECASE), + re.compile(r"spark\..*is not set", re.IGNORECASE), + ], + description="Configuration or classpath error", + common_causes=[ + "Missing JAR dependency", + "Incompatible library versions", + "Invalid configuration value", + "Missing required configuration", + ], + ), + ] + + # Patterns that suggest system-level issues (recommend torc analysis) + TORC_INDICATORS: dict[str, list[re.Pattern[str]]] = { + "slurm": [ + 
re.compile(r"SLURM.*CANCELLED", re.IGNORECASE), + re.compile(r"srun:.*error", re.IGNORECASE), + re.compile(r"Job.*exceeded.*limit", re.IGNORECASE), + re.compile(r"PREEMPTED", re.IGNORECASE), + ], + "filesystem": [ + re.compile(r"Permission denied"), + re.compile(r"Stale file handle"), + re.compile(r"Input/output error"), + re.compile(r"Read-only file system"), + ], + "node_health": [ + re.compile(r"Node.*unhealthy", re.IGNORECASE), + re.compile(r"Lost executor on.*Node", re.IGNORECASE), + re.compile(r"Could not reach", re.IGNORECASE), + re.compile(r"Node blacklisted", re.IGNORECASE), + ], + } + + @classmethod + def classify_error(cls, text: str) -> ErrorPattern | None: + """Classify text by matching against known error patterns. + + Parameters + ---------- + text + The log text to classify (can be multi-line). + + Returns + ------- + ErrorPattern | None + The matching ErrorPattern or None if no match. + """ + for pattern in cls.PATTERNS: + for regex in pattern.patterns: + if regex.search(text): + return pattern + return None + + @classmethod + def should_recommend_torc(cls, text: str) -> tuple[bool, str | None]: + """Determine if torc analysis should be recommended based on log text. + + Parameters + ---------- + text + The log text to check. + + Returns + ------- + tuple[bool, str | None] + Tuple of (should_recommend, reason). + """ + for category, patterns in cls.TORC_INDICATORS.items(): + for pattern in patterns: + if pattern.search(text): + reason_map = { + "slurm": "Slurm job management issue detected", + "filesystem": "Filesystem/permission issue detected", + "node_health": "Node health issue detected", + } + return True, reason_map.get(category, f"{category} issue detected") + return False, None + + @classmethod + def get_root_cause_priority(cls) -> dict[ErrorCategory, int]: + """Return priority ordering for root cause determination. + + Returns + ------- + dict[ErrorCategory, int] + Priority ordering where lower number = more likely to be root cause. 
+ """ + return { + ErrorCategory.OOM: 1, + ErrorCategory.RESOURCE: 2, + ErrorCategory.DISK: 3, + ErrorCategory.SHUFFLE: 4, + ErrorCategory.CONNECTION: 5, + ErrorCategory.SERIALIZATION: 6, + ErrorCategory.STAGE: 7, + ErrorCategory.TIMEOUT: 8, + ErrorCategory.CONFIGURATION: 9, + ErrorCategory.UNKNOWN: 10, + } diff --git a/src/sparkctl/mcp_server/log_parser.py b/src/sparkctl/mcp_server/log_parser.py new file mode 100644 index 0000000..8f65f78 --- /dev/null +++ b/src/sparkctl/mcp_server/log_parser.py @@ -0,0 +1,298 @@ +"""Log parsing utilities for Spark logs.""" + +import re +from dataclasses import dataclass +from datetime import datetime +from pathlib import Path + + +@dataclass +class ParsedLogEntry: + """A parsed log entry from a Spark log file.""" + + timestamp: datetime | None + level: str # INFO, WARN, ERROR, DEBUG, FATAL + logger: str # e.g., "Master", "Worker", "BlockManager" + message: str + full_text: str # Original text including stack trace + line_start: int + line_end: int + + +@dataclass +class LogFileInfo: + """Information about a log file.""" + + path: Path + source_type: str # master, worker, executor, connect, thrift + source_id: str # e.g., "0" for executor-0 + last_modified: datetime | None = None + + +class SparkLogLocator: + """Discovers Spark log files from spark_scratch directory.""" + + def __init__(self, spark_scratch: Path): + self.spark_scratch = Path(spark_scratch) + self.logs_dir = self.spark_scratch / "logs" + self.workers_dir = self.spark_scratch / "workers" + + def get_all_log_files(self) -> list[LogFileInfo]: + """Get all log files with their metadata.""" + files: list[LogFileInfo] = [] + files.extend(self.get_master_logs()) + files.extend(self.get_worker_logs()) + files.extend(self.get_connect_logs()) + files.extend(self.get_thrift_logs()) + files.extend(self.get_executor_logs()) + return files + + def get_master_logs(self) -> list[LogFileInfo]: + """Get master log files. 
+ + Returns + ------- + list[LogFileInfo] + Master log files matching pattern: + spark-{user}-org.apache.spark.deploy.master.Master-*.out + """ + if not self.logs_dir.exists(): + return [] + files = [] + for path in self.logs_dir.glob("spark-*-org.apache.spark.deploy.master.Master-*.out"): + files.append(self._make_log_info(path, "master", "0")) + return files + + def get_worker_logs(self) -> list[LogFileInfo]: + """Get worker log files. + + Returns + ------- + list[LogFileInfo] + Worker log files matching pattern: + spark-{user}-org.apache.spark.deploy.worker.Worker-*.out + """ + if not self.logs_dir.exists(): + return [] + files = [] + for path in self.logs_dir.glob("spark-*-org.apache.spark.deploy.worker.Worker-*.out"): + # Extract worker instance from filename + match = re.search(r"Worker-(\d+)-", path.name) + worker_id = match.group(1) if match else "0" + files.append(self._make_log_info(path, "worker", worker_id)) + return files + + def get_connect_logs(self) -> list[LogFileInfo]: + """Get Spark Connect server log files. + + Returns + ------- + list[LogFileInfo] + Connect server log files matching pattern: + spark-*-org.apache.spark.sql.connect.service.SparkConnectServer-*.out + """ + if not self.logs_dir.exists(): + return [] + files = [] + for path in self.logs_dir.glob("spark-*-org.apache.spark.sql.connect.*.out"): + files.append(self._make_log_info(path, "connect", "0")) + return files + + def get_thrift_logs(self) -> list[LogFileInfo]: + """Get Thrift server log files. + + Returns + ------- + list[LogFileInfo] + Thrift server log files matching pattern: + spark-*-org.apache.spark.sql.hive.thriftserver.*.out + """ + if not self.logs_dir.exists(): + return [] + files = [] + for path in self.logs_dir.glob("spark-*-org.apache.spark.sql.hive.thriftserver.*.out"): + files.append(self._make_log_info(path, "thrift", "0")) + return files + + def get_executor_logs(self, app_id: str | None = None) -> list[LogFileInfo]: + """Get executor log files. 
+ + Parameters + ---------- + app_id + Optional application ID to filter executor logs. + + Returns + ------- + list[LogFileInfo] + Executor log files matching pattern: + spark_scratch/workers/app-{id}/{executor}/stderr + """ + if not self.workers_dir.exists(): + return [] + files = [] + pattern = f"{app_id}/*/stderr" if app_id else "app-*/*/stderr" + for path in self.workers_dir.glob(pattern): + # Extract app_id and executor_id from path + parts = path.parts + exec_id = parts[-2] # The executor directory name + app = parts[-3] # The app directory name + source_id = f"{app}/{exec_id}" + files.append(self._make_log_info(path, "executor", source_id)) + return files + + def get_application_ids(self) -> list[str]: + """Get list of application IDs from executor logs.""" + if not self.workers_dir.exists(): + return [] + app_ids = set() + for path in self.workers_dir.glob("app-*/"): + if path.is_dir(): + app_ids.add(path.name) + return sorted(app_ids) + + def _make_log_info(self, path: Path, source_type: str, source_id: str) -> LogFileInfo: + """Create LogFileInfo with metadata.""" + try: + mtime = datetime.fromtimestamp(path.stat().st_mtime) + except OSError: + mtime = None + return LogFileInfo( + path=path, + source_type=source_type, + source_id=source_id, + last_modified=mtime, + ) + + +class SparkLogParser: + """Parses Spark log files handling multiline entries (stack traces).""" + + # Matches: 25/01/15 12:00:00 INFO Master: Starting Spark... + LOG_PATTERN = re.compile( + r"^(\d{2}/\d{2}/\d{2} \d{2}:\d{2}:\d{2}) " # timestamp + r"(DEBUG|INFO|WARN|ERROR|FATAL) " # level + r"(\S+): " # logger name + r"(.*)$" # message + ) + + def parse_file(self, filepath: Path, tail_lines: int | None = None) -> list[ParsedLogEntry]: + """Parse log file, handling multiline entries (stack traces). + + Parameters + ---------- + filepath + Path to the log file. + tail_lines + If specified, only read the last N lines. 
+ + Returns + ------- + list[ParsedLogEntry] + List of parsed log entries. + """ + lines = self._read_lines(filepath, tail_lines) + return self._parse_lines(lines) + + def _parse_lines(self, lines: list[str]) -> list[ParsedLogEntry]: + """Parse lines into log entries, handling multiline stack traces.""" + entries: list[ParsedLogEntry] = [] + current_lines: list[str] = [] + current_start = 0 + + for i, line in enumerate(lines): + if self.LOG_PATTERN.match(line): + # New log entry starts + if current_lines: + entries.append(self._create_entry(current_lines, current_start, i - 1)) + current_lines = [line] + current_start = i + else: + # Continuation of previous entry (stack trace or wrapped line) + current_lines.append(line) + + # Don't forget the last entry + if current_lines: + entries.append(self._create_entry(current_lines, current_start, len(lines) - 1)) + + return entries + + def _read_lines(self, filepath: Path, tail_lines: int | None) -> list[str]: + """Read lines from file, optionally tailing. + + Parameters + ---------- + filepath + Path to the file. + tail_lines + If specified, only return the last N lines. + + Returns + ------- + list[str] + List of lines from the file. 
+ """ + try: + content = filepath.read_text(encoding="utf-8", errors="replace") + except OSError: + return [] + + all_lines = content.splitlines() + + if tail_lines is None or tail_lines >= len(all_lines): + return all_lines + + return all_lines[-tail_lines:] if tail_lines > 0 else [] + + def _create_entry(self, lines: list[str], start: int, end: int) -> ParsedLogEntry: + """Create ParsedLogEntry from accumulated lines.""" + full_text = "\n".join(lines) + match = self.LOG_PATTERN.match(lines[0]) + + if match: + ts_str, level, logger, message = match.groups() + try: + # Handle year ambiguity - assume 20xx + timestamp = datetime.strptime(f"20{ts_str}", "%Y/%m/%d %H:%M:%S") + except ValueError: + timestamp = None + else: + timestamp, level, logger, message = None, "UNKNOWN", "unknown", lines[0] + + return ParsedLogEntry( + timestamp=timestamp, + level=level, + logger=logger, + message=message, + full_text=full_text, + line_start=start, + line_end=end, + ) + + def filter_by_level( + self, entries: list[ParsedLogEntry], levels: list[str] + ) -> list[ParsedLogEntry]: + """Filter entries by log level. + + Parameters + ---------- + entries + List of parsed log entries. + levels + List of levels to include (e.g., ["ERROR", "WARN"]). + + Returns + ------- + list[ParsedLogEntry] + Filtered list of entries. 
+ """ + levels_upper = [x.upper() for x in levels] + return [e for e in entries if e.level in levels_upper] + + def get_errors_only(self, entries: list[ParsedLogEntry]) -> list[ParsedLogEntry]: + """Get only ERROR and FATAL entries.""" + return self.filter_by_level(entries, ["ERROR", "FATAL"]) + + def get_warnings_and_errors(self, entries: list[ParsedLogEntry]) -> list[ParsedLogEntry]: + """Get WARN, ERROR, and FATAL entries.""" + return self.filter_by_level(entries, ["WARN", "ERROR", "FATAL"]) diff --git a/src/sparkctl/mcp_server/models.py b/src/sparkctl/mcp_server/models.py new file mode 100644 index 0000000..8f85d23 --- /dev/null +++ b/src/sparkctl/mcp_server/models.py @@ -0,0 +1,170 @@ +"""Pydantic models for MCP server responses.""" + +from datetime import datetime +from enum import StrEnum +from typing import Literal + +from pydantic import BaseModel, ConfigDict, Field + + +class MCPBaseModel(BaseModel): + """Base model for MCP response models.""" + + model_config = ConfigDict( + str_strip_whitespace=True, + extra="forbid", + use_enum_values=True, + ) + + +class ErrorCategory(StrEnum): + """Categories of Spark errors.""" + + OOM = "out_of_memory" + SHUFFLE = "shuffle_failure" + STAGE = "stage_failure" + RESOURCE = "resource_exhaustion" + CONNECTION = "connection_failure" + SERIALIZATION = "serialization_error" + DISK = "disk_failure" + TIMEOUT = "timeout" + CONFIGURATION = "configuration_error" + UNKNOWN = "unknown" + + +# --- Log retrieval models --- + + +class LogEntry(MCPBaseModel): + """A single log file's content.""" + + source: str = Field(description="Log source type (e.g., 'master', 'worker-0', 'executor-1')") + filepath: str = Field(description="Path to the log file") + content: str = Field(description="Log content (possibly truncated)") + line_count: int = Field(description="Number of lines in the returned content") + last_modified: datetime | None = Field( + default=None, description="Last modification time of the log file" + ) + + +class 
SparkLogsResponse(MCPBaseModel): + """Response from get_spark_logs tool.""" + + logs: list[LogEntry] = Field(default_factory=list, description="List of log entries") + total_entries: int = Field(description="Total number of log files found") + truncated: bool = Field(default=False, description="Whether any log content was truncated") + spark_scratch: str = Field(description="Path to spark_scratch directory used") + + +# --- Failure analysis models --- + + +class ErrorOccurrence(MCPBaseModel): + """A detected error in the logs.""" + + error_type: ErrorCategory = Field(description="Category of the error") + severity: Literal["critical", "error", "warning"] = Field( + description="Severity level of the error" + ) + message: str = Field(description="Error message or summary") + stack_trace: str | None = Field(default=None, description="Full stack trace if available") + source_file: str = Field(description="Log file where error was found") + line_number: int = Field(description="Line number in the log file") + timestamp: datetime | None = Field( + default=None, description="Timestamp of the error if parseable" + ) + context_lines: list[str] = Field(default_factory=list, description="Surrounding context lines") + + +class SparkFailureAnalysis(MCPBaseModel): + """Response from analyze_spark_failure tool.""" + + app_id: str | None = Field(default=None, description="Application ID if detected") + errors: list[ErrorOccurrence] = Field( + default_factory=list, description="List of detected errors" + ) + error_summary: dict[str, int] = Field( + default_factory=dict, description="Count of errors by type" + ) + likely_root_cause: str | None = Field( + default=None, description="Most likely root cause based on error patterns" + ) + affected_stages: list[str] = Field( + default_factory=list, description="Stage IDs that had failures" + ) + affected_executors: list[str] = Field( + default_factory=list, description="Executor IDs that had failures" + ) + analysis_timestamp: datetime = 
Field( + default_factory=datetime.now, description="When analysis was performed" + ) + recommend_torc_analysis: bool = Field( + default=False, + description="Whether torc's analyze_workflow_logs should be used for system-level issues", + ) + torc_recommendation: str | None = Field( + default=None, description="Specific recommendation for torc analysis" + ) + + +# --- Recovery suggestion models --- + + +class Suggestion(MCPBaseModel): + """A recovery suggestion.""" + + priority: int = Field(description="Priority (1 = highest)") + category: str = Field( + description="Category (e.g., 'memory', 'shuffle', 'resource', 'configuration')" + ) + title: str = Field(description="Short title for the suggestion") + description: str = Field(description="Detailed description of the fix") + config_changes: dict[str, str] | None = Field( + default=None, description="Spark configuration changes to apply" + ) + sparkctl_command: str | None = Field( + default=None, description="sparkctl command to apply the fix" + ) + estimated_impact: str = Field(description="Expected impact of applying this fix") + + +class RecoverySuggestions(MCPBaseModel): + """Response from get_recovery_suggestions tool.""" + + suggestions: list[Suggestion] = Field( + default_factory=list, description="Ordered list of recovery suggestions" + ) + requires_cluster_restart: bool = Field( + default=False, description="Whether applying these changes requires a cluster restart" + ) + recommend_torc_analysis: bool = Field( + default=False, description="Whether to recommend torc analysis for system-level issues" + ) + torc_recommendation: str | None = Field( + default=None, description="Specific guidance for using torc" + ) + + +# --- Application listing models --- + + +class SparkApplication(MCPBaseModel): + """Information about a Spark application.""" + + app_id: str = Field(description="Application ID (e.g., 'app-20240115120000-0000')") + start_time: datetime | None = Field(default=None, description="When the 
application started") + has_executor_logs: bool = Field( + default=False, description="Whether executor logs exist for this app" + ) + executor_count: int = Field(default=0, description="Number of executors found") + log_files: list[str] = Field(default_factory=list, description="Paths to associated log files") + + +class SparkApplicationList(MCPBaseModel): + """Response from list_spark_applications tool.""" + + applications: list[SparkApplication] = Field( + default_factory=list, description="List of Spark applications" + ) + total_count: int = Field(description="Total number of applications found") + spark_scratch: str = Field(description="Path to spark_scratch directory used") diff --git a/src/sparkctl/mcp_server/recovery.py b/src/sparkctl/mcp_server/recovery.py new file mode 100644 index 0000000..bdb286d --- /dev/null +++ b/src/sparkctl/mcp_server/recovery.py @@ -0,0 +1,399 @@ +"""Recovery suggestion engine for Spark failures.""" + +from sparkctl.mcp_server.models import ( + ErrorCategory, + RecoverySuggestions, + Suggestion, +) + + +class RecoveryEngine: + """Generates recovery suggestions based on detected errors.""" + + # Recovery strategies organized by error category + RECOVERY_STRATEGIES: dict[ErrorCategory, list[Suggestion]] = { + ErrorCategory.OOM: [ + Suggestion( + priority=1, + category="memory", + title="Increase executor memory", + description=( + "Allocate more memory per executor to handle larger data. " + "Double the current executor memory as a starting point." + ), + config_changes={"spark.executor.memory": "16g"}, + sparkctl_command="sparkctl configure --executor-memory-gb 16", + estimated_impact="High - directly addresses memory pressure", + ), + Suggestion( + priority=2, + category="memory", + title="Reduce executor cores", + description=( + "Fewer cores per executor means more memory per task. " + "This trades parallelism for memory headroom." 
+ ), + config_changes={"spark.executor.cores": "2"}, + sparkctl_command="sparkctl configure --executor-cores 2", + estimated_impact="Medium - trades parallelism for memory headroom", + ), + Suggestion( + priority=3, + category="shuffle", + title="Increase shuffle partitions", + description=( + "More partitions means smaller data chunks per task. " + "Increase the shuffle partition multiplier to spread data more evenly." + ), + config_changes={"spark.sql.shuffle.partitions": "400"}, + sparkctl_command="sparkctl configure --shuffle-partition-multiplier 2", + estimated_impact="Medium - reduces memory per partition", + ), + Suggestion( + priority=4, + category="memory", + title="Enable off-heap memory", + description=( + "Use off-heap memory for caching and shuffle to reduce GC pressure. " + "Requires additional configuration in spark-defaults.conf." + ), + config_changes={ + "spark.memory.offHeap.enabled": "true", + "spark.memory.offHeap.size": "4g", + }, + sparkctl_command=None, + estimated_impact="Medium - reduces GC pressure", + ), + ], + ErrorCategory.SHUFFLE: [ + Suggestion( + priority=1, + category="shuffle", + title="Increase shuffle partitions", + description=( + "Shuffle block too large errors occur when partitions exceed 2GB. " + "Increase partitions to reduce individual shuffle block sizes." 
+ ), + config_changes={"spark.sql.shuffle.partitions": "800"}, + sparkctl_command="sparkctl configure --shuffle-partition-multiplier 4", + estimated_impact="High - prevents 2GB shuffle block limit", + ), + Suggestion( + priority=2, + category="shuffle", + title="Enable shuffle compression", + description=("Compress shuffle data to reduce network transfer and disk usage."), + config_changes={ + "spark.shuffle.compress": "true", + "spark.shuffle.spill.compress": "true", + }, + sparkctl_command=None, + estimated_impact="Medium - reduces shuffle data size", + ), + Suggestion( + priority=3, + category="resource", + title="Use local storage for shuffle", + description=( + "Use fast local disks instead of shared filesystem for shuffle. " + "Significantly improves shuffle I/O performance." + ), + config_changes=None, + sparkctl_command="sparkctl configure --local-storage", + estimated_impact="High - faster shuffle I/O", + ), + Suggestion( + priority=4, + category="network", + title="Increase network timeout", + description=("Allow more time for shuffle block transfers on slow networks."), + config_changes={ + "spark.network.timeout": "600s", + "spark.shuffle.io.maxRetries": "6", + }, + sparkctl_command=None, + estimated_impact="Medium - prevents premature timeouts", + ), + ], + ErrorCategory.RESOURCE: [ + Suggestion( + priority=1, + category="resource", + title="Reduce executor resource requirements", + description=( + "Lower executor memory/cores to fit within worker capacity. " + "Check that requested resources don't exceed any single worker's limits." + ), + config_changes=None, + sparkctl_command="sparkctl configure --executor-cores 4 --executor-memory-gb 8", + estimated_impact="High - ensures jobs can be scheduled", + ), + Suggestion( + priority=2, + category="resource", + title="Enable dynamic allocation", + description=( + "Let Spark scale executors based on workload demand. " + "Helps with resource utilization and prevents resource starvation." 
+ ), + config_changes=None, + sparkctl_command="sparkctl configure --dynamic-allocation", + estimated_impact="Medium - better resource utilization", + ), + Suggestion( + priority=3, + category="resource", + title="Check cluster health", + description=( + "Verify all workers are running and healthy. " + "Use sparkctl status or check Spark UI for worker status." + ), + config_changes=None, + sparkctl_command=None, + estimated_impact="High - identifies infrastructure issues", + ), + ], + ErrorCategory.DISK: [ + Suggestion( + priority=1, + category="disk", + title="Use local storage", + description=( + "Use node-local storage for shuffle and temp data. " + "Provides dedicated disk space separate from shared filesystem." + ), + config_changes=None, + sparkctl_command="sparkctl configure --local-storage", + estimated_impact="High - dedicated local disk space", + ), + Suggestion( + priority=2, + category="disk", + title="Clean up scratch directory", + description=( + "Remove old application data from spark_scratch to free disk space. " + "Old executor logs and shuffle data can accumulate." + ), + config_changes=None, + sparkctl_command="rm -rf spark_scratch/workers/app-* spark_scratch/local/*", + estimated_impact="High - frees disk space immediately", + ), + Suggestion( + priority=3, + category="shuffle", + title="Enable shuffle compression", + description=("Compress shuffle data to reduce disk usage during spills."), + config_changes={ + "spark.shuffle.compress": "true", + "spark.shuffle.spill.compress": "true", + }, + sparkctl_command=None, + estimated_impact="Medium - reduces disk space for shuffle", + ), + ], + ErrorCategory.CONNECTION: [ + Suggestion( + priority=1, + category="network", + title="Increase network timeout", + description=( + "Increase timeout for RPC and network operations. " + "Helps with transient network issues." 
+ ), + config_changes={ + "spark.network.timeout": "600s", + "spark.rpc.askTimeout": "600s", + }, + sparkctl_command=None, + estimated_impact="Medium - prevents premature timeouts", + ), + Suggestion( + priority=2, + category="network", + title="Check network connectivity", + description=( + "Verify all nodes can communicate on required ports. " + "Spark uses various ports for master/worker communication." + ), + config_changes=None, + sparkctl_command=None, + estimated_impact="High - identifies network issues", + ), + Suggestion( + priority=3, + category="resource", + title="Restart the cluster", + description=( + "Connection issues may resolve after a cluster restart. " + "Stop and start the cluster to reset all connections." + ), + config_changes=None, + sparkctl_command="sparkctl stop && sparkctl start", + estimated_impact="Medium - resets connection state", + ), + ], + ErrorCategory.SERIALIZATION: [ + Suggestion( + priority=1, + category="code", + title="Check for non-serializable objects", + description=( + "Ensure all objects used in closures are serializable. " + "Move non-serializable objects outside of closures or use broadcast variables." + ), + config_changes=None, + sparkctl_command=None, + estimated_impact="High - fixes root cause", + ), + Suggestion( + priority=2, + category="configuration", + title="Use Kryo serialization", + description=( + "Kryo is faster and more compact than Java serialization. " + "May require registering custom classes." 
+ ), + config_changes={ + "spark.serializer": "org.apache.spark.serializer.KryoSerializer", + }, + sparkctl_command=None, + estimated_impact="Medium - better serialization performance", + ), + ], + ErrorCategory.TIMEOUT: [ + Suggestion( + priority=1, + category="configuration", + title="Increase heartbeat interval", + description=("Increase executor heartbeat timeout for long-running tasks."), + config_changes={ + "spark.executor.heartbeatInterval": "60s", + "spark.network.timeout": "600s", + }, + sparkctl_command=None, + estimated_impact="Medium - prevents false timeouts", + ), + Suggestion( + priority=2, + category="memory", + title="Reduce GC pressure", + description=( + "Long GC pauses can cause heartbeat timeouts. " + "Increase memory or enable off-heap to reduce GC." + ), + config_changes={"spark.executor.memory": "16g"}, + sparkctl_command="sparkctl configure --executor-memory-gb 16", + estimated_impact="Medium - reduces GC pauses", + ), + ], + ErrorCategory.STAGE: [ + Suggestion( + priority=1, + category="retry", + title="Increase task max failures", + description=("Allow more retries for transient failures."), + config_changes={"spark.task.maxFailures": "8"}, + sparkctl_command=None, + estimated_impact="Low - masks underlying issues", + ), + Suggestion( + priority=2, + category="resource", + title="Check executor logs", + description=( + "Stage failures are usually caused by task failures. " + "Check executor logs for the root cause." + ), + config_changes=None, + sparkctl_command=None, + estimated_impact="High - identifies root cause", + ), + ], + ErrorCategory.CONFIGURATION: [ + Suggestion( + priority=1, + category="configuration", + title="Check classpath and dependencies", + description=( + "ClassNotFoundException usually indicates missing JARs. " + "Ensure all required dependencies are available to executors." 
+ ), + config_changes=None, + sparkctl_command=None, + estimated_impact="High - fixes missing dependencies", + ), + Suggestion( + priority=2, + category="configuration", + title="Verify configuration values", + description=("Check spark-defaults.conf for invalid configuration values."), + config_changes=None, + sparkctl_command=None, + estimated_impact="Medium - fixes configuration errors", + ), + ], + } + + def get_suggestions( + self, + error_categories: list[ErrorCategory], + current_config: dict[str, str] | None = None, + ) -> RecoverySuggestions: + """Generate recovery suggestions based on detected error categories. + + Parameters + ---------- + error_categories + List of error categories detected. + current_config + Current Spark configuration (optional, for context). + + Returns + ------- + RecoverySuggestions + RecoverySuggestions with prioritized remediation steps. + """ + suggestions: list[Suggestion] = [] + seen_titles: set[str] = set() + + # Collect suggestions for each error category + for category in error_categories: + if category in self.RECOVERY_STRATEGIES: + for suggestion in self.RECOVERY_STRATEGIES[category]: + # Avoid duplicate suggestions + if suggestion.title not in seen_titles: + suggestions.append(suggestion) + seen_titles.add(suggestion.title) + + # Sort by priority + suggestions.sort(key=lambda s: s.priority) + + # Determine if cluster restart is needed + restart_categories = { + "memory", + "resource", + } + requires_restart = any(s.category in restart_categories for s in suggestions) + + # Determine if torc analysis is recommended + system_categories = { + ErrorCategory.CONNECTION, + ErrorCategory.DISK, + ErrorCategory.TIMEOUT, + } + recommend_torc = bool(set(error_categories) & system_categories) + + torc_recommendation = None + if recommend_torc: + torc_recommendation = ( + "System-level issues detected (connection, disk, or timeout errors). 
" + "Use torc's analyze_workflow_logs tool to check for Slurm job failures, " + "node health issues, or filesystem problems that may be causing Spark failures." + ) + + return RecoverySuggestions( + suggestions=suggestions, + requires_cluster_restart=requires_restart, + recommend_torc_analysis=recommend_torc, + torc_recommendation=torc_recommendation, + ) diff --git a/src/sparkctl/mcp_server/server.py b/src/sparkctl/mcp_server/server.py new file mode 100644 index 0000000..880e21a --- /dev/null +++ b/src/sparkctl/mcp_server/server.py @@ -0,0 +1,225 @@ +"""MCP Server for sparkctl Spark job failure diagnostics. + +This server provides tools for diagnosing Spark job failures by analyzing +logs from master, worker, executor, thrift-server, and connect-server. + +Usage: + sparkctl-mcp-server + +The server communicates over stdio using the MCP protocol. +""" + +import sys +from typing import Literal + +# Check for mcp package availability +try: + from mcp.server.fastmcp import FastMCP +except ImportError as e: + print( + "Error: The 'mcp' package is required to run the MCP server.\n" + "Install it with: pip install 'sparkctl[mcp]'", + file=sys.stderr, + ) + raise SystemExit(1) from e + +from sparkctl.mcp_server import tools + + +# Create the MCP server +mcp = FastMCP( + name="sparkctl-diagnostics", +) + + +@mcp.tool() +def get_spark_logs( + spark_scratch: str, + log_type: Literal["master", "worker", "executor", "connect", "thrift", "all"] = "all", + app_id: str | None = None, + executor_id: str | None = None, + tail_lines: int = 500, +) -> str: + """Retrieve and aggregate Spark logs from the cluster. + + Use this tool to fetch log content from Spark components for analysis. + Logs are retrieved from the spark_scratch directory which contains + master, worker, and executor logs. + + Parameters + ---------- + spark_scratch + Path to the spark_scratch directory containing logs. 
+ log_type + Type of logs to retrieve: + - "master": Spark master logs + - "worker": Spark worker logs + - "executor": Executor stderr logs (per application) + - "connect": Spark Connect server logs + - "thrift": Thrift server logs + - "all": All available logs + app_id + Filter executor logs by application ID (e.g., "app-20240115120000-0000"). + executor_id + Filter by specific executor ID (e.g., "0", "1"). + tail_lines + Number of lines to retrieve from the end of each log file. + + Returns + ------- + str + JSON containing log entries with source, filepath, content, and metadata. + """ + result = tools.get_spark_logs( + spark_scratch=spark_scratch, + log_type=log_type, + app_id=app_id, + executor_id=executor_id, + tail_lines=tail_lines, + ) + return result.model_dump_json(indent=2) + + +@mcp.tool() +def analyze_spark_failure( + spark_scratch: str, + app_id: str | None = None, + include_stack_traces: bool = True, + max_errors: int = 50, +) -> str: + """Analyze Spark logs for failure patterns and provide diagnosis. + + This is the primary diagnostic tool. It scans all Spark logs for known + error patterns including: + - Out of memory errors (OOM) + - Shuffle failures (FetchFailedException) + - Stage/task failures and resource exhaustion + - Connection/network issues + - Disk space issues + - Serialization errors + - Timeout errors + + The analysis provides: + - Categorized list of detected errors with severity + - Error count summary by type + - Most likely root cause determination + - Affected stages and executors + - Recommendation on whether to use torc for system-level analysis + + Parameters + ---------- + spark_scratch + Path to the spark_scratch directory containing logs. + app_id + Specific application ID to analyze. If not specified, analyzes all apps. + include_stack_traces + Whether to include full stack traces in the output. + max_errors + Maximum number of error occurrences to include (default 50). + 
+ + Returns + ------- + str + JSON containing SparkFailureAnalysis with errors, summary, and recommendations. + """ + result = tools.analyze_spark_failure( + spark_scratch=spark_scratch, + app_id=app_id, + include_stack_traces=include_stack_traces, + max_errors=max_errors, + ) + return result.model_dump_json(indent=2) + + +@mcp.tool() +def get_recovery_suggestions( + error_types: list[str], + current_config: dict[str, str] | None = None, +) -> str: + """Get recovery suggestions for detected Spark errors. + + Based on the error types detected by analyze_spark_failure, this tool + provides prioritized suggestions for fixing the issues. Each suggestion + includes: + - Priority ranking + - Category (memory, shuffle, resource, etc.) + - Description of the fix + - Specific configuration changes to apply + - sparkctl commands to run (if applicable) + - Expected impact + + Common error_types values: + - "out_of_memory": OOM errors + - "shuffle_failure": Shuffle/fetch failures + - "stage_failure": Stage or task failures + - "resource_exhaustion": Cluster resource issues + - "connection_failure": Network issues + - "disk_failure": Disk space issues + - "timeout": Timeout errors + + Parameters + ---------- + error_types + List of error type strings from analyze_spark_failure's error_summary. + current_config + Current Spark configuration values (optional, for context). + + Returns + ------- + str + JSON containing prioritized recovery suggestions and guidance. + """ + result = tools.get_recovery_suggestions( + error_types=error_types, + current_config=current_config, + ) + return result.model_dump_json(indent=2) + + +@mcp.tool() +def list_spark_applications( + spark_scratch: str, +) -> str: + """List Spark applications found in the spark_scratch directory. 
+ + Discovers applications from executor log directories and provides + metadata about each application including: + - Application ID + - Start time (from directory creation) + - Whether executor logs exist + - Number of executors + - Paths to log files + + This is useful for identifying which applications to analyze. + + Parameters + ---------- + spark_scratch + Path to the spark_scratch directory. + + Returns + ------- + str + JSON containing list of SparkApplication objects with metadata. + """ + result = tools.list_spark_applications(spark_scratch=spark_scratch) + return result.model_dump_json(indent=2) + + +def main(): + """Entry point for sparkctl-mcp-server command.""" + # Configure logging to stderr only (stdout is for MCP protocol) + import logging + + logging.basicConfig( + level=logging.WARNING, + format="%(levelname)s: %(message)s", + stream=sys.stderr, + ) + + # Run the MCP server with stdio transport + mcp.run() + + +if __name__ == "__main__": + main() diff --git a/src/sparkctl/mcp_server/tools.py b/src/sparkctl/mcp_server/tools.py new file mode 100644 index 0000000..0689667 --- /dev/null +++ b/src/sparkctl/mcp_server/tools.py @@ -0,0 +1,398 @@ +"""MCP tool implementations for Spark log analysis.""" + +import re +from datetime import datetime +from pathlib import Path +from typing import Literal + +from loguru import logger +from sparkctl.mcp_server.error_patterns import ErrorPatternRegistry +from sparkctl.mcp_server.log_parser import SparkLogLocator, SparkLogParser +from sparkctl.mcp_server.models import ( + ErrorCategory, + ErrorOccurrence, + LogEntry, + SparkApplication, + SparkApplicationList, + SparkFailureAnalysis, + SparkLogsResponse, +) +from sparkctl.mcp_server.recovery import RecoveryEngine + + +def get_spark_logs( + spark_scratch: str, + log_type: Literal["master", "worker", "executor", "connect", "thrift", "all"] = "all", + app_id: str | None = None, + executor_id: str | None = None, + tail_lines: int = 500, +) -> SparkLogsResponse: + 
"""Retrieve and aggregate Spark logs from the cluster. + + Parameters + ---------- + spark_scratch + Path to the spark_scratch directory containing logs. + log_type + Type of logs to retrieve (master, worker, executor, connect, thrift, or all). + app_id + Filter executor logs by application ID (e.g., "app-20240115120000-0000"). + executor_id + Filter by specific executor ID (e.g., "0", "1"). + tail_lines + Number of lines to retrieve from the end of each log file. + + Returns + ------- + SparkLogsResponse + SparkLogsResponse containing aggregated log content from requested sources. + """ + scratch_path = Path(spark_scratch) + locator = SparkLogLocator(scratch_path) + parser = SparkLogParser() + + log_entries: list[LogEntry] = [] + truncated = False + + # Get requested log files + log_files = [] + if log_type in ("master", "all"): + log_files.extend(locator.get_master_logs()) + if log_type in ("worker", "all"): + log_files.extend(locator.get_worker_logs()) + if log_type in ("connect", "all"): + log_files.extend(locator.get_connect_logs()) + if log_type in ("thrift", "all"): + log_files.extend(locator.get_thrift_logs()) + if log_type in ("executor", "all"): + executor_logs = locator.get_executor_logs(app_id) + # Filter by executor_id if specified + if executor_id is not None: + executor_logs = [f for f in executor_logs if f.source_id.endswith(f"/{executor_id}")] + log_files.extend(executor_logs) + + for log_file in log_files: + try: + lines = parser._read_lines(log_file.path, tail_lines) + content = "\n".join(lines) + + # Check if we truncated + if tail_lines and len(lines) == tail_lines: + truncated = True + + source = f"{log_file.source_type}-{log_file.source_id}" + log_entries.append( + LogEntry( + source=source, + filepath=str(log_file.path), + content=content, + line_count=len(lines), + last_modified=log_file.last_modified, + ) + ) + except Exception: + # Skip files we can't read + continue + + return SparkLogsResponse( + logs=log_entries, + 
total_entries=len(log_entries), + truncated=truncated, + spark_scratch=str(scratch_path), + ) + + +def _extract_stack_trace(entry_full_text: str, include_stack_traces: bool) -> str | None: + """Extract stack trace from log entry if present.""" + if not include_stack_traces or "\n" not in entry_full_text: + return None + return "\n".join(entry_full_text.split("\n")[1:]) + + +def _extract_affected_resources( + entry_full_text: str, + stage_pattern: re.Pattern[str], + executor_pattern: re.Pattern[str], + affected_stages: set[str], + affected_executors: set[str], +) -> None: + """Extract and update affected stages and executors from log entry.""" + stage_match = stage_pattern.search(entry_full_text) + if stage_match: + affected_stages.add(stage_match.group(1)) + + executor_match = executor_pattern.search(entry_full_text) + if executor_match: + affected_executors.add(executor_match.group(1)) + + +def _process_log_entry( + entry, + log_file, + errors: list[ErrorOccurrence], + error_counts: dict[str, int], + affected_stages: set[str], + affected_executors: set[str], + stage_pattern: re.Pattern[str], + executor_pattern: re.Pattern[str], + include_stack_traces: bool, + max_errors: int, +) -> tuple[bool, str | None]: + """Process a single log entry and update analysis state.""" + should_torc, reason = ErrorPatternRegistry.should_recommend_torc(entry.full_text) + + pattern = ErrorPatternRegistry.classify_error(entry.full_text) + if pattern is None: + return should_torc, reason + + category_str = pattern.category.value + error_counts[category_str] = error_counts.get(category_str, 0) + 1 + + _extract_affected_resources( + entry.full_text, stage_pattern, executor_pattern, affected_stages, affected_executors + ) + + if len(errors) < max_errors: + stack_trace = _extract_stack_trace(entry.full_text, include_stack_traces) + errors.append( + ErrorOccurrence( + error_type=pattern.category, + severity=pattern.severity, + message=entry.message, + stack_trace=stack_trace, + 
source_file=str(log_file.path), + line_number=entry.line_start, + timestamp=entry.timestamp, + context_lines=[], + ) + ) + + return should_torc, reason + + +def analyze_spark_failure( + spark_scratch: str, + app_id: str | None = None, + include_stack_traces: bool = True, + max_errors: int = 50, +) -> SparkFailureAnalysis: + """Analyze Spark logs for failure patterns and provide diagnosis. + + Scans master, worker, and executor logs for known error patterns including + OOM errors, shuffle failures, stage failures, connection issues, and more. + Provides a summary of detected errors and recommends recovery actions. + + Parameters + ---------- + spark_scratch + Path to the spark_scratch directory containing logs. + app_id + Specific application ID to analyze (optional, analyzes all if not specified). + include_stack_traces + Whether to include full stack traces in the output. + max_errors + Maximum number of error occurrences to include in response. + + Returns + ------- + SparkFailureAnalysis + SparkFailureAnalysis with categorized errors, root cause analysis, + and recommendations for recovery. 
+ """ + scratch_path = Path(spark_scratch) + locator = SparkLogLocator(scratch_path) + parser = SparkLogParser() + + errors: list[ErrorOccurrence] = [] + error_counts: dict[str, int] = {} + affected_stages: set[str] = set() + affected_executors: set[str] = set() + recommend_torc = False + torc_reason: str | None = None + + stage_pattern = re.compile(r"stage[s]?\s*(\d+(?:\.\d+)?)", re.IGNORECASE) + executor_pattern = re.compile(r"executor[s]?\s*(\d+)", re.IGNORECASE) + + log_files = locator.get_all_log_files() + if app_id: + log_files = [f for f in log_files if f.source_type != "executor" or app_id in f.source_id] + + for log_file in log_files: + try: + entries = parser.parse_file(log_file.path) + problem_entries = parser.get_warnings_and_errors(entries) + + for entry in problem_entries: + should_torc, reason = _process_log_entry( + entry, + log_file, + errors, + error_counts, + affected_stages, + affected_executors, + stage_pattern, + executor_pattern, + include_stack_traces, + max_errors, + ) + + if should_torc: + recommend_torc = True + if torc_reason is None: + torc_reason = reason + except Exception: + logger.exception("Failed to analyze file {}", log_file) + continue + + likely_root_cause = _determine_root_cause(error_counts) + + torc_recommendation = None + if recommend_torc: + torc_recommendation = ( + f"{torc_reason}. Use torc's analyze_workflow_logs tool to check for " + "Slurm job failures, node health issues, or filesystem problems." 
+ ) + + return SparkFailureAnalysis( + app_id=app_id, + errors=errors, + error_summary=error_counts, + likely_root_cause=likely_root_cause, + affected_stages=sorted(affected_stages), + affected_executors=sorted(affected_executors), + analysis_timestamp=datetime.now(), + recommend_torc_analysis=recommend_torc, + torc_recommendation=torc_recommendation, + ) + + +def _determine_root_cause(error_counts: dict[str, int]) -> str | None: + """Determine the most likely root cause from error counts.""" + if not error_counts: + return None + + # Priority order for root cause determination + priority = ErrorPatternRegistry.get_root_cause_priority() + + # Find the highest priority error that occurred + root_cause = None + best_priority = float("inf") + + for error_type, count in error_counts.items(): + if count > 0: + try: + category = ErrorCategory(error_type) + p = priority.get(category, 100) + if p < best_priority: + best_priority = p + root_cause = error_type + except ValueError: + continue + + if root_cause: + # Get description from pattern registry + for pattern in ErrorPatternRegistry.PATTERNS: + if pattern.category.value == root_cause: + causes = ", ".join(pattern.common_causes[:2]) + return f"{pattern.description}. Common causes: {causes}" + + return root_cause + + +def get_recovery_suggestions( + error_types: list[str], + current_config: dict[str, str] | None = None, +): + """Get recovery suggestions for detected Spark errors. + + Based on the error types detected by analyze_spark_failure, provides + prioritized suggestions for fixing the issues. + + Parameters + ---------- + error_types + List of error type strings from analyze_spark_failure + (e.g., ["out_of_memory", "shuffle_failure"]). + current_config + Current Spark configuration values (optional, for context). + + Returns + ------- + RecoverySuggestions + RecoverySuggestions with prioritized remediation steps and + guidance on whether to use torc for system-level analysis. 
+ """ + engine = RecoveryEngine() + + # Convert string error types to ErrorCategory + categories: list[ErrorCategory] = [] + for error_type in error_types: + try: + categories.append(ErrorCategory(error_type)) + except ValueError: + # Skip unknown error types + continue + + return engine.get_suggestions(categories, current_config) + + +def list_spark_applications( + spark_scratch: str, + include_completed: bool = True, + include_failed: bool = True, +) -> SparkApplicationList: + """List Spark applications found in the spark_scratch directory. + + Discovers applications from executor log directories and provides + metadata about each application. + + Parameters + ---------- + spark_scratch + Path to the spark_scratch directory. + include_completed + Include applications that completed successfully. + include_failed + Include applications that failed. + + Returns + ------- + SparkApplicationList + SparkApplicationList with application metadata. + """ + scratch_path = Path(spark_scratch) + locator = SparkLogLocator(scratch_path) + + applications: list[SparkApplication] = [] + app_ids = locator.get_application_ids() + + for app_id in app_ids: + executor_logs = locator.get_executor_logs(app_id) + executor_count = len(executor_logs) + + # Get start time from directory creation time + app_dir = scratch_path / "workers" / app_id + start_time = None + if app_dir.exists(): + try: + start_time = datetime.fromtimestamp(app_dir.stat().st_ctime) + except OSError: + pass + + log_files = [str(log.path) for log in executor_logs] + + applications.append( + SparkApplication( + app_id=app_id, + start_time=start_time, + has_executor_logs=executor_count > 0, + executor_count=executor_count, + log_files=log_files, + ) + ) + + return SparkApplicationList( + applications=applications, + total_count=len(applications), + spark_scratch=str(scratch_path), + )