diff --git a/sigma/cli/check.py b/sigma/cli/check.py index b2565cd..279de45 100644 --- a/sigma/cli/check.py +++ b/sigma/cli/check.py @@ -1,4 +1,5 @@ import pathlib +import xml.etree.ElementTree as ET from collections import Counter from sys import stderr from textwrap import fill @@ -16,6 +17,66 @@ severity_color = {"low": "green", "medium": "yellow", "high": "red"} +# JUnit-specific icons covering PySigma validation and parsing possibilities +SEVERITY_ICONS = { + "critical": "đŸ’Ĩ", + "high": "🔴", + "medium": "🟠", + "low": "🟡", + "informational": "â„šī¸", + "parsing_error": "đŸšĢ", + "condition_error": "❌", + "error": "❌", # Fallback for other SigmaError subclasses + "ok": "✅", +} + + +def generate_junit_report(results, output_file): + """Generates JUnit XML grouped by PySigma error/issue types.""" + root = ET.Element("testsuites", name="Sigma Rule Validation") + + suites = {} + for res in results: + suite_name = res.get("issue_type", "Validation Success") + if suite_name not in suites: + suites[suite_name] = [] + suites[suite_name].append(res) + + for suite_name, tests in suites.items(): + failures = len([t for t in tests if t["status"] == "failed"]) + test_suite = ET.SubElement( + root, + "testsuite", + name=suite_name, + tests=str(len(tests)), + failures=str(failures), + ) + + for res in tests: + severity = (res.get("severity") or "ok").lower() + icon = SEVERITY_ICONS.get(severity, SEVERITY_ICONS["error"]) + + test_case = ET.SubElement( + test_suite, + "testcase", + name=f"{icon} {res['rule_name']}", + classname=suite_name, + file=res.get("file_path", "unknown"), + ) + + if res["status"] == "failed": + failure = ET.SubElement( + test_case, + "failure", + message=f"{res.get('issue_type')}: {(res.get('severity') or 'UNKNOWN').upper()}", + ) + failure.text = res.get("description", "") + + tree = ET.ElementTree(root) + with open(output_file, "wb") as f: + tree.write(f, encoding="utf-8", xml_declaration=True) + + # ========================================== # Data Processing & Extraction Functions # ========================================== @@ -70,7 +131,7 @@ def validate_loaded_rules(check_rules, rule_validator): return issues -def load_and_check_rules(input, file_pattern, rule_errors, cond_errors): +def load_and_check_rules(input, file_pattern, rule_errors, cond_errors, junit_results=None): rule_collection = load_rules(input, file_pattern) check_rules = list() first_error = True @@ -85,6 +146,15 @@ def load_and_check_rules(input, file_pattern, rule_errors, cond_errors): for error in rule.errors: click.echo(error) rule_errors.update((error.__class__.__name__,)) + if junit_results is not None: + error_type = error.__class__.__name__ + rule_name = rule.title or str(rule.path) + file_path = str(rule.source) if rule.source else "unknown" + sev = "parsing_error" if "Parse" in error_type else "error" + junit_results.append({ + "rule_name": rule_name, "file_path": file_path, "status": "failed", + "issue_type": error_type, "severity": sev, "description": str(error) + }) elif isinstance(rule, SigmaRule): # rule has no errors, parse condition try: for condition in rule.detection.parsed_condition: @@ -96,6 +166,15 @@ def load_and_check_rules(input, file_pattern, rule_errors, cond_errors): f"Condition error in { str(condition.source) }:{ error }" ) cond_errors.update((error,)) + if junit_results is not None: + junit_results.append({ + "rule_name": rule.title or str(rule.path), + "file_path": str(rule.source) if rule.source else "unknown", + "status": "failed", + "issue_type": e.__class__.__name__, + "severity": "condition_error", + "description": str(e), + }) else: check_rules.append(rule) return check_rules @@ -128,6 +207,11 @@ def load_and_check_rules(input, file_pattern, rule_errors, cond_errors): show_default=True, help="Fail on Sigma rule validation issues.", ) +@click.option( + "--junitxml", + type=click.Path(path_type=pathlib.Path), + help="Output results in JUnit XML format to the specified file.", +) @click.option( "--exclude", "-x", @@ -143,7 +227,7 @@ def load_and_check_rules(input, file_pattern, rule_errors, cond_errors): type=click.Path(exists=True, allow_dash=True, path_type=pathlib.Path), ) def check( - input, validation_config, file_pattern, fail_on_error, fail_on_issues, exclude + input, validation_config, file_pattern, fail_on_error, fail_on_issues, junitxml, exclude ): """Check Sigma rules for validity and best practices (not yet implemented).""" @@ -152,7 +236,8 @@ def check( try: rule_errors = Counter() cond_errors = Counter() - check_rules = load_and_check_rules(input, file_pattern, rule_errors, cond_errors) + junit_results = [] if junitxml else None + check_rules = load_and_check_rules(input, file_pattern, rule_errors, cond_errors, junit_results) # TODO: From Python 3.10 the commented line below can be used. rule_error_count = sum(rule_errors.values()) @@ -198,6 +283,16 @@ def check( + f" {additional_fields}" ) issue_counter.update((issue.__class__,)) + if junit_results is not None: + for rule in issue.rules: + junit_results.append({ + "rule_name": rule.title or str(rule.path), + "file_path": str(rule.source) if rule.source else "unknown", + "status": "failed", + "issue_type": type(issue).__name__, + "severity": issue.severity.name.lower(), + "description": str(issue.description), + }) # TODO: From Python 3.10 the commented line below can be used. cond_error_count = sum(cond_errors.values()) @@ -267,6 +362,10 @@ def check( else: click.echo("No validation issues found.") + if junitxml: + generate_junit_report(junit_results, junitxml) + click.echo(f"\nJUnit report saved to: {junitxml}") + if ( fail_on_error and (rule_error_count > 0 or cond_error_count > 0) @@ -276,4 +375,13 @@ def check( click.echo("Check failure") click.get_current_context().exit(1) except SigmaError as e: + if junitxml: + generate_junit_report([{ + "rule_name": "Global/Collection Loading Error", + "file_path": str(input), + "status": "failed", + "issue_type": e.__class__.__name__, + "severity": "error", + "description": str(e), + }], junitxml) raise click.ClickException("Check error: " + str(e)) diff --git a/tests/test_check.py b/tests/test_check.py index 5fa0c39..1aa98c1 100644 --- a/tests/test_check.py +++ b/tests/test_check.py @@ -1,3 +1,6 @@ +import xml.etree.ElementTree as ET +from pathlib import Path + from click.testing import CliRunner from sigma.cli.check import check @@ -88,3 +91,53 @@ def test_check_exclude(): assert "Invalid validators name" in result.stdout assert "myvalidator" in result.stdout assert "Check failure" in result.stdout + + +def test_check_junitxml_creates_file(tmp_path): + """JUnit XML file is created and well-formed when --junitxml is specified.""" + cli = CliRunner() + output_xml = tmp_path / "results.xml" + # Use --pass-on-error so exit code is 0; invalid rules don't reach the network + # validator, which avoids the pre-existing MITRE D3FEND network dependency. + result = cli.invoke( + check, ["--pass-on-error", "--junitxml", str(output_xml), "tests/files/invalid"] + ) + assert result.exit_code == 0 + assert output_xml.exists(), "JUnit XML file was not created" + assert f"JUnit report saved to: {output_xml}" in result.stdout + + # Verify the XML is well-formed and has the expected root element + tree = ET.parse(output_xml) + root = tree.getroot() + assert root.tag == "testsuites" + + +def test_check_junitxml_invalid(tmp_path): + """JUnit XML file contains failure entries when rules have errors.""" + cli = CliRunner() + output_xml = tmp_path / "results.xml" + result = cli.invoke( + check, ["--pass-on-error", "--junitxml", str(output_xml), "tests/files/invalid"] + ) + assert result.exit_code == 0 + assert output_xml.exists(), "JUnit XML file was not created" + + tree = ET.parse(output_xml) + root = tree.getroot() + assert root.tag == "testsuites" + + # At least one testsuite should have failures > 0 + failures = sum( + int(suite.get("failures", "0")) for suite in root.findall("testsuite") + ) + assert failures > 0, "Expected failure entries in JUnit XML for invalid rules" + + # The number of elements across all testcases must match + # the reported failure counts in the testsuite attributes. + for suite in root.findall("testsuite"): + reported_failures = int(suite.get("failures", "0")) + actual_failures = sum( + 1 for testcase in suite.findall("testcase") + if testcase.find("failure") is not None + ) + assert actual_failures == reported_failures