Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
114 changes: 111 additions & 3 deletions sigma/cli/check.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import pathlib
import xml.etree.ElementTree as ET
from collections import Counter
from sys import stderr
from textwrap import fill
Expand All @@ -16,6 +17,66 @@

severity_color = {"low": "green", "medium": "yellow", "high": "red"}

# JUnit-specific icons covering PySigma validation and parsing possibilities
SEVERITY_ICONS = {
"critical": "💥",
"high": "🔴",
"medium": "🟠",
"low": "🟡",
"informational": "ℹ️",
"parsing_error": "🚫",
"condition_error": "❌",
"error": "❌", # Fallback for other SigmaError subclasses
"ok": "✅",
}


def generate_junit_report(results, output_file):
"""Generates JUnit XML grouped by PySigma error/issue types."""
root = ET.Element("testsuites", name="Sigma Rule Validation")

suites = {}
for res in results:
suite_name = res.get("issue_type", "Validation Success")
if suite_name not in suites:
suites[suite_name] = []
suites[suite_name].append(res)

for suite_name, tests in suites.items():
failures = len([t for t in tests if t["status"] == "failed"])
test_suite = ET.SubElement(
root,
"testsuite",
name=suite_name,
tests=str(len(tests)),
failures=str(failures),
)

for res in tests:
severity = (res.get("severity") or "ok").lower()
icon = SEVERITY_ICONS.get(severity, SEVERITY_ICONS["error"])

test_case = ET.SubElement(
test_suite,
"testcase",
name=f"{icon} {res['rule_name']}",
classname=suite_name,
file=res.get("file_path", "unknown"),
)

if res["status"] == "failed":
failure = ET.SubElement(
test_case,
"failure",
message=f"{res.get('issue_type')}: {(res.get('severity') or 'UNKNOWN').upper()}",
)
failure.text = res.get("description", "")

tree = ET.ElementTree(root)
with open(output_file, "wb") as f:
tree.write(f, encoding="utf-8", xml_declaration=True)


# ==========================================
# Data Processing & Extraction Functions
# ==========================================
Expand Down Expand Up @@ -70,7 +131,7 @@ def validate_loaded_rules(check_rules, rule_validator):
return issues


def load_and_check_rules(input, file_pattern, rule_errors, cond_errors):
def load_and_check_rules(input, file_pattern, rule_errors, cond_errors, junit_results=None):
rule_collection = load_rules(input, file_pattern)
check_rules = list()
first_error = True
Expand All @@ -85,6 +146,15 @@ def load_and_check_rules(input, file_pattern, rule_errors, cond_errors):
for error in rule.errors:
click.echo(error)
rule_errors.update((error.__class__.__name__,))
if junit_results is not None:
error_type = error.__class__.__name__
rule_name = rule.title or str(rule.path)
file_path = str(rule.source) if rule.source else "unknown"
sev = "parsing_error" if "Parse" in error_type else "error"
junit_results.append({
"rule_name": rule_name, "file_path": file_path, "status": "failed",
"issue_type": error_type, "severity": sev, "description": str(error)
})
elif isinstance(rule, SigmaRule): # rule has no errors, parse condition
try:
for condition in rule.detection.parsed_condition:
Expand All @@ -96,6 +166,15 @@ def load_and_check_rules(input, file_pattern, rule_errors, cond_errors):
f"Condition error in { str(condition.source) }:{ error }"
)
cond_errors.update((error,))
if junit_results is not None:
junit_results.append({
"rule_name": rule.title or str(rule.path),
"file_path": str(rule.source) if rule.source else "unknown",
"status": "failed",
"issue_type": e.__class__.__name__,
"severity": "condition_error",
"description": str(e),
})
else:
check_rules.append(rule)
return check_rules
Expand Down Expand Up @@ -128,6 +207,11 @@ def load_and_check_rules(input, file_pattern, rule_errors, cond_errors):
show_default=True,
help="Fail on Sigma rule validation issues.",
)
@click.option(
"--junitxml",
type=click.Path(path_type=pathlib.Path),
help="Output results in JUnit XML format to the specified file.",
)
@click.option(
"--exclude",
"-x",
Expand All @@ -143,7 +227,7 @@ def load_and_check_rules(input, file_pattern, rule_errors, cond_errors):
type=click.Path(exists=True, allow_dash=True, path_type=pathlib.Path),
)
def check(
input, validation_config, file_pattern, fail_on_error, fail_on_issues, exclude
input, validation_config, file_pattern, fail_on_error, fail_on_issues, junitxml, exclude
):
"""Check Sigma rules for validity and best practices (not yet implemented)."""

Expand All @@ -152,7 +236,8 @@ def check(
try:
rule_errors = Counter()
cond_errors = Counter()
check_rules = load_and_check_rules(input, file_pattern, rule_errors, cond_errors)
junit_results = [] if junitxml else None
check_rules = load_and_check_rules(input, file_pattern, rule_errors, cond_errors, junit_results)

# TODO: From Python 3.10 the commented line below can be used.
rule_error_count = sum(rule_errors.values())
Expand Down Expand Up @@ -198,6 +283,16 @@ def check(
+ f" {additional_fields}"
)
issue_counter.update((issue.__class__,))
if junit_results is not None:
for rule in issue.rules:
junit_results.append({
"rule_name": rule.title or str(rule.path),
"file_path": str(rule.source) if rule.source else "unknown",
"status": "failed",
"issue_type": type(issue).__name__,
"severity": issue.severity.name.lower(),
"description": str(issue.description),
})

# TODO: From Python 3.10 the commented line below can be used.
cond_error_count = sum(cond_errors.values())
Expand Down Expand Up @@ -267,6 +362,10 @@ def check(
else:
click.echo("No validation issues found.")

if junitxml:
generate_junit_report(junit_results, junitxml)
click.echo(f"\nJUnit report saved to: {junitxml}")

if (
fail_on_error
and (rule_error_count > 0 or cond_error_count > 0)
Expand All @@ -276,4 +375,13 @@ def check(
click.echo("Check failure")
click.get_current_context().exit(1)
except SigmaError as e:
if junitxml:
generate_junit_report([{
"rule_name": "Global/Collection Loading Error",
"file_path": str(input),
"status": "failed",
"issue_type": e.__class__.__name__,
"severity": "error",
"description": str(e),
}], junitxml)
raise click.ClickException("Check error: " + str(e))
53 changes: 53 additions & 0 deletions tests/test_check.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
import xml.etree.ElementTree as ET
from pathlib import Path

from click.testing import CliRunner

from sigma.cli.check import check
Expand Down Expand Up @@ -88,3 +91,53 @@ def test_check_exclude():
assert "Invalid validators name" in result.stdout
assert "myvalidator" in result.stdout
assert "Check failure" in result.stdout


def test_check_junitxml_creates_file(tmp_path):
"""JUnit XML file is created and well-formed when --junitxml is specified."""
cli = CliRunner()
output_xml = tmp_path / "results.xml"
# Use --pass-on-error so exit code is 0; invalid rules don't reach the network
# validator, which avoids the pre-existing MITRE D3FEND network dependency.
result = cli.invoke(
check, ["--pass-on-error", "--junitxml", str(output_xml), "tests/files/invalid"]
)
assert result.exit_code == 0
assert output_xml.exists(), "JUnit XML file was not created"
assert f"JUnit report saved to: {output_xml}" in result.stdout

# Verify the XML is well-formed and has the expected root element
tree = ET.parse(output_xml)
root = tree.getroot()
assert root.tag == "testsuites"


def test_check_junitxml_invalid(tmp_path):
"""JUnit XML file contains failure entries when rules have errors."""
cli = CliRunner()
output_xml = tmp_path / "results.xml"
result = cli.invoke(
check, ["--pass-on-error", "--junitxml", str(output_xml), "tests/files/invalid"]
)
assert result.exit_code == 0
assert output_xml.exists(), "JUnit XML file was not created"

tree = ET.parse(output_xml)
root = tree.getroot()
assert root.tag == "testsuites"

# At least one testsuite should have failures > 0
failures = sum(
int(suite.get("failures", "0")) for suite in root.findall("testsuite")
)
assert failures > 0, "Expected failure entries in JUnit XML for invalid rules"

# The number of <failure> elements across all testcases must match
# the reported failure counts in the testsuite attributes.
for suite in root.findall("testsuite"):
reported_failures = int(suite.get("failures", "0"))
actual_failures = sum(
1 for testcase in suite.findall("testcase")
if testcase.find("failure") is not None
)
assert actual_failures == reported_failures