diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS
index 137ce180071..3d6469ffc10 100644
--- a/.github/CODEOWNERS
+++ b/.github/CODEOWNERS
@@ -1,14 +1,16 @@
 # detection-rules code owners
 # POC: Elastic Security Intelligence and Analytics Team
-tests/**/*.py @mikaayenson @eric-forte-elastic @terrancedejesus
-detection_rules/ @mikaayenson @eric-forte-elastic @terrancedejesus
-tests/ @mikaayenson @eric-forte-elastic @terrancedejesus
-lib/ @mikaayenson @eric-forte-elastic @terrancedejesus
-hunting/ @mikaayenson @eric-forte-elastic @terrancedejesus
+tests/**/*.py @mikaayenson @eric-forte-elastic @traut
+detection_rules/ @mikaayenson @eric-forte-elastic @traut
+tests/ @mikaayenson @eric-forte-elastic @traut
+lib/ @mikaayenson @eric-forte-elastic @traut
+hunting/ @mikaayenson @eric-forte-elastic @traut
 
 # skip rta-mapping to avoid the spam
-detection_rules/etc/packages.yaml @mikaayenson @eric-forte-elastic @terrancedejesus
-detection_rules/etc/*.json @mikaayenson @eric-forte-elastic @terrancedejesus
-detection_rules/etc/*.json @mikaayenson @eric-forte-elastic @terrancedejesus
-detection_rules/etc/*/* @mikaayenson @eric-forte-elastic @terrancedejesus
+detection_rules/etc/packages.yaml @mikaayenson @eric-forte-elastic @traut
+detection_rules/etc/*.json @mikaayenson @eric-forte-elastic @traut
+detection_rules/etc/*/* @mikaayenson @eric-forte-elastic @traut
+
+# exclude files from code owners
+detection_rules/etc/non-ecs-schema.json
diff --git a/.github/workflows/code-checks.yaml b/.github/workflows/code-checks.yaml
new file mode 100644
index 00000000000..7e5698b206b
--- /dev/null
+++ b/.github/workflows/code-checks.yaml
@@ -0,0 +1,47 @@
+name: Code checks
+
+on:
+  push:
+    branches: [ "main", "7.*", "8.*", "9.*" ]
+  pull_request:
+    branches: [ "*" ]
+    paths:
+      - 'detection_rules/**/*.py'
+      - 'hunting/**/*.py'
+
+jobs:
+  code-checks:
+
+    runs-on: ubuntu-latest
+
+    steps:
+    - uses: actions/checkout@v4
+      with:
+        fetch-depth: 1
+
+    - name: Set up Python 3.13
+      uses: actions/setup-python@v5
+      with:
+        python-version: '3.13'
+
+    - name: Install dependencies
+      run: |
+        python -m pip install --upgrade pip
+        pip cache purge
+        pip install .[dev]
+
+    - name: Linting check
+      run: |
+        ruff check --exit-non-zero-on-fix
+
+    - name: Formatting check
+      run: |
+        ruff format --check
+
+    - name: Pyright check
+      run: |
+        pyright
+
+    - name: Python License Check
+      run: |
+        python -m detection_rules dev license-check
diff --git a/.github/workflows/pythonpackage.yml b/.github/workflows/pythonpackage.yml
index 97ca3e0f025..5b454cd7187 100644
--- a/.github/workflows/pythonpackage.yml
+++ b/.github/workflows/pythonpackage.yml
@@ -20,10 +20,10 @@ jobs:
       run: |
         git fetch origin main:refs/remotes/origin/main
 
-    - name: Set up Python 3.12
+    - name: Set up Python 3.13
      uses: actions/setup-python@v5
      with:
-        python-version: '3.12'
+        python-version: '3.13'
 
     - name: Install dependencies
       run: |
@@ -31,14 +31,6 @@ jobs:
         pip cache purge
         pip install .[dev]
 
-    - name: Python Lint
-      run: |
-        python -m flake8 tests detection_rules --ignore D203,N815 --max-line-length 120
-
-    - name: Python License Check
-      run: |
-        python -m detection_rules dev license-check
-
     - name: Unit tests
       env:
         # only run the test test_rule_change_has_updated_date on pull request events to main
diff --git a/detection_rules/__init__.py b/detection_rules/__init__.py
index ebf6fdb0b37..e6490fb99cd 100644
--- a/detection_rules/__init__.py
+++ b/detection_rules/__init__.py
@@ -25,23 +25,23 @@
     rule_formatter,
     rule_loader,
     schemas,
-    utils
+    utils,
 )
 
 __all__ = (
-    'custom_rules',
-    'custom_schemas',
-    'devtools',
-    'docs',
-    'eswrap',
-    'ghwrap',
-    'kbwrap',
+    "custom_rules",
+    "custom_schemas",
+    "devtools",
+    "docs",
+    "eswrap",
+    "ghwrap",
+    "kbwrap",
     "main",
-    'misc',
-    'ml',
-    'navigator',
-    'rule_formatter',
-    'rule_loader',
-    'schemas',
-    'utils'
+    "misc",
+    "ml",
+    "navigator",
+    "rule_formatter",
+    "rule_loader",
+    "schemas",
+    "utils",
 )
diff --git a/detection_rules/__main__.py b/detection_rules/__main__.py
index 8d576e3e014..10439a7093f 100644
--- a/detection_rules/__main__.py
+++ b/detection_rules/__main__.py
@@ -5,6 +5,7 @@
 # coding=utf-8
 
 """Shell for detection-rules."""
+
 import sys
 from pathlib import Path
 
diff --git a/detection_rules/action.py b/detection_rules/action.py
index 95ee9b997d2..a0e59adf8b7 100644
--- a/detection_rules/action.py
+++ b/detection_rules/action.py
@@ -4,9 +4,10 @@
 # 2.0.
 
 """Dataclasses for Action."""
+
+from typing import Any
 from dataclasses import dataclass
 from pathlib import Path
-from typing import List, Optional
 
 from .mixins import MarshmallowDataclassMixin
 from .schemas import definitions
@@ -15,43 +16,49 @@
 @dataclass(frozen=True)
 class ActionMeta(MarshmallowDataclassMixin):
     """Data stored in an exception's [metadata] section of TOML."""
+
     creation_date: definitions.Date
-    rule_id: List[definitions.UUIDString]
+    rule_id: list[definitions.UUIDString]
     rule_name: str
     updated_date: definitions.Date
 
     # Optional fields
-    deprecation_date: Optional[definitions.Date]
-    comments: Optional[str]
-    maturity: Optional[definitions.Maturity]
+    deprecation_date: definitions.Date | None
+    comments: str | None
+    maturity: definitions.Maturity | None
 
 
 @dataclass
 class Action(MarshmallowDataclassMixin):
     """Data object for rule Action."""
+
     @dataclass
     class ActionParams:
         """Data object for rule Action params."""
+
         body: str
 
     action_type_id: definitions.ActionTypeId
     group: str
     params: ActionParams
-    id: Optional[str]
-    frequency: Optional[dict]
-    alerts_filter: Optional[dict]
+
+    id: str | None
+    frequency: dict[str, Any] | None
+    alerts_filter: dict[str, Any] | None
 
 
 @dataclass(frozen=True)
 class TOMLActionContents(MarshmallowDataclassMixin):
     """Object for action from TOML file."""
+
     metadata: ActionMeta
-    actions: List[Action]
+    actions: list[Action]
 
 
 @dataclass(frozen=True)
 class TOMLAction:
     """Object for action from TOML file."""
+
     contents: TOMLActionContents
     path: Path
 
diff --git a/detection_rules/action_connector.py b/detection_rules/action_connector.py
index 8a31c2a8f0e..9fce011cb16 100644
--- a/detection_rules/action_connector.py
+++ b/detection_rules/action_connector.py
@@ -4,12 +4,13 @@
 # 2.0.
 
 """Dataclasses for Action."""
+
 from dataclasses import dataclass
 from datetime import datetime
 from pathlib import Path
-from typing import List, Optional, Tuple
+from typing import Any
 
-import pytoml
+import pytoml  # type: ignore[reportMissingTypeStubs]
 from marshmallow import EXCLUDE
 
 from .mixins import MarshmallowDataclassMixin
@@ -25,14 +26,14 @@ class ActionConnectorMeta(MarshmallowDataclassMixin):
 
     creation_date: definitions.Date
     action_connector_name: str
-    rule_ids: List[definitions.UUIDString]
-    rule_names: List[str]
+    rule_ids: list[definitions.UUIDString]
+    rule_names: list[str]
     updated_date: definitions.Date
 
     # Optional fields
-    deprecation_date: Optional[definitions.Date]
-    comments: Optional[str]
-    maturity: Optional[definitions.Maturity]
+    deprecation_date: definitions.Date | None
+    comments: str | None
+    maturity: definitions.Maturity | None
 
 
 @dataclass
@@ -40,11 +41,11 @@ class ActionConnector(MarshmallowDataclassMixin):
     """Data object for rule Action Connector."""
 
     id: str
-    attributes: dict
-    frequency: Optional[dict]
-    managed: Optional[bool]
-    type: Optional[str]
-    references: Optional[List]
+    attributes: dict[str, Any]
+    frequency: dict[str, Any] | None
+    managed: bool | None
+    type: str | None
+    references: list[Any] | None
 
 
 @dataclass(frozen=True)
@@ -52,13 +53,15 @@ class TOMLActionConnectorContents(MarshmallowDataclassMixin):
     """Object for action connector from TOML file."""
 
     metadata: ActionConnectorMeta
-    action_connectors: List[ActionConnector]
+    action_connectors: list[ActionConnector]
 
     @classmethod
-    def from_action_connector_dict(cls, actions_dict: dict, rule_list: dict) -> "TOMLActionConnectorContents":
+    def from_action_connector_dict(
+        cls, actions_dict: dict[str, Any], rule_list: list[dict[str, Any]]
+    ) -> "TOMLActionConnectorContents":
         """Create a TOMLActionContents from a kibana rule resource."""
-        rule_ids = []
-        rule_names = []
+        rule_ids: list[str] = []
+        rule_names: list[str] = []
 
         for rule in rule_list:
             rule_ids.append(rule["id"])
@@ -77,9 +80,9 @@ def from_action_connector_dict(cls, actions_dict: dict, rule_list: dict) -> "TOM
 
         return cls.from_dict({"metadata": metadata, "action_connectors": [actions_dict]}, unknown=EXCLUDE)
 
-    def to_api_format(self) -> List[dict]:
+    def to_api_format(self) -> list[dict[str, Any]]:
         """Convert the TOML Action Connector to the API format."""
-        converted = []
+        converted: list[dict[str, Any]] = []
 
         for action in self.action_connectors:
             converted.append(action.to_dict())
@@ -109,13 +112,15 @@ def save_toml(self):
             contents_dict = self.contents.to_dict()
             # Sort the dictionary so that 'metadata' is at the top
             sorted_dict = dict(sorted(contents_dict.items(), key=lambda item: item[0] != "metadata"))
-            pytoml.dump(sorted_dict, f)
+            pytoml.dump(sorted_dict, f)  # type: ignore[reportUnknownMemberType]
 
 
-def parse_action_connector_results_from_api(results: List[dict]) -> tuple[List[dict], List[dict]]:
+def parse_action_connector_results_from_api(
+    results: list[dict[str, Any]],
+) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]:
     """Filter Kibana export rule results for action connector dictionaries."""
-    action_results = []
-    non_action_results = []
+    action_results: list[dict[str, Any]] = []
+    non_action_results: list[dict[str, Any]] = []
     for result in results:
         if result.get("type") != "action":
             non_action_results.append(result)
@@ -125,17 +130,21 @@ def parse_action_connector_results_from_api(results: List[d
 
     return action_results, non_action_results
 
 
-def build_action_connector_objects(action_connectors:
List[dict], action_connector_rule_table: dict, - action_connectors_directory: Path, save_toml: bool = False, - skip_errors: bool = False, verbose=False, - ) -> Tuple[List[TOMLActionConnector], List[str], List[str]]: +def build_action_connector_objects( + action_connectors: list[dict[str, Any]], + action_connector_rule_table: dict[str, Any], + action_connectors_directory: Path | None, + save_toml: bool = False, + skip_errors: bool = False, + verbose: bool = False, +) -> tuple[list[TOMLActionConnector], list[str], list[str]]: """Build TOMLActionConnector objects from a list of action connector dictionaries.""" - output = [] - errors = [] - toml_action_connectors = [] + output: list[str] = [] + errors: list[str] = [] + toml_action_connectors: list[TOMLActionConnector] = [] for action_connector_dict in action_connectors: try: - connector_id = action_connector_dict.get("id") + connector_id = action_connector_dict["id"] rule_list = action_connector_rule_table.get(connector_id) if not rule_list: output.append(f"Warning action connector {connector_id} has no associated rules. Loading skipped.") @@ -161,6 +170,7 @@ def build_action_connector_objects(action_connectors: List[dict], action_connect ) if save_toml: ac_object.save_toml() + toml_action_connectors.append(ac_object) except Exception as e: diff --git a/detection_rules/attack.py b/detection_rules/attack.py index 2178d2c8cd5..33ff2e816f8 100644 --- a/detection_rules/attack.py +++ b/detection_rules/attack.py @@ -4,10 +4,11 @@ # 2.0. """Mitre attack info.""" + import re import time from pathlib import Path -from typing import Optional +from typing import Any import json import requests @@ -16,209 +17,214 @@ from semver import Version from .utils import cached, clear_caches, get_etc_path, get_etc_glob_path, read_gzip, gzip_compress -PLATFORMS = ['Windows', 'macOS', 'Linux'] -CROSSWALK_FILE = get_etc_path('attack-crosswalk.json') -TECHNIQUES_REDIRECT_FILE = get_etc_path('attack-technique-redirects.json') +PLATFORMS = ["Windows", "macOS", "Linux"] +CROSSWALK_FILE = get_etc_path(["attack-crosswalk.json"]) +TECHNIQUES_REDIRECT_FILE = get_etc_path(["attack-technique-redirects.json"]) -tactics_map = {} +tactics_map: dict[str, Any] = {} @cached -def load_techniques_redirect() -> dict: - return json.loads(TECHNIQUES_REDIRECT_FILE.read_text())['mapping'] +def load_techniques_redirect() -> dict[str, Any]: + return json.loads(TECHNIQUES_REDIRECT_FILE.read_text())["mapping"] def get_attack_file_path() -> Path: - pattern = 'attack-v*.json.gz' - attack_file = get_etc_glob_path(pattern) + pattern = "attack-v*.json.gz" + attack_file = get_etc_glob_path([pattern]) if len(attack_file) < 1: - raise FileNotFoundError(f'Missing required {pattern} file') + raise FileNotFoundError(f"Missing required {pattern} file") elif len(attack_file) != 1: - raise FileExistsError(f'Multiple files found with {pattern} pattern. Only one is allowed') + raise FileExistsError(f"Multiple files found with {pattern} pattern. 
Only one is allowed") return Path(attack_file[0]) -_, _attack_path_base = str(get_attack_file_path()).split('-v') -_ext_length = len('.json.gz') +_, _attack_path_base = str(get_attack_file_path()).split("-v") +_ext_length = len(".json.gz") CURRENT_ATTACK_VERSION = _attack_path_base[:-_ext_length] -def load_attack_gz() -> dict: - +def load_attack_gz() -> dict[str, Any]: return json.loads(read_gzip(get_attack_file_path())) attack = load_attack_gz() -technique_lookup = {} -revoked = {} -deprecated = {} +technique_lookup: dict[str, Any] = {} +revoked: dict[str, Any] = {} +deprecated: dict[str, Any] = {} for item in attack["objects"]: if item["type"] == "x-mitre-tactic": - tactics_map[item['name']] = item['external_references'][0]['external_id'] + tactics_map[item["name"]] = item["external_references"][0]["external_id"] - if item["type"] == "attack-pattern" and item["external_references"][0]['source_name'] == 'mitre-attack': - technique_id = item['external_references'][0]['external_id'] + if item["type"] == "attack-pattern" and item["external_references"][0]["source_name"] == "mitre-attack": + technique_id = item["external_references"][0]["external_id"] technique_lookup[technique_id] = item - if item.get('revoked'): + if item.get("revoked"): revoked[technique_id] = item - if item.get('x_mitre_deprecated'): + if item.get("x_mitre_deprecated"): deprecated[technique_id] = item revoked = dict(sorted(revoked.items())) deprecated = dict(sorted(deprecated.items())) tactics = list(tactics_map) -matrix = {tactic: [] for tactic in tactics} -no_tactic = [] -attack_tm = 'ATT&CK\u2122' +matrix: dict[str, list[str]] = {tactic: [] for tactic in tactics} +no_tactic: list[str] = [] +attack_tm = "ATT&CK\u2122" # Enumerate over the techniques and build the matrix back up -for technique_id, technique in sorted(technique_lookup.items(), key=lambda kv: kv[1]['name'].lower()): - kill_chain = technique.get('kill_chain_phases') +for technique_id, technique in sorted(technique_lookup.items(), key=lambda kv: kv[1]["name"].lower()): + kill_chain = technique.get("kill_chain_phases") if kill_chain: for tactic in kill_chain: - tactic_name = next(t for t in tactics if tactic['kill_chain_name'] == 'mitre-attack' and t.lower() == tactic['phase_name'].replace("-", " ")) # noqa: E501 + tactic_name = next( + t + for t in tactics + if tactic["kill_chain_name"] == "mitre-attack" and t.lower() == tactic["phase_name"].replace("-", " ") + ) # noqa: E501 matrix[tactic_name].append(technique_id) else: no_tactic.append(technique_id) for tactic in matrix: - matrix[tactic].sort(key=lambda tid: technique_lookup[tid]['name'].lower()) + matrix[tactic].sort(key=lambda tid: technique_lookup[tid]["name"].lower()) technique_lookup = OrderedDict(sorted(technique_lookup.items())) -techniques = sorted({v['name'] for k, v in technique_lookup.items()}) -technique_id_list = [t for t in technique_lookup if '.' not in t] -sub_technique_id_list = [t for t in technique_lookup if '.' in t] +techniques = sorted({v["name"] for _, v in technique_lookup.items()}) +technique_id_list = [t for t in technique_lookup if "." not in t] +sub_technique_id_list = [t for t in technique_lookup if "." 
in t] -def refresh_attack_data(save=True) -> (Optional[dict], Optional[bytes]): +def refresh_attack_data(save: bool = True) -> tuple[dict[str, Any] | None, bytes | None]: """Refresh ATT&CK data from Mitre.""" attack_path = get_attack_file_path() - filename, _, _ = attack_path.name.rsplit('.', 2) + filename, _, _ = attack_path.name.rsplit(".", 2) - def get_version_from_tag(name, pattern='att&ck-v'): + def get_version_from_tag(name: str, pattern: str = "att&ck-v"): _, version = name.lower().split(pattern, 1) return version - current_version = Version.parse(get_version_from_tag(filename, 'attack-v'), optional_minor_and_patch=True) + current_version = Version.parse(get_version_from_tag(filename, "attack-v"), optional_minor_and_patch=True) - r = requests.get('https://api.github.com/repos/mitre/cti/tags') + r = requests.get("https://api.github.com/repos/mitre/cti/tags") r.raise_for_status() - releases = [t for t in r.json() if t['name'].startswith('ATT&CK-v')] - latest_release = max(releases, key=lambda release: Version.parse(get_version_from_tag(release['name']), - optional_minor_and_patch=True)) - release_name = latest_release['name'] + releases = [t for t in r.json() if t["name"].startswith("ATT&CK-v")] + latest_release = max( + releases, + key=lambda release: Version.parse(get_version_from_tag(release["name"]), optional_minor_and_patch=True), + ) + release_name = latest_release["name"] latest_version = Version.parse(get_version_from_tag(release_name), optional_minor_and_patch=True) if current_version >= latest_version: - print(f'No versions newer than the current detected: {current_version}') + print(f"No versions newer than the current detected: {current_version}") return None, None - download = f'https://raw.githubusercontent.com/mitre/cti/{release_name}/enterprise-attack/enterprise-attack.json' + download = f"https://raw.githubusercontent.com/mitre/cti/{release_name}/enterprise-attack/enterprise-attack.json" r = requests.get(download) r.raise_for_status() attack_data = r.json() compressed = gzip_compress(json.dumps(attack_data, sort_keys=True)) if save: - new_path = get_etc_path(f'attack-v{latest_version}.json.gz') - new_path.write_bytes(compressed) + new_path = get_etc_path([f"attack-v{latest_version}.json.gz"]) + _ = new_path.write_bytes(compressed) attack_path.unlink() - print(f'Replaced file: {attack_path} with {new_path}') + print(f"Replaced file: {attack_path} with {new_path}") return attack_data, compressed -def build_threat_map_entry(tactic: str, *technique_ids: str) -> dict: +def build_threat_map_entry(tactic: str, *technique_ids: str) -> dict[str, Any]: """Build rule threat map from technique IDs.""" techniques_redirect_map = load_techniques_redirect() - url_base = 'https://attack.mitre.org/{type}/{id}/' + url_base = "https://attack.mitre.org/{type}/{id}/" tactic_id = tactics_map[tactic] - tech_entries = {} + tech_entries: dict[str, Any] = {} - def make_entry(_id): + def make_entry(_id: str): e = { - 'id': _id, - 'name': technique_lookup[_id]['name'], - 'reference': url_base.format(type='techniques', id=_id.replace('.', '/')) + "id": _id, + "name": technique_lookup[_id]["name"], + "reference": url_base.format(type="techniques", id=_id.replace(".", "/")), } return e for tid in technique_ids: # fail if deprecated or else convert if it has been replaced if tid in deprecated: - raise ValueError(f'Technique ID: {tid} has been deprecated and should not be used') + raise ValueError(f"Technique ID: {tid} has been deprecated and should not be used") elif tid in techniques_redirect_map: 
tid = techniques_redirect_map[tid] if tid not in matrix[tactic]: - raise ValueError(f'Technique ID: {tid} does not fall under tactic: {tactic}') + raise ValueError(f"Technique ID: {tid} does not fall under tactic: {tactic}") # sub-techniques - if '.' in tid: - parent_technique, _ = tid.split('.', 1) + if "." in tid: + parent_technique, _ = tid.split(".", 1) tech_entries.setdefault(parent_technique, make_entry(parent_technique)) - tech_entries[parent_technique].setdefault('subtechnique', []).append(make_entry(tid)) + tech_entries[parent_technique].setdefault("subtechnique", []).append(make_entry(tid)) else: tech_entries.setdefault(tid, make_entry(tid)) - entry = { - 'framework': 'MITRE ATT&CK', - 'tactic': { - 'id': tactic_id, - 'name': tactic, - 'reference': url_base.format(type='tactics', id=tactic_id) - } + entry: dict[str, Any] = { + "framework": "MITRE ATT&CK", + "tactic": {"id": tactic_id, "name": tactic, "reference": url_base.format(type="tactics", id=tactic_id)}, } if tech_entries: - entry['technique'] = sorted(tech_entries.values(), key=lambda x: x['id']) + entry["technique"] = sorted(tech_entries.values(), key=lambda x: x["id"]) return entry -def update_threat_map(rule_threat_map): +def update_threat_map(rule_threat_map: list[dict[str, Any]]): """Update rule map techniques to reflect changes from ATT&CK.""" for entry in rule_threat_map: - for tech in entry['technique']: - tech['name'] = technique_lookup[tech['id']]['name'] + for tech in entry["technique"]: + tech["name"] = technique_lookup[tech["id"]]["name"] def retrieve_redirected_id(asset_id: str): """Get the ID for a redirected ATT&CK asset.""" if asset_id in (tactics_map.values()): - attack_type = 'tactics' + attack_type = "tactics" elif asset_id in list(technique_lookup): - attack_type = 'techniques' + attack_type = "techniques" else: - raise ValueError(f'Unknown asset_id: {asset_id}') + raise ValueError(f"Unknown asset_id: {asset_id}") - response = requests.get(f'https://attack.mitre.org/{attack_type}/{asset_id.replace(".", "/")}') + response = requests.get(f"https://attack.mitre.org/{attack_type}/{asset_id.replace('.', '/')}") text = response.text.strip().strip("'").lower() if text.startswith(' dict: +def load_crosswalk_map() -> dict[str, Any]: """Retrieve the replacement mapping.""" - return json.loads(CROSSWALK_FILE.read_text())['mapping'] + return json.loads(CROSSWALK_FILE.read_text())["mapping"] diff --git a/detection_rules/beats.py b/detection_rules/beats.py index ed2a1d9b9e6..28258d9e8a6 100644 --- a/detection_rules/beats.py +++ b/detection_rules/beats.py @@ -4,30 +4,29 @@ # 2.0. 
"""ECS Schemas management.""" + import json import os import re -from typing import List, Optional, Union +from typing import Any -import eql +import eql # type: ignore[reportMissingTypeStubs] import requests from semver import Version import yaml -import kql +import kql # type: ignore[reportMissingTypeStubs] -from .utils import (DateTimeEncoder, cached, get_etc_path, gzip_compress, - read_gzip, unzip) +from .utils import DateTimeEncoder, cached, get_etc_path, gzip_compress, read_gzip, unzip -def _decompress_and_save_schema(url, release_name): +def _decompress_and_save_schema(url: str, release_name: str): print(f"Downloading beats {release_name}") response = requests.get(url) print(f"Downloaded {len(response.content) / 1024.0 / 1024.0:.2f} MB release.") - fs = {} - parsed = {} + fs: dict[str, Any] = {} with unzip(response.content) as archive: base_directory = archive.namelist()[0] @@ -37,18 +36,18 @@ def _decompress_and_save_schema(url, release_name): contents = archive.read(name) # chop off the base directory name - key = name[len(base_directory):] + key = name[len(base_directory) :] if key.startswith("x-pack"): - key = key[len("x-pack") + 1:] + key = key[len("x-pack") + 1 :] try: decoded = yaml.safe_load(contents) except yaml.YAMLError: print(f"Error loading {name}") + raise ValueError(f"Error loading {name}") # create a hierarchical structure - parsed[key] = decoded branch = fs directory, base_name = os.path.split(key) for limb in directory.split(os.path.sep): @@ -61,36 +60,36 @@ def _decompress_and_save_schema(url, release_name): print(f"Saving detection_rules/etc/beats_schema/{release_name}.json") compressed = gzip_compress(json.dumps(fs, sort_keys=True, cls=DateTimeEncoder)) - path = get_etc_path("beats_schemas", release_name + ".json.gz") - with open(path, 'wb') as f: - f.write(compressed) + path = get_etc_path(["beats_schemas", release_name + ".json.gz"]) + with open(path, "wb") as f: + _ = f.write(compressed) def download_beats_schema(version: str): """Download a beats schema by version.""" - url = 'https://api.github.com/repos/elastic/beats/releases' + url = "https://api.github.com/repos/elastic/beats/releases" releases = requests.get(url) - version = f'v{version.lstrip("v")}' + version = f"v{version.lstrip('v')}" beats_release = None for release in releases.json(): - if release['tag_name'] == version: + if release["tag_name"] == version: beats_release = release break if not beats_release: - print(f'beats release {version} not found!') + print(f"beats release {version} not found!") return - beats_url = beats_release['zipball_url'] - name = beats_release['tag_name'] + beats_url = beats_release["zipball_url"] + name = beats_release["tag_name"] _decompress_and_save_schema(beats_url, name) def download_latest_beats_schema(): """Download additional schemas from beats releases.""" - url = 'https://api.github.com/repos/elastic/beats/releases' + url = "https://api.github.com/repos/elastic/beats/releases" releases = requests.get(url) latest_release = max(releases.json(), key=lambda release: Version.parse(release["tag_name"].lstrip("v"))) @@ -99,15 +98,15 @@ def download_latest_beats_schema(): def refresh_main_schema(): """Download and refresh beats schema from main.""" - _decompress_and_save_schema('https://github.com/elastic/beats/archive/main.zip', 'main') + _decompress_and_save_schema("https://github.com/elastic/beats/archive/main.zip", "main") -def _flatten_schema(schema: list, prefix="") -> list: +def _flatten_schema(schema: list[dict[str, Any]] | None, prefix: str = "") -> 
list[dict[str, Any]]: if schema is None: # sometimes we see `fields: null` in the yaml return [] - flattened = [] + flattened: list[dict[str, Any]] = [] for s in schema: if s.get("type") == "group": nested_prefix = prefix + s["name"] + "." @@ -141,15 +140,15 @@ def _flatten_schema(schema: list, prefix="") -> list: return flattened -def flatten_ecs_schema(schema: dict) -> dict: +def flatten_ecs_schema(schema: list[dict[str, Any]]) -> list[dict[str, Any]]: return _flatten_schema(schema) -def get_field_schema(base_directory, prefix="", include_common=False): +def get_field_schema(base_directory: dict[str, Any], prefix: str = "", include_common: bool = False): base_directory = base_directory.get("folders", {}).get("_meta", {}).get("files", {}) - flattened = [] + flattened: list[dict[str, Any]] = [] - file_names = ("fields.yml", "fields.common.yml") if include_common else ("fields.yml", ) + file_names = ("fields.yml", "fields.common.yml") if include_common else ("fields.yml",) for name in file_names: if name in base_directory: @@ -158,7 +157,7 @@ def get_field_schema(base_directory, prefix="", include_common=False): return flattened -def get_beat_root_schema(schema: dict, beat: str): +def get_beat_root_schema(schema: dict[str, Any], beat: str): if beat not in schema: raise KeyError(f"Unknown beats module {beat}") @@ -168,22 +167,24 @@ def get_beat_root_schema(schema: dict, beat: str): return {field["name"]: field for field in sorted(flattened, key=lambda f: f["name"])} -def get_beats_sub_schema(schema: dict, beat: str, module: str, *datasets: str): +def get_beats_sub_schema(schema: dict[str, Any], beat: str, module: str, *datasets: str): if beat not in schema: raise KeyError(f"Unknown beats module {beat}") - flattened = [] + flattened: list[dict[str, Any]] = [] beat_dir = schema[beat] module_dir = beat_dir.get("folders", {}).get("module", {}).get("folders", {}).get(module, {}) # if we only have a module then we'll work with what we got - if not datasets: - datasets = [d for d in module_dir.get("folders", {}) if not d.startswith("_")] + if datasets: + all_datasets = datasets + else: + all_datasets = [d for d in module_dir.get("folders", {}) if not d.startswith("_")] - for dataset in datasets: + for dataset in all_datasets: # replace aws.s3 -> s3 if dataset.startswith(module + "."): - dataset = dataset[len(module) + 1:] + dataset = dataset[len(module) + 1 :] dataset_dir = module_dir.get("folders", {}).get(dataset, {}) flattened.extend(get_field_schema(dataset_dir, prefix=module + ".", include_common=True)) @@ -195,10 +196,10 @@ def get_beats_sub_schema(schema: dict, beat: str, module: str, *datasets: str): @cached -def get_versions() -> List[Version]: - versions = [] - for filename in os.listdir(get_etc_path("beats_schemas")): - version_match = re.match(r'v(.+)\.json\.gz', filename) +def get_versions() -> list[Version]: + versions: list[Version] = [] + for filename in os.listdir(get_etc_path(["beats_schemas"])): + version_match = re.match(r"v(.+)\.json\.gz", filename) if version_match: versions.append(Version.parse(version_match.groups()[0])) @@ -211,23 +212,24 @@ def get_max_version() -> str: @cached -def read_beats_schema(version: str = None): - if version and version.lower() == 'main': - return json.loads(read_gzip(get_etc_path('beats_schemas', 'main.json.gz'))) +def read_beats_schema(version: str | None = None): + if version and version.lower() == "main": + path = get_etc_path(["beats_schemas", "main.json.gz"]) + return json.loads(read_gzip(path)) - version = Version.parse(version) if version 
else None + ver = Version.parse(version) if version else None beats_schemas = get_versions() - if version and version not in beats_schemas: - raise ValueError(f'Unknown beats schema: {version}') + if ver and ver not in beats_schemas: + raise ValueError(f"Unknown beats schema: {ver}") version = version or get_max_version() - return json.loads(read_gzip(get_etc_path('beats_schemas', f'v{version}.json.gz'))) + return json.loads(read_gzip(get_etc_path(["beats_schemas", f"v{version}.json.gz"]))) -def get_schema_from_datasets(beats, modules, datasets, version=None): - filtered = {} +def get_schema_from_datasets(beats: list[str], modules: set[str], datasets: set[str], version: str | None = None): + filtered: dict[str, Any] = {} beats_schema = read_beats_schema(version=version) # infer the module if only a dataset are defined @@ -246,53 +248,50 @@ def get_schema_from_datasets(beats, modules, datasets, version=None): return filtered -def get_datasets_and_modules(tree: Union[eql.ast.BaseNode, kql.ast.BaseNode]) -> tuple: +def get_datasets_and_modules(tree: eql.ast.BaseNode | kql.ast.BaseNode) -> tuple[set[str], set[str]]: """Get datasets and modules from an EQL or KQL AST.""" - modules = set() - datasets = set() + modules: set[str] = set() + datasets: set[str] = set() # extract out event.module and event.dataset from the query's AST - for node in tree: - if isinstance(node, eql.ast.Comparison) and node.comparator == node.EQ and \ - isinstance(node.right, eql.ast.String): + for node in tree: # type: ignore[reportUnknownVariableType] + if ( + isinstance(node, eql.ast.Comparison) + and node.comparator == node.EQ + and isinstance(node.right, eql.ast.String) + ): if node.left == eql.ast.Field("event", ["module"]): - modules.add(node.right.render()) + modules.add(node.right.render()) # type: ignore[reportUnknownMemberType] elif node.left == eql.ast.Field("event", ["dataset"]): - datasets.add(node.right.render()) + datasets.add(node.right.render()) # type: ignore[reportUnknownMemberType] elif isinstance(node, eql.ast.InSet): if node.expression == eql.ast.Field("event", ["module"]): - modules.add(node.get_literals()) + modules.add(node.get_literals()) # type: ignore[reportUnknownMemberType] elif node.expression == eql.ast.Field("event", ["dataset"]): - datasets.add(node.get_literals()) - elif isinstance(node, kql.ast.FieldComparison) and node.field == kql.ast.Field("event.module"): - modules.update(child.value for child in node.value if isinstance(child, kql.ast.String)) - elif isinstance(node, kql.ast.FieldComparison) and node.field == kql.ast.Field("event.dataset"): - datasets.update(child.value for child in node.value if isinstance(child, kql.ast.String)) + datasets.add(node.get_literals()) # type: ignore[reportUnknownMemberType] + elif isinstance(node, kql.ast.FieldComparison) and node.field == kql.ast.Field("event.module"): # type: ignore[reportUnknownMemberType] + modules.update(child.value for child in node.value if isinstance(child, kql.ast.String)) # type: ignore[reportUnknownMemberType, reportUnknownVariableType] + elif isinstance(node, kql.ast.FieldComparison) and node.field == kql.ast.Field("event.dataset"): # type: ignore[reportUnknownMemberType] + datasets.update(child.value for child in node.value if isinstance(child, kql.ast.String)) # type: ignore[reportUnknownMemberType, reportUnknownVariableType] return datasets, modules -def get_schema_from_eql(tree: eql.ast.BaseNode, beats: list, version: str = None) -> dict: - """Get a schema based on datasets and modules in an EQL AST.""" - datasets, 
modules = get_datasets_and_modules(tree) - return get_schema_from_datasets(beats, modules, datasets, version=version) - - -def get_schema_from_kql(tree: kql.ast.BaseNode, beats: list, version: str = None) -> dict: +def get_schema_from_kql(tree: kql.ast.BaseNode, beats: list[str], version: str | None = None) -> dict[str, Any]: """Get a schema based on datasets and modules in an KQL AST.""" datasets, modules = get_datasets_and_modules(tree) return get_schema_from_datasets(beats, modules, datasets, version=version) -def parse_beats_from_index(index: Optional[list]) -> List[str]: +def parse_beats_from_index(indexes: list[str] | None) -> list[str]: """Parse beats schema types from index.""" - indexes = index or [] - beat_types = [] + indexes = indexes or [] + beat_types: list[str] = [] # Need to split on : or :: to support cross-cluster search # e.g. mycluster:logs-* -> logs-* for index in indexes: if "beat-*" in index: - index_parts = index.replace('::', ':').split(':', 1) + index_parts = index.replace("::", ":").split(":", 1) last_part = index_parts[-1] beat_type = last_part.split("-")[0] beat_types.append(beat_type) diff --git a/detection_rules/cli_utils.py b/detection_rules/cli_utils.py index 0ab967bdb69..0d61dd0d5b0 100644 --- a/detection_rules/cli_utils.py +++ b/detection_rules/cli_utils.py @@ -9,18 +9,16 @@ import os import typing from pathlib import Path -from typing import List, Optional +from typing import Callable, Any import click -import kql +import kql # type: ignore[reportMissingTypeStubs] from . import ecs from .attack import build_threat_map_entry, matrix, tactics from .rule import BYPASS_VERSION_LOCK, TOMLRule, TOMLRuleContents -from .rule_loader import (DEFAULT_PREBUILT_BBR_DIRS, - DEFAULT_PREBUILT_RULES_DIRS, RuleCollection, - dict_filter) +from .rule_loader import DEFAULT_PREBUILT_BBR_DIRS, DEFAULT_PREBUILT_RULES_DIRS, RuleCollection, dict_filter from .schemas import definitions from .utils import clear_caches, ensure_list_of_strings, rulename_to_filename from .config import parse_rules_config @@ -28,32 +26,33 @@ RULES_CONFIG = parse_rules_config() -def single_collection(f): +def single_collection(f: Callable[..., Any]): """Add arguments to get a RuleCollection by file, directory or a list of IDs""" - from .misc import client_error + from .misc import raise_client_error - @click.option('--rule-file', '-f', multiple=False, required=False, type=click.Path(dir_okay=False)) - @click.option('--rule-id', '-id', multiple=False, required=False) + @click.option("--rule-file", "-f", multiple=False, required=False, type=click.Path(dir_okay=False)) + @click.option("--rule-id", "-id", multiple=False, required=False) @functools.wraps(f) - def get_collection(*args, **kwargs): - rule_name: List[str] = kwargs.pop("rule_name", []) - rule_id: List[str] = kwargs.pop("rule_id", []) - rule_files: List[str] = kwargs.pop("rule_file") - directories: List[str] = kwargs.pop("directory") + def get_collection(*args: Any, **kwargs: Any): + rule_name: list[str] = kwargs.pop("rule_name", []) + rule_id: list[str] = kwargs.pop("rule_id", []) + rule_files: list[str] = kwargs.pop("rule_file") + directories: list[str] = kwargs.pop("directory") rules = RuleCollection() if bool(rule_name) + bool(rule_id) + bool(rule_files) != 1: - client_error('Required: exactly one of --rule-id, --rule-file, or --directory') + raise_client_error("Required: exactly one of --rule-id, --rule-file, or --directory") rules.load_files(Path(p) for p in rule_files) rules.load_directories(Path(d) for d in directories) if rule_id: - 
rules.load_directories(DEFAULT_PREBUILT_RULES_DIRS + DEFAULT_PREBUILT_BBR_DIRS, - obj_filter=dict_filter(rule__rule_id=rule_id)) + rules.load_directories( + DEFAULT_PREBUILT_RULES_DIRS + DEFAULT_PREBUILT_BBR_DIRS, obj_filter=dict_filter(rule__rule_id=rule_id) + ) if len(rules) != 1: - client_error(f"Could not find rule with ID {rule_id}") + raise_client_error(f"Could not find rule with ID {rule_id}") kwargs["rules"] = rules return f(*args, **kwargs) @@ -61,28 +60,38 @@ def get_collection(*args, **kwargs): return get_collection -def multi_collection(f): +def multi_collection(f: Callable[..., Any]): """Add arguments to get a RuleCollection by file, directory or a list of IDs""" - from .misc import client_error + from .misc import raise_client_error @click.option("--rule-file", "-f", multiple=True, type=click.Path(dir_okay=False), required=False) - @click.option("--directory", "-d", multiple=True, type=click.Path(file_okay=False), required=False, - help="Recursively load rules from a directory") + @click.option( + "--directory", + "-d", + multiple=True, + type=click.Path(file_okay=False), + required=False, + help="Recursively load rules from a directory", + ) @click.option("--rule-id", "-id", multiple=True, required=False) - @click.option("--no-tactic-filename", "-nt", is_flag=True, required=False, - help="Allow rule filenames without tactic prefix. " - "Use this if rules have been exported with this flag.") + @click.option( + "--no-tactic-filename", + "-nt", + is_flag=True, + required=False, + help="Allow rule filenames without tactic prefix. Use this if rules have been exported with this flag.", + ) @functools.wraps(f) - def get_collection(*args, **kwargs): - rule_id: List[str] = kwargs.pop("rule_id", []) - rule_files: List[str] = kwargs.pop("rule_file") - directories: List[str] = kwargs.pop("directory") + def get_collection(*args: Any, **kwargs: Any): + rule_id: list[str] = kwargs.pop("rule_id", []) + rule_files: list[str] = kwargs.pop("rule_file") + directories: list[str] = kwargs.pop("directory") no_tactic_filename: bool = kwargs.pop("no_tactic_filename", False) rules = RuleCollection() if not (directories or rule_id or rule_files or (DEFAULT_PREBUILT_RULES_DIRS + DEFAULT_PREBUILT_BBR_DIRS)): - client_error("Required: at least one of --rule-id, --rule-file, or --directory") + raise_client_error("Required: at least one of --rule-id, --rule-file, or --directory") rules.load_files(Path(p) for p in rule_files) rules.load_directories(Path(d) for d in directories) @@ -95,12 +104,12 @@ def get_collection(*args, **kwargs): missing = set(rule_id).difference(found_ids) if missing: - client_error(f'Could not find rules with IDs: {", ".join(missing)}') + raise_client_error(f"Could not find rules with IDs: {', '.join(missing)}") elif not rule_files and not directories: rules.load_directories(Path(d) for d in (DEFAULT_PREBUILT_RULES_DIRS + DEFAULT_PREBUILT_BBR_DIRS)) if len(rules) == 0: - client_error("No rules found") + raise_client_error("No rules found") # Warn that if the path does not match the expected path, it will be saved to the expected path for rule in rules: @@ -110,7 +119,9 @@ def get_collection(*args, **kwargs): no_tactic_filename = no_tactic_filename or RULES_CONFIG.no_tactic_filename tactic_name = None if no_tactic_filename else first_tactic rule_name = rulename_to_filename(rule.contents.data.name, tactic_name=tactic_name) - if rule.path.name != rule_name: + if not rule.path: + click.secho(f"WARNING: Rule path for rule not found: {rule_name}", fg="yellow") + elif rule.path.name != 
rule_name: click.secho( f"WARNING: Rule path does not match required path: {rule.path.name} != {rule_name}", fg="yellow" ) @@ -121,67 +132,84 @@ def get_collection(*args, **kwargs): return get_collection -def rule_prompt(path=None, rule_type=None, required_only=True, save=True, verbose=False, - additional_required: Optional[list] = None, skip_errors: bool = False, strip_none_values=True, **kwargs, - ) -> TOMLRule: +def rule_prompt( + path: Path | None = None, + rule_type: str | None = None, + required_only: bool = True, + save: bool = True, + verbose: bool = False, + additional_required: list[str] | None = None, + skip_errors: bool = False, + strip_none_values: bool = True, + **kwargs: Any, +) -> TOMLRule | str: """Prompt loop to build a rule.""" from .misc import schema_prompt additional_required = additional_required or [] creation_date = datetime.date.today().strftime("%Y/%m/%d") if verbose and path: - click.echo(f'[+] Building rule for {path}') + click.echo(f"[+] Building rule for {path}") kwargs = copy.deepcopy(kwargs) - rule_name = kwargs.get('name') + rule_name = kwargs.get("name") - if 'rule' in kwargs and 'metadata' in kwargs: - kwargs.update(kwargs.pop('metadata')) - kwargs.update(kwargs.pop('rule')) + if "rule" in kwargs and "metadata" in kwargs: + kwargs.update(kwargs.pop("metadata")) + kwargs.update(kwargs.pop("rule")) - rule_type = rule_type or kwargs.get('type') or \ - click.prompt('Rule type', type=click.Choice(typing.get_args(definitions.RuleType))) + rule_type_val = ( + rule_type + or kwargs.get("type") + or click.prompt("Rule type", type=click.Choice(typing.get_args(definitions.RuleType))) + ) - target_data_subclass = TOMLRuleContents.get_data_subclass(rule_type) + target_data_subclass = TOMLRuleContents.get_data_subclass(rule_type_val) schema = target_data_subclass.jsonschema() - props = schema['properties'] - required_fields = schema.get('required', []) + additional_required - contents = {} - skipped = [] + props = schema["properties"] + required_fields = schema.get("required", []) + additional_required + contents: dict[str, Any] = {} + skipped: list[str] = [] for name, options in props.items(): - - if name == 'index' and kwargs.get("type") == "esql": + if name == "index" and kwargs.get("type") == "esql": continue - if name == 'type': + if name == "type": contents[name] = rule_type continue # these are set at package release time depending on the version strategy - if (name == 'version' or name == 'revision') and not BYPASS_VERSION_LOCK: + if (name == "version" or name == "revision") and not BYPASS_VERSION_LOCK: continue if required_only and name not in required_fields: continue # build this from technique ID - if name == 'threat': - threat_map = [] + if name == "threat": + threat_map: list[dict[str, Any]] = [] if not skip_errors: - while click.confirm('add mitre tactic?'): - tactic = schema_prompt('mitre tactic name', type='string', enum=tactics, is_required=True) - technique_ids = schema_prompt(f'technique or sub-technique IDs for {tactic}', type='array', - is_required=False, enum=list(matrix[tactic])) or [] + while click.confirm("add mitre tactic?"): + tactic = schema_prompt("mitre tactic name", type="string", enum=tactics, is_required=True) + technique_ids = ( # type: ignore[reportUnknownVariableType] + schema_prompt( + f"technique or sub-technique IDs for {tactic}", + type="array", + is_required=False, + enum=list(matrix[tactic]), + ) + or [] + ) try: - threat_map.append(build_threat_map_entry(tactic, *technique_ids)) + 
threat_map.append(build_threat_map_entry(tactic, *technique_ids)) # type: ignore[reportUnknownArgumentType] except KeyError as e: - click.secho(f'Unknown ID: {e.args[0]} - entry not saved for: {tactic}', fg='red', err=True) + click.secho(f"Unknown ID: {e.args[0]} - entry not saved for: {tactic}", fg="red", err=True) continue except ValueError as e: - click.secho(f'{e} - entry not saved for: {tactic}', fg='red', err=True) + click.secho(f"{e} - entry not saved for: {tactic}", fg="red", err=True) continue if len(threat_map) > 0: @@ -194,7 +222,7 @@ def rule_prompt(path=None, rule_type=None, required_only=True, save=True, verbos if name == "new_terms": # patch to allow new_term imports - result = {"field": "new_terms_fields"} + result: dict[str, Any] = {"field": "new_terms_fields"} new_terms_fields_value = schema_prompt("new_terms_fields", value=kwargs.pop("new_terms_fields", None)) result["value"] = ensure_list_of_strings(new_terms_fields_value) history_window_start_value = kwargs.pop("history_window_start", None) @@ -208,19 +236,19 @@ def rule_prompt(path=None, rule_type=None, required_only=True, save=True, verbos else: if skip_errors: # return missing information - return f"Rule: {kwargs["id"]}, Rule Name: {rule_name} is missing {name} information" + return f"Rule: {kwargs['id']}, Rule Name: {rule_name} is missing {name} information" else: result = schema_prompt(name, is_required=name in required_fields, **options.copy()) if result: - if name not in required_fields and result == options.get('default', ''): + if name not in required_fields and result == options.get("default", ""): skipped.append(name) continue contents[name] = result # DEFAULT_PREBUILT_RULES_DIRS[0] is a required directory just as a suggestion - suggested_path = Path(DEFAULT_PREBUILT_RULES_DIRS[0]) / contents['name'] - path = Path(path or input(f'File path for rule [{suggested_path}]: ') or suggested_path).resolve() + suggested_path: Path = Path(DEFAULT_PREBUILT_RULES_DIRS[0]) / contents["name"] + path = Path(path or input(f"File path for rule [{suggested_path}]: ") or suggested_path).resolve() # Inherit maturity and optionally local dates from the rule if it already exists meta = { "creation_date": kwargs.get("creation_date") or creation_date, @@ -229,28 +257,31 @@ def rule_prompt(path=None, rule_type=None, required_only=True, save=True, verbos } try: - rule = TOMLRule(path=Path(path), contents=TOMLRuleContents.from_dict({'rule': contents, 'metadata': meta})) + rule = TOMLRule(path=Path(path), contents=TOMLRuleContents.from_dict({"rule": contents, "metadata": meta})) except kql.KqlParseError as e: if skip_errors: return f"Rule: {kwargs['id']}, Rule Name: {rule_name} query failed to parse: {e.error_msg}" - if e.error_msg == 'Unknown field': - warning = ('If using a non-ECS field, you must update "ecs{}.non-ecs-schema.json" under `beats` or ' - '`legacy-endgame` (Non-ECS fields should be used minimally).'.format(os.path.sep)) - click.secho(e.args[0], fg='red', err=True) - click.secho(warning, fg='yellow', err=True) + if e.error_msg == "Unknown field": + warning = ( + 'If using a non-ECS field, you must update "ecs{}.non-ecs-schema.json" under `beats` or ' + "`legacy-endgame` (Non-ECS fields should be used minimally).".format(os.path.sep) + ) + click.secho(e.args[0], fg="red", err=True) + click.secho(warning, fg="yellow", err=True) click.pause() # if failing due to a query, loop until resolved or terminated while True: try: - contents['query'] = click.edit(contents['query'], extension='.eql') - rule = TOMLRule(path=Path(path), 
- contents=TOMLRuleContents.from_dict({'rule': contents, 'metadata': meta})) + contents["query"] = click.edit(contents["query"], extension=".eql") + rule = TOMLRule( + path=Path(path), contents=TOMLRuleContents.from_dict({"rule": contents, "metadata": meta}) + ) except kql.KqlParseError as e: - click.secho(e.args[0], fg='red', err=True) + click.secho(e.args[0], fg="red", err=True) click.pause() - if e.error_msg.startswith("Unknown field"): + if e.error_msg.startswith("Unknown field"): # type: ignore[reportUnknownMemberType] # get the latest schema for schema errors clear_caches() ecs.get_kql_schema(indexes=contents.get("index", [])) @@ -266,7 +297,7 @@ def rule_prompt(path=None, rule_type=None, required_only=True, save=True, verbos rule.save_toml(strip_none_values=strip_none_values) if skipped: - print('Did not set the following values because they are un-required when set to the default value') - print(' - {}'.format('\n - '.join(skipped))) + print("Did not set the following values because they are un-required when set to the default value") + print(" - {}".format("\n - ".join(skipped))) return rule diff --git a/detection_rules/config.py b/detection_rules/config.py index cd2804c35f3..5bf13b7c755 100644 --- a/detection_rules/config.py +++ b/detection_rules/config.py @@ -4,42 +4,44 @@ # 2.0. """Configuration support for custom components.""" + import fnmatch import os from dataclasses import dataclass, field from pathlib import Path from functools import cached_property -from typing import Dict, List, Optional +from typing import Any import yaml -from eql.utils import load_dump +from eql.utils import load_dump # type: ignore[reportMissingTypeStubs] from .misc import discover_tests from .utils import cached, load_etc_dump, get_etc_path, set_all_validation_bypass ROOT_DIR = Path(__file__).parent.parent -CUSTOM_RULES_DIR = os.getenv('CUSTOM_RULES_DIR', None) +CUSTOM_RULES_DIR = os.getenv("CUSTOM_RULES_DIR", None) @dataclass class UnitTest: """Base object for unit tests configuration.""" - bypass: Optional[List[str]] = None - test_only: Optional[List[str]] = None + + bypass: list[str] | None = None + test_only: list[str] | None = None def __post_init__(self): - assert (self.bypass is None or self.test_only is None), \ - 'Cannot set both `test_only` and `bypass` in test_config!' + assert self.bypass is None or self.test_only is None, "Cannot set both `test_only` and `bypass` in test_config!" 
@dataclass class RuleValidation: """Base object for rule validation configuration.""" - bypass: Optional[List[str]] = None - test_only: Optional[List[str]] = None + + bypass: list[str] | None = None + test_only: list[str] | None = None def __post_init__(self): - assert not (self.bypass and self.test_only), 'Cannot use both test_only and bypass' + assert not (self.bypass and self.test_only), "Cannot use both test_only and bypass" @dataclass @@ -50,32 +52,30 @@ class ConfigFile: class FilePaths: packages_file: str stack_schema_map_file: str - deprecated_rules_file: Optional[str] = None - version_lock_file: Optional[str] = None + deprecated_rules_file: str | None = None + version_lock_file: str | None = None @dataclass class TestConfigPath: config: str files: FilePaths - rule_dir: List[str] - testing: Optional[TestConfigPath] = None + rule_dir: list[str] + testing: TestConfigPath | None = None @classmethod - def from_dict(cls, obj: dict) -> 'ConfigFile': - files_data = obj.get('files', {}) + def from_dict(cls, obj: dict[str, Any]) -> "ConfigFile": + files_data = obj.get("files", {}) files = cls.FilePaths( - deprecated_rules_file=files_data.get('deprecated_rules'), - packages_file=files_data['packages'], - stack_schema_map_file=files_data['stack_schema_map'], - version_lock_file=files_data.get('version_lock') + deprecated_rules_file=files_data.get("deprecated_rules"), + packages_file=files_data["packages"], + stack_schema_map_file=files_data["stack_schema_map"], + version_lock_file=files_data.get("version_lock"), ) - rule_dir = obj['rule_dirs'] + rule_dir = obj["rule_dirs"] - testing_data = obj.get('testing') - testing = cls.TestConfigPath( - config=testing_data['config'] - ) if testing_data else None + testing_data = obj.get("testing") + testing = cls.TestConfigPath(config=testing_data["config"]) if testing_data else None return cls(files=files, rule_dir=rule_dir, testing=testing) @@ -83,59 +83,69 @@ def from_dict(cls, obj: dict) -> 'ConfigFile': @dataclass class TestConfig: """Detection rules test config file""" - test_file: Optional[Path] = None - unit_tests: Optional[UnitTest] = None - rule_validation: Optional[RuleValidation] = None + + test_file: Path | None = None + unit_tests: UnitTest | None = None + rule_validation: RuleValidation | None = None @classmethod - def from_dict(cls, test_file: Optional[Path] = None, unit_tests: Optional[dict] = None, - rule_validation: Optional[dict] = None) -> 'TestConfig': - return cls(test_file=test_file or None, unit_tests=UnitTest(**unit_tests or {}), - rule_validation=RuleValidation(**rule_validation or {})) + def from_dict( + cls, + test_file: Path | None = None, + unit_tests: dict[str, Any] | None = None, + rule_validation: dict[str, Any] | None = None, + ) -> "TestConfig": + return cls( + test_file=test_file or None, + unit_tests=UnitTest(**unit_tests or {}), + rule_validation=RuleValidation(**rule_validation or {}), + ) @cached_property def all_tests(self): """Get the list of all test names.""" return discover_tests() - def tests_by_patterns(self, *patterns: str) -> List[str]: + def tests_by_patterns(self, *patterns: str) -> list[str]: """Get the list of test names by patterns.""" - tests = set() + tests: set[str] = set() for pattern in patterns: tests.update(list(fnmatch.filter(self.all_tests, pattern))) return sorted(tests) @staticmethod - def parse_out_patterns(names: List[str]) -> (List[str], List[str]): + def parse_out_patterns(names: list[str]) -> tuple[list[str], list[str]]: """Parse out test patterns from a list of test names.""" - 
patterns = [] - tests = [] + patterns: list[str] = [] + tests: list[str] = [] for name in names: - if name.startswith('pattern:') and '*' in name: - patterns.append(name[len('pattern:'):]) + if name.startswith("pattern:") and "*" in name: + patterns.append(name[len("pattern:") :]) else: tests.append(name) return patterns, tests @staticmethod - def format_tests(tests: List[str]) -> List[str]: + def format_tests(tests: list[str]) -> list[str]: """Format unit test names into expected format for direct calling.""" - raw = [t.rsplit('.', maxsplit=2) for t in tests] - formatted = [] + raw = [t.rsplit(".", maxsplit=2) for t in tests] + formatted: list[str] = [] for test in raw: path, clazz, method = test - path = f'{path.replace(".", os.path.sep)}.py' - formatted.append('::'.join([path, clazz, method])) + path = f"{path.replace('.', os.path.sep)}.py" + formatted.append("::".join([path, clazz, method])) return formatted - def get_test_names(self, formatted: bool = False) -> (List[str], List[str]): + def get_test_names(self, formatted: bool = False) -> tuple[list[str], list[str]]: """Get the list of test names to run.""" + if not self.unit_tests: + raise ValueError("No unit tests defined") patterns_t, tests_t = self.parse_out_patterns(self.unit_tests.test_only or []) patterns_b, tests_b = self.parse_out_patterns(self.unit_tests.bypass or []) defined_tests = tests_t + tests_b patterns = patterns_t + patterns_b unknowns = sorted(set(defined_tests) - set(self.all_tests)) - assert not unknowns, f'Unrecognized test names in config ({self.test_file}): {unknowns}' + assert not unknowns, f"Unrecognized test names in config ({self.test_file}): {unknowns}" combined_tests = sorted(set(defined_tests + self.tests_by_patterns(*patterns))) @@ -143,8 +153,8 @@ def get_test_names(self, formatted: bool = False) -> (List[str], List[str]): tests = combined_tests skipped = [t for t in self.all_tests if t not in tests] elif self.unit_tests.bypass: - tests = [] - skipped = [] + tests: list[str] = [] + skipped: list[str] = [] for test in self.all_tests: if test not in combined_tests: tests.append(test) @@ -161,6 +171,8 @@ def get_test_names(self, formatted: bool = False) -> (List[str], List[str]): def check_skip_by_rule_id(self, rule_id: str) -> bool: """Check if a rule_id should be skipped.""" + if not self.rule_validation: + raise ValueError("No rule validation specified") bypass = self.rule_validation.bypass test_only = self.rule_validation.test_only @@ -168,50 +180,51 @@ def check_skip_by_rule_id(self, rule_id: str) -> bool: if not (bypass or test_only): return False # if defined in bypass or not defined in test_only, then skip - return (bypass and rule_id in bypass) or (test_only and rule_id not in test_only) + return bool((bypass and rule_id in bypass) or (test_only and rule_id not in test_only)) @dataclass class RulesConfig: """Detection rules config file.""" + deprecated_rules_file: Path - deprecated_rules: Dict[str, dict] + deprecated_rules: dict[str, dict[str, Any]] packages_file: Path - packages: Dict[str, dict] - rule_dirs: List[Path] + packages: dict[str, dict[str, Any]] + rule_dirs: list[Path] stack_schema_map_file: Path - stack_schema_map: Dict[str, dict] + stack_schema_map: dict[str, dict[str, Any]] test_config: TestConfig version_lock_file: Path - version_lock: Dict[str, dict] + version_lock: dict[str, dict[str, Any]] - action_dir: Optional[Path] = None - action_connector_dir: Optional[Path] = None - auto_gen_schema_file: Optional[Path] = None - bbr_rules_dirs: Optional[List[Path]] = 
field(default_factory=list) + action_dir: Path | None = None + action_connector_dir: Path | None = None + auto_gen_schema_file: Path | None = None + bbr_rules_dirs: list[Path] = field(default_factory=list) # type: ignore[reportUnknownVariableType] bypass_version_lock: bool = False - exception_dir: Optional[Path] = None + exception_dir: Path | None = None normalize_kql_keywords: bool = True bypass_optional_elastic_validation: bool = False no_tactic_filename: bool = False def __post_init__(self): """Perform post validation on packages.yaml file.""" - if 'package' not in self.packages: - raise ValueError('Missing the `package` field defined in packages.yaml.') + if "package" not in self.packages: + raise ValueError("Missing the `package` field defined in packages.yaml.") - if 'name' not in self.packages['package']: - raise ValueError('Missing the `name` field defined in packages.yaml.') + if "name" not in self.packages["package"]: + raise ValueError("Missing the `name` field defined in packages.yaml.") @cached -def parse_rules_config(path: Optional[Path] = None) -> RulesConfig: +def parse_rules_config(path: Path | None = None) -> RulesConfig: """Parse the _config.yaml file for default or custom rules.""" if path: - assert path.exists(), f'rules config file does not exist: {path}' + assert path.exists(), f"rules config file does not exist: {path}" loaded = yaml.safe_load(path.read_text()) elif CUSTOM_RULES_DIR: - path = Path(CUSTOM_RULES_DIR) / '_config.yaml' + path = Path(CUSTOM_RULES_DIR) / "_config.yaml" if not path.exists(): raise FileNotFoundError( """ @@ -222,15 +235,15 @@ def parse_rules_config(path: Optional[Path] = None) -> RulesConfig: ) loaded = yaml.safe_load(path.read_text()) else: - path = Path(get_etc_path('_config.yaml')) - loaded = load_etc_dump('_config.yaml') + path = Path(get_etc_path(["_config.yaml"])) + loaded = load_etc_dump(["_config.yaml"]) try: - ConfigFile.from_dict(loaded) + _ = ConfigFile.from_dict(loaded) except KeyError as e: - raise SystemExit(f'Missing key `{str(e)}` in _config.yaml file.') + raise SystemExit(f"Missing key `{str(e)}` in _config.yaml file.") except (AttributeError, TypeError): - raise SystemExit(f'No data properly loaded from {path}') + raise SystemExit(f"No data properly loaded from {path}") except ValueError as e: raise SystemExit(e) @@ -239,11 +252,11 @@ def parse_rules_config(path: Optional[Path] = None) -> RulesConfig: # testing # precedence to the environment variable # environment variable is absolute path and config file is relative to the _config.yaml file - test_config_ev = os.getenv('DETECTION_RULES_TEST_CONFIG', None) + test_config_ev = os.getenv("DETECTION_RULES_TEST_CONFIG", None) if test_config_ev: test_config_path = Path(test_config_ev) else: - test_config_file = loaded.get('testing', {}).get('config') + test_config_file = loaded.get("testing", {}).get("config") if test_config_file: test_config_path = base_dir.joinpath(test_config_file) else: @@ -254,72 +267,72 @@ def parse_rules_config(path: Optional[Path] = None) -> RulesConfig: # overwrite None with empty list to allow implicit exemption of all tests with `test_only` defined to None in # test config - if 'unit_tests' in test_config_data and test_config_data['unit_tests'] is not None: - test_config_data['unit_tests'] = {k: v or [] for k, v in test_config_data['unit_tests'].items()} + if "unit_tests" in test_config_data and test_config_data["unit_tests"] is not None: + test_config_data["unit_tests"] = {k: v or [] for k, v in test_config_data["unit_tests"].items()} test_config = 
TestConfig.from_dict(test_file=test_config_path, **test_config_data) else: test_config = TestConfig.from_dict() # files # paths are relative - files = {f'{k}_file': base_dir.joinpath(v) for k, v in loaded['files'].items()} - contents = {k: load_dump(str(base_dir.joinpath(v).resolve())) for k, v in loaded['files'].items()} + files = {f"{k}_file": base_dir.joinpath(v) for k, v in loaded["files"].items()} + contents = {k: load_dump(str(base_dir.joinpath(v).resolve())) for k, v in loaded["files"].items()} contents.update(**files) # directories # paths are relative - if loaded.get('directories'): - contents.update({k: base_dir.joinpath(v).resolve() for k, v in loaded['directories'].items()}) + if loaded.get("directories"): + contents.update({k: base_dir.joinpath(v).resolve() for k, v in loaded["directories"].items()}) # rule_dirs # paths are relative - contents['rule_dirs'] = [base_dir.joinpath(d).resolve() for d in loaded.get('rule_dirs')] + contents["rule_dirs"] = [base_dir.joinpath(d).resolve() for d in loaded.get("rule_dirs")] # directories # paths are relative - if loaded.get('directories'): - directories = loaded.get('directories') - if directories.get('exception_dir'): - contents['exception_dir'] = base_dir.joinpath(directories.get('exception_dir')).resolve() - if directories.get('action_dir'): - contents['action_dir'] = base_dir.joinpath(directories.get('action_dir')).resolve() - if directories.get('action_connector_dir'): - contents['action_connector_dir'] = base_dir.joinpath(directories.get('action_connector_dir')).resolve() + if loaded.get("directories"): + directories = loaded.get("directories") + if directories.get("exception_dir"): + contents["exception_dir"] = base_dir.joinpath(directories.get("exception_dir")).resolve() + if directories.get("action_dir"): + contents["action_dir"] = base_dir.joinpath(directories.get("action_dir")).resolve() + if directories.get("action_connector_dir"): + contents["action_connector_dir"] = base_dir.joinpath(directories.get("action_connector_dir")).resolve() # version strategy - contents['bypass_version_lock'] = loaded.get('bypass_version_lock', False) + contents["bypass_version_lock"] = loaded.get("bypass_version_lock", False) # bbr_rules_dirs # paths are relative - if loaded.get('bbr_rules_dirs'): - contents['bbr_rules_dirs'] = [base_dir.joinpath(d).resolve() for d in loaded.get('bbr_rules_dirs', [])] + if loaded.get("bbr_rules_dirs"): + contents["bbr_rules_dirs"] = [base_dir.joinpath(d).resolve() for d in loaded.get("bbr_rules_dirs", [])] # kql keyword normalization - contents['normalize_kql_keywords'] = loaded.get('normalize_kql_keywords', True) + contents["normalize_kql_keywords"] = loaded.get("normalize_kql_keywords", True) - if loaded.get('auto_gen_schema_file'): - contents['auto_gen_schema_file'] = base_dir.joinpath(loaded['auto_gen_schema_file']) + if loaded.get("auto_gen_schema_file"): + contents["auto_gen_schema_file"] = base_dir.joinpath(loaded["auto_gen_schema_file"]) # Check if the file exists - if not contents['auto_gen_schema_file'].exists(): + if not contents["auto_gen_schema_file"].exists(): # If the file doesn't exist, create the necessary directories and file - contents['auto_gen_schema_file'].parent.mkdir(parents=True, exist_ok=True) - contents['auto_gen_schema_file'].write_text('{}') + contents["auto_gen_schema_file"].parent.mkdir(parents=True, exist_ok=True) + _ = contents["auto_gen_schema_file"].write_text("{}") # bypass_optional_elastic_validation - contents['bypass_optional_elastic_validation'] = 
loaded.get('bypass_optional_elastic_validation', False) - if contents['bypass_optional_elastic_validation']: - set_all_validation_bypass(contents['bypass_optional_elastic_validation']) + contents["bypass_optional_elastic_validation"] = loaded.get("bypass_optional_elastic_validation", False) + if contents["bypass_optional_elastic_validation"]: + set_all_validation_bypass(contents["bypass_optional_elastic_validation"]) # no_tactic_filename - contents['no_tactic_filename'] = loaded.get('no_tactic_filename', False) + contents["no_tactic_filename"] = loaded.get("no_tactic_filename", False) # return the config try: - rules_config = RulesConfig(test_config=test_config, **contents) + rules_config = RulesConfig(test_config=test_config, **contents) # type: ignore[reportArgumentType] except (ValueError, TypeError) as e: - raise SystemExit(f'Error parsing packages.yaml: {str(e)}') + raise SystemExit(f"Error parsing packages.yaml: {str(e)}") return rules_config @@ -327,4 +340,4 @@ def parse_rules_config(path: Optional[Path] = None) -> RulesConfig: @cached def load_current_package_version() -> str: """Load the current package version from config file.""" - return parse_rules_config().packages['package']['name'] + return parse_rules_config().packages["package"]["name"] diff --git a/detection_rules/custom_rules.py b/detection_rules/custom_rules.py index dd99006750e..c6e1c8221e7 100644 --- a/detection_rules/custom_rules.py +++ b/detection_rules/custom_rules.py @@ -4,6 +4,7 @@ # 2.0. """Commands for supporting custom rules.""" + from pathlib import Path import click @@ -14,11 +15,11 @@ from .main import root from .utils import ROOT_DIR, get_etc_path, load_etc_dump -DEFAULT_CONFIG_PATH = Path(get_etc_path('_config.yaml')) -CUSTOM_RULES_DOC_PATH = Path(ROOT_DIR).joinpath(REPO_DOCS_DIR, 'custom-rules-management.md') +DEFAULT_CONFIG_PATH = Path(get_etc_path(["_config.yaml"])) +CUSTOM_RULES_DOC_PATH = Path(ROOT_DIR).joinpath(REPO_DOCS_DIR, "custom-rules-management.md") -@root.group('custom-rules') +@root.group("custom-rules") def custom_rules(): """Commands for supporting custom rules.""" @@ -27,22 +28,20 @@ def create_config_content() -> str: """Create the initial content for the _config.yaml file.""" # Base structure of the configuration config_content = { - 'rule_dirs': ['rules'], - 'bbr_rules_dirs': ['rules_building_block'], - 'directories': { - 'action_dir': 'actions', - 'action_connector_dir': 'action_connectors', - 'exception_dir': 'exceptions', + "rule_dirs": ["rules"], + "bbr_rules_dirs": ["rules_building_block"], + "directories": { + "action_dir": "actions", + "action_connector_dir": "action_connectors", + "exception_dir": "exceptions", }, - 'files': { - 'deprecated_rules': 'etc/deprecated_rules.json', - 'packages': 'etc/packages.yaml', - 'stack_schema_map': 'etc/stack-schema-map.yaml', - 'version_lock': 'etc/version.lock.json', + "files": { + "deprecated_rules": "etc/deprecated_rules.json", + "packages": "etc/packages.yaml", + "stack_schema_map": "etc/stack-schema-map.yaml", + "version_lock": "etc/version.lock.json", }, - 'testing': { - 'config': 'etc/test_config.yaml' - } + "testing": {"config": "etc/test_config.yaml"}, } return yaml.safe_dump(config_content, default_flow_style=False) @@ -77,24 +76,24 @@ def format_test_string(test_string: str, comment_char: str) -> str: return "\n".join(lines) -@custom_rules.command('setup-config') -@click.argument('directory', type=Path) -@click.argument('kibana-version', type=str, default=load_etc_dump('packages.yaml')['package']['name']) 
-@click.option('--overwrite', is_flag=True, help="Overwrite the existing _config.yaml file.") +@custom_rules.command("setup-config") +@click.argument("directory", type=Path) +@click.argument("kibana-version", type=str, default=load_etc_dump(["packages.yaml"])["package"]["name"]) +@click.option("--overwrite", is_flag=True, help="Overwrite the existing _config.yaml file.") @click.option( "--enable-prebuilt-tests", "-e", is_flag=True, help="Enable all prebuilt tests instead of default subset." ) def setup_config(directory: Path, kibana_version: str, overwrite: bool, enable_prebuilt_tests: bool): """Setup the custom rules configuration directory and files with defaults.""" - config = directory / '_config.yaml' + config = directory / "_config.yaml" if not overwrite and config.exists(): - raise FileExistsError(f'{config} already exists. Use --overwrite to update') + raise FileExistsError(f"{config} already exists. Use --overwrite to update") - etc_dir = directory / 'etc' - test_config = etc_dir / 'test_config.yaml' - package_config = etc_dir / 'packages.yaml' - stack_schema_map_config = etc_dir / 'stack-schema-map.yaml' + etc_dir = directory / "etc" + test_config = etc_dir / "test_config.yaml" + package_config = etc_dir / "packages.yaml" + stack_schema_map_config = etc_dir / "stack-schema-map.yaml" config_files = [ package_config, stack_schema_map_config, @@ -102,49 +101,49 @@ def setup_config(directory: Path, kibana_version: str, overwrite: bool, enable_p config, ] directories = [ - directory / 'actions', - directory / 'action_connectors', - directory / 'exceptions', - directory / 'rules', - directory / 'rules_building_block', + directory / "actions", + directory / "action_connectors", + directory / "exceptions", + directory / "rules", + directory / "rules_building_block", etc_dir, ] version_files = [ - etc_dir / 'deprecated_rules.json', - etc_dir / 'version.lock.json', + etc_dir / "deprecated_rules.json", + etc_dir / "version.lock.json", ] # Create directories for dir_ in directories: dir_.mkdir(parents=True, exist_ok=True) - click.echo(f'Created directory: {dir_}') + click.echo(f"Created directory: {dir_}") # Create version_files and populate with default content if applicable for file_ in version_files: - file_.write_text('{}') - click.echo( - f'Created file with default content: {file_}' - ) + _ = file_.write_text("{}") + click.echo(f"Created file with default content: {file_}") # Create the stack-schema-map.yaml file - stack_schema_map_content = load_etc_dump('stack-schema-map.yaml') + stack_schema_map_content = load_etc_dump(["stack-schema-map.yaml"]) latest_version = max(stack_schema_map_content.keys(), key=lambda v: Version.parse(v)) latest_entry = {latest_version: stack_schema_map_content[latest_version]} - stack_schema_map_config.write_text(yaml.safe_dump(latest_entry, default_flow_style=False)) + _ = stack_schema_map_config.write_text(yaml.safe_dump(latest_entry, default_flow_style=False)) # Create default packages.yaml - package_content = {'package': {'name': kibana_version}} - package_config.write_text(yaml.safe_dump(package_content, default_flow_style=False)) + package_content = {"package": {"name": kibana_version}} + _ = package_config.write_text(yaml.safe_dump(package_content, default_flow_style=False)) # Create and configure test_config.yaml - test_config.write_text(create_test_config_content(enable_prebuilt_tests)) + _ = test_config.write_text(create_test_config_content(enable_prebuilt_tests)) # Create and configure _config.yaml - config.write_text(create_config_content()) + _ 
= config.write_text(create_config_content()) for file_ in config_files: - click.echo(f'Created file with default content: {file_}') + click.echo(f"Created file with default content: {file_}") - click.echo(f'\n# For details on how to configure the _config.yaml file,\n' - f'# consult: {DEFAULT_CONFIG_PATH.resolve()}\n' - f'# or the docs: {CUSTOM_RULES_DOC_PATH.resolve()}') + click.echo( + f"\n# For details on how to configure the _config.yaml file,\n" + f"# consult: {DEFAULT_CONFIG_PATH.resolve()}\n" + f"# or the docs: {CUSTOM_RULES_DOC_PATH.resolve()}" + ) diff --git a/detection_rules/custom_schemas.py b/detection_rules/custom_schemas.py index 84252178b7d..3071d840bfa 100644 --- a/detection_rules/custom_schemas.py +++ b/detection_rules/custom_schemas.py @@ -4,12 +4,13 @@ # 2.0. """Custom Schemas management.""" + import uuid from pathlib import Path +from typing import Any -import eql -import eql.types -from eql import load_dump, save_dump +import eql # type: ignore[reportMissingTypeStubs] +from eql import load_dump, save_dump # type: ignore from .config import parse_rules_config from .utils import cached, clear_caches @@ -19,9 +20,9 @@ @cached -def get_custom_schemas(stack_version: str = None) -> dict: +def get_custom_schemas(stack_version: str | None = None) -> dict[str, Any]: """Load custom schemas if present.""" - custom_schema_dump = {} + custom_schema_dump: dict[str, Any] = {} stack_versions = [stack_version] if stack_version else RULES_CONFIG.stack_schema_map.keys() @@ -34,7 +35,7 @@ def get_custom_schemas(stack_version: str = None) -> dict: if not schema_path.is_absolute(): schema_path = RULES_CONFIG.stack_schema_map_file.parent / value if schema_path.is_file(): - custom_schema_dump.update(eql.utils.load_dump(str(schema_path))) + custom_schema_dump.update(eql.utils.load_dump(str(schema_path))) # type: ignore[reportUnknownMemberType] else: raise ValueError(f"Custom schema must be a file: {schema_path}") @@ -47,13 +48,16 @@ def resolve_schema_path(path: str) -> Path: return path_obj if path_obj.is_absolute() else RULES_CONFIG.stack_schema_map_file.parent.joinpath(path) -def update_data(index: str, field: str, data: dict, field_type: str = None) -> dict: +def update_data(index: str, field: str, data: dict[str, Any], field_type: str | None = None) -> dict[str, Any]: """Update the schema entry with the appropriate index and field.""" data.setdefault(index, {})[field] = field_type if field_type else "keyword" return data -def update_stack_schema_map(stack_schema_map: dict, auto_gen_schema_file: str) -> dict: +def update_stack_schema_map( + stack_schema_map: dict[str, Any], + auto_gen_schema_file: str, +) -> tuple[dict[str, Any], str | None, str]: """Update the stack-schema-map.yaml file with the appropriate auto_gen_schema_file location.""" random_uuid = str(uuid.uuid4()) auto_generated_id = None @@ -72,7 +76,9 @@ def update_stack_schema_map(stack_schema_map: dict, auto_gen_schema_file: str) - return stack_schema_map, auto_generated_id, random_uuid -def clean_stack_schema_map(stack_schema_map: dict, auto_generated_id: str, random_uuid: str) -> dict: +def clean_stack_schema_map( + stack_schema_map: dict[str, Any], auto_generated_id: str, random_uuid: str +) -> dict[str, Any]: """Clean up the stack-schema-map.yaml file replacing the random UUID with a known key if possible.""" for version in stack_schema_map: if random_uuid in stack_schema_map[version]: @@ -80,7 +86,7 @@ def clean_stack_schema_map(stack_schema_map: dict, auto_generated_id: str, rando return stack_schema_map -def 
update_auto_generated_schema(index: str, field: str, field_type: str = None): +def update_auto_generated_schema(index: str, field: str, field_type: str | None = None): """Load custom schemas if present.""" auto_gen_schema_file = str(RULES_CONFIG.auto_gen_schema_file) stack_schema_map_file = str(RULES_CONFIG.stack_schema_map_file) @@ -93,6 +99,10 @@ def update_auto_generated_schema(index: str, field: str, field_type: str = None) # Update the stack-schema-map.yaml file with the appropriate auto_gen_schema_file location stack_schema_map = load_dump(stack_schema_map_file) stack_schema_map, auto_generated_id, random_uuid = update_stack_schema_map(stack_schema_map, auto_gen_schema_file) + + if not auto_generated_id: + raise ValueError("Autogenerated ID not found") + save_dump(stack_schema_map, stack_schema_map_file) # Clean up the stack-schema-map.yaml file replacing the random UUID with the auto_generated_id diff --git a/detection_rules/devtools.py b/detection_rules/devtools.py index 11073780b30..97fa44b2c63 100644 --- a/detection_rules/devtools.py +++ b/detection_rules/devtools.py @@ -4,6 +4,7 @@ # 2.0. """CLI commands for internal detection_rules dev team.""" + import csv import dataclasses import io @@ -18,21 +19,21 @@ import urllib.parse from collections import defaultdict from pathlib import Path -from typing import Dict, List, Optional, Tuple +from typing import Literal, Any import click -import pytoml +import pytoml # type: ignore[reportMissingTypeStubs] import requests.exceptions import yaml from elasticsearch import Elasticsearch -from eql.table import Table from semver import Version -from kibana.connector import Kibana +from kibana.connector import Kibana # type: ignore[reportMissingTypeStubs] +from eql.table import Table # type: ignore[reportMissingTypeStubs] +from eql.utils import load_dump # type: ignore[reportMissingTypeStubs, reportUnknownVariableType] from . 
import attack, rule_loader, utils -from .beats import (download_beats_schema, download_latest_beats_schema, - refresh_main_schema) +from .beats import download_beats_schema, download_latest_beats_schema, refresh_main_schema from .cli_utils import single_collection from .config import parse_rules_config from .docs import IntegrationSecurityDocs, IntegrationSecurityDocsMDX, REPO_DOCS_DIR @@ -40,29 +41,37 @@ from .endgame import EndgameSchemaManager from .eswrap import CollectEvents, add_range_to_dsl from .ghwrap import GithubClient, update_gist -from .integrations import (SecurityDetectionEngine, - build_integrations_manifest, - build_integrations_schemas, - find_latest_compatible_version, - find_latest_integration_version, - load_integrations_manifests) +from .integrations import ( + SecurityDetectionEngine, + build_integrations_manifest, + build_integrations_schemas, + find_latest_compatible_version, + find_latest_integration_version, + load_integrations_manifests, +) from .main import root -from .misc import PYTHON_LICENSE, add_client, client_error -from .packaging import (CURRENT_RELEASE_PATH, PACKAGE_FILE, RELEASE_DIR, - Package) -from .rule import (AnyRuleData, BaseRuleData, DeprecatedRule, QueryRuleData, - RuleTransform, ThreatMapping, TOMLRule, TOMLRuleContents) +from .misc import PYTHON_LICENSE, add_client, raise_client_error +from .packaging import CURRENT_RELEASE_PATH, PACKAGE_FILE, RELEASE_DIR, Package +from .rule import ( + AnyRuleData, + BaseRuleData, + QueryRuleData, + RuleTransform, + ThreatMapping, + TOMLRule, + TOMLRuleContents, + DeprecatedRule, +) from .rule_loader import RuleCollection, production_filter from .schemas import definitions, get_stack_versions -from .utils import (dict_hash, get_etc_path, get_path, check_version_lock_double_bumps, - load_dump) +from .utils import dict_hash, get_etc_path, get_path, check_version_lock_double_bumps from .version_lock import VersionLockFile, loaded_version_lock GH_CONFIG = Path.home() / ".config" / "gh" / "hosts.yml" -NAVIGATOR_GIST_ID = '0443cfb5016bed103f1940b2f336e45a' -NAVIGATOR_URL = 'https://ela.st/detection-rules-navigator-trade' +NAVIGATOR_GIST_ID = "0443cfb5016bed103f1940b2f336e45a" +NAVIGATOR_URL = "https://ela.st/detection-rules-navigator-trade" NAVIGATOR_BADGE = ( - f'[![ATT&CK navigator coverage](https://img.shields.io/badge/ATT&CK-Navigator-red.svg)]({NAVIGATOR_URL})' + f"[![ATT&CK navigator coverage](https://img.shields.io/badge/ATT&CK-Navigator-red.svg)]({NAVIGATOR_URL})" ) RULES_CONFIG = parse_rules_config() @@ -74,7 +83,7 @@ MAX_HISTORICAL_VERSIONS_PRE_DIFF = 1 -def get_github_token() -> Optional[str]: +def get_github_token() -> str | None: """Get the current user's GitHub token.""" token = os.getenv("GITHUB_TOKEN") @@ -84,60 +93,81 @@ def get_github_token() -> Optional[str]: return token -@root.group('dev') +@root.group("dev") def dev_group(): """Commands related to the Elastic Stack rules release lifecycle.""" -@dev_group.command('build-release') -@click.argument('config-file', type=click.Path(exists=True, dir_okay=False), required=False, default=PACKAGE_FILE) -@click.option('--update-version-lock', '-u', is_flag=True, - help='Save version.lock.json file with updated rule versions in the package') -@click.option('--generate-navigator', is_flag=True, help='Generate ATT&CK navigator files') -@click.option('--generate-docs', is_flag=True, default=False, help='Generate markdown documentation') -@click.option('--update-message', type=str, help='Update message for new package') +@dev_group.command("build-release") 
+@click.argument( + "config-file", type=click.Path(exists=True, dir_okay=False, path_type=Path), required=False, default=PACKAGE_FILE +) +@click.option( + "--update-version-lock", + "-u", + is_flag=True, + help="Save version.lock.json file with updated rule versions in the package", +) +@click.option("--generate-navigator", is_flag=True, help="Generate ATT&CK navigator files") +@click.option("--generate-docs", is_flag=True, default=False, help="Generate markdown documentation") +@click.option("--update-message", type=str, help="Update message for new package") @click.pass_context -def build_release(ctx: click.Context, config_file, update_version_lock: bool, generate_navigator: bool, - generate_docs: str, update_message: str, release=None, verbose=True): +def build_release( + ctx: click.Context, + config_file: Path, + update_version_lock: bool, + generate_navigator: bool, + generate_docs: str, + update_message: str, + release: str | None = None, + verbose: bool = True, +): """Assemble all the rules into Kibana-ready release files.""" if RULES_CONFIG.bypass_version_lock: - click.echo('WARNING: You cannot run this command when the versioning strategy is configured to bypass the ' - 'version lock. Set `bypass_version_lock` to `False` in the rules config to use the version lock.') + click.echo( + "WARNING: You cannot run this command when the versioning strategy is configured to bypass the " + "version lock. Set `bypass_version_lock` to `False` in the rules config to use the version lock." + ) ctx.exit() - config = load_dump(config_file)['package'] + config = load_dump(str(config_file))["package"] - err_msg = f'No `registry_data` in package config. Please see the {get_etc_path("package.yaml")} file for an' \ - f' example on how to supply this field in {PACKAGE_FILE}.' - assert 'registry_data' in config, err_msg + package_path = get_etc_path(["package.yaml"]) + err_msg = ( + f"No `registry_data` in package config. Please see the {package_path} file for an" + f" example on how to supply this field in {PACKAGE_FILE}." 
+ ) + assert "registry_data" in config, err_msg - registry_data = config['registry_data'] + registry_data = config["registry_data"] if generate_navigator: - config['generate_navigator'] = True + config["generate_navigator"] = True if release is not None: - config['release'] = release + config["release"] = release if verbose: - click.echo(f'[+] Building package {config.get("name")}') + click.echo(f"[+] Building package {config.get('name')}") package = Package.from_config(config=config, verbose=verbose) if update_version_lock: - loaded_version_lock.manage_versions(package.rules, save_changes=True, verbose=verbose) + _ = loaded_version_lock.manage_versions(package.rules, save_changes=True, verbose=verbose) package.save(verbose=verbose) - previous_pkg_version = find_latest_integration_version("security_detection_engine", "ga", - registry_data['conditions']['kibana.version'].strip("^")) + previous_pkg_version = find_latest_integration_version( + "security_detection_engine", "ga", registry_data["conditions"]["kibana.version"].strip("^") + ) sde = SecurityDetectionEngine() historical_rules = sde.load_integration_assets(previous_pkg_version) - current_pkg_version = Version.parse(registry_data['version']) + current_pkg_version = Version.parse(registry_data["version"]) # pre-release versions are not included in the version comparison # Version 8.17.0-beta.1 is considered lower than 8.17.0 - current_pkg_version_no_prerelease = Version(major=current_pkg_version.major, - minor=current_pkg_version.minor, patch=current_pkg_version.patch) + current_pkg_version_no_prerelease = Version( + major=current_pkg_version.major, minor=current_pkg_version.minor, patch=current_pkg_version.patch + ) hist_versions_num = ( MAX_HISTORICAL_VERSIONS_FOR_DIFF @@ -145,78 +175,101 @@ def build_release(ctx: click.Context, config_file, update_version_lock: bool, ge else MAX_HISTORICAL_VERSIONS_PRE_DIFF ) click.echo( - '[+] Limit historical rule versions in the release package for ' - f'version {current_pkg_version_no_prerelease}: {hist_versions_num} versions') + "[+] Limit historical rule versions in the release package for " + f"version {current_pkg_version_no_prerelease}: {hist_versions_num} versions" + ) limited_historical_rules = sde.keep_latest_versions(historical_rules, num_versions=hist_versions_num) - package.add_historical_rules(limited_historical_rules, registry_data['version']) - click.echo(f'[+] Adding historical rules from {previous_pkg_version} package') + _ = package.add_historical_rules(limited_historical_rules, registry_data["version"]) + click.echo(f"[+] Adding historical rules from {previous_pkg_version} package") # NOTE: stopgap solution until security doc migration if generate_docs: - click.echo(f'[+] Generating security docs for {registry_data["version"]} package') - docs = IntegrationSecurityDocsMDX(registry_data['version'], Path(f'releases/{config["name"]}-docs'), - True, limited_historical_rules, package, note=update_message) - docs.generate() + click.echo(f"[+] Generating security docs for {registry_data['version']} package") + docs = IntegrationSecurityDocsMDX( + registry_data["version"], + Path(f"releases/{config['name']}-docs"), + True, + package, + limited_historical_rules, + note=update_message, + ) + _ = docs.generate() if verbose: - package.get_package_hash(verbose=verbose) - click.echo(f'- {len(package.rules)} rules included') + _ = package.get_package_hash(verbose=verbose) + click.echo(f"- {len(package.rules)} rules included") return package -def get_release_diff(pre: str, post: str, remote: 
Optional[str] = 'origin' - ) -> (Dict[str, TOMLRule], Dict[str, TOMLRule], Dict[str, DeprecatedRule]): +def get_release_diff( + pre: str, + post: str, + remote: str = "origin", +) -> tuple[dict[str, TOMLRule], dict[str, TOMLRule], dict[str, DeprecatedRule]]: """Build documents from two git tags for an integration package.""" pre_rules = RuleCollection() - pre_rules.load_git_tag(f'integration-v{pre}', remote, skip_query_validation=True) + pre_rules.load_git_tag(f"integration-v{pre}", remote, skip_query_validation=True) if pre_rules.errors: - click.echo(f'error loading {len(pre_rules.errors)} rule(s) from: {pre}, skipping:') - click.echo(' - ' + '\n - '.join([str(p) for p in pre_rules.errors])) + click.echo(f"error loading {len(pre_rules.errors)} rule(s) from: {pre}, skipping:") + click.echo(" - " + "\n - ".join([str(p) for p in pre_rules.errors])) post_rules = RuleCollection() - post_rules.load_git_tag(f'integration-v{post}', remote, skip_query_validation=True) + post_rules.load_git_tag(f"integration-v{post}", remote, skip_query_validation=True) if post_rules.errors: - click.echo(f'error loading {len(post_rules.errors)} rule(s) from: {post}, skipping:') - click.echo(' - ' + '\n - '.join([str(p) for p in post_rules.errors])) + click.echo(f"error loading {len(post_rules.errors)} rule(s) from: {post}, skipping:") + click.echo(" - " + "\n - ".join([str(p) for p in post_rules.errors])) rules_changes = pre_rules.compare_collections(post_rules) return rules_changes -@dev_group.command('build-integration-docs') -@click.argument('registry-version') -@click.option('--pre', required=True, type=str, help='Tag for pre-existing rules') -@click.option('--post', required=True, type=str, help='Tag for rules post updates') -@click.option('--directory', '-d', type=Path, required=True, help='Output directory to save docs to') -@click.option('--force', '-f', is_flag=True, help='Bypass the confirmation prompt') -@click.option('--remote', '-r', default='origin', help='Override the remote from "origin"') -@click.option('--update-message', default='Rule Updates.', type=str, help='Update message for new package') +@dev_group.command("build-integration-docs") +@click.argument("registry-version") +@click.option("--pre", required=True, type=str, help="Tag for pre-existing rules") +@click.option("--post", required=True, type=str, help="Tag for rules post updates") +@click.option("--directory", "-d", type=Path, required=True, help="Output directory to save docs to") +@click.option("--force", "-f", is_flag=True, help="Bypass the confirmation prompt") +@click.option("--remote", "-r", default="origin", help='Override the remote from "origin"') +@click.option("--update-message", default="Rule Updates.", type=str, help="Update message for new package") @click.pass_context -def build_integration_docs(ctx: click.Context, registry_version: str, pre: str, post: str, - directory: Path, force: bool, update_message: str, - remote: Optional[str] = 'origin') -> IntegrationSecurityDocs: +def build_integration_docs( + ctx: click.Context, + registry_version: str, + pre: str, + post: str, + directory: Path, + force: bool, + update_message: str, + remote: str = "origin", +) -> IntegrationSecurityDocs: """Build documents from two git tags for an integration package.""" if not force: - if not click.confirm(f'This will refresh tags and may overwrite local tags for: {pre} and {post}. Continue?'): + if not click.confirm(f"This will refresh tags and may overwrite local tags for: {pre} and {post}. 
Continue?"): ctx.exit(1) - assert Version.parse(pre) < Version.parse(post), f'pre: {pre} is not less than post: {post}' - assert Version.parse(pre), f'pre: {pre} is not a valid semver' - assert Version.parse(post), f'post: {post} is not a valid semver' + assert Version.parse(pre) < Version.parse(post), f"pre: {pre} is not less than post: {post}" + assert Version.parse(pre), f"pre: {pre} is not a valid semver" + assert Version.parse(post), f"post: {post} is not a valid semver" rules_changes = get_release_diff(pre, post, remote) - docs = IntegrationSecurityDocs(registry_version, directory, True, *rules_changes, update_message=update_message) + docs = IntegrationSecurityDocs( + registry_version, + directory, + True, + *rules_changes, + update_message=update_message, + ) package_dir = docs.generate() - click.echo(f'Generated documents saved to: {package_dir}') + click.echo(f"Generated documents saved to: {package_dir}") updated, new, deprecated = rules_changes - click.echo(f'- {len(updated)} updated rules') - click.echo(f'- {len(new)} new rules') - click.echo(f'- {len(deprecated)} deprecated rules') + click.echo(f"- {len(updated)} updated rules") + click.echo(f"- {len(new)} new rules") + click.echo(f"- {len(deprecated)} deprecated rules") return docs @@ -225,13 +278,17 @@ def build_integration_docs(ctx: click.Context, registry_version: str, pre: str, @click.option("--major-release", is_flag=True, help="bump the major version") @click.option("--minor-release", is_flag=True, help="bump the minor version") @click.option("--patch-release", is_flag=True, help="bump the patch version") -@click.option("--new-package", type=click.Choice(['true', 'false']), help="indicates new package") -@click.option("--maturity", type=click.Choice(['beta', 'ga'], case_sensitive=False), - required=True, help="beta or production versions") +@click.option("--new-package", type=click.Choice(["true", "false"]), help="indicates new package") +@click.option( + "--maturity", + type=click.Choice(["beta", "ga"], case_sensitive=False), + required=True, + help="beta or production versions", +) def bump_versions(major_release: bool, minor_release: bool, patch_release: bool, new_package: str, maturity: str): """Bump the versions""" - pkg_data = RULES_CONFIG.packages['package'] + pkg_data = RULES_CONFIG.packages["package"] kibana_ver = Version.parse(pkg_data["name"], optional_minor_and_patch=True) pkg_ver = Version.parse(pkg_data["registry_data"]["version"]) pkg_kibana_ver = Version.parse(pkg_data["registry_data"]["conditions"]["kibana.version"].lstrip("^")) @@ -246,8 +303,9 @@ def bump_versions(major_release: bool, minor_release: bool, patch_release: bool, pkg_data["registry_data"]["conditions"]["kibana.version"] = f"^{pkg_kibana_ver.bump_minor()}" pkg_data["registry_data"]["version"] = str(pkg_ver.bump_minor().bump_prerelease("beta")) if patch_release: - latest_patch_release_ver = find_latest_integration_version("security_detection_engine", - maturity, pkg_kibana_ver) + latest_patch_release_ver = find_latest_integration_version( + "security_detection_engine", maturity, pkg_kibana_ver + ) # if an existing minor or major does not have a package, bump from the last # example is 8.10.0-beta.1 is last, but on 9.0.0 major @@ -265,8 +323,8 @@ def bump_versions(major_release: bool, minor_release: bool, patch_release: bool, latest_patch_release_ver = latest_patch_release_ver.bump_patch() pkg_data["registry_data"]["version"] = str(latest_patch_release_ver.bump_prerelease("beta")) - if 'release' in pkg_data['registry_data']: - 
pkg_data['registry_data']['release'] = maturity + if "release" in pkg_data["registry_data"]: + pkg_data["registry_data"]["release"] = maturity click.echo(f"Kibana version: {pkg_data['name']}") click.echo(f"Package Kibana version: {pkg_data['registry_data']['conditions']['kibana.version']}") @@ -295,7 +353,12 @@ def bump_versions(major_release: bool, minor_release: bool, patch_release: bool, @click.option("--save-double-bumps", type=Path, help="Optional path to save the double bumps to a file") @click.pass_context def check_version_lock( - ctx: click.Context, pr_number: int, local_file: str, token: str, comment: bool, save_double_bumps: Path + ctx: click.Context, + pr_number: int, + local_file: str, + token: str, + comment: bool, + save_double_bumps: Path, ): """ Check the version lock file and optionally comment on the PR if the --comment flag is set. @@ -312,7 +375,7 @@ def check_version_lock( double_bumps = [] comment_body = "No double bumps detected." - def format_comment_body(double_bumps: list) -> str: + def format_comment_body(double_bumps: list[tuple[str, str, int, int]]) -> str: """Format the comment body for double bumps.""" comment_body = f"{len(double_bumps)} Double bumps detected:\n\n" comment_body += "
\n" @@ -325,7 +388,7 @@ def format_comment_body(double_bumps: list) -> str: comment_body += "\n
\n" return comment_body - def save_double_bumps_to_file(double_bumps: list, save_path: Path): + def save_double_bumps_to_file(double_bumps: list[tuple[str, str, int, int]], save_path: Path): """Save double bumps to a CSV file.""" save_path.parent.mkdir(parents=True, exist_ok=True) if save_path.is_file(): @@ -335,6 +398,8 @@ def save_double_bumps_to_file(double_bumps: list, save_path: Path): csv.writer(csvfile).writerows([["Rule ID", "Rule Name", "Removed", "Added"]] + double_bumps) click.echo(f"Double bumps saved to {save_path}") + pr = None + if pr_number: click.echo(f"Fetching version lock file from PR #{pr_number}") pr = repo.get_pull(pr_number) @@ -349,47 +414,50 @@ def save_double_bumps_to_file(double_bumps: list, save_path: Path): click.echo(f"{len(double_bumps)} Double bumps detected") if comment and pr_number: comment_body = format_comment_body(double_bumps) - pr.create_issue_comment(comment_body) + if pr: + _ = pr.create_issue_comment(comment_body) if save_double_bumps: save_double_bumps_to_file(double_bumps, save_double_bumps) ctx.exit(1) else: click.echo("No double bumps detected.") - if comment and pr_number: - pr.create_issue_comment(comment_body) + if comment and pr_number and pr: + _ = pr.create_issue_comment(comment_body) @dataclasses.dataclass class GitChangeEntry: status: str original_path: Path - new_path: Optional[Path] = None + new_path: Path | None = None @classmethod - def from_line(cls, text: str) -> 'GitChangeEntry': + def from_line(cls, text: str) -> "GitChangeEntry": columns = text.split("\t") assert 2 <= len(columns) <= 3 - columns[1:] = [Path(c) for c in columns[1:]] - return cls(*columns) + paths = [Path(c) for c in columns[1:]] + return cls(columns[0], *paths) @property def path(self) -> Path: return self.new_path or self.original_path - def revert(self, dry_run=False): + def revert(self, dry_run: bool = False): """Run a git command to revert this change.""" - def git(*args): + def git(*args: Any): command_line = ["git"] + [str(arg) for arg in args] click.echo(subprocess.list2cmdline(command_line)) if not dry_run: - subprocess.check_call(command_line) + _ = subprocess.check_call(command_line) if self.status.startswith("R"): # renames are actually Delete (D) and Add (A) # revert in opposite order + if not self.new_path: + raise ValueError("No new path found") GitChangeEntry("A", self.new_path).revert(dry_run=dry_run) GitChangeEntry("D", self.original_path).revert(dry_run=dry_run) return @@ -397,7 +465,7 @@ def git(*args): # remove the file from the staging area (A|M|D) git("restore", "--staged", self.original_path) - def read(self, git_tree="HEAD") -> bytes: + def read(self, git_tree: str = "HEAD") -> bytes: """Read the file from disk or git.""" if self.status == "D": # deleted files need to be recovered from git @@ -417,14 +485,14 @@ def prune_staging_area(target_stack_version: str, dry_run: bool, exception_list: } exceptions.update(exception_list.split(",")) - target_stack_version = Version.parse(target_stack_version, optional_minor_and_patch=True) + target_stack_version_parsed = Version.parse(target_stack_version, optional_minor_and_patch=True) # load a structured summary of the diff from git git_output = subprocess.check_output(["git", "diff", "--name-status", "HEAD"]) changes = [GitChangeEntry.from_line(line) for line in git_output.decode("utf-8").splitlines()] # track which changes need to be reverted because of incompatibilities - reversions: List[GitChangeEntry] = [] + reversions: list[GitChangeEntry] = [] for change in changes: if str(change.path) in 
exceptions: @@ -437,10 +505,11 @@ def prune_staging_area(target_stack_version: str, dry_run: bool, exception_list: if str(change.path.absolute()).startswith(str(rules_dir)) and change.path.suffix == ".toml": # bypass TOML validation in case there were schema changes dict_contents = RuleCollection.deserialize_toml_string(change.read()) - min_stack_version: Optional[str] = dict_contents.get("metadata", {}).get("min_stack_version") + min_stack_version: str | None = dict_contents.get("metadata", {}).get("min_stack_version") - if min_stack_version is not None and \ - (target_stack_version < Version.parse(min_stack_version, optional_minor_and_patch=True)): + if min_stack_version is not None and ( + target_stack_version_parsed < Version.parse(min_stack_version, optional_minor_and_patch=True) + ): # rule is incompatible, add to the list of reversions to make later reversions.append(change) break @@ -454,11 +523,11 @@ def prune_staging_area(target_stack_version: str, dry_run: bool, exception_list: change.revert(dry_run=dry_run) -@dev_group.command('update-lock-versions') -@click.argument('rule-ids', nargs=-1, required=False) +@dev_group.command("update-lock-versions") +@click.argument("rule-ids", nargs=-1, required=False) @click.pass_context -@click.option('--force', is_flag=True, help='Force update without confirmation') -def update_lock_versions(ctx: click.Context, rule_ids: Tuple[str, ...], force: bool): +@click.option("--force", is_flag=True, help="Force update without confirmation") +def update_lock_versions(ctx: click.Context, rule_ids: tuple[str, ...], force: bool): """Update rule hashes in version.lock.json file without bumping version.""" rules = RuleCollection.default() @@ -468,30 +537,32 @@ def update_lock_versions(ctx: click.Context, rule_ids: Tuple[str, ...], force: b rules = rules.filter(production_filter) if not force and not click.confirm( - f'Are you sure you want to update hashes for {len(rules)} rules without a version bump?' + f"Are you sure you want to update hashes for {len(rules)} rules without a version bump?" ): return if RULES_CONFIG.bypass_version_lock: - click.echo('WARNING: You cannot run this command when the versioning strategy is configured to bypass the ' - 'version lock. Set `bypass_version_lock` to `False` in the rules config to use the version lock.') + click.echo( + "WARNING: You cannot run this command when the versioning strategy is configured to bypass the " + "version lock. Set `bypass_version_lock` to `False` in the rules config to use the version lock." 
+ ) ctx.exit() # this command may not function as expected anymore due to previous changes eliminating the use of add_new=False - changed, new, _ = loaded_version_lock.manage_versions(rules, exclude_version_update=True, save_changes=True) + changed, _, _ = loaded_version_lock.manage_versions(rules, exclude_version_update=True, save_changes=True) if not changed: - click.echo('No hashes updated') + click.echo("No hashes updated") return changed -@dev_group.command('kibana-diff') -@click.option('--rule-id', '-r', multiple=True, help='Optionally specify rule ID') -@click.option('--repo', default='elastic/kibana', help='Repository where branch is located') -@click.option('--branch', '-b', default='main', help='Specify the kibana branch to diff against') -@click.option('--threads', '-t', type=click.IntRange(1), default=50, help='Number of threads to use to download rules') -def kibana_diff(rule_id, repo, branch, threads): +@dev_group.command("kibana-diff") +@click.option("--rule-id", "-r", multiple=True, help="Optionally specify rule ID") +@click.option("--repo", default="elastic/kibana", help="Repository where branch is located") +@click.option("--branch", "-b", default="main", help="Specify the kibana branch to diff against") +@click.option("--threads", "-t", type=click.IntRange(1), default=50, help="Number of threads to use to download rules") +def kibana_diff(rule_id: list[str], repo: str, branch: str, threads: int): """Diff rules against their version represented in kibana if exists.""" from .misc import get_kibana_rules @@ -504,43 +575,56 @@ def kibana_diff(rule_id, repo, branch, threads): repo_hashes = {r.id: r.contents.get_hash(include_version=True) for r in rules.values()} - kibana_rules = {r['rule_id']: r for r in get_kibana_rules(repo=repo, branch=branch, threads=threads).values()} - kibana_hashes = {r['rule_id']: dict_hash(r) for r in kibana_rules.values()} + kibana_rules = {r["rule_id"]: r for r in get_kibana_rules(repo=repo, branch=branch, threads=threads).values()} + kibana_hashes = {r["rule_id"]: dict_hash(r) for r in kibana_rules.values()} missing_from_repo = list(set(kibana_hashes).difference(set(repo_hashes))) missing_from_kibana = list(set(repo_hashes).difference(set(kibana_hashes))) - rule_diff = [] - for rule_id, rule_hash in repo_hashes.items(): - if rule_id in missing_from_kibana: + rule_diff: list[str] = [] + for _rule_id, _rule_hash in repo_hashes.items(): + if _rule_id in missing_from_kibana: continue - if rule_hash != kibana_hashes[rule_id]: + if _rule_hash != kibana_hashes[_rule_id]: rule_diff.append( - f'versions - repo: {rules[rule_id].contents.autobumped_version}, ' - f'kibana: {kibana_rules[rule_id]["version"]} -> ' - f'{rule_id} - {rules[rule_id].contents.name}' + f"versions - repo: {rules[_rule_id].contents.autobumped_version}, " + f"kibana: {kibana_rules[_rule_id]['version']} -> " + f"{_rule_id} - {rules[_rule_id].contents.name}" ) - diff = { - 'missing_from_kibana': [f'{r} - {rules[r].name}' for r in missing_from_kibana], - 'diff': rule_diff, - 'missing_from_repo': [f'{r} - {kibana_rules[r]["name"]}' for r in missing_from_repo] + diff: dict[str, Any] = { + "missing_from_kibana": [f"{r} - {rules[r].name}" for r in missing_from_kibana], + "diff": rule_diff, + "missing_from_repo": [f"{r} - {kibana_rules[r]['name']}" for r in missing_from_repo], } - diff['stats'] = {k: len(v) for k, v in diff.items()} - diff['stats'].update(total_repo_prod_rules=len(rules), total_gh_prod_rules=len(kibana_rules)) + diff["stats"] = {k: len(v) for k, v in diff.items()} + 
diff["stats"].update(total_repo_prod_rules=len(rules), total_gh_prod_rules=len(kibana_rules)) click.echo(json.dumps(diff, indent=2, sort_keys=True)) return diff @dev_group.command("integrations-pr") -@click.argument("local-repo", type=click.Path(exists=True, file_okay=False, dir_okay=True), - default=get_path("..", "integrations")) -@click.option("--token", required=True, prompt=get_github_token() is None, default=get_github_token(), - help="GitHub token to use for the PR", hide_input=True) -@click.option("--pkg-directory", "-d", help="Directory to save the package in cloned repository", - default=Path("packages", "security_detection_engine")) +@click.argument( + "local-repo", + type=click.Path(exists=True, file_okay=False, dir_okay=True, path_type=Path), + default=get_path(["..", "integrations"]), +) +@click.option( + "--token", + required=True, + prompt=get_github_token() is None, + default=get_github_token(), + help="GitHub token to use for the PR", + hide_input=True, +) +@click.option( + "--pkg-directory", + "-d", + help="Directory to save the package in cloned repository", + default=Path("packages", "security_detection_engine"), +) @click.option("--base-branch", "-b", help="Base branch in target repository", default="main") @click.option("--branch-name", "-n", help="New branch for the rules commit") @click.option("--github-repo", "-r", help="Repository to use for the branch", default="elastic/integrations") @@ -549,9 +633,19 @@ def kibana_diff(rule_id, repo, branch, threads): @click.option("--draft", is_flag=True, help="Open the PR as a draft") @click.option("--remote", help="Override the remote from 'origin'", default="origin") @click.pass_context -def integrations_pr(ctx: click.Context, local_repo: str, token: str, draft: bool, - pkg_directory: str, base_branch: str, remote: str, - branch_name: Optional[str], github_repo: str, assign: Tuple[str, ...], label: Tuple[str, ...]): +def integrations_pr( + ctx: click.Context, + local_repo: Path, + token: str, + draft: bool, + pkg_directory: str, + base_branch: str, + remote: str, + branch_name: str | None, + github_repo: str, + assign: tuple[str, ...], + label: tuple[str, ...], +): """Create a pull request to publish the Fleet package to elastic/integrations.""" github = GithubClient(token) github.assert_github() @@ -559,11 +653,15 @@ def integrations_pr(ctx: click.Context, local_repo: str, token: str, draft: bool repo = client.get_repo(github_repo) # Use elastic-package to format and lint - gopath = utils.gopath().strip("'\"") - assert gopath is not None, "$GOPATH isn't set" + gopath = utils.gopath() + + if not gopath: + raise ValueError("GOPATH not found") + + gopath = gopath.strip("'\"") - err = 'elastic-package missing, run: go install github.com/elastic/elastic-package@latest and verify go bin path' - assert subprocess.check_output(['elastic-package'], stderr=subprocess.DEVNULL), err + err = "elastic-package missing, run: go install github.com/elastic/elastic-package@latest and verify go bin path" + assert subprocess.check_output(["elastic-package"], stderr=subprocess.DEVNULL), err local_repo = Path(local_repo).resolve() stack_version = Package.load_configs()["name"] @@ -588,44 +686,53 @@ def integrations_pr(ctx: click.Context, local_repo: str, token: str, draft: bool # refresh the local clone of the repository git = utils.make_git("-C", local_repo) - git("checkout", base_branch) - git("pull", remote, base_branch) + _ = git("checkout", base_branch) + _ = git("pull", remote, base_branch) # Switch to a new branch in elastic/integrations 
branch_name = branch_name or f"detection-rules/{package_version}-{short_commit_hash}" - git("checkout", "-b", branch_name) + _ = git("checkout", "-b", branch_name) # Load the changelog in memory, before it's removed. Come back for it after the PR is created target_directory = local_repo / pkg_directory changelog_path = target_directory / "changelog.yml" - changelog_entries: list = yaml.safe_load(changelog_path.read_text(encoding="utf-8")) - - changelog_entries.insert(0, { - "version": package_version, - "changes": [ - # This will be changed later - {"description": "Release security rules update", "type": "enhancement", - "link": "https://github.com/elastic/integrations/pulls/0000"} - ] - }) + changelog_entries: list[dict[str, Any]] = yaml.safe_load(changelog_path.read_text(encoding="utf-8")) + + changelog_entries.insert( + 0, + { + "version": package_version, + "changes": [ + # This will be changed later + { + "description": "Release security rules update", + "type": "enhancement", + "link": "https://github.com/elastic/integrations/pulls/0000", + } + ], + }, + ) # Remove existing assets and replace everything shutil.rmtree(target_directory) actual_target_directory = shutil.copytree(release_dir, target_directory) - assert Path(actual_target_directory).absolute() == Path(target_directory).absolute(), \ + assert Path(actual_target_directory).absolute() == Path(target_directory).absolute(), ( f"Expected a copy to {pkg_directory}" + ) # Add the changelog back def save_changelog(): with changelog_path.open("wt") as f: # add a note for other maintainers of elastic/integrations to be careful with versions - f.write("# newer versions go on top\n") - f.write("# NOTE: please use pre-release versions (e.g. -beta.0) until a package is ready for production\n") + _ = f.write("# newer versions go on top\n") + _ = f.write( + "# NOTE: please use pre-release versions (e.g. 
-beta.0) until a package is ready for production\n" + ) yaml.dump(changelog_entries, f, allow_unicode=True, default_flow_style=False, indent=2, sort_keys=False) save_changelog() - def elastic_pkg(*args): + def elastic_pkg(*args: Any): """Run a command with $GOPATH/bin/elastic-package in the package directory.""" prev = Path.cwd() os.chdir(target_directory) @@ -637,12 +744,12 @@ def elastic_pkg(*args): finally: os.chdir(str(prev)) - elastic_pkg("format") + _ = elastic_pkg("format") # Upload the files to a branch - git("add", pkg_directory) - git("commit", "-m", message) - git("push", "--set-upstream", remote, branch_name) + _ = git("add", pkg_directory) + _ = git("commit", "-m", message) + _ = git("push", "--set-upstream", remote, branch_name) # Create a pull request (not done yet, but we need the PR number) body = textwrap.dedent(f""" @@ -673,14 +780,17 @@ def elastic_pkg(*args): None """) # noqa: E501 - pr = repo.create_pull(title=message, body=body, base=base_branch, head=branch_name, - maintainer_can_modify=True, draft=draft) + pr = repo.create_pull( + title=message, body=body, base=base_branch, head=branch_name, maintainer_can_modify=True, draft=draft + ) # labels could also be comma separated - label = {lbl for cs_labels in label for lbl in cs_labels.split(",") if lbl} + cs_labels_split = {lbl for cs_labels in label for lbl in cs_labels.split(",") if lbl} + + labels = sorted(list(label) + list(cs_labels_split)) - if label: - pr.add_to_labels(*sorted(label)) + if labels: + pr.add_to_labels(*labels) if assign: pr.add_to_assignees(*assign) @@ -693,26 +803,25 @@ def elastic_pkg(*args): save_changelog() # format the yml file with elastic-package - elastic_pkg("format") - elastic_pkg("lint") + _ = elastic_pkg("format") + _ = elastic_pkg("lint") # Push the updated changelog to the PR branch - git("add", pkg_directory) - git("commit", "-m", f"Add changelog entry for {package_version}") - git("push") + _ = git("add", pkg_directory) + _ = git("commit", "-m", f"Add changelog entry for {package_version}") + _ = git("push") -@dev_group.command('license-check') -@click.option('--ignore-directory', '-i', multiple=True, help='Directories to skip (relative to base)') +@dev_group.command("license-check") +@click.option("--ignore-directory", "-i", multiple=True, help="Directories to skip (relative to base)") @click.pass_context -def license_check(ctx, ignore_directory): +def license_check(ctx: click.Context, ignore_directory: list[str]): """Check that all code files contain a valid license.""" ignore_directory += ("env",) failed = False - base_path = get_path() - for path in base_path.rglob('*.py'): - relative_path = path.relative_to(base_path) + for path in utils.ROOT_DIR.rglob("*.py"): + relative_path = path.relative_to(utils.ROOT_DIR) if relative_path.parts[0] in ignore_directory: continue @@ -733,144 +842,156 @@ def license_check(ctx, ignore_directory): ctx.exit(int(failed)) -@dev_group.command('test-version-lock') -@click.argument('branches', nargs=-1, required=True) -@click.option('--remote', '-r', default='origin', help='Override the remote from "origin"') +@dev_group.command("test-version-lock") +@click.argument("branches", nargs=-1, required=True) +@click.option("--remote", "-r", default="origin", help='Override the remote from "origin"') @click.pass_context -def test_version_lock(ctx: click.Context, branches: tuple, remote: str): +def test_version_lock(ctx: click.Context, branches: list[str], remote: str): """Simulate the incremental step in the version locking to find version change 
violations.""" - git = utils.make_git('-C', '.') - current_branch = git('rev-parse', '--abbrev-ref', 'HEAD') + git = utils.make_git("-C", ".") + current_branch = git("rev-parse", "--abbrev-ref", "HEAD") try: - click.echo(f'iterating lock process for branches: {branches}') + click.echo(f"iterating lock process for branches: {branches}") for branch in branches: click.echo(branch) - git('checkout', f'{remote}/{branch}') - subprocess.check_call(['python', '-m', 'detection_rules', 'dev', 'build-release', '-u']) + _ = git("checkout", f"{remote}/{branch}") + _ = subprocess.check_call(["python", "-m", "detection_rules", "dev", "build-release", "-u"]) finally: - rules_config = ctx.obj['rules_config'] - diff = git('--no-pager', 'diff', str(rules_config.version_lock_file)) - outfile = get_path() / 'lock-diff.txt' - outfile.write_text(diff) - click.echo(f'diff saved to {outfile}') + rules_config = ctx.obj["rules_config"] + diff = git("--no-pager", "diff", str(rules_config.version_lock_file)) + outfile = utils.ROOT_DIR / "lock-diff.txt" + _ = outfile.write_text(diff) + click.echo(f"diff saved to {outfile}") - click.echo('reverting changes in version.lock') - git('checkout', '-f') - git('checkout', current_branch) + click.echo("reverting changes in version.lock") + _ = git("checkout", "-f") + _ = git("checkout", current_branch) -@dev_group.command('package-stats') -@click.option('--token', '-t', help='GitHub token to search API authenticated (may exceed threshold without auth)') -@click.option('--threads', default=50, help='Number of threads to download rules from GitHub') +@dev_group.command("package-stats") +@click.option("--token", "-t", help="GitHub token to search API authenticated (may exceed threshold without auth)") +@click.option("--threads", default=50, help="Number of threads to download rules from GitHub") @click.pass_context -def package_stats(ctx, token, threads): +def package_stats(ctx: click.Context, token: str | None, threads: int): """Get statistics for current rule package.""" - current_package: Package = ctx.invoke(build_release, verbose=False, release=None) - release = f'v{current_package.name}.0' - new, modified, errors = rule_loader.load_github_pr_rules(labels=[release], token=token, threads=threads) - - click.echo(f'Total rules as of {release} package: {len(current_package.rules)}') - click.echo(f'New rules: {len(current_package.new_ids)}') - click.echo(f'Modified rules: {len(current_package.changed_ids)}') - click.echo(f'Deprecated rules: {len(current_package.removed_ids)}') - - click.echo('\n-----\n') - click.echo('Rules in active PRs for current package: ') - click.echo(f'New rules: {len(new)}') - click.echo(f'Modified rules: {len(modified)}') - - -@dev_group.command('search-rule-prs') -@click.argument('query', required=False) -@click.option('--no-loop', '-n', is_flag=True, help='Run once with no loop') -@click.option('--columns', '-c', multiple=True, help='Specify columns to add the table') -@click.option('--language', type=click.Choice(["eql", "kql"]), default="kql") -@click.option('--token', '-t', help='GitHub token to search API authenticated (may exceed threshold without auth)') -@click.option('--threads', default=50, help='Number of threads to download rules from GitHub') + current_package: Package = ctx.invoke(build_release, verbose=False) + release = f"v{current_package.name}.0" + new, modified, _ = rule_loader.load_github_pr_rules(labels=[release], token=token, threads=threads) + + click.echo(f"Total rules as of {release} package: {len(current_package.rules)}") + 
click.echo(f"New rules: {len(current_package.new_ids)}") + click.echo(f"Modified rules: {len(current_package.changed_ids)}") + click.echo(f"Deprecated rules: {len(current_package.removed_ids)}") + + click.echo("\n-----\n") + click.echo("Rules in active PRs for current package: ") + click.echo(f"New rules: {len(new)}") + click.echo(f"Modified rules: {len(modified)}") + + +@dev_group.command("search-rule-prs") +@click.argument("query", required=False) +@click.option("--no-loop", "-n", is_flag=True, help="Run once with no loop") +@click.option("--columns", "-c", multiple=True, help="Specify columns to add the table") +@click.option("--language", type=click.Choice(["eql", "kql"]), default="kql") +@click.option("--token", "-t", help="GitHub token to search API authenticated (may exceed threshold without auth)") +@click.option("--threads", default=50, help="Number of threads to download rules from GitHub") @click.pass_context -def search_rule_prs(ctx, no_loop, query, columns, language, token, threads): +def search_rule_prs( + ctx: click.Context, + no_loop: bool, + query: str | None, + columns: list[str], + language: Literal["eql", "kql"], + token: str | None, + threads: int, +): """Use KQL or EQL to find matching rules from active GitHub PRs.""" from uuid import uuid4 from .main import search_rules - all_rules: Dict[Path, TOMLRule] = {} - new, modified, errors = rule_loader.load_github_pr_rules(token=token, threads=threads) + all_rules: dict[Path, TOMLRule] = {} + new, modified, _ = rule_loader.load_github_pr_rules(token=token, threads=threads) - def add_github_meta(this_rule: TOMLRule, status: str, original_rule_id: Optional[definitions.UUIDString] = None): + def add_github_meta(this_rule: TOMLRule, status: str, original_rule_id: definitions.UUIDString | None = None): pr = this_rule.gh_pr data = rule.contents.data extend_meta = { - 'status': status, - 'github': { - 'base': pr.base.label, - 'comments': [c.body for c in pr.get_comments()], - 'commits': pr.commits, - 'created_at': str(pr.created_at), - 'head': pr.head.label, - 'is_draft': pr.draft, - 'labels': [lbl.name for lbl in pr.get_labels()], - 'last_modified': str(pr.last_modified), - 'title': pr.title, - 'url': pr.html_url, - 'user': pr.user.login - } + "status": status, + "github": { + "base": pr.base.label, + "comments": [c.body for c in pr.get_comments()], + "commits": pr.commits, + "created_at": str(pr.created_at), + "head": pr.head.label, + "is_draft": pr.draft, + "labels": [lbl.name for lbl in pr.get_labels()], + "last_modified": str(pr.last_modified), + "title": pr.title, + "url": pr.html_url, + "user": pr.user.login, + }, } if original_rule_id: - extend_meta['original_rule_id'] = original_rule_id + extend_meta["original_rule_id"] = original_rule_id data = dataclasses.replace(rule.contents.data, rule_id=str(uuid4())) - rule_path = Path(f'pr-{pr.number}-{rule.path}') + rule_path = Path(f"pr-{pr.number}-{rule.path}") new_meta = dataclasses.replace(rule.contents.metadata, extended=extend_meta) contents = dataclasses.replace(rule.contents, metadata=new_meta, data=data) new_rule = TOMLRule(path=rule_path, contents=contents) + if not new_rule.path: + raise ValueError("No rule path found") all_rules[new_rule.path] = new_rule for rule_id, rule in new.items(): - add_github_meta(rule, 'new') + add_github_meta(rule, "new") for rule_id, rules in modified.items(): for rule in rules: - add_github_meta(rule, 'modified', rule_id) + add_github_meta(rule, "modified", rule_id) loop = not no_loop ctx.invoke(search_rules, query=query, columns=columns, 
language=language, rules=all_rules, pager=loop) while loop: - query = click.prompt(f'Search loop - enter new {language} query or ctrl-z to exit') - columns = click.prompt('columns', default=','.join(columns)).split(',') + query = click.prompt(f"Search loop - enter new {language} query or ctrl-z to exit") + columns = click.prompt("columns", default=",".join(columns)).split(",") ctx.invoke(search_rules, query=query, columns=columns, language=language, rules=all_rules, pager=True) -@dev_group.command('deprecate-rule') -@click.argument('rule-file', type=Path) -@click.option('--deprecation-folder', '-d', type=Path, required=True, - help='Location to move the deprecated rule file to') +@dev_group.command("deprecate-rule") +@click.argument("rule-file", type=Path) +@click.option( + "--deprecation-folder", "-d", type=Path, required=True, help="Location to move the deprecated rule file to" +) @click.pass_context def deprecate_rule(ctx: click.Context, rule_file: Path, deprecation_folder: Path): """Deprecate a rule.""" version_info = loaded_version_lock.version_lock rule_collection = RuleCollection() contents = rule_collection.load_file(rule_file).contents - rule = TOMLRule(path=rule_file, contents=contents) + rule = TOMLRule(path=rule_file, contents=contents) # type: ignore[reportArgumentType] if rule.contents.id not in version_info and not RULES_CONFIG.bypass_version_lock: - click.echo('Rule has not been version locked and so does not need to be deprecated. ' - 'Delete the file or update the maturity to `development` instead.') + click.echo( + "Rule has not been version locked and so does not need to be deprecated. " + "Delete the file or update the maturity to `development` instead." + ) ctx.exit() - today = time.strftime('%Y/%m/%d') + today = time.strftime("%Y/%m/%d") deprecated_path = deprecation_folder / rule_file.name # create the new rule and save it - new_meta = dataclasses.replace(rule.contents.metadata, - updated_date=today, - deprecation_date=today, - maturity='deprecated') + new_meta = dataclasses.replace( + rule.contents.metadata, updated_date=today, deprecation_date=today, maturity="deprecated" + ) contents = dataclasses.replace(rule.contents, metadata=new_meta) new_rule = TOMLRule(contents=contents, path=deprecated_path) deprecated_path.parent.mkdir(parents=True, exist_ok=True) @@ -878,72 +999,87 @@ def deprecate_rule(ctx: click.Context, rule_file: Path, deprecation_folder: Path # remove the old rule rule_file.unlink() - click.echo(f'Rule moved to {deprecated_path} - remember to git add this file') - - -@dev_group.command('update-navigator-gists') -@click.option('--directory', type=Path, default=CURRENT_RELEASE_PATH.joinpath('extras', 'navigator_layers'), - help='Directory containing only navigator files.') -@click.option('--token', required=True, prompt=get_github_token() is None, default=get_github_token(), - help='GitHub token to push to gist', hide_input=True) -@click.option('--gist-id', default=NAVIGATOR_GIST_ID, help='Gist ID to be updated (must exist).') -@click.option('--print-markdown', is_flag=True, help='Print the generated urls') -@click.option('--update-coverage', is_flag=True, help=f'Update the {REPO_DOCS_DIR}/ATT&CK-coverage.md file') -def update_navigator_gists(directory: Path, token: str, gist_id: str, print_markdown: bool, - update_coverage: bool) -> list: + click.echo(f"Rule moved to {deprecated_path} - remember to git add this file") + + +@dev_group.command("update-navigator-gists") +@click.option( + "--directory", + type=Path, + 
default=CURRENT_RELEASE_PATH.joinpath("extras", "navigator_layers"), + help="Directory containing only navigator files.", +) +@click.option( + "--token", + required=True, + prompt=get_github_token() is None, + default=get_github_token(), + help="GitHub token to push to gist", + hide_input=True, +) +@click.option("--gist-id", default=NAVIGATOR_GIST_ID, help="Gist ID to be updated (must exist).") +@click.option("--print-markdown", is_flag=True, help="Print the generated urls") +@click.option("--update-coverage", is_flag=True, help=f"Update the {REPO_DOCS_DIR}/ATT&CK-coverage.md file") +def update_navigator_gists( + directory: Path, + token: str, + gist_id: str, + print_markdown: bool, + update_coverage: bool, +): """Update the gists with new navigator files.""" - assert directory.exists(), f'{directory} does not exist' + assert directory.exists(), f"{directory} does not exist" - def raw_permalink(raw_link): + def raw_permalink(raw_link: str): # Gist file URLs change with each revision, but can be permalinked to the latest by removing the hash after raw - prefix, _, suffix = raw_link.rsplit('/', 2) - return '/'.join([prefix, suffix]) + prefix, _, suffix = raw_link.rsplit("/", 2) + return "/".join([prefix, suffix]) - file_map = {f: f.read_text() for f in directory.glob('*.json')} + file_map = {f: f.read_text() for f in directory.glob("*.json")} try: - response = update_gist(token, - file_map, - description='ATT&CK Navigator layer files.', - gist_id=gist_id, - pre_purge=True) + response = update_gist( + token, file_map, description="ATT&CK Navigator layer files.", gist_id=gist_id, pre_purge=True + ) except requests.exceptions.HTTPError as exc: if exc.response.status_code == requests.status_codes.codes.not_found: - raise client_error('Gist not found: verify the gist_id exists and the token has access to it', exc=exc) + raise raise_client_error( + "Gist not found: verify the gist_id exists and the token has access to it", exc=exc + ) else: raise response_data = response.json() - raw_urls = {name: raw_permalink(data['raw_url']) for name, data in response_data['files'].items()} + raw_urls = {name: raw_permalink(data["raw_url"]) for name, data in response_data["files"].items()} - base_url = 'https://mitre-attack.github.io/attack-navigator/#layerURL={}&leave_site_dialog=false&tabs=false' + base_url = "https://mitre-attack.github.io/attack-navigator/#layerURL={}&leave_site_dialog=false&tabs=false" # pull out full and platform coverage to print on top of markdown table - all_url = base_url.format(urllib.parse.quote_plus(raw_urls.pop('Elastic-detection-rules-all.json'))) - platforms_url = base_url.format(urllib.parse.quote_plus(raw_urls.pop('Elastic-detection-rules-platforms.json'))) + all_url = base_url.format(urllib.parse.quote_plus(raw_urls.pop("Elastic-detection-rules-all.json"))) + platforms_url = base_url.format(urllib.parse.quote_plus(raw_urls.pop("Elastic-detection-rules-platforms.json"))) generated_urls = [all_url, platforms_url] - markdown_links = [] + markdown_links: list[str] = [] for name, gist_url in raw_urls.items(): query = urllib.parse.quote_plus(gist_url) - url = f'https://mitre-attack.github.io/attack-navigator/#layerURL={query}&leave_site_dialog=false&tabs=false' + url = f"https://mitre-attack.github.io/attack-navigator/#layerURL={query}&leave_site_dialog=false&tabs=false" generated_urls.append(url) - link_name = name.split('.')[0] - markdown_links.append(f'|[{link_name}]({url})|') + link_name = name.split(".")[0] + markdown_links.append(f"|[{link_name}]({url})|") markdown = [ - 
f'**Full coverage**: {NAVIGATOR_BADGE}', - '\n', - f'**Coverage by platform**: [navigator]({platforms_url})', - '\n', - '| other navigator links by rule attributes |', - '|------------------------------------------|', + f"**Full coverage**: {NAVIGATOR_BADGE}", + "\n", + f"**Coverage by platform**: [navigator]({platforms_url})", + "\n", + "| other navigator links by rule attributes |", + "|------------------------------------------|", ] + markdown_links if print_markdown: - click.echo('\n'.join(markdown) + '\n') + click.echo("\n".join(markdown) + "\n") if update_coverage: - coverage_file_path = get_path(REPO_DOCS_DIR, 'ATT&CK-coverage.md') + coverage_file_path = get_path([REPO_DOCS_DIR, "ATT&CK-coverage.md"]) header_lines = textwrap.dedent("""# Rule coverage ATT&CK navigator layer files are generated when a package is built with `make release` or @@ -958,59 +1094,65 @@ def raw_permalink(raw_link): The source files for these links are regenerated with every successful merge to main. These represent coverage from the state of rules in the `main` branch. """) - updated_file = header_lines + '\n\n' + '\n'.join(markdown) + '\n' + updated_file = header_lines + "\n\n" + "\n".join(markdown) + "\n" # Replace the old URLs with the new ones - with open(coverage_file_path, 'w') as md_file: - md_file.write(updated_file) - click.echo(f'Updated ATT&CK coverage URL(s) in {coverage_file_path}' + '\n') + with open(coverage_file_path, "w") as md_file: + _ = md_file.write(updated_file) + click.echo(f"Updated ATT&CK coverage URL(s) in {coverage_file_path}" + "\n") - click.echo(f'Gist update status on {len(generated_urls)} files: {response.status_code} {response.reason}') + click.echo(f"Gist update status on {len(generated_urls)} files: {response.status_code} {response.reason}") return generated_urls -@dev_group.command('trim-version-lock') -@click.argument('stack_version') -@click.option('--skip-rule-updates', is_flag=True, help='Skip updating the rules') -@click.option('--dry-run', is_flag=True, help='Print the changes rather than saving the file') +@dev_group.command("trim-version-lock") +@click.argument("stack_version") +@click.option("--skip-rule-updates", is_flag=True, help="Skip updating the rules") +@click.option("--dry-run", is_flag=True, help="Print the changes rather than saving the file") @click.pass_context def trim_version_lock(ctx: click.Context, stack_version: str, skip_rule_updates: bool, dry_run: bool): """Trim all previous entries within the version lock file which are lower than the min_version.""" stack_versions = get_stack_versions() - assert stack_version in stack_versions, \ - f'Unknown min_version ({stack_version}), expected: {", ".join(stack_versions)}' + assert stack_version in stack_versions, ( + f"Unknown min_version ({stack_version}), expected: {', '.join(stack_versions)}" + ) min_version = Version.parse(stack_version) if RULES_CONFIG.bypass_version_lock: - click.echo('WARNING: Cannot trim the version lock when the versioning strategy is configured to bypass the ' - 'version lock. Set `bypass_version_lock` to `false` in the rules config to use the version lock.') + click.echo( + "WARNING: Cannot trim the version lock when the versioning strategy is configured to bypass the " + "version lock. Set `bypass_version_lock` to `false` in the rules config to use the version lock." 
+ ) ctx.exit() version_lock_dict = loaded_version_lock.version_lock.to_dict() - removed = defaultdict(list) - rule_msv_drops = [] + removed: dict[str, list[str]] = defaultdict(list) + rule_msv_drops: list[str] = [] - today = time.strftime('%Y/%m/%d') + today = time.strftime("%Y/%m/%d") rc: RuleCollection | None = None if dry_run: rc = RuleCollection() else: if not skip_rule_updates: - click.echo('Loading rules ...') + click.echo("Loading rules ...") rc = RuleCollection.default() + if not rc: + raise ValueError("No rule collection found") + for rule_id, lock in version_lock_dict.items(): file_min_stack: Version | None = None - if 'min_stack_version' in lock: - file_min_stack = Version.parse((lock['min_stack_version']), optional_minor_and_patch=True) + if "min_stack_version" in lock: + file_min_stack = Version.parse((lock["min_stack_version"]), optional_minor_and_patch=True) if file_min_stack <= min_version: removed[rule_id].append( - f'locked min_stack_version <= {min_version} - {"will remove" if dry_run else "removing"}!' + f"locked min_stack_version <= {min_version} - {'will remove' if dry_run else 'removing'}!" ) rule_msv_drops.append(rule_id) file_min_stack = None if not dry_run: - lock.pop('min_stack_version') + lock.pop("min_stack_version") if not skip_rule_updates: # remove the min_stack_version and min_stack_comments from rules as well (and update date) rule = rc.id_map.get(rule_id) @@ -1019,17 +1161,17 @@ def trim_version_lock(ctx: click.Context, stack_version: str, skip_rule_updates: rule.contents.metadata, updated_date=today, min_stack_version=None, - min_stack_comments=None + min_stack_comments=None, ) contents = dataclasses.replace(rule.contents, metadata=new_meta) new_rule = TOMLRule(contents=contents, path=rule.path) new_rule.save_toml() - removed[rule_id].append('rule min_stack_version dropped') + removed[rule_id].append("rule min_stack_version dropped") else: - removed[rule_id].append('rule not found to update!') + removed[rule_id].append("rule not found to update!") - if 'previous' in lock: - prev_vers = [Version.parse(v, optional_minor_and_patch=True) for v in list(lock['previous'])] + if "previous" in lock: + prev_vers = [Version.parse(v, optional_minor_and_patch=True) for v in list(lock["previous"])] outdated_vers = [v for v in prev_vers if v < min_version] if not outdated_vers: @@ -1041,60 +1183,60 @@ def trim_version_lock(ctx: click.Context, stack_version: str, skip_rule_updates: for outdated in outdated_vers: short_outdated = f"{outdated.major}.{outdated.minor}" - popped = lock['previous'].pop(str(short_outdated)) + popped = lock["previous"].pop(str(short_outdated)) # the core of the update - we only need to keep previous entries that are newer than the min supported # version (from stack-schema-map and stack-version parameter) and older than the locked # min_stack_version for a given rule, if one exists if file_min_stack and outdated == latest_version and outdated < file_min_stack: - lock['previous'][f'{min_version.major}.{min_version.minor}'] = popped - removed[rule_id].append(f'{short_outdated} updated to: {min_version.major}.{min_version.minor}') + lock["previous"][f"{min_version.major}.{min_version.minor}"] = popped + removed[rule_id].append(f"{short_outdated} updated to: {min_version.major}.{min_version.minor}") else: - removed[rule_id].append(f'{outdated} dropped') + removed[rule_id].append(f"{outdated} dropped") # remove the whole previous entry if it is now blank - if not lock['previous']: - lock.pop('previous') + if not lock["previous"]: + 
lock.pop("previous") - click.echo(f'Changes {"that will be " if dry_run else ""} applied:' if removed else 'No changes') - click.echo('\n'.join(f'{k}: {", ".join(v)}' for k, v in removed.items())) + click.echo(f"Changes {'that will be ' if dry_run else ''} applied:" if removed else "No changes") + click.echo("\n".join(f"{k}: {', '.join(v)}" for k, v in removed.items())) if not dry_run: new_lock = VersionLockFile.from_dict(dict(data=version_lock_dict)) new_lock.save_to_file() -@dev_group.group('diff') +@dev_group.group("diff") def diff_group(): """Commands for statistics on changes and diffs.""" -@diff_group.command('endpoint-by-attack') -@click.option('--pre', required=True, help='Tag for pre-existing rules') -@click.option('--post', required=True, help='Tag for rules post updates') -@click.option('--force', '-f', is_flag=True, help='Bypass the confirmation prompt') -@click.option('--remote', '-r', default='origin', help='Override the remote from "origin"') +@diff_group.command("endpoint-by-attack") +@click.option("--pre", required=True, help="Tag for pre-existing rules") +@click.option("--post", required=True, help="Tag for rules post updates") +@click.option("--force", "-f", is_flag=True, help="Bypass the confirmation prompt") +@click.option("--remote", "-r", default="origin", help='Override the remote from "origin"') @click.pass_context -def endpoint_by_attack(ctx: click.Context, pre: str, post: str, force: bool, remote: Optional[str] = 'origin'): +def endpoint_by_attack(ctx: click.Context, pre: str, post: str, force: bool, remote: str = "origin"): """Rule diffs across tagged branches, broken down by ATT&CK tactics.""" if not force: - if not click.confirm(f'This will refresh tags and may overwrite local tags for: {pre} and {post}. Continue?'): + if not click.confirm(f"This will refresh tags and may overwrite local tags for: {pre} and {post}. 
Continue?"): ctx.exit(1) changed, new, deprecated = get_release_diff(pre, post, remote) - oses = ('windows', 'linux', 'macos') + oses = ("windows", "linux", "macos") - def delta_stats(rule_map) -> List[dict]: - stats = defaultdict(lambda: defaultdict(int)) - os_totals = defaultdict(int) - tactic_totals = defaultdict(int) + def delta_stats(rule_map: dict[str, TOMLRule] | dict[str, DeprecatedRule]) -> list[dict[str, Any]]: + stats: dict[str, dict[str, int]] = defaultdict(lambda: defaultdict(int)) + os_totals: dict[str, int] = defaultdict(int) + tactic_totals: dict[str, int] = defaultdict(int) - for rule_id, rule in rule_map.items(): - threat = rule.contents.data.get('threat') - os_types = [i.lower() for i in rule.contents.data.get('tags') or [] if i.lower() in oses] + for _, rule in rule_map.items(): + threat = rule.contents.data.get("threat") + os_types: list[str] = [i.lower() for i in rule.contents.data.get("tags") or [] if i.lower() in oses] # type: ignore[reportUnknownVariableType] if not threat or not os_types: continue if isinstance(threat[0], dict): - tactics = sorted(set(e['tactic']['name'] for e in threat)) + tactics = sorted(set(e["tactic"]["name"] for e in threat)) else: tactics = ThreatMapping.flatten(threat).tactic_names for tactic in tactics: @@ -1104,138 +1246,181 @@ def delta_stats(rule_map) -> List[dict]: stats[tactic][os_type] += 1 # structure stats for table - rows = [] + rows: list[dict[str, Any]] = [] for tac, stat in stats.items(): - row = {'tactic': tac, 'total': tactic_totals[tac]} + row: dict[str, Any] = {"tactic": tac, "total": tactic_totals[tac]} for os_type, count in stat.items(): row[os_type] = count rows.append(row) - rows.append(dict(tactic='total_by_os', **os_totals)) - + rows.append(dict(tactic="total_by_os", **os_totals)) return rows - fields = ['tactic', 'linux', 'macos', 'windows', 'total'] + fields = ["tactic", "linux", "macos", "windows", "total"] changed_stats = delta_stats(changed) - table = Table.from_list(fields, changed_stats) - click.echo(f'Changed rules {len(changed)}\n{table}\n') + table = Table.from_list(fields, changed_stats) # type: ignore[reportUnknownMemberType] + click.echo(f"Changed rules {len(changed)}\n{table}\n") new_stats = delta_stats(new) - table = Table.from_list(fields, new_stats) - click.echo(f'New rules {len(new)}\n{table}\n') + table = Table.from_list(fields, new_stats) # type: ignore[reportUnknownMemberType] + click.echo(f"New rules {len(new)}\n{table}\n") dep_stats = delta_stats(deprecated) - table = Table.from_list(fields, dep_stats) - click.echo(f'Deprecated rules {len(deprecated)}\n{table}\n') + table = Table.from_list(fields, dep_stats) # type: ignore[reportUnknownMemberType] + click.echo(f"Deprecated rules {len(deprecated)}\n{table}\n") return changed_stats, new_stats, dep_stats -@dev_group.group('test') +@dev_group.group("test") def test_group(): """Commands for testing against stack resources.""" -@test_group.command('event-search') -@click.argument('query') -@click.option('--index', '-i', multiple=True, help='Index patterns to search against') -@click.option('--eql/--lucene', '-e/-l', 'language', default=None, help='Query language used (default: kql)') -@click.option('--date-range', '-d', type=(str, str), default=('now-7d', 'now'), help='Date range to scope search') -@click.option('--count', '-c', is_flag=True, help='Return count of results only') -@click.option('--max-results', '-m', type=click.IntRange(1, 1000), default=100, - help='Max results to return (capped at 1000)') -@click.option('--verbose', '-v', 
is_flag=True, default=True) -@add_client('elasticsearch') -def event_search(query, index, language, date_range, count, max_results, verbose=True, - elasticsearch_client: Elasticsearch = None): +@test_group.command("event-search") +@click.argument("query") +@click.option("--index", "-i", multiple=True, help="Index patterns to search against") +@click.option("--eql/--lucene", "-e/-l", "language", default=None, help="Query language used (default: kql)") +@click.option("--date-range", "-d", type=(str, str), default=("now-7d", "now"), help="Date range to scope search") +@click.option("--count", "-c", is_flag=True, help="Return count of results only") +@click.option( + "--max-results", + "-m", + type=click.IntRange(1, 1000), + default=100, + help="Max results to return (capped at 1000)", +) +@click.option("--verbose", "-v", is_flag=True, default=True) +@add_client(["elasticsearch"]) +def event_search( + query: str, + index: list[str], + language: str | None, + date_range: tuple[str, str], + count: bool, + max_results: int, + elasticsearch_client: Elasticsearch, + verbose: bool = True, +): """Search using a query against an Elasticsearch instance.""" start_time, end_time = date_range - index = index or ('*',) - language_used = "kql" if language is None else "eql" if language is True else "lucene" + index = index or ["*"] + language_used = "kql" if language is None else "eql" if language else "lucene" collector = CollectEvents(elasticsearch_client, max_results) if verbose: - click.echo(f'searching {",".join(index)} from {start_time} to {end_time}') - click.echo(f'{language_used}: {query}') + click.echo(f"searching {','.join(index)} from {start_time} to {end_time}") + click.echo(f"{language_used}: {query}") if count: results = collector.count(query, language_used, index, start_time, end_time) - click.echo(f'total results: {results}') + click.echo(f"total results: {results}") else: results = collector.search(query, language_used, index, start_time, end_time, max_results) - click.echo(f'total results: {len(results)} (capped at {max_results})') + click.echo(f"total results: {len(results)} (capped at {max_results})") click.echo_via_pager(json.dumps(results, indent=2, sort_keys=True)) return results -@test_group.command('rule-event-search') +@test_group.command("rule-event-search") @single_collection -@click.option('--date-range', '-d', type=(str, str), default=('now-7d', 'now'), help='Date range to scope search') -@click.option('--count', '-c', is_flag=True, help='Return count of results only') -@click.option('--max-results', '-m', type=click.IntRange(1, 1000), default=100, - help='Max results to return (capped at 1000)') -@click.option('--verbose', '-v', is_flag=True) +@click.option("--date-range", "-d", type=(str, str), default=("now-7d", "now"), help="Date range to scope search") +@click.option("--count", "-c", is_flag=True, help="Return count of results only") +@click.option( + "--max-results", + "-m", + type=click.IntRange(1, 1000), + default=100, + help="Max results to return (capped at 1000)", +) +@click.option("--verbose", "-v", is_flag=True) @click.pass_context -@add_client('elasticsearch') -def rule_event_search(ctx, rule, date_range, count, max_results, verbose, - elasticsearch_client: Elasticsearch = None): +@add_client(["elasticsearch"]) +def rule_event_search( + ctx: click.Context, + rule: Any, + date_range: tuple[str, str], + count: bool, + max_results: int, + elasticsearch_client: Elasticsearch, + verbose: bool = False, +): """Search using a rule file against an Elasticsearch 
instance.""" if isinstance(rule.contents.data, QueryRuleData): if verbose: - click.echo(f'Searching rule: {rule.name}') + click.echo(f"Searching rule: {rule.name}") data = rule.contents.data rule_lang = data.language - if rule_lang == 'kuery': + if rule_lang == "kuery": language_flag = None - elif rule_lang == 'eql': + elif rule_lang == "eql": language_flag = True else: language_flag = False - index = data.index or ['*'] - ctx.invoke(event_search, query=data.query, index=index, language=language_flag, - date_range=date_range, count=count, max_results=max_results, verbose=verbose, - elasticsearch_client=elasticsearch_client) + index = data.index or ["*"] + ctx.invoke( + event_search, + query=data.query, + index=index, + language=language_flag, + date_range=date_range, + count=count, + max_results=max_results, + verbose=verbose, + elasticsearch_client=elasticsearch_client, + ) else: - client_error('Rule is not a query rule!') + raise_client_error("Rule is not a query rule!") -@test_group.command('rule-survey') -@click.argument('query', required=False) -@click.option('--date-range', '-d', type=(str, str), default=('now-7d', 'now'), help='Date range to scope search') -@click.option('--dump-file', type=click.Path(dir_okay=False), - default=get_path('surveys', f'{time.strftime("%Y%m%dT%H%M%SL")}.json'), - help='Save details of results (capped at 1000 results/rule)') -@click.option('--hide-zero-counts', '-z', is_flag=True, help='Exclude rules with zero hits from printing') -@click.option('--hide-errors', '-e', is_flag=True, help='Exclude rules with errors from printing') +@test_group.command("rule-survey") +@click.argument("query", required=False) +@click.option("--date-range", "-d", type=(str, str), default=("now-7d", "now"), help="Date range to scope search") +@click.option( + "--dump-file", + type=click.Path(dir_okay=False, path_type=Path), + default=get_path(["surveys", f"{time.strftime('%Y%m%dT%H%M%SL')}.json"]), + help="Save details of results (capped at 1000 results/rule)", +) +@click.option("--hide-zero-counts", "-z", is_flag=True, help="Exclude rules with zero hits from printing") +@click.option("--hide-errors", "-e", is_flag=True, help="Exclude rules with errors from printing") @click.pass_context -@add_client('elasticsearch', 'kibana', add_to_ctx=True) -def rule_survey(ctx: click.Context, query, date_range, dump_file, hide_zero_counts, hide_errors, - elasticsearch_client: Elasticsearch = None, kibana_client: Kibana = None): +@add_client(["elasticsearch", "kibana"], add_to_ctx=True) +def rule_survey( + ctx: click.Context, + query: str, + date_range: tuple[str, str], + dump_file: Path, + hide_zero_counts: bool, + hide_errors: bool, + elasticsearch_client: Elasticsearch, + kibana_client: Kibana, +): """Survey rule counts.""" - from kibana.resources import Signal + from kibana.resources import Signal # type: ignore[reportMissingTypeStubs] from .main import search_rules # from .eswrap import parse_unique_field_results - survey_results = [] + survey_results: list[dict[str, int]] = [] start_time, end_time = date_range if query: rules = RuleCollection() - paths = [Path(r['file']) for r in ctx.invoke(search_rules, query=query, verbose=False)] + paths = [Path(r["file"]) for r in ctx.invoke(search_rules, query=query, verbose=False)] rules.load_files(paths) else: rules = RuleCollection.default().filter(production_filter) - click.echo(f'Running survey against {len(rules)} rules') - click.echo(f'Saving detailed dump to: {dump_file}') + click.echo(f"Running survey against {len(rules)} rules") + 
click.echo(f"Saving detailed dump to: {dump_file}") collector = CollectEvents(elasticsearch_client) details = collector.search_from_rule(rules, start_time=start_time, end_time=end_time) @@ -1243,10 +1428,12 @@ def rule_survey(ctx: click.Context, query, date_range, dump_file, hide_zero_coun # add alerts with kibana_client: - range_dsl = {'query': {'bool': {'filter': []}}} - add_range_to_dsl(range_dsl['query']['bool']['filter'], start_time, end_time) - alerts = {a['_source']['signal']['rule']['rule_id']: a['_source'] - for a in Signal.search(range_dsl, size=10000)['hits']['hits']} + range_dsl: dict[str, Any] = {"query": {"bool": {"filter": []}}} + add_range_to_dsl(range_dsl["query"]["bool"]["filter"], start_time, end_time) + alerts: dict[str, Any] = { + a["_source"]["signal"]["rule"]["rule_id"]: a["_source"] + for a in Signal.search(range_dsl, size=10000)["hits"]["hits"] # type: ignore[reportUnknownMemberType] + } # for alert in alerts: # rule_id = alert['signal']['rule']['rule_id'] @@ -1256,58 +1443,64 @@ def rule_survey(ctx: click.Context, query, date_range, dump_file, hide_zero_coun for rule_id, count in counts.items(): alert_count = len(alerts.get(rule_id, [])) if alert_count > 0: - count['alert_count'] = alert_count + count["alert_count"] = alert_count details[rule_id].update(count) - search_count = count['search_count'] + search_count = count["search_count"] if not alert_count and (hide_zero_counts and search_count == 0) or (hide_errors and search_count == -1): continue survey_results.append(count) - fields = ['rule_id', 'name', 'search_count', 'alert_count'] - table = Table.from_list(fields, survey_results) + fields = ["rule_id", "name", "search_count", "alert_count"] + table = Table.from_list(fields, survey_results) # type: ignore[reportUnknownMemberType] if len(survey_results) > 200: click.echo_via_pager(table) else: click.echo(table) - os.makedirs(get_path('surveys'), exist_ok=True) - with open(dump_file, 'w') as f: + os.makedirs(get_path(["surveys"]), exist_ok=True) + with open(dump_file, "w") as f: json.dump(details, f, indent=2, sort_keys=True) return survey_results -@dev_group.group('utils') +@dev_group.group("utils") def utils_group(): """Commands for dev utility methods.""" -@utils_group.command('get-branches') -@click.option('--outfile', '-o', type=Path, default=get_etc_path("target-branches.yaml"), help='File to save output to') +@utils_group.command("get-branches") +@click.option( + "--outfile", + "-o", + type=Path, + default=get_etc_path(["target-branches.yaml"]), + help="File to save output to", +) def get_branches(outfile: Path): branch_list = get_stack_versions(drop_patch=True) target_branches = json.dumps(branch_list[:-1]) + "\n" - outfile.write_text(target_branches) + _ = outfile.write_text(target_branches) -@dev_group.group('integrations') +@dev_group.group("integrations") def integrations_group(): """Commands for dev integrations methods.""" -@integrations_group.command('build-manifests') -@click.option('--overwrite', '-o', is_flag=True, help="Overwrite the existing integrations-manifest.json.gz file") +@integrations_group.command("build-manifests") +@click.option("--overwrite", "-o", is_flag=True, help="Overwrite the existing integrations-manifest.json.gz file") @click.option("--integration", "-i", type=str, help="Adds an integration tag to the manifest file") @click.option("--prerelease", "-p", is_flag=True, default=False, help="Include prerelease versions") def build_integration_manifests(overwrite: bool, integration: str, prerelease: bool = False): """Builds 
consolidated integrations manifests file.""" click.echo("loading rules to determine all integration tags") - def flatten(tag_list: List[str]) -> List[str]: + def flatten(tag_list: list[str | list[str]] | list[str]) -> list[str]: return list(set([tag for tags in tag_list for tag in (flatten(tags) if isinstance(tags, list) else [tags])])) if integration: @@ -1320,10 +1513,11 @@ def flatten(tag_list: List[str]) -> List[str]: build_integrations_manifest(overwrite, rule_integrations=unique_integration_tags) -@integrations_group.command('build-schemas') -@click.option('--overwrite', '-o', is_flag=True, help="Overwrite the entire integrations-schema.json.gz file") -@click.option('--integration', '-i', type=str, - help="Adds a single integration schema to the integrations-schema.json.gz file") +@integrations_group.command("build-schemas") +@click.option("--overwrite", "-o", is_flag=True, help="Overwrite the entire integrations-schema.json.gz file") +@click.option( + "--integration", "-i", type=str, help="Adds a single integration schema to the integrations-schema.json.gz file" +) def build_integration_schemas(overwrite: bool, integration: str): """Builds consolidated integrations schemas file.""" click.echo("Building integration schemas...") @@ -1337,9 +1531,9 @@ def build_integration_schemas(overwrite: bool, integration: str): click.echo(f"Time taken to generate schemas: {(end_time - start_time) / 60:.2f} minutes") -@integrations_group.command('show-latest-compatible') -@click.option('--package', '-p', help='Name of package') -@click.option('--stack_version', '-s', required=True, help='Rule stack version') +@integrations_group.command("show-latest-compatible") +@click.option("--package", "-p", help="Name of package") +@click.option("--stack_version", "-s", required=True, help="Rule stack version") def show_latest_compatible_version(package: str, stack_version: str) -> None: """Prints the latest integration compatible version for specified package based on stack version supplied.""" @@ -1351,16 +1545,16 @@ def show_latest_compatible_version(package: str, stack_version: str) -> None: return try: - version = find_latest_compatible_version(package, "", - Version.parse(stack_version, optional_minor_and_patch=True), - packages_manifest) + version = find_latest_compatible_version( + package, "", Version.parse(stack_version, optional_minor_and_patch=True), packages_manifest + ) click.echo(f"Compatible integration {version=}") except Exception as e: click.echo(f"Error finding compatible version: {str(e)}") return -@dev_group.group('schemas') +@dev_group.group("schemas") def schemas_group(): """Commands for dev schema methods.""" @@ -1374,10 +1568,21 @@ def update_rule_data_schemas(): @schemas_group.command("generate") -@click.option("--token", required=True, prompt=get_github_token() is None, default=get_github_token(), - help="GitHub token to use for the PR", hide_input=True) -@click.option("--schema", "-s", required=True, type=click.Choice(["endgame", "ecs", "beats", "endpoint"]), - help="Schema to generate") +@click.option( + "--token", + required=True, + prompt=get_github_token() is None, + default=get_github_token(), + help="GitHub token to use for the PR", + hide_input=True, +) +@click.option( + "--schema", + "-s", + required=True, + type=click.Choice(["endgame", "ecs", "beats", "endpoint"]), + help="Schema to generate", +) @click.option("--schema-version", "-sv", help="Tagged version from TBD. 
e.g., 1.9.0") @click.option("--endpoint-target", "-t", type=str, default="endpoint", help="Target endpoint schema") @click.option("--overwrite", is_flag=True, help="Overwrite if versions exist") @@ -1413,8 +1618,9 @@ def generate_schema(token: str, schema: str, schema_version: str, endpoint_targe repo = client.get_repo("elastic/endpoint-package") contents = repo.get_contents("custom_schemas") optional_endpoint_targets = [ - Path(f.path).name.replace("custom_", "").replace(".yml", "") - for f in contents if f.name.endswith(".yml") or Path(f.path).name == endpoint_target + Path(f.path).name.replace("custom_", "").replace(".yml", "") # type: ignore[reportUnknownMemberType] + for f in contents # type: ignore[reportUnknownVariableType] + if f.name.endswith(".yml") or Path(f.path).name == endpoint_target # type: ignore ] if not endpoint_target: @@ -1426,45 +1632,45 @@ def generate_schema(token: str, schema: str, schema_version: str, endpoint_targe click.echo(f"Done generating {schema} schema") -@dev_group.group('attack') +@dev_group.group("attack") def attack_group(): """Commands for managing Mitre ATT&CK data and mappings.""" -@attack_group.command('refresh-data') -def refresh_attack_data() -> dict: +@attack_group.command("refresh-data") +def refresh_attack_data() -> dict[str, Any] | None: """Refresh the ATT&CK data file.""" data, _ = attack.refresh_attack_data() return data -@attack_group.command('refresh-redirect-mappings') +@attack_group.command("refresh-redirect-mappings") def refresh_threat_mappings(): """Refresh the ATT&CK redirect file and update all rule threat mappings.""" # refresh the attack_technique_redirects - click.echo('refreshing data in attack_technique_redirects.json') + click.echo("refreshing data in attack_technique_redirects.json") attack.refresh_redirected_techniques_map() -@attack_group.command('update-rules') -def update_attack_in_rules() -> List[Optional[TOMLRule]]: +@attack_group.command("update-rules") +def update_attack_in_rules() -> list[TOMLRule]: """Update threat mappings attack data in all rules.""" - new_rules = [] + new_rules: list[TOMLRule] = [] redirected_techniques = attack.load_techniques_redirect() - today = time.strftime('%Y/%m/%d') + today = time.strftime("%Y/%m/%d") rules = RuleCollection.default() for rule in rules.rules: needs_update = False - valid_threat: List[ThreatMapping] = [] - threat_pending_update = {} + valid_threat: list[ThreatMapping] = [] + threat_pending_update: dict[str, list[str]] = {} threat = rule.contents.data.threat or [] for entry in threat: tactic = entry.tactic.name - technique_ids = [] - technique_names = [] + technique_ids: list[str] = [] + technique_names: list[str] = [] for technique in entry.technique or []: technique_ids.append(technique.id) technique_names.append(technique.name) @@ -1494,7 +1700,7 @@ def update_attack_in_rules() -> List[Optional[TOMLRule]]: try: updated_threat = attack.build_threat_map_entry(tactic, *techniques) except ValueError as err: - raise ValueError(f'{rule.id} - {rule.name}: {err}') + raise ValueError(f"{rule.id} - {rule.name}: {err}") tm = ThreatMapping.from_dict(updated_threat) valid_threat.append(tm) @@ -1507,66 +1713,71 @@ def update_attack_in_rules() -> List[Optional[TOMLRule]]: new_rules.append(new_rule) if new_rules: - click.echo(f'\nFinished - {len(new_rules)} rules updated!') + click.echo(f"\nFinished - {len(new_rules)} rules updated!") else: - click.echo('No rule changes needed') + click.echo("No rule changes needed") return new_rules -@dev_group.group('transforms') 
+@dev_group.group("transforms") def transforms_group(): """Commands for managing TOML [transform].""" -def guide_plugin_convert_(contents: Optional[str] = None, default: Optional[str] = '' - ) -> Optional[Dict[str, Dict[str, list]]]: +def guide_plugin_convert_( + contents: str | None = None, + default: str | None = "", +) -> dict[str, dict[str, list[str]]] | None: """Convert investigation guide plugin format to toml""" - contents = contents or click.prompt('Enter plugin contents', default=default) + contents = contents or click.prompt("Enter plugin contents", default=default) if not contents: return - parsed = re.match(r'!{(?P\w+)(?P{.+})}', contents.strip()) + parsed = re.match(r"!{(?P\w+)(?P{.+})}", contents.strip()) + if not parsed: + raise ValueError("No plugin name found") try: - plugin = parsed.group('plugin') - data = parsed.group('data') + plugin = parsed.group("plugin") + data = parsed.group("data") except AttributeError as e: - raise client_error('Unrecognized pattern', exc=e) - loaded = {'transform': {plugin: [json.loads(data)]}} - click.echo(pytoml.dumps(loaded)) + raise raise_client_error("Unrecognized pattern", exc=e) + loaded = {"transform": {plugin: [json.loads(data)]}} + click.echo(pytoml.dumps(loaded)) # type: ignore[reportUnknownMemberType] return loaded -@transforms_group.command('guide-plugin-convert') -def guide_plugin_convert(contents: Optional[str] = None, default: Optional[str] = '' - ) -> Optional[Dict[str, Dict[str, list]]]: +@transforms_group.command("guide-plugin-convert") +def guide_plugin_convert( + contents: str | None = None, default: str | None = "" +) -> dict[str, dict[str, list[str]]] | None: """Convert investigation guide plugin format to toml.""" return guide_plugin_convert_(contents=contents, default=default) -@transforms_group.command('guide-plugin-to-rule') -@click.argument('rule-path', type=Path) +@transforms_group.command("guide-plugin-to-rule") +@click.argument("rule-path", type=Path) @click.pass_context def guide_plugin_to_rule(ctx: click.Context, rule_path: Path, save: bool = True) -> TOMLRule: """Convert investigation guide plugin format to toml and save to rule.""" rc = RuleCollection() rule = rc.load_file(rule_path) - transforms = defaultdict(list) - existing_transform = rule.contents.transform - transforms.update(existing_transform.to_dict() if existing_transform is not None else {}) + transforms: dict[str, list[Any]] = defaultdict(list) + existing_transform: RuleTransform | None = rule.contents.transform # type: ignore[reportAssignmentType] + transforms.update(existing_transform.to_dict() if existing_transform else {}) - click.secho('(blank line to continue)', fg='yellow') + click.secho("(blank line to continue)", fg="yellow") while True: loaded = ctx.invoke(guide_plugin_convert) if not loaded: break - data = loaded['transform'] + data = loaded["transform"] for plugin, entries in data.items(): transforms[plugin].extend(entries) transform = RuleTransform.from_dict(transforms) - new_contents = TOMLRuleContents(data=rule.contents.data, metadata=rule.contents.metadata, transform=transform) + new_contents = TOMLRuleContents(data=rule.contents.data, metadata=rule.contents.metadata, transform=transform) # type: ignore[reportArgumentType] updated_rule = TOMLRule(contents=new_contents, path=rule.path) if save: diff --git a/detection_rules/docs.py b/detection_rules/docs.py index 27c455d203f..98600ada4a0 100644 --- a/detection_rules/docs.py +++ b/detection_rules/docs.py @@ -4,6 +4,7 @@ # 2.0. 
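# Illustrative sketch (not part of this diff): the `transforms guide-plugin-convert`
# command above parses the investigation-guide plugin syntax !{plugin{...json...}}
# into a TOML [transform] table, using the named groups "plugin" and "data" that the
# code reads via parsed.group(...). A small, self-contained parse with made-up
# osquery plugin contents:
import json
import re

contents = '!{osquery{"label":"Host processes","query":"SELECT * FROM processes"}}'
parsed = re.match(r"!{(?P<plugin>\w+)(?P<data>{.+})}", contents.strip())
assert parsed is not None
loaded = {"transform": {parsed.group("plugin"): [json.loads(parsed.group("data"))]}}
# loaded == {"transform": {"osquery": [{"label": "Host processes",
#                                       "query": "SELECT * FROM processes"}]}}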
"""Create summary documents for a rule package.""" + import itertools import json import re @@ -13,9 +14,11 @@ from dataclasses import asdict, dataclass from datetime import datetime from pathlib import Path -from typing import Dict, Iterable, List, Optional, Union +import typing +from typing import Any -import xlsxwriter +import xlsxwriter # type: ignore[reportMissingTypeStubs] +import xlsxwriter.format # type: ignore[reportMissingTypeStubs] from semver import Version from .attack import attack_tm, matrix, tactics, technique_lookup @@ -30,38 +33,41 @@ class PackageDocument(xlsxwriter.Workbook): """Excel document for summarizing a rules package.""" - def __init__(self, path, package: Package): + def __init__(self, path: str, package: Package): """Create an excel workbook for the package.""" - self._default_format = {'font_name': 'Helvetica', 'font_size': 12} - super(PackageDocument, self).__init__(path) + self._default_format = {"font_name": "Helvetica", "font_size": 12} + super(PackageDocument, self).__init__(path) # type: ignore[reportUnknownMemberType] self.package = package self.deprecated_rules = package.deprecated_rules self.production_rules = package.rules - self.percent = self.add_format({'num_format': '0%'}) - self.bold = self.add_format({'bold': True}) - self.default_header_format = self.add_format({'bold': True, 'bg_color': '#FFBE33'}) - self.center = self.add_format({'align': 'center', 'valign': 'center'}) - self.bold_center = self.add_format({'bold': True, 'align': 'center', 'valign': 'center'}) - self.right_align = self.add_format({'align': 'right'}) + self.percent = self.add_format({"num_format": "0%"}) + self.bold = self.add_format({"bold": True}) + self.default_header_format = self.add_format({"bold": True, "bg_color": "#FFBE33"}) + self.center = self.add_format({"align": "center", "valign": "center"}) + self.bold_center = self.add_format({"bold": True, "align": "center", "valign": "center"}) + self.right_align = self.add_format({"align": "right"}) self._coverage = self._get_attack_coverage() - def add_format(self, properties=None): + def add_format(self, properties: dict[str, Any] | None = None) -> xlsxwriter.format.Format: """Add a format to the doc.""" properties = properties or {} for key in self._default_format: if key not in properties: properties[key] = self._default_format[key] - return super(PackageDocument, self).add_format(properties) + return super(PackageDocument, self).add_format(properties) # type: ignore[reportUnknownMemberType] def _get_attack_coverage(self): - coverage = defaultdict(lambda: defaultdict(lambda: defaultdict(int))) + coverage: dict[str, dict[str, dict[str, int]]] = defaultdict(lambda: defaultdict(lambda: defaultdict(int))) for rule in self.package.rules: threat = rule.contents.data.threat + if not rule.path: + raise ValueError("No rule path found") + sub_dir = Path(rule.path).parent.name if threat: @@ -79,11 +85,12 @@ def populate(self): self.add_summary() self.add_rule_details() self.add_attack_matrix() - self.add_rule_details(self.deprecated_rules, 'Deprecated Rules') + self.add_rule_details(self.deprecated_rules, "Deprecated Rules") + @typing.no_type_check def add_summary(self): """Add the summary worksheet.""" - worksheet = self.add_worksheet('Summary') + worksheet = self.add_worksheet("Summary") worksheet.freeze_panes(1, 0) worksheet.set_column(0, 0, 25) worksheet.set_column(1, 1, 10) @@ -92,66 +99,88 @@ def add_summary(self): worksheet.merge_range(row, 0, row, 1, "SUMMARY", self.bold_center) row += 1 - worksheet.write(row, 0, "Package 
Name") - worksheet.write(row, 1, self.package.name, self.right_align) + _ = worksheet.write(row, 0, "Package Name") + _ = worksheet.write(row, 1, self.package.name, self.right_align) row += 1 - tactic_counts = defaultdict(int) + tactic_counts: dict[str, int] = defaultdict(int) for rule in self.package.rules: threat = rule.contents.data.threat if threat: for entry in threat: tactic_counts[entry.tactic.name] += 1 - worksheet.write(row, 0, "Total Production Rules") - worksheet.write(row, 1, len(self.production_rules)) + _ = worksheet.write(row, 0, "Total Production Rules") + _ = worksheet.write(row, 1, len(self.production_rules)) row += 2 - worksheet.write(row, 0, "Total Deprecated Rules") - worksheet.write(row, 1, len(self.deprecated_rules)) + _ = worksheet.write(row, 0, "Total Deprecated Rules") + _ = worksheet.write(row, 1, len(self.deprecated_rules)) row += 1 - worksheet.write(row, 0, "Total Rules") - worksheet.write(row, 1, len(self.package.rules)) + _ = worksheet.write(row, 0, "Total Rules") + _ = worksheet.write(row, 1, len(self.package.rules)) row += 2 worksheet.merge_range(row, 0, row, 3, f"MITRE {attack_tm} TACTICS", self.bold_center) row += 1 for tactic in tactics: - worksheet.write(row, 0, tactic) - worksheet.write(row, 1, tactic_counts[tactic]) + _ = worksheet.write(row, 0, tactic) + _ = worksheet.write(row, 1, tactic_counts[tactic]) num_techniques = len(self._coverage[tactic]) total_techniques = len(matrix[tactic]) percent = float(num_techniques) / float(total_techniques) - worksheet.write(row, 2, percent, self.percent) - worksheet.write(row, 3, f'{num_techniques}/{total_techniques}', self.right_align) + _ = worksheet.write(row, 2, percent, self.percent) + _ = worksheet.write(row, 3, f"{num_techniques}/{total_techniques}", self.right_align) row += 1 - def add_rule_details(self, rules: Optional[Union[DeprecatedCollection, RuleCollection]] = None, - name='Rule Details'): + def add_rule_details( + self, + rules: DeprecatedCollection | RuleCollection | None = None, + name: str = "Rule Details", + ): """Add a worksheet for detailed metadata of rules.""" if rules is None: rules = self.production_rules - worksheet = self.add_worksheet(name) - worksheet.freeze_panes(1, 1) - headers = ('Name', 'ID', 'Version', 'Type', 'Language', 'Index', 'Tags', - f'{attack_tm} Tactics', f'{attack_tm} Techniques', 'Description') + worksheet = self.add_worksheet(name) # type: ignore[reportUnknownVariableType] + worksheet.freeze_panes(1, 1) # type: ignore[reportUnknownVariableType] + headers = ( + "Name", + "ID", + "Version", + "Type", + "Language", + "Index", + "Tags", + f"{attack_tm} Tactics", + f"{attack_tm} Techniques", + "Description", + ) for column, header in enumerate(headers): - worksheet.write(0, column, header, self.default_header_format) + _ = worksheet.write(0, column, header, self.default_header_format) # type: ignore[reportUnknownMemberType] - column_max_widths = [0 for i in range(len(headers))] + column_max_widths = [0 for _ in range(len(headers))] metadata_fields = ( - 'name', 'rule_id', 'version', 'type', 'language', 'index', 'tags', 'tactics', 'techniques', 'description' + "name", + "rule_id", + "version", + "type", + "language", + "index", + "tags", + "tactics", + "techniques", + "description", ) for row, rule in enumerate(rules, 1): - rule_contents = {'tactics': '', 'techniques': ''} + rule_contents = {"tactics": "", "techniques": ""} if isinstance(rules, RuleCollection): - flat_mitre = ThreatMapping.flatten(rule.contents.data.threat) - rule_contents = {'tactics': 
flat_mitre.tactic_names, 'techniques': flat_mitre.technique_ids} + flat_mitre = ThreatMapping.flatten(rule.contents.data.threat) # type: ignore[reportAttributeAccessIssue] + rule_contents = {"tactics": flat_mitre.tactic_names, "techniques": flat_mitre.technique_ids} rule_contents.update(rule.contents.to_api_format()) @@ -160,8 +189,8 @@ def add_rule_details(self, rules: Optional[Union[DeprecatedCollection, RuleColle if value is None: continue elif isinstance(value, list): - value = ', '.join(value) - worksheet.write(row, column, value) + value = ", ".join(value) + _ = worksheet.write(row, column, value) # type: ignore[reportUnknownMemberType] column_max_widths[column] = max(column_max_widths[column], len(str(value))) # cap description width at 80 @@ -169,37 +198,43 @@ def add_rule_details(self, rules: Optional[Union[DeprecatedCollection, RuleColle # this is still not perfect because the font used is not monospaced, but it gets it close for index, width in enumerate(column_max_widths): - worksheet.set_column(index, index, width) + _ = worksheet.set_column(index, index, width) # type: ignore[reportUnknownMemberType] - worksheet.autofilter(0, 0, len(rules) + 1, len(headers) - 1) + _ = worksheet.autofilter(0, 0, len(rules) + 1, len(headers) - 1) # type: ignore[reportUnknownMemberType] def add_attack_matrix(self): """Add a worksheet for ATT&CK coverage.""" - worksheet = self.add_worksheet(attack_tm + ' Coverage') - worksheet.freeze_panes(1, 0) - header = self.add_format({'font_size': 12, 'bold': True, 'bg_color': '#005B94', 'font_color': 'white'}) - default = self.add_format({'font_size': 10, 'text_wrap': True}) - bold = self.add_format({'font_size': 10, 'bold': True, 'text_wrap': True}) - technique_url = 'https://attack.mitre.org/techniques/' + worksheet = self.add_worksheet(attack_tm + " Coverage") # type: ignore[reportUnknownMemberType] + worksheet.freeze_panes(1, 0) # type: ignore[reportUnknownMemberType] + header = self.add_format({"font_size": 12, "bold": True, "bg_color": "#005B94", "font_color": "white"}) + default = self.add_format({"font_size": 10, "text_wrap": True}) + bold = self.add_format({"font_size": 10, "bold": True, "text_wrap": True}) + technique_url = "https://attack.mitre.org/techniques/" for column, tactic in enumerate(tactics): - worksheet.write(0, column, tactic, header) - worksheet.set_column(column, column, 20) + _ = worksheet.write(0, column, tactic, header) # type: ignore[reportUnknownMemberType] + _ = worksheet.set_column(column, column, 20) # type: ignore[reportUnknownMemberType] for row, technique_id in enumerate(matrix[tactic], 1): technique = technique_lookup[technique_id] fmt = bold if technique_id in self._coverage[tactic] else default coverage = self._coverage[tactic].get(technique_id) - coverage_str = '' + coverage_str = "" if coverage: - coverage_str = '\n\n' - coverage_str += '\n'.join(f'{sub_dir}: {count}' for sub_dir, count in coverage.items()) + coverage_str = "\n\n" + coverage_str += "\n".join(f"{sub_dir}: {count}" for sub_dir, count in coverage.items()) - worksheet.write_url(row, column, technique_url + technique_id.replace('.', '/'), cell_format=fmt, - string=technique['name'], tip=f'{technique_id}{coverage_str}') + _ = worksheet.write_url( # type: ignore[reportUnknownMemberType] + row, + column, + technique_url + technique_id.replace(".", "/"), + cell_format=fmt, + string=technique["name"], + tip=f"{technique_id}{coverage_str}", + ) - worksheet.autofilter(0, 0, max([len(v) for k, v in matrix.items()]) + 1, len(tactics) - 1) + _ = 
worksheet.autofilter(0, 0, max([len(v) for _, v in matrix.items()]) + 1, len(tactics) - 1) # type: ignore[reportUnknownMemberType] # product rule docs @@ -207,41 +242,40 @@ def add_attack_matrix(self): class AsciiDoc: - @classmethod def bold_kv(cls, key: str, value: str): - return f'*{key}*: {value}' + return f"*{key}*: {value}" @classmethod - def description_list(cls, value: Dict[str, str], linesep='\n\n'): - return f'{linesep}'.join(f'{k}::\n{v}' for k, v in value.items()) + def description_list(cls, value: dict[str, str], linesep: str = "\n\n"): + return f"{linesep}".join(f"{k}::\n{v}" for k, v in value.items()) @classmethod - def bulleted(cls, value: str, depth=1): - return f'{"*" * depth} {value}' + def bulleted(cls, value: str, depth: int = 1): + return f"{'*' * depth} {value}" @classmethod - def bulleted_list(cls, values: Iterable): - return '* ' + '\n* '.join(values) + def bulleted_list(cls, values: list[str]): + return "* " + "\n* ".join(values) @classmethod - def code(cls, value: str, code='js'): + def code(cls, value: str, code: str = "js"): line_sep = "-" * 34 - return f'[source, {code}]\n{line_sep}\n{value}\n{line_sep}' + return f"[source, {code}]\n{line_sep}\n{value}\n{line_sep}" @classmethod def title(cls, depth: int, value: str): - return f'{"=" * depth} {value}' + return f"{'=' * depth} {value}" @classmethod def inline_anchor(cls, value: str): - return f'[[{value}]]' + return f"[[{value}]]" @classmethod - def table(cls, data: dict) -> str: - entries = [f'| {k} | {v}' for k, v in data.items()] - table = ['[width="100%"]', '|==='] + entries + ['|==='] - return '\n'.join(table) + def table(cls, data: dict[str, Any]) -> str: + entries = [f"| {k} | {v}" for k, v in data.items()] + table = ['[width="100%"]', "|==="] + entries + ["|==="] + return "\n".join(table) class SecurityDocs: @@ -252,35 +286,50 @@ class KibanaSecurityDocs: """Generate docs for prebuilt rules in Elastic documentation.""" @staticmethod - def cmp_value(value): + def cmp_value(value: Any) -> Any: if isinstance(value, list): - cmp_new = tuple(value) + cmp_new = tuple(value) # type: ignore[reportUnknownArgumentType] elif isinstance(value, dict): cmp_new = json.dumps(value, sort_keys=True, indent=2) else: cmp_new = value - return cmp_new + return cmp_new # type: ignore[reportUnknownVariableType] class IntegrationSecurityDocs: """Generate docs for prebuilt rules in Elastic documentation.""" - def __init__(self, registry_version: str, directory: Path, overwrite=False, - updated_rules: Optional[Dict[str, TOMLRule]] = None, new_rules: Optional[Dict[str, TOMLRule]] = None, - deprecated_rules: Optional[Dict[str, TOMLRule]] = None, update_message: str = ""): + def __init__( + self, + registry_version: str, + directory: Path, + overwrite: bool = False, + updated_rules: dict[str, TOMLRule] | None = None, + new_rules: dict[str, TOMLRule] | None = None, + deprecated_rules: dict[str, DeprecatedRule] | None = None, + update_message: str = "", + ): self.new_rules = new_rules self.updated_rules = updated_rules self.deprecated_rules = deprecated_rules - self.included_rules = list(itertools.chain(new_rules.values(), - updated_rules.values(), - deprecated_rules.values())) + self.included_rules: list[TOMLRule | DeprecatedRule] = [] + if new_rules: + self.included_rules += new_rules.values() + + if updated_rules: + self.included_rules += updated_rules.values() + + if deprecated_rules: + self.included_rules += deprecated_rules.values() all_rules = RuleCollection.default().rules self.sorted_rules = sorted(all_rules, key=lambda rule: 
rule.name) self.registry_version_str, self.base_name, self.prebuilt_rule_base = self.parse_registry(registry_version) self.directory = directory - self.package_directory = directory / "docs" / "detections" / "prebuilt-rules" / "downloadable-packages" / self.base_name # noqa: E501 + self.package_directory = ( + directory / "docs" / "detections" / "prebuilt-rules" / "downloadable-packages" / self.base_name + ) # noqa: E501 self.rule_details = directory / "docs" / "detections" / "prebuilt-rules" / "rule-details" self.update_message = update_message @@ -290,18 +339,20 @@ def __init__(self, registry_version: str, directory: Path, overwrite=False, self.package_directory.mkdir(parents=True, exist_ok=overwrite) @staticmethod - def parse_registry(registry_version: str) -> (str, str, str): - registry_version = Version.parse(registry_version, optional_minor_and_patch=True) - short_registry_version = [str(n) for n in registry_version[:3]] - registry_version_str = '.'.join(short_registry_version) + def parse_registry(registry_version_val: str) -> tuple[str, str, str]: + registry_version = Version.parse(registry_version_val, optional_minor_and_patch=True) + + parts = registry_version[:3] + short_registry_version = [str(n) for n in parts] # type: ignore + registry_version_str = ".".join(short_registry_version) base_name = "-".join(short_registry_version) - prebuilt_rule_base = f'prebuilt-rule-{base_name}' + prebuilt_rule_base = f"prebuilt-rule-{base_name}" return registry_version_str, base_name, prebuilt_rule_base def generate_appendix(self): # appendix - appendix = self.package_directory / f'prebuilt-rules-{self.base_name}-appendix.asciidoc' + appendix = self.package_directory / f"prebuilt-rules-{self.base_name}-appendix.asciidoc" appendix_header = textwrap.dedent(f""" ["appendix",role="exclude",id="prebuilt-rule-{self.base_name}-prebuilt-rules-{self.base_name}-appendix"] @@ -309,15 +360,15 @@ def generate_appendix(self): This section lists all updates associated with version {self.registry_version_str} of the Fleet integration *Prebuilt Security Detection Rules*. 
- """).lstrip() # noqa: E501 + """).lstrip() - include_format = f'include::{self.prebuilt_rule_base}-' + '{}.asciidoc[]' - appendix_lines = [appendix_header] + [include_format.format(name_to_title(r.name)) for r in self.included_rules] - appendix_str = '\n'.join(appendix_lines) + '\n' - appendix.write_text(appendix_str) + include_format = f"include::{self.prebuilt_rule_base}-" + "{}.asciidoc[]" + appendix_lines = [appendix_header] + [include_format.format(name_to_title(r.name)) for r in self.included_rules] # type: ignore[reportArgumentType] + appendix_str = "\n".join(appendix_lines) + "\n" + _ = appendix.write_text(appendix_str) def generate_summary(self): - summary = self.package_directory / f'prebuilt-rules-{self.base_name}-summary.asciidoc' + summary = self.package_directory / f"prebuilt-rules-{self.base_name}-summary.asciidoc" summary_header = textwrap.dedent(f""" [[prebuilt-rule-{self.base_name}-prebuilt-rules-{self.base_name}-summary]] @@ -332,25 +383,34 @@ def generate_summary(self): |Rule |Description |Status |Version """).lstrip() # noqa: E501 - rule_entries = [] + rule_entries: list[str] = [] for rule in self.included_rules: - if rule.contents.metadata.get('maturity') == 'development': + if rule.contents.metadata.get("maturity") == "development": continue - title_name = name_to_title(rule.name) - status = 'new' if rule.id in self.new_rules else 'update' if rule.id in self.updated_rules else 'deprecated' - description = rule.contents.to_api_format()['description'] + title_name = name_to_title(rule.name) # type: ignore[reportArgumentType] + + if self.new_rules and rule.id in self.new_rules: + status = "new" + elif self.updated_rules and rule.id in self.updated_rules: + status = "update" + else: + status = "deprecated" + + description = rule.contents.to_api_format()["description"] version = rule.contents.autobumped_version - rule_entries.append(f'|<> ' - f'| {description} | {status} | {version} \n') + rule_entries.append( + f"|<> " + f"| {description} | {status} | {version} \n" + ) - summary_lines = [summary_header] + rule_entries + ['|=============================================='] - summary_str = '\n'.join(summary_lines) + '\n' - summary.write_text(summary_str) + summary_lines = [summary_header] + rule_entries + ["|=============================================="] + summary_str = "\n".join(summary_lines) + "\n" + _ = summary.write_text(summary_str) def generate_rule_reference(self): """Generate rule reference page for prebuilt rules.""" - summary = self.directory / "docs" / "detections" / "prebuilt-rules" / 'prebuilt-rules-reference.asciidoc' - rule_list = self.directory / "docs" / "detections" / "prebuilt-rules" / 'rule-desc-index.asciidoc' + summary = self.directory / "docs" / "detections" / "prebuilt-rules" / "prebuilt-rules-reference.asciidoc" + rule_list = self.directory / "docs" / "detections" / "prebuilt-rules" / "rule-desc-index.asciidoc" summary_header = textwrap.dedent(""" [[prebuilt-rules]] @@ -370,51 +430,51 @@ def generate_rule_reference(self): """).lstrip() # noqa: E501 - rule_entries = [] - rule_includes = [] + rule_entries: list[str] = [] + rule_includes: list[str] = [] for rule in self.sorted_rules: if isinstance(rule, DeprecatedRule): continue - if rule.contents.metadata.get('maturity') == 'development': + if rule.contents.metadata.get("maturity") == "development": continue title_name = name_to_title(rule.name) # skip rules not built for this package - built_rules = [x.name for x in self.rule_details.glob('*.asciidoc')] + built_rules = [x.name for x in 
self.rule_details.glob("*.asciidoc")] if f"{title_name}.asciidoc" not in built_rules: continue - rule_includes.append(f'include::rule-details/{title_name}.asciidoc[]') - tags = ', '.join(f'[{tag}]' for tag in rule.contents.data.tags) - description = rule.contents.to_api_format()['description'] + rule_includes.append(f"include::rule-details/{title_name}.asciidoc[]") + tags = ", ".join(f"[{tag}]" for tag in rule.contents.data.tags) # type: ignore[reportOptionalIterable] + description = rule.contents.to_api_format()["description"] version = rule.contents.autobumped_version added = rule.contents.metadata.min_stack_version - rule_entries.append(f'|<<{title_name}, {rule.name}>> |{description} |{tags} |{added} |{version}\n') + rule_entries.append(f"|<<{title_name}, {rule.name}>> |{description} |{tags} |{added} |{version}\n") - summary_lines = [summary_header] + rule_entries + ['|=============================================='] - summary_str = '\n'.join(summary_lines) + '\n' - summary.write_text(summary_str) + summary_lines = [summary_header] + rule_entries + ["|=============================================="] + summary_str = "\n".join(summary_lines) + "\n" + _ = summary.write_text(summary_str) # update rule-desc-index.asciidoc - rule_list.write_text('\n'.join(rule_includes)) + _ = rule_list.write_text("\n".join(rule_includes)) def generate_rule_details(self): """Generate rule details for each prebuilt rule.""" included_rules = [x.name for x in self.included_rules] for rule in self.sorted_rules: - if rule.contents.metadata.get('maturity') == 'development': + if rule.contents.metadata.get("maturity") == "development": continue rule_detail = IntegrationRuleDetail(rule.id, rule.contents.to_api_format(), {}, self.base_name) - rule_path = self.package_directory / f'{self.prebuilt_rule_base}-{name_to_title(rule.name)}.asciidoc' - prebuilt_rule_path = self.rule_details / f'{name_to_title(rule.name)}.asciidoc' # noqa: E501 + rule_path = self.package_directory / f"{self.prebuilt_rule_base}-{name_to_title(rule.name)}.asciidoc" + prebuilt_rule_path = self.rule_details / f"{name_to_title(rule.name)}.asciidoc" # noqa: E501 if rule.name in included_rules: # only include updates - rule_path.write_text(rule_detail.generate()) + _ = rule_path.write_text(rule_detail.generate()) # add all available rules to the rule details directory - prebuilt_rule_path.write_text(rule_detail.generate(title=f'{name_to_title(rule.name)}')) + _ = prebuilt_rule_path.write_text(rule_detail.generate(title=f"{name_to_title(rule.name)}")) def generate_manual_updates(self): """ @@ -423,33 +483,33 @@ def generate_manual_updates(self): updates = {} # Update downloadable rule updates entry - today = datetime.today().strftime('%d %b %Y') + today = datetime.today().strftime("%d %b %Y") - updates['downloadable-updates.asciidoc'] = { - 'table_entry': ( - f'|<> | {today} | {len(self.new_rules)} | ' - f'{len(self.updated_rules)} | ' + updates["downloadable-updates.asciidoc"] = { + "table_entry": ( + f"|<> | {today} | {len(self.new_rules or [])} | " + f"{len(self.updated_rules or [])} | " + ), + "table_include": ( + f"include::downloadable-packages/{self.base_name}/" + f"prebuilt-rules-{self.base_name}-summary.asciidoc[leveloffset=+1]" ), - 'table_include': ( - f'include::downloadable-packages/{self.base_name}/' - f'prebuilt-rules-{self.base_name}-summary.asciidoc[leveloffset=+1]' - ) } - updates['index.asciidoc'] = { - 'index_include': ( - f'include::detections/prebuilt-rules/downloadable-packages/{self.base_name}/' - 
f'prebuilt-rules-{self.base_name}-appendix.asciidoc[]' + updates["index.asciidoc"] = { + "index_include": ( + f"include::detections/prebuilt-rules/downloadable-packages/{self.base_name}/" + f"prebuilt-rules-{self.base_name}-appendix.asciidoc[]" ) } # Add index.asciidoc:index_include in docs/index.asciidoc - docs_index = self.package_directory.parent.parent.parent.parent / 'index.asciidoc' - docs_index.write_text(docs_index.read_text() + '\n' + updates['index.asciidoc']['index_include'] + '\n') + docs_index = self.package_directory.parent.parent.parent.parent / "index.asciidoc" + _ = docs_index.write_text(docs_index.read_text() + "\n" + updates["index.asciidoc"]["index_include"] + "\n") # Add table_entry to docs/detections/prebuilt-rules/prebuilt-rules-downloadable-updates.asciidoc - downloadable_updates = self.package_directory.parent.parent / 'prebuilt-rules-downloadable-updates.asciidoc' + downloadable_updates = self.package_directory.parent.parent / "prebuilt-rules-downloadable-updates.asciidoc" version = Version.parse(self.registry_version_str) last_version = f"{version.major}.{version.minor - 1}" update_url = f"https://www.elastic.co/guide/en/security/{last_version}/prebuilt-rules-downloadable-updates.html" @@ -469,21 +529,22 @@ def generate_manual_updates(self): |Update version |Date | New rules | Updated rules | Notes """).lstrip() # noqa: E501 - new_content = updates['downloadable-updates.asciidoc']['table_entry'] + '\n' + self.update_message + new_content = updates["downloadable-updates.asciidoc"]["table_entry"] + "\n" + self.update_message self.add_content_to_table_top(downloadable_updates, summary_header, new_content) # Add table_include to/docs/detections/prebuilt-rules/prebuilt-rules-downloadable-updates.asciidoc # Reset the historic information at the beginning of each minor version - historic_data = downloadable_updates.read_text() if Version.parse(self.registry_version_str).patch > 1 else '' - downloadable_updates.write_text(historic_data + # noqa: W504 - updates['downloadable-updates.asciidoc']['table_include'] + '\n') + historic_data = downloadable_updates.read_text() if Version.parse(self.registry_version_str).patch > 1 else "" + _ = downloadable_updates.write_text( + historic_data + updates["downloadable-updates.asciidoc"]["table_include"] + "\n" + ) def add_content_to_table_top(self, file_path: Path, summary_header: str, new_content: str): """Insert content at the top of a Markdown table right after the specified header.""" file_contents = file_path.read_text() # Find the header in the file - header = '|Update version |Date | New rules | Updated rules | Notes\n' + header = "|Update version |Date | New rules | Updated rules | Notes\n" header_index = file_contents.find(header) if header_index == -1: @@ -496,7 +557,7 @@ def add_content_to_table_top(self, file_path: Path, summary_header: str, new_con updated_contents = summary_header + f"\n{new_content}\n" + file_contents[insert_position:] # Write the updated contents back to the file - file_path.write_text(updated_contents) + _ = file_path.write_text(updated_contents) def generate(self) -> Path: self.generate_appendix() @@ -510,139 +571,147 @@ def generate(self) -> Path: class IntegrationRuleDetail: """Rule detail page generation.""" - def __init__(self, rule_id: str, rule: dict, changelog: Dict[str, dict], package_str: str): + def __init__(self, rule_id: str, rule: dict[str, Any], changelog: dict[str, dict[str, Any]], package_str: str): self.rule_id = rule_id self.rule = rule self.changelog = changelog self.package = 
package_str - self.rule_title = f'prebuilt-rule-{self.package}-{name_to_title(self.rule["name"])}' + self.rule_title = f"prebuilt-rule-{self.package}-{name_to_title(self.rule['name'])}" # set some defaults - self.rule.setdefault('max_signals', 100) - self.rule.setdefault('interval', '5m') + self.rule.setdefault("max_signals", 100) + self.rule.setdefault("interval", "5m") - def generate(self, title: str = None) -> str: + def generate(self, title: str | None = None) -> str: """Generate the rule detail page.""" title = title or self.rule_title page = [ AsciiDoc.inline_anchor(title), - AsciiDoc.title(3, self.rule['name']), - '', - self.rule['description'], - '', + AsciiDoc.title(3, self.rule["name"]), + "", + self.rule["description"], + "", self.metadata_str(), - '' + "", ] - if 'note' in self.rule: - page.extend([self.guide_str(), '']) - if 'setup' in self.rule: - page.extend([self.setup_str(), '']) - if 'query' in self.rule: - page.extend([self.query_str(), '']) - if 'threat' in self.rule: - page.extend([self.threat_mapping_str(), '']) + if "note" in self.rule: + page.extend([self.guide_str(), ""]) + if "setup" in self.rule: + page.extend([self.setup_str(), ""]) + if "query" in self.rule: + page.extend([self.query_str(), ""]) + if "threat" in self.rule: + page.extend([self.threat_mapping_str(), ""]) - return '\n'.join(page) + return "\n".join(page) def metadata_str(self) -> str: """Add the metadata section to the rule detail page.""" fields = { - 'type': 'Rule type', - 'index': 'Rule indices', - 'severity': 'Severity', - 'risk_score': 'Risk score', - 'interval': 'Runs every', - 'from': 'Searches indices from', - 'max_signals': 'Maximum alerts per execution', - 'references': 'References', - 'tags': 'Tags', - 'version': 'Version', - 'author': 'Rule authors', - 'license': 'Rule license' + "type": "Rule type", + "index": "Rule indices", + "severity": "Severity", + "risk_score": "Risk score", + "interval": "Runs every", + "from": "Searches indices from", + "max_signals": "Maximum alerts per execution", + "references": "References", + "tags": "Tags", + "version": "Version", + "author": "Rule authors", + "license": "Rule license", } - values = [] + values: list[str] = [] for field, friendly_name in fields.items(): value = self.rule.get(field) or self.changelog.get(field) - if isinstance(value, list): - str_value = f'\n\n{AsciiDoc.bulleted_list(value)}' + if value is None: + str_value = "None" + elif isinstance(value, list): + str_value = f"\n\n{AsciiDoc.bulleted_list(value)}" # type: ignore[reportUnknownArgumentType] else: str_value = str(value) - if field == 'from': - str_value += ' ({ref}/common-options.html#date-math[Date Math format], see also <>)' + if field == "from": + str_value += ( + " ({ref}/common-options.html#date-math[Date Math format], see also <>)" + ) - values.extend([AsciiDoc.bold_kv(friendly_name, str_value), '']) + values.extend([AsciiDoc.bold_kv(friendly_name, str_value), ""]) - return '\n'.join(values) + return "\n".join(values) def guide_str(self) -> str: """Add the guide section to the rule detail page.""" - guide = convert_markdown_to_asciidoc(self.rule['note']) - return f'{AsciiDoc.title(4, "Investigation guide")}\n\n\n{guide}' + guide = convert_markdown_to_asciidoc(self.rule["note"]) + return f"{AsciiDoc.title(4, 'Investigation guide')}\n\n\n{guide}" def setup_str(self) -> str: """Add the setup section to the rule detail page.""" - setup = convert_markdown_to_asciidoc(self.rule['setup']) - return f'{AsciiDoc.title(4, "Setup")}\n\n\n{setup}' + setup = 
convert_markdown_to_asciidoc(self.rule["setup"]) + return f"{AsciiDoc.title(4, 'Setup')}\n\n\n{setup}" def query_str(self) -> str: """Add the query section to the rule detail page.""" - return f'{AsciiDoc.title(4, "Rule query")}\n\n\n{AsciiDoc.code(self.rule["query"])}' + return f"{AsciiDoc.title(4, 'Rule query')}\n\n\n{AsciiDoc.code(self.rule['query'])}" def threat_mapping_str(self) -> str: """Add the threat mapping section to the rule detail page.""" - values = [AsciiDoc.bold_kv('Framework', 'MITRE ATT&CK^TM^'), ''] + values = [AsciiDoc.bold_kv("Framework", "MITRE ATT&CK^TM^"), ""] - for entry in self.rule['threat']: - tactic = entry['tactic'] + for entry in self.rule["threat"]: + tactic = entry["tactic"] entry_values = [ - AsciiDoc.bulleted('Tactic:'), - AsciiDoc.bulleted(f'Name: {tactic["name"]}', depth=2), - AsciiDoc.bulleted(f'ID: {tactic["id"]}', depth=2), - AsciiDoc.bulleted(f'Reference URL: {tactic["reference"]}', depth=2) + AsciiDoc.bulleted("Tactic:"), + AsciiDoc.bulleted(f"Name: {tactic['name']}", depth=2), + AsciiDoc.bulleted(f"ID: {tactic['id']}", depth=2), + AsciiDoc.bulleted(f"Reference URL: {tactic['reference']}", depth=2), ] - techniques = entry.get('technique', []) + techniques = entry.get("technique", []) for technique in techniques: - entry_values.extend([ - AsciiDoc.bulleted('Technique:'), - AsciiDoc.bulleted(f'Name: {technique["name"]}', depth=2), - AsciiDoc.bulleted(f'ID: {technique["id"]}', depth=2), - AsciiDoc.bulleted(f'Reference URL: {technique["reference"]}', depth=2) - ]) - - subtechniques = technique.get('subtechnique', []) + entry_values.extend( + [ + AsciiDoc.bulleted("Technique:"), + AsciiDoc.bulleted(f"Name: {technique['name']}", depth=2), + AsciiDoc.bulleted(f"ID: {technique['id']}", depth=2), + AsciiDoc.bulleted(f"Reference URL: {technique['reference']}", depth=2), + ] + ) + + subtechniques = technique.get("subtechnique", []) for subtechnique in subtechniques: - entry_values.extend([ - AsciiDoc.bulleted('Sub-technique:'), - AsciiDoc.bulleted(f'Name: {subtechnique["name"]}', depth=2), - AsciiDoc.bulleted(f'ID: {subtechnique["id"]}', depth=2), - AsciiDoc.bulleted(f'Reference URL: {subtechnique["reference"]}', depth=2) - ]) + entry_values.extend( + [ + AsciiDoc.bulleted("Sub-technique:"), + AsciiDoc.bulleted(f"Name: {subtechnique['name']}", depth=2), + AsciiDoc.bulleted(f"ID: {subtechnique['id']}", depth=2), + AsciiDoc.bulleted(f"Reference URL: {subtechnique['reference']}", depth=2), + ] + ) values.extend(entry_values) - return '\n'.join(values) + return "\n".join(values) def name_to_title(name: str) -> str: """Convert a rule name to tile.""" - initial = re.sub(r'[^\w]|_', r'-', name.lower().strip()) - return re.sub(r'-{2,}', '-', initial).strip('-') + initial = re.sub(r"[^\w]|_", r"-", name.lower().strip()) + return re.sub(r"-{2,}", "-", initial).strip("-") def convert_markdown_to_asciidoc(text: str) -> str: """Convert investigation guides and setup content from markdown to asciidoc.""" # Format the content after the stripped headers (#) to bold text with newlines. 
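    # Illustrative example (hypothetical input, derived from the two substitutions in this function):
    #   convert_markdown_to_asciidoc("## Triage\nSee [ES docs](https://example.com).")
    #   -> "\n*Triage*\n\nSee https://example.com[ES docs]."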
- markdown_header_pattern = re.compile(r'^(#+)\s*(.*?)$', re.MULTILINE) - text = re.sub(markdown_header_pattern, lambda m: f'\n*{m.group(2).strip()}*\n', text) + markdown_header_pattern = re.compile(r"^(#+)\s*(.*?)$", re.MULTILINE) + text = re.sub(markdown_header_pattern, lambda m: f"\n*{m.group(2).strip()}*\n", text) # Convert Markdown links to AsciiDoc format - markdown_link_pattern = re.compile(r'\[([^\]]+)\]\(([^)]+)\)') - text = re.sub(markdown_link_pattern, lambda m: f'{m.group(2)}[{m.group(1)}]', text) + markdown_link_pattern = re.compile(r"\[([^\]]+)\]\(([^)]+)\)") + text = re.sub(markdown_link_pattern, lambda m: f"{m.group(2)}[{m.group(1)}]", text) return text @@ -650,6 +719,7 @@ def convert_markdown_to_asciidoc(text: str) -> str: @dataclass class UpdateEntry: """A class schema for downloadable update entries.""" + update_version: str date: str new_rules: int @@ -661,20 +731,21 @@ class UpdateEntry: @dataclass class DownloadableUpdates: """A class for managing downloadable updates.""" - packages: List[UpdateEntry] + + packages: list[UpdateEntry] @classmethod def load_updates(cls): """Load the package.""" - prebuilt = load_etc_dump("downloadable_updates.json") - packages = [UpdateEntry(**entry) for entry in prebuilt['packages']] + prebuilt = load_etc_dump(["downloadable_updates.json"]) + packages = [UpdateEntry(**entry) for entry in prebuilt["packages"]] return cls(packages) def save_updates(self): """Save the package.""" sorted_package = sorted(self.packages, key=lambda entry: Version.parse(entry.update_version), reverse=True) - data = {'packages': [asdict(entry) for entry in sorted_package]} - save_etc_dump(data, "downloadable_updates.json") + data = {"packages": [asdict(entry) for entry in sorted_package]} + save_etc_dump(data, ["downloadable_updates.json"]) def add_entry(self, entry: UpdateEntry, overwrite: bool = False): """Add an entry to the package.""" @@ -698,37 +769,37 @@ class MDX: @classmethod def bold(cls, value: str): """Return a bold str in Markdown.""" - return f'**{value}**' + return f"**{value}**" @classmethod def bold_kv(cls, key: str, value: str): """Return a bold key-value pair in Markdown.""" - return f'**{key}**: {value}' + return f"**{key}**: {value}" @classmethod - def description_list(cls, value: Dict[str, str], linesep='\n\n'): + def description_list(cls, value: dict[str, str], linesep: str = "\n\n"): """Create a description list in Markdown.""" - return f'{linesep}'.join(f'**{k}**:\n\n{v}' for k, v in value.items()) + return f"{linesep}".join(f"**{k}**:\n\n{v}" for k, v in value.items()) @classmethod - def bulleted(cls, value: str, depth=1): + def bulleted(cls, value: str, depth: int = 1): """Create a bulleted list item with a specified depth.""" - return f'{" " * (depth - 1)}* {value}' + return f"{' ' * (depth - 1)}* {value}" @classmethod - def bulleted_list(cls, values: Iterable): + def bulleted_list(cls, values: list[str]): """Create a bulleted list from an iterable.""" - return '\n* ' + '\n* '.join(values) + return "\n* " + "\n* ".join(values) @classmethod - def code(cls, value: str, language='js'): + def code(cls, value: str, language: str = "js"): """Return a code block with the specified language.""" return f"```{language}\n{value}```" @classmethod def title(cls, depth: int, value: str): """Create a title with the specified depth.""" - return f'{"#" * depth} {value}' + return f"{'#' * depth} {value}" @classmethod def inline_anchor(cls, value: str): @@ -736,26 +807,31 @@ def inline_anchor(cls, value: str): return f'' @classmethod - def table(cls, 
data: dict) -> str: + def table(cls, data: dict[str, Any]) -> str: """Create a table from a dictionary.""" - entries = [f'| {k} | {v}' for k, v in data.items()] - table = ['|---|---|'] + entries - return '\n'.join(table) + entries = [f"| {k} | {v}" for k, v in data.items()] + table = ["|---|---|"] + entries + return "\n".join(table) class IntegrationSecurityDocsMDX: """Generate docs for prebuilt rules in Elastic documentation using MDX.""" - def __init__(self, release_version: str, directory: Path, overwrite: bool = False, - historical_package: Optional[Dict[str, dict]] = - None, new_package: Optional[Dict[str, TOMLRule]] = None, - note: Optional[str] = "Rule Updates."): + def __init__( + self, + release_version: str, + directory: Path, + overwrite: bool = False, + new_package: Package | None = None, + historical_package: dict[str, Any] | None = None, + note: str | None = "Rule Updates.", + ): self.historical_package = historical_package self.new_package = new_package self.rule_changes = self.get_rule_changes() - self.included_rules = list(itertools.chain(self.rule_changes["new"], - self.rule_changes["updated"], - self.rule_changes["deprecated"])) + self.included_rules = list( + itertools.chain(self.rule_changes["new"], self.rule_changes["updated"], self.rule_changes["deprecated"]) + ) self.release_version_str, self.base_name, self.prebuilt_rule_base = self.parse_release(release_version) self.package_directory = directory / self.base_name @@ -768,62 +844,69 @@ def __init__(self, release_version: str, directory: Path, overwrite: bool = Fals self.package_directory.mkdir(parents=True, exist_ok=overwrite) @staticmethod - def parse_release(release_version: str) -> (str, str, str): + def parse_release(release_version_val: str) -> tuple[str, str, str]: """Parse the release version into a string, base name, and prebuilt rule base.""" - release_version = Version.parse(release_version) - short_release_version = [str(n) for n in release_version[:3]] - release_version_str = '.'.join(short_release_version) + release_version = Version.parse(release_version_val) + parts = release_version[:3] + short_release_version = [str(n) for n in parts] # type: ignore + release_version_str = ".".join(short_release_version) base_name = "-".join(short_release_version) - prebuilt_rule_base = f'prebuilt-rule-{base_name}' + prebuilt_rule_base = f"prebuilt-rule-{base_name}" return release_version_str, base_name, prebuilt_rule_base def get_rule_changes(self): """Compare the rules from the new_package against rules in the historical_package.""" - rule_changes = defaultdict(list) - rule_changes["new"], rule_changes["updated"], rule_changes["deprecated"] = [], [], [] + rule_changes: dict[str, list[TOMLRule | DeprecatedRule]] = { + "new": [], + "updated": [], + "deprecated": [], + } - historical_rule_ids = set(self.historical_package.keys()) + historical_package: dict[str, Any] = self.historical_package or dict() + historical_rule_ids: set[str] = set(historical_package.keys()) - # Identify new and updated rules - for rule in self.new_package.rules: - rule_to_api_format = rule.contents.to_api_format() + if self.new_package: + # Identify new and updated rules + for rule in self.new_package.rules: + rule_to_api_format = rule.contents.to_api_format() - latest_version = rule_to_api_format["version"] - rule_id = f'{rule.id}_{latest_version}' + latest_version = rule_to_api_format["version"] + rule_id = f"{rule.id}_{latest_version}" - if rule_id not in historical_rule_ids and latest_version == 1: - rule_changes['new'].append(rule) - 
elif rule_id not in historical_rule_ids: - rule_changes['updated'].append(rule) + if rule_id not in historical_rule_ids and latest_version == 1: + rule_changes["new"].append(rule) + elif rule_id not in historical_rule_ids: + rule_changes["updated"].append(rule) # Identify deprecated rules # if rule is in the historical but not in the current package, its deprecated - deprecated_rule_ids = [] - for _, content in self.historical_package.items(): + deprecated_rule_ids: list[str] = [] + for _, content in historical_package.items(): rule_id = content["attributes"]["rule_id"] - if rule_id in self.new_package.deprecated_rules.id_map.keys(): + if self.new_package and rule_id in self.new_package.deprecated_rules.id_map.keys(): deprecated_rule_ids.append(rule_id) deprecated_rule_ids = list(set(deprecated_rule_ids)) for rule_id in deprecated_rule_ids: - rule_changes['deprecated'].append(self.new_package.deprecated_rules.id_map[rule_id]) + if self.new_package: + rule_changes["deprecated"].append(self.new_package.deprecated_rules.id_map[rule_id]) return dict(rule_changes) def generate_current_rule_summary(self): """Generate a summary of all available current rules in the latest package.""" - slug = f'prebuilt-rules-{self.base_name}-all-available-summary.mdx' + slug = f"prebuilt-rules-{self.base_name}-all-available-summary.mdx" summary = self.package_directory / slug - title = f'Latest rules for Stack Version ^{self.release_version_str}' + title = f"Latest rules for Stack Version ^{self.release_version_str}" summary_header = textwrap.dedent(f""" --- id: {slug} slug: /security-rules/{slug} title: {title} - date: {datetime.today().strftime('%Y-%d-%m')} + date: {datetime.today().strftime("%Y-%d-%m")} tags: ["rules", "security", "detection-rules"] --- @@ -835,23 +918,27 @@ def generate_current_rule_summary(self): |---|---|---|---| """).lstrip() - rule_entries = [] - for rule in self.new_package.rules: - title_name = name_to_title(rule.name) - to_api_format = rule.contents.to_api_format() - tags = ", ".join(to_api_format["tags"]) - rule_entries.append(f'| [{title_name}](rules/{self.prebuilt_rule_base}-{name_to_title(rule.name)}.mdx) | ' - f'{to_api_format["description"]} | {tags} | ' - f'{to_api_format["version"]}') + rule_entries: list[str] = [] + + if self.new_package: + for rule in self.new_package.rules: + title_name = name_to_title(rule.name) + to_api_format = rule.contents.to_api_format() + tags = ", ".join(to_api_format["tags"]) + rule_entries.append( + f"| [{title_name}](rules/{self.prebuilt_rule_base}-{name_to_title(rule.name)}.mdx) | " + f"{to_api_format['description']} | {tags} | " + f"{to_api_format['version']}" + ) rule_entries = sorted(rule_entries) - rule_entries = '\n'.join(rule_entries) + rule_entries_str = "\n".join(rule_entries) - summary.write_text(summary_header + rule_entries) + _ = summary.write_text(summary_header + rule_entries_str) def generate_update_summary(self): """Generate a summary of all rule updates based on the latest package.""" - slug = f'prebuilt-rules-{self.base_name}-update-summary.mdx' + slug = f"prebuilt-rules-{self.base_name}-update-summary.mdx" summary = self.package_directory / slug title = "Current Available Rules" @@ -860,7 +947,7 @@ def generate_update_summary(self): id: {slug} slug: /security-rules/{slug} title: {title} - date: {datetime.today().strftime('%Y-%d-%m')} + date: {datetime.today().strftime("%Y-%d-%m")} tags: ["rules", "security", "detection-rules"] --- @@ -872,52 +959,58 @@ def generate_update_summary(self): |---|---|---|---| """).lstrip() - 
rule_entries = [] + rule_entries: list[str] = [] new_rule_id_list = [rule.id for rule in self.rule_changes["new"]] updated_rule_id_list = [rule.id for rule in self.rule_changes["updated"]] for rule in self.included_rules: + if not rule.name: + raise ValueError("No rule name found") title_name = name_to_title(rule.name) - status = 'new' if rule.id in new_rule_id_list else 'update' if rule.id in updated_rule_id_list \ - else 'deprecated' + status = ( + "new" if rule.id in new_rule_id_list else "update" if rule.id in updated_rule_id_list else "deprecated" + ) to_api_format = rule.contents.to_api_format() - rule_entries.append(f'| [{title_name}](rules/{self.prebuilt_rule_base}-{name_to_title(rule.name)}.mdx) | ' - f'{to_api_format["description"]} | {status} | ' - f'{to_api_format["version"]}') + rule_entries.append( + f"| [{title_name}](rules/{self.prebuilt_rule_base}-{name_to_title(rule.name)}.mdx) | " + f"{to_api_format['description']} | {status} | " + f"{to_api_format['version']}" + ) rule_entries = sorted(rule_entries) - rule_entries = '\n'.join(rule_entries) + rule_entries_str = "\n".join(rule_entries) - summary.write_text(summary_header + rule_entries) + _ = summary.write_text(summary_header + rule_entries_str) def generate_rule_details(self): """Generate a markdown file for each rule.""" rules_dir = self.package_directory / "rules" rules_dir.mkdir(exist_ok=True) - for rule in self.new_package.rules: - slug = f'{self.prebuilt_rule_base}-{name_to_title(rule.name)}.mdx' - rule_detail = IntegrationRuleDetailMDX(rule.id, rule.contents.to_api_format(), {}, self.base_name) - rule_path = rules_dir / slug - tags = ', '.join(f"\"{tag}\"" for tag in rule.contents.data.tags) - frontmatter = textwrap.dedent(f""" - --- - id: {slug} - slug: /security-rules/{slug} - title: {rule.name} - date: {datetime.today().strftime('%Y-%d-%m')} - tags: [{tags}] - --- - - """).lstrip() - rule_path.write_text(frontmatter + rule_detail.generate()) + if self.new_package: + for rule in self.new_package.rules: + slug = f"{self.prebuilt_rule_base}-{name_to_title(rule.name)}.mdx" + rule_detail = IntegrationRuleDetailMDX(rule.id, rule.contents.to_api_format(), {}, self.base_name) + rule_path = rules_dir / slug + tags = ", ".join(f'"{tag}"' for tag in rule.contents.data.tags) # type: ignore[reportOptionalIterable] + frontmatter = textwrap.dedent(f""" + --- + id: {slug} + slug: /security-rules/{slug} + title: {rule.name} + date: {datetime.today().strftime("%Y-%d-%m")} + tags: [{tags}] + --- + + """).lstrip() + _ = rule_path.write_text(frontmatter + rule_detail.generate()) def generate_downloadable_updates_summary(self): """Generate a summary of all the downloadable updates.""" - docs_url = 'https://www.elastic.co/guide/en/security/current/rules-ui-management.html#update-prebuilt-rules' - slug = 'prebuilt-rules-downloadable-packages-summary.mdx' + docs_url = "https://www.elastic.co/guide/en/security/current/rules-ui-management.html#update-prebuilt-rules" + slug = "prebuilt-rules-downloadable-packages-summary.mdx" title = "Downloadable rule updates" summary = self.package_directory / slug - today = datetime.today().strftime('%d %b %Y') + today = datetime.today().strftime("%d %b %Y") package_list = DownloadableUpdates.load_updates() ref = f"./prebuilt-rules-{self.base_name}-update-summary.mdx" @@ -927,8 +1020,8 @@ def generate_downloadable_updates_summary(self): date=today, new_rules=len(self.rule_changes["new"]), updated_rules=len(self.rule_changes["updated"]), - note=self.note, - url=ref + note=self.note or "", + url=ref, ) 
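        # Illustrative example: assuming a hypothetical release_version of "9.0.1" with 3 new and
        # 10 updated rules and the default note, this UpdateEntry would later render in the
        # downloadable-updates table roughly as:
        #   | [9.0.1](./prebuilt-rules-9-0-1-update-summary.mdx) | <today> | 3 | 10 | Rule Updates.|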
package_list.add_entry(new_entry, self.overwrite) @@ -941,7 +1034,7 @@ def generate_downloadable_updates_summary(self): id: {slug} slug: /security-rules/{slug} title: {title} - date: {datetime.today().strftime('%Y-%d-%m')} + date: {datetime.today().strftime("%Y-%d-%m")} tags: ["rules", "security", "detection-rules"] --- @@ -957,13 +1050,15 @@ def generate_downloadable_updates_summary(self): |---|---|---|---|---| """).lstrip() - entries = [] + entries: list[str] = [] for entry in sorted(package_list.packages, key=lambda entry: Version.parse(entry.update_version), reverse=True): - entries.append(f'| [{entry.update_version}]({entry.url}) | {today} |' - f' {entry.new_rules} | {entry.updated_rules} | {entry.note}| ') + entries.append( + f"| [{entry.update_version}]({entry.url}) | {today} |" + f" {entry.new_rules} | {entry.updated_rules} | {entry.note}| " + ) - entries = '\n'.join(entries) - summary.write_text(summary_header + entries) + entries_str = "\n".join(entries) + _ = summary.write_text(summary_header + entries_str) def generate(self) -> Path: """Generate the updates.""" @@ -986,7 +1081,7 @@ def generate(self) -> Path: class IntegrationRuleDetailMDX: """Generates a rule detail page in Markdown.""" - def __init__(self, rule_id: str, rule: dict, changelog: Dict[str, dict], package_str: str): + def __init__(self, rule_id: str, rule: dict[str, Any], changelog: dict[str, dict[str, Any]], package_str: str): """Initialize with rule ID, rule details, changelog, and package string. >>> rule_file = "/path/to/rule.toml" @@ -999,30 +1094,23 @@ def __init__(self, rule_id: str, rule: dict, changelog: Dict[str, dict], package self.rule = rule self.changelog = changelog self.package = package_str - self.rule_title = f'prebuilt-rule-{self.package}-{name_to_title(self.rule["name"])}' + self.rule_title = f"prebuilt-rule-{self.package}-{name_to_title(self.rule['name'])}" # set some defaults - self.rule.setdefault('max_signals', 100) - self.rule.setdefault('interval', '5m') + self.rule.setdefault("max_signals", 100) + self.rule.setdefault("interval", "5m") def generate(self) -> str: """Generate the rule detail page in Markdown.""" - page = [ - MDX.title(1, self.rule["name"]), - '', - self.rule['description'], - '', - self.metadata_str(), - '' - ] - if 'note' in self.rule: - page.extend([self.guide_str(), '']) - if 'query' in self.rule: - page.extend([self.query_str(), '']) - if 'threat' in self.rule: - page.extend([self.threat_mapping_str(), '']) + page = [MDX.title(1, self.rule["name"]), "", self.rule["description"], "", self.metadata_str(), ""] + if "note" in self.rule: + page.extend([self.guide_str(), ""]) + if "query" in self.rule: + page.extend([self.query_str(), ""]) + if "threat" in self.rule: + page.extend([self.threat_mapping_str(), ""]) - return '\n'.join(page) + return "\n".join(page) def metadata_str(self) -> str: """Generate the metadata section for the rule detail page.""" @@ -1030,73 +1118,79 @@ def metadata_str(self) -> str: date_math_doc = "https://www.elastic.co/guide/en/elasticsearch/reference/current/common-options.html#date-math" loopback_doc = "https://www.elastic.co/guide/en/security/current/rules-ui-create.html#rule-schedule" fields = { - 'type': 'Rule type', - 'index': 'Rule indices', - 'severity': 'Severity', - 'risk_score': 'Risk score', - 'interval': 'Runs every', - 'from': 'Searches indices from', - 'max_signals': 'Maximum alerts per execution', - 'references': 'References', - 'tags': 'Tags', - 'version': 'Version', - 'author': 'Rule authors', - 'license': 'Rule license' + 
"type": "Rule type", + "index": "Rule indices", + "severity": "Severity", + "risk_score": "Risk score", + "interval": "Runs every", + "from": "Searches indices from", + "max_signals": "Maximum alerts per execution", + "references": "References", + "tags": "Tags", + "version": "Version", + "author": "Rule authors", + "license": "Rule license", } - values = [] + values: list[str] = [] for field, friendly_name in fields.items(): value = self.rule.get(field) or self.changelog.get(field) - if isinstance(value, list): - str_value = MDX.bulleted_list(value) + if value is None: + str_value = "NONE" + elif isinstance(value, list): + str_value = MDX.bulleted_list(value) # type: ignore[reportUnknownArgumentType] else: str_value = str(value) - if field == 'from': - str_value += f' ([Date Math format]({date_math_doc}), [Additional look-back time]({loopback_doc}))' + if field == "from": + str_value += f" ([Date Math format]({date_math_doc}), [Additional look-back time]({loopback_doc}))" values.append(MDX.bold_kv(friendly_name, str_value)) - return '\n\n'.join(values) + return "\n\n".join(values) def guide_str(self) -> str: """Generate the investigation guide section for the rule detail page.""" - return f'{MDX.title(2, "Investigation guide")}\n\n{MDX.code(self.rule["note"], "markdown")}' + return f"{MDX.title(2, 'Investigation guide')}\n\n{MDX.code(self.rule['note'], 'markdown')}" def query_str(self) -> str: """Generate the rule query section for the rule detail page.""" - return f'{MDX.title(2, "Rule query")}\n\n{MDX.code(self.rule["query"], "sql")}' + return f"{MDX.title(2, 'Rule query')}\n\n{MDX.code(self.rule['query'], 'sql')}" def threat_mapping_str(self) -> str: """Generate the threat mapping section for the rule detail page.""" - values = [MDX.bold_kv('Framework', 'MITRE ATT&CK^TM^')] + values = [MDX.bold_kv("Framework", "MITRE ATT&CK^TM^")] - for entry in self.rule['threat']: - tactic = entry['tactic'] + for entry in self.rule["threat"]: + tactic = entry["tactic"] entry_values = [ - MDX.bulleted(MDX.bold('Tactic:')), - MDX.bulleted(f'Name: {tactic["name"]}', depth=2), - MDX.bulleted(f'ID: {tactic["id"]}', depth=2), - MDX.bulleted(f'Reference URL: {tactic["reference"]}', depth=2) + MDX.bulleted(MDX.bold("Tactic:")), + MDX.bulleted(f"Name: {tactic['name']}", depth=2), + MDX.bulleted(f"ID: {tactic['id']}", depth=2), + MDX.bulleted(f"Reference URL: {tactic['reference']}", depth=2), ] - techniques = entry.get('technique', []) + techniques = entry.get("technique", []) for technique in techniques: - entry_values.extend([ - MDX.bulleted('Technique:'), - MDX.bulleted(f'Name: {technique["name"]}', depth=3), - MDX.bulleted(f'ID: {technique["id"]}', depth=3), - MDX.bulleted(f'Reference URL: {technique["reference"]}', depth=3) - ]) - - subtechniques = technique.get('subtechnique', []) + entry_values.extend( + [ + MDX.bulleted("Technique:"), + MDX.bulleted(f"Name: {technique['name']}", depth=3), + MDX.bulleted(f"ID: {technique['id']}", depth=3), + MDX.bulleted(f"Reference URL: {technique['reference']}", depth=3), + ] + ) + + subtechniques = technique.get("subtechnique", []) for subtechnique in subtechniques: - entry_values.extend([ - MDX.bulleted('Sub-technique:'), - MDX.bulleted(f'Name: {subtechnique["name"]}', depth=3), - MDX.bulleted(f'ID: {subtechnique["id"]}', depth=3), - MDX.bulleted(f'Reference URL: {subtechnique["reference"]}', depth=4) - ]) + entry_values.extend( + [ + MDX.bulleted("Sub-technique:"), + MDX.bulleted(f"Name: {subtechnique['name']}", depth=3), + MDX.bulleted(f"ID: {subtechnique['id']}", 
depth=3), + MDX.bulleted(f"Reference URL: {subtechnique['reference']}", depth=4), + ] + ) values.extend(entry_values) - return '\n'.join(values) + return "\n".join(values) diff --git a/detection_rules/ecs.py b/detection_rules/ecs.py index e3fe2a66247..7a3322faf3e 100644 --- a/detection_rules/ecs.py +++ b/detection_rules/ecs.py @@ -4,32 +4,34 @@ # 2.0. """ECS Schemas management.""" + import copy import glob import json import os import shutil -import eql -import eql.types +import eql # type: ignore[reportMissingTypeStubs] +import eql.types # type: ignore[reportMissingTypeStubs] import requests from semver import Version import yaml +from typing import Any + from .config import CUSTOM_RULES_DIR, parse_rules_config from .custom_schemas import get_custom_schemas from .integrations import load_integrations_schemas -from .utils import (DateTimeEncoder, cached, get_etc_path, gzip_compress, - load_etc_dump, read_gzip, unzip) +from .utils import DateTimeEncoder, cached, get_etc_path, gzip_compress, load_etc_dump, read_gzip, unzip ECS_NAME = "ecs_schemas" -ECS_SCHEMAS_DIR = get_etc_path(ECS_NAME) +ECS_SCHEMAS_DIR = get_etc_path([ECS_NAME]) ENDPOINT_NAME = "endpoint_schemas" -ENDPOINT_SCHEMAS_DIR = get_etc_path(ENDPOINT_NAME) +ENDPOINT_SCHEMAS_DIR = get_etc_path([ENDPOINT_NAME]) RULES_CONFIG = parse_rules_config() -def add_field(schema, name, info): +def add_field(schema: dict[str, Any], name: str, info: Any): """Nest a dotted field within a dictionary.""" if "." not in name: schema[name] = info @@ -41,7 +43,7 @@ def add_field(schema, name, info): add_field(schema, remaining, info) -def _recursive_merge(existing, new, depth=0): +def _recursive_merge(existing: dict[str, Any], new: dict[str, Any], depth: int = 0): """Return an existing dict merged into a new one.""" for key, value in existing.items(): if isinstance(value, dict): @@ -49,7 +51,7 @@ def _recursive_merge(existing, new, depth=0): new = copy.deepcopy(new) node = new.setdefault(key, {}) - _recursive_merge(value, node, depth + 1) + _ = _recursive_merge(value, node, depth + 1) # type: ignore[reportUnknownArgumentType] else: new[key] = value @@ -58,16 +60,16 @@ def _recursive_merge(existing, new, depth=0): def get_schema_files(): """Get schema files from ecs directory.""" - return glob.glob(os.path.join(ECS_SCHEMAS_DIR, '*', '*.json.gz'), recursive=True) + return glob.glob(os.path.join(ECS_SCHEMAS_DIR, "*", "*.json.gz"), recursive=True) def get_schema_map(): """Get local schema files by version.""" - schema_map = {} + schema_map: dict[str, Any] = {} for file_name in get_schema_files(): path, name = os.path.split(file_name) - name = name.split('.')[0] + name = name.split(".")[0] version = os.path.basename(path) schema_map.setdefault(version, {})[name] = file_name @@ -86,40 +88,40 @@ def get_schemas(): return schema_map -def get_max_version(include_master=False): +def get_max_version(include_master: bool = False): """Get maximum available schema version.""" versions = get_schema_map().keys() - if include_master and any([v.startswith('master') for v in versions]): - return list(ECS_SCHEMAS_DIR.glob('master*'))[0].name + if include_master and any([v.startswith("master") for v in versions]): + return list(ECS_SCHEMAS_DIR.glob("master*"))[0].name - return str(max([Version.parse(v) for v in versions if not v.startswith('master')])) + return str(max([Version.parse(v) for v in versions if not v.startswith("master")])) @cached -def get_schema(version=None, name='ecs_flat'): +def get_schema(version: str | None = None, name: str = "ecs_flat") -> dict[str, 
Any]: """Get schema by version.""" - if version == 'master': + if version == "master": version = get_max_version(include_master=True) return get_schemas()[version or str(get_max_version())][name] @cached -def get_eql_schema(version=None, index_patterns=None): +def get_eql_schema(version: str | None = None, index_patterns: list[str] | None = None): """Return schema in expected format for eql.""" - schema = get_schema(version, name='ecs_flat') - str_types = ('text', 'ip', 'keyword', 'date', 'object', 'geo_point') - num_types = ('float', 'integer', 'long') + schema = get_schema(version, name="ecs_flat") + str_types = ("text", "ip", "keyword", "date", "object", "geo_point") + num_types = ("float", "integer", "long") schema = schema.copy() - def convert_type(t): - return 'string' if t in str_types else 'number' if t in num_types else 'boolean' + def convert_type(t: str): + return "string" if t in str_types else "number" if t in num_types else "boolean" - converted = {} + converted: dict[str, Any] = {} for field, schema_info in schema.items(): - field_type = schema_info.get('type', '') + field_type = schema_info.get("type", "") add_field(converted, field, convert_type(field_type)) # add non-ecs schema @@ -141,20 +143,20 @@ def convert_type(t): return converted -def flatten(schema): - flattened = {} +def flatten(schema: dict[str, Any]) -> dict[str, Any]: + flattened: dict[str, Any] = {} for k, v in schema.items(): if isinstance(v, dict): - flattened.update((k + "." + vk, vv) for vk, vv in flatten(v).items()) + flattened.update((k + "." + vk, vv) for vk, vv in flatten(v).items()) # type: ignore[reportUnknownArgumentType] else: flattened[k] = v return flattened @cached -def get_all_flattened_schema() -> dict: +def get_all_flattened_schema() -> dict[str, Any]: """Load all schemas into a flattened dictionary.""" - all_flattened_schema = {} + all_flattened_schema: dict[str, Any] = {} for _, schema in get_non_ecs_schema().items(): all_flattened_schema.update(flatten(schema)) @@ -179,31 +181,31 @@ def get_all_flattened_schema() -> dict: @cached def get_non_ecs_schema(): """Load non-ecs schema.""" - return load_etc_dump('non-ecs-schema.json') + return load_etc_dump(["non-ecs-schema.json"]) @cached -def get_custom_index_schema(index_name: str, stack_version: str = None): +def get_custom_index_schema(index_name: str, stack_version: str | None = None): """Load custom schema.""" custom_schemas = get_custom_schemas(stack_version) index_schema = custom_schemas.get(index_name, {}) - ccs_schema = custom_schemas.get(index_name.replace('::', ':').split(":", 1)[-1], {}) + ccs_schema = custom_schemas.get(index_name.replace("::", ":").split(":", 1)[-1], {}) index_schema.update(ccs_schema) return index_schema @cached -def get_index_schema(index_name): +def get_index_schema(index_name: str): """Load non-ecs schema.""" non_ecs_schema = get_non_ecs_schema() index_schema = non_ecs_schema.get(index_name, {}) - ccs_schema = non_ecs_schema.get(index_name.replace('::', ':').split(":", 1)[-1], {}) + ccs_schema = non_ecs_schema.get(index_name.replace("::", ":").split(":", 1)[-1], {}) index_schema.update(ccs_schema) return index_schema -def flatten_multi_fields(schema): - converted = {} +def flatten_multi_fields(schema: dict[str, Any]) -> dict[str, Any]: + converted: dict[str, Any] = {} for field, info in schema.items(): converted[field] = info["type"] for subfield in info.get("multi_fields", []): @@ -224,20 +226,23 @@ class KqlSchema2Eql(eql.Schema): "boolean": eql.types.TypeHint.Boolean, } - def __init__(self, kql_schema): + 
def __init__(self, kql_schema: dict[str, Any]): self.kql_schema = kql_schema - eql.Schema.__init__(self, {}, allow_any=True, allow_generic=False, allow_missing=False) + eql.Schema.__init__(self, {}, allow_any=True, allow_generic=False, allow_missing=False) # type: ignore[reportUnknownMemberType] - def validate_event_type(self, event_type): + def validate_event_type(self, event_type: Any): # allow all event types to fill in X: # `X` where .... return True - def get_event_type_hint(self, event_type, path): - from kql.parser import elasticsearch_type_family + def get_event_type_hint(self, event_type: Any, path: list[str]): + from kql.parser import elasticsearch_type_family # type: ignore[reportMissingTypeStubs] dotted = ".".join(path) elasticsearch_type = self.kql_schema.get(dotted) + if not elasticsearch_type: + return None + es_type_family = elasticsearch_type_family(elasticsearch_type) eql_hint = self.type_mapping.get(es_type_family) @@ -246,10 +251,14 @@ def get_event_type_hint(self, event_type, path): @cached -def get_kql_schema(version=None, indexes=None, beat_schema=None) -> dict: +def get_kql_schema( + version: str | None = None, + indexes: list[str] | None = None, + beat_schema: dict[str, Any] | None = None, +) -> dict[str, Any]: """Get schema for KQL.""" - indexes = indexes or () - converted = flatten_multi_fields(get_schema(version, name='ecs_flat')) + indexes = indexes or [] + converted = flatten_multi_fields(get_schema(version, name="ecs_flat")) # non-ecs schema for index_name in indexes: @@ -269,14 +278,14 @@ def get_kql_schema(version=None, indexes=None, beat_schema=None) -> dict: return converted -def download_schemas(refresh_master=True, refresh_all=False, verbose=True): +def download_schemas(refresh_master: bool = True, refresh_all: bool = False, verbose: bool = True): """Download additional schemas from ecs releases.""" existing = [Version.parse(v) for v in get_schema_map()] if not refresh_all else [] - url = 'https://api.github.com/repos/elastic/ecs/releases' + url = "https://api.github.com/repos/elastic/ecs/releases" releases = requests.get(url) for release in releases.json(): - version = Version.parse(release.get('tag_name', '').lstrip('v')) + version = Version.parse(release.get("tag_name", "").lstrip("v")) # we don't ever want beta if not version or version < Version.parse("1.0.1") or version in existing: @@ -284,13 +293,13 @@ def download_schemas(refresh_master=True, refresh_all=False, verbose=True): schema_dir = os.path.join(ECS_SCHEMAS_DIR, str(version)) - with unzip(requests.get(release['zipball_url']).content) as archive: + with unzip(requests.get(release["zipball_url"]).content) as archive: name_list = archive.namelist() base = name_list[0] # members = [m for m in name_list if m.startswith('{}{}/'.format(base, 'use-cases')) and m.endswith('.yml')] - members = ['{}generated/ecs/ecs_flat.yml'.format(base), '{}generated/ecs/ecs_nested.yml'.format(base)] - saved = [] + members = ["{}generated/ecs/ecs_flat.yml".format(base), "{}generated/ecs/ecs_nested.yml".format(base)] + saved: list[str] = [] for member in members: file_name = os.path.basename(member) @@ -301,38 +310,38 @@ def download_schemas(refresh_master=True, refresh_all=False, verbose=True): out_file = file_name.replace(".yml", ".json.gz") compressed = gzip_compress(json.dumps(contents, sort_keys=True, cls=DateTimeEncoder)) - new_path = get_etc_path(ECS_NAME, str(version), out_file) - with open(new_path, 'wb') as f: - f.write(compressed) + new_path = get_etc_path([ECS_NAME, str(version), out_file]) + with 
open(new_path, "wb") as f: + _ = f.write(compressed) saved.append(out_file) if verbose: - print('Saved files to {}: \n\t- {}'.format(schema_dir, '\n\t- '.join(saved))) + print("Saved files to {}: \n\t- {}".format(schema_dir, "\n\t- ".join(saved))) # handle working master separately if refresh_master: - master_ver = requests.get('https://raw.githubusercontent.com/elastic/ecs/master/version') + master_ver = requests.get("https://raw.githubusercontent.com/elastic/ecs/master/version") master_ver = Version.parse(master_ver.text.strip()) - master_schema = requests.get('https://raw.githubusercontent.com/elastic/ecs/master/generated/ecs/ecs_flat.yml') + master_schema = requests.get("https://raw.githubusercontent.com/elastic/ecs/master/generated/ecs/ecs_flat.yml") master_schema = yaml.safe_load(master_schema.text) # prepend with underscore so that we can differentiate the fact that this is a working master version # but first clear out any existing masters, since we only ever want 1 at a time - existing_master = glob.glob(os.path.join(ECS_SCHEMAS_DIR, 'master_*')) + existing_master = glob.glob(os.path.join(ECS_SCHEMAS_DIR, "master_*")) for m in existing_master: shutil.rmtree(m, ignore_errors=True) master_dir = "master_{}".format(master_ver) - os.makedirs(get_etc_path(ECS_NAME, master_dir), exist_ok=True) + os.makedirs(get_etc_path([ECS_NAME, master_dir]), exist_ok=True) compressed = gzip_compress(json.dumps(master_schema, sort_keys=True, cls=DateTimeEncoder)) - new_path = get_etc_path(ECS_NAME, master_dir, "ecs_flat.json.gz") - with open(new_path, 'wb') as f: - f.write(compressed) + new_path = get_etc_path([ECS_NAME, master_dir, "ecs_flat.json.gz"]) + with open(new_path, "wb") as f: + _ = f.write(compressed) if verbose: - print('Saved files to {}: \n\t- {}'.format(master_dir, 'ecs_flat.json.gz')) + print("Saved files to {}: \n\t- {}".format(master_dir, "ecs_flat.json.gz")) def download_endpoint_schemas(target: str, overwrite: bool = True) -> None: @@ -351,11 +360,11 @@ def download_endpoint_schemas(target: str, overwrite: bool = True) -> None: # iterate over nested fields and flatten them for f in fields: - if 'multi_fields' in f: - for mf in f['multi_fields']: - flattened[f"{root_name}.{f['name']}.{mf['name']}"] = mf['type'] + if "multi_fields" in f: + for mf in f["multi_fields"]: + flattened[f"{root_name}.{f['name']}.{mf['name']}"] = mf["type"] else: - flattened[f"{root_name}.{f['name']}"] = f['type'] + flattened[f"{root_name}.{f['name']}"] = f["type"] # save schema to disk ENDPOINT_SCHEMAS_DIR.mkdir(parents=True, exist_ok=True) @@ -363,16 +372,16 @@ def download_endpoint_schemas(target: str, overwrite: bool = True) -> None: new_path = ENDPOINT_SCHEMAS_DIR / f"endpoint_{target}.json.gz" if overwrite: shutil.rmtree(new_path, ignore_errors=True) - with open(new_path, 'wb') as f: - f.write(compressed) + with open(new_path, "wb") as f: + _ = f.write(compressed) print(f"Saved endpoint schema to {new_path}") @cached -def get_endpoint_schemas() -> dict: +def get_endpoint_schemas() -> dict[str, Any]: """Load endpoint schemas.""" - schema = {} - existing = glob.glob(os.path.join(ENDPOINT_SCHEMAS_DIR, '*.json.gz')) + schema: dict[str, Any] = {} + existing = glob.glob(os.path.join(ENDPOINT_SCHEMAS_DIR, "*.json.gz")) for f in existing: schema.update(json.loads(read_gzip(f))) return schema diff --git a/detection_rules/endgame.py b/detection_rules/endgame.py index 4ed6bd6246d..809c168b974 100644 --- a/detection_rules/endgame.py +++ b/detection_rules/endgame.py @@ -4,12 +4,16 @@ # 2.0. 
"""Endgame Schemas management.""" + import json import shutil import sys -import eql +import eql # type: ignore[reportMissingTypeStubs] + +from typing import Any +from github import Github from .utils import ETC_DIR, DateTimeEncoder, cached, gzip_compress, read_gzip ENDGAME_SCHEMA_DIR = ETC_DIR / "endgame_schemas" @@ -18,12 +22,12 @@ class EndgameSchemaManager: """Endgame Class to download, convert, and save endgame schemas from endgame-evecs.""" - def __init__(self, github_client, endgame_version: str): + def __init__(self, github_client: Github, endgame_version: str): self.repo = github_client.get_repo("elastic/endgame-evecs") self.endgame_version = endgame_version self.endgame_schema = self.download_endgame_schema() - def download_endgame_schema(self) -> dict: + def download_endgame_schema(self) -> dict[str, Any]: """Download schema from endgame-evecs.""" # Use the static mapping.json file downloaded from the endgame-evecs repo. @@ -31,7 +35,7 @@ def download_endgame_schema(self) -> dict: main_branch_sha = main_branch.commit.sha schema_path = "pkg/mapper/ecs/schema.json" contents = self.repo.get_contents(schema_path, ref=main_branch_sha) - endgame_mapping = json.loads(contents.decoded_content.decode()) + endgame_mapping = json.loads(contents.decoded_content.decode()) # type: ignore[reportAttributeAccessIssue] return endgame_mapping @@ -49,31 +53,32 @@ def save_schemas(self, overwrite: bool = False): raw_os_schema = self.endgame_schema os_schema_path = schemas_dir / "endgame_ecs_mapping.json.gz" compressed = gzip_compress(json.dumps(raw_os_schema, sort_keys=True, cls=DateTimeEncoder)) - os_schema_path.write_bytes(compressed) + _ = os_schema_path.write_bytes(compressed) print(f"Endgame raw schema file saved: {os_schema_path}") class EndgameSchema(eql.Schema): """Endgame schema for query validation.""" - type_mapping = { - "keyword": eql.types.TypeHint.String, - "ip": eql.types.TypeHint.String, - "float": eql.types.TypeHint.Numeric, - "integer": eql.types.TypeHint.Numeric, - "boolean": eql.types.TypeHint.Boolean, - "text": eql.types.TypeHint.String, + type_mapping: dict[str, Any] = { + "keyword": eql.types.TypeHint.String, # type: ignore[reportAttributeAccessIssue] + "ip": eql.types.TypeHint.String, # type: ignore[reportAttributeAccessIssue] + "float": eql.types.TypeHint.Numeric, # type: ignore[reportAttributeAccessIssue] + "integer": eql.types.TypeHint.Numeric, # type: ignore[reportAttributeAccessIssue] + "boolean": eql.types.TypeHint.Boolean, # type: ignore[reportAttributeAccessIssue] + "text": eql.types.TypeHint.String, # type: ignore[reportAttributeAccessIssue] } - def __init__(self, endgame_schema): + def __init__(self, endgame_schema: dict[str, Any]): self.endgame_schema = endgame_schema - eql.Schema.__init__(self, {}, allow_any=True, allow_generic=False, allow_missing=False) + eql.Schema.__init__(self, {}, allow_any=True, allow_generic=False, allow_missing=False) # type: ignore[reportUnknownMemberType] + + def get_event_type_hint(self, event_type: str, path: list[str]): + from kql.parser import elasticsearch_type_family # type: ignore[reportMissingTypeStubs] - def get_event_type_hint(self, event_type, path): - from kql.parser import elasticsearch_type_family dotted = ".".join(str(p) for p in path) elasticsearch_type = self.endgame_schema.get(dotted) - es_type_family = elasticsearch_type_family(elasticsearch_type) + es_type_family = elasticsearch_type_family(elasticsearch_type) # type: ignore[reportArgumentType] eql_hint = self.type_mapping.get(es_type_family) if eql_hint is not 
None: @@ -81,7 +86,7 @@ def get_event_type_hint(self, event_type, path): @cached -def read_endgame_schema(endgame_version: str, warn=False) -> dict: +def read_endgame_schema(endgame_version: str, warn: bool = False) -> dict[str, Any] | None: """Load Endgame json schema. The schemas must be generated with the `download_endgame_schema()` method.""" # expect versions to be in format of N.N.N or master/main diff --git a/detection_rules/eswrap.py b/detection_rules/eswrap.py index eaca05b553d..56d85c24194 100644 --- a/detection_rules/eswrap.py +++ b/detection_rules/eswrap.py @@ -4,70 +4,72 @@ # 2.0. """Elasticsearch cli commands.""" + import json import os import sys import time +from pathlib import Path from collections import defaultdict -from typing import List, Union +from typing import Any, IO import click import elasticsearch from elasticsearch import Elasticsearch from elasticsearch.client import AsyncSearchClient -import kql +import kql # type: ignore[reportMissingTypeStubs] from .config import parse_rules_config from .main import root -from .misc import add_params, client_error, elasticsearch_options, get_elasticsearch_client, nested_get +from .misc import add_params, raise_client_error, elasticsearch_options, get_elasticsearch_client, nested_get from .rule import TOMLRule from .rule_loader import RuleCollection -from .utils import format_command_options, normalize_timing_and_sort, unix_time_to_formatted, get_path +from .utils import event_sort, format_command_options, normalize_timing_and_sort, unix_time_to_formatted, get_path -COLLECTION_DIR = get_path('collections') -MATCH_ALL = {'bool': {'filter': [{'match_all': {}}]}} +COLLECTION_DIR = get_path(["collections"]) +MATCH_ALL: dict[str, dict[str, Any]] = {"bool": {"filter": [{"match_all": {}}]}} RULES_CONFIG = parse_rules_config() -def add_range_to_dsl(dsl_filter, start_time, end_time='now'): +def add_range_to_dsl(dsl_filter: list[dict[str, Any]], start_time: str, end_time: str = "now"): dsl_filter.append( {"range": {"@timestamp": {"gt": start_time, "lte": end_time, "format": "strict_date_optional_time"}}} ) -def parse_unique_field_results(rule_type: str, unique_fields: List[str], search_results: dict): - parsed_results = defaultdict(lambda: defaultdict(int)) - hits = search_results['hits'] - hits = hits['hits'] if rule_type != 'eql' else hits.get('events') or hits.get('sequences', []) +def parse_unique_field_results(rule_type: str, unique_fields: list[str], search_results: dict[str, Any]): + parsed_results: dict[str, dict[str, int]] = defaultdict(lambda: defaultdict(int)) + hits = search_results["hits"] + hits = hits["hits"] if rule_type != "eql" else hits.get("events") or hits.get("sequences", []) for hit in hits: for field in unique_fields: - if 'events' in hit: - match = [] - for event in hit['events']: - matched = nested_get(event['_source'], field) - match.extend([matched] if not isinstance(matched, list) else matched) + if "events" in hit: + match: list[Any] = [] + for event in hit["events"]: + matched = nested_get(event["_source"], field) + match.extend([matched] if not isinstance(matched, list) else matched) # type: ignore[reportUnknownArgumentType] if not match: continue else: - match = nested_get(hit['_source'], field) + match = nested_get(hit["_source"], field) if not match: continue - match = ','.join(sorted(match)) if isinstance(match, list) else match - parsed_results[field][match] += 1 + match = ",".join(sorted(match)) if isinstance(match, list) else match # type: ignore[reportUnknownArgumentType] + 
parsed_results[field][match] += 1 # type: ignore[reportUnknownArgumentType] # if rule.type == eql, structure is different - return {'results': parsed_results} if parsed_results else {} + return {"results": parsed_results} if parsed_results else {} class Events: """Events collected from Elasticsearch.""" - def __init__(self, events): - self.events: dict = self._normalize_event_timing(events) + def __init__(self, events: dict[str, Any]): + self.events = self._normalize_event_timing(events) @staticmethod - def _normalize_event_timing(events): + def _normalize_event_timing(events: dict[str, Any]) -> dict[str, Any]: """Normalize event timestamps and sort.""" for agent_type, _events in events.items(): events[agent_type] = normalize_timing_and_sort(_events) @@ -75,338 +77,435 @@ def _normalize_event_timing(events): return events @staticmethod - def _get_dump_dir(rta_name=None, host_id=None, host_os_family=None): + def _get_dump_dir( + rta_name: str | None = None, host_id: str | None = None, host_os_family: str | None = None + ) -> Path: """Prepare and get the dump path.""" if rta_name and host_os_family: - dump_dir = get_path('unit_tests', 'data', 'true_positives', rta_name, host_os_family) + dump_dir = get_path(["unit_tests", "data", "true_positives", rta_name, host_os_family]) os.makedirs(dump_dir, exist_ok=True) return dump_dir else: - time_str = time.strftime('%Y%m%dT%H%M%SL') - dump_dir = os.path.join(COLLECTION_DIR, host_id or 'unknown_host', time_str) + time_str = time.strftime("%Y%m%dT%H%M%SL") + dump_dir = os.path.join(COLLECTION_DIR, host_id or "unknown_host", time_str) os.makedirs(dump_dir, exist_ok=True) - return dump_dir + return Path(dump_dir) - def evaluate_against_rule(self, rule_id, verbose=True): + def evaluate_against_rule(self, rule_id: str, verbose: bool = True): """Evaluate a rule against collected events and update mapping.""" - from .utils import combine_sources, evaluate - rule = RuleCollection.default().id_map.get(rule_id) assert rule is not None, f"Unable to find rule with ID {rule_id}" merged_events = combine_sources(*self.events.values()) filtered = evaluate(rule, merged_events, normalize_kql_keywords=RULES_CONFIG.normalize_kql_keywords) if verbose: - click.echo('Matching results found') + click.echo("Matching results found") return filtered - def echo_events(self, pager=False, pretty=True): + def echo_events(self, pager: bool = False, pretty: bool = True): """Print events to stdout.""" echo_fn = click.echo_via_pager if pager else click.echo echo_fn(json.dumps(self.events, indent=2 if pretty else None, sort_keys=True)) - def save(self, rta_name=None, dump_dir=None, host_id=None): + def save(self, rta_name: str | None = None, dump_dir: Path | None = None, host_id: str | None = None): """Save collected events.""" - assert self.events, 'Nothing to save. Run Collector.run() method first or verify logging' + assert self.events, "Nothing to save. 
Run Collector.run() method first or verify logging" host_os_family = None for key in self.events.keys(): - if self.events.get(key, {})[0].get('host', {}).get('id') == host_id: - host_os_family = self.events.get(key, {})[0].get('host', {}).get('os').get('family') + if self.events.get(key, {})[0].get("host", {}).get("id") == host_id: + host_os_family = self.events.get(key, {})[0].get("host", {}).get("os").get("family") break if not host_os_family: - click.echo('Unable to determine host.os.family for host_id: {}'.format(host_id)) - host_os_family = click.prompt("Please enter the host.os.family for this host_id", - type=click.Choice(["windows", "macos", "linux"]), default="windows") + click.echo("Unable to determine host.os.family for host_id: {}".format(host_id)) + host_os_family = click.prompt( + "Please enter the host.os.family for this host_id", + type=click.Choice(["windows", "macos", "linux"]), + default="windows", + ) dump_dir = dump_dir or self._get_dump_dir(rta_name=rta_name, host_id=host_id, host_os_family=host_os_family) for source, events in self.events.items(): - path = os.path.join(dump_dir, source + '.ndjson') - with open(path, 'w') as f: - f.writelines([json.dumps(e, sort_keys=True) + '\n' for e in events]) - click.echo('{} events saved to: {}'.format(len(events), path)) + path = os.path.join(dump_dir, source + ".ndjson") + with open(path, "w") as f: + f.writelines([json.dumps(e, sort_keys=True) + "\n" for e in events]) + click.echo("{} events saved to: {}".format(len(events), path)) class CollectEvents(object): """Event collector for elastic stack.""" - def __init__(self, client, max_events=3000): - self.client: Elasticsearch = client + def __init__(self, client: Elasticsearch, max_events: int = 3000): + self.client = client self.max_events = max_events - def _build_timestamp_map(self, index_str): + def _build_timestamp_map(self, index: str): """Build a mapping of indexes to timestamp data formats.""" - mappings = self.client.indices.get_mapping(index=index_str) - timestamp_map = {n: m['mappings'].get('properties', {}).get('@timestamp', {}) for n, m in mappings.items()} + mappings = self.client.indices.get_mapping(index=index) + timestamp_map = {n: m["mappings"].get("properties", {}).get("@timestamp", {}) for n, m in mappings.items()} return timestamp_map - def _get_last_event_time(self, index_str, dsl=None): + def _get_last_event_time(self, index: str, dsl: dict[str, Any] | None = None): """Get timestamp of most recent event.""" - last_event = self.client.search(query=dsl, index=index_str, size=1, sort='@timestamp:desc')['hits']['hits'] + last_event = self.client.search(query=dsl, index=index, size=1, sort="@timestamp:desc")["hits"]["hits"] if not last_event: return last_event = last_event[0] - index = last_event['_index'] - timestamp = last_event['_source']['@timestamp'] + index = last_event["_index"] + timestamp = last_event["_source"]["@timestamp"] - timestamp_map = self._build_timestamp_map(index_str) - event_date_format = timestamp_map[index].get('format', '').split('||') + timestamp_map = self._build_timestamp_map(index) + event_date_format = timestamp_map[index].get("format", "").split("||") # there are many native supported date formats and even custom data formats, but most, including beats use the # default `strict_date_optional_time`. It would be difficult to try to account for all possible formats, so this # will work on the default and unix time. 
- if set(event_date_format) & {'epoch_millis', 'epoch_second'}: + if set(event_date_format) & {"epoch_millis", "epoch_second"}: timestamp = unix_time_to_formatted(timestamp) return timestamp @staticmethod - def _prep_query(query, language, index, start_time=None, end_time=None): + def _prep_query( + query: str | dict[str, Any], + language: str, + index: str | list[str] | tuple[str], + start_time: str | None = None, + end_time: str | None = None, + ) -> tuple[str, dict[str, Any], str | None]: """Prep a query for search.""" - index_str = ','.join(index if isinstance(index, (list, tuple)) else index.split(',')) - lucene_query = query if language == 'lucene' else None - - if language in ('kql', 'kuery'): - formatted_dsl = {'query': kql.to_dsl(query)} - elif language == 'eql': - formatted_dsl = {'query': query, 'filter': MATCH_ALL} - elif language == 'lucene': - formatted_dsl = {'query': {'bool': {'filter': []}}} - elif language == 'dsl': - formatted_dsl = {'query': query} + index_str = ",".join(index if isinstance(index, (list, tuple)) else index.split(",")) + lucene_query = str(query) if language == "lucene" else None + + if language in ("kql", "kuery"): + formatted_dsl = {"query": kql.to_dsl(query)} # type: ignore[reportUnknownMemberType] + elif language == "eql": + formatted_dsl = {"query": query, "filter": MATCH_ALL} + elif language == "lucene": + formatted_dsl: dict[str, Any] = {"query": {"bool": {"filter": []}}} + elif language == "dsl": + formatted_dsl = {"query": query} else: - raise ValueError(f'Unknown search language: {language}') + raise ValueError(f"Unknown search language: {language}") if start_time or end_time: - end_time = end_time or 'now' - dsl = formatted_dsl['filter']['bool']['filter'] if language == 'eql' else \ - formatted_dsl['query']['bool'].setdefault('filter', []) + end_time = end_time or "now" + dsl = ( + formatted_dsl["filter"]["bool"]["filter"] + if language == "eql" + else formatted_dsl["query"]["bool"].setdefault("filter", []) + ) + if not start_time: + raise ValueError("No start time provided") + add_range_to_dsl(dsl, start_time, end_time) return index_str, formatted_dsl, lucene_query - def search(self, query, language, index: Union[str, list] = '*', start_time=None, end_time=None, size=None, - **kwargs): + def search( + self, + query: str | dict[str, Any], + language: str, + index: str | list[str] = "*", + start_time: str | None = None, + end_time: str | None = None, + size: int | None = None, + **kwargs: Any, + ) -> list[Any]: """Search an elasticsearch instance.""" - index_str, formatted_dsl, lucene_query = self._prep_query(query=query, language=language, index=index, - start_time=start_time, end_time=end_time) + index_str, formatted_dsl, lucene_query = self._prep_query( + query=query, language=language, index=index, start_time=start_time, end_time=end_time + ) formatted_dsl.update(size=size or self.max_events) - if language == 'eql': - results = self.client.eql.search(body=formatted_dsl, index=index_str, **kwargs)['hits'] - results = results.get('events') or results.get('sequences', []) + if language == "eql": + results = self.client.eql.search(body=formatted_dsl, index=index_str, **kwargs)["hits"] + results = results.get("events") or results.get("sequences", []) else: - results = self.client.search(body=formatted_dsl, q=lucene_query, index=index_str, - allow_no_indices=True, ignore_unavailable=True, **kwargs)['hits']['hits'] + results = self.client.search( + body=formatted_dsl, + q=lucene_query, + index=index_str, + allow_no_indices=True, + 
ignore_unavailable=True, + **kwargs, + )["hits"]["hits"] return results - def search_from_rule(self, rules: RuleCollection, start_time=None, end_time='now', size=None): + def search_from_rule( + self, + rules: RuleCollection, + start_time: str | None = None, + end_time: str = "now", + size: int | None = None, + ): """Search an elasticsearch instance using a rule.""" async_client = AsyncSearchClient(self.client) - survey_results = {} - multi_search = [] - multi_search_rules = [] - async_searches = [] - eql_searches = [] + survey_results: dict[str, Any] = {} + multi_search: list[dict[str, Any]] = [] + multi_search_rules: list[TOMLRule] = [] + async_searches: list[tuple[TOMLRule, Any]] = [] + eql_searches: list[tuple[TOMLRule, dict[str, Any]]] = [] for rule in rules: - if not rule.contents.data.get('query'): + if not rule.contents.data.get("query"): continue - language = rule.contents.data.get('language') - query = rule.contents.data.query + language = rule.contents.data.get("language") + query = rule.contents.data.query # type: ignore[reportAttributeAccessIssue] rule_type = rule.contents.data.type - index_str, formatted_dsl, lucene_query = self._prep_query(query=query, - language=language, - index=rule.contents.data.get('index', '*'), - start_time=start_time, - end_time=end_time) + index_str, formatted_dsl, _ = self._prep_query( + query=query, # type: ignore[reportUnknownArgumentType] + language=language, # type: ignore[reportUnknownArgumentType] + index=rule.contents.data.get("index", "*"), # type: ignore[reportUnknownArgumentType] + start_time=start_time, + end_time=end_time, + ) formatted_dsl.update(size=size or self.max_events) # prep for searches: msearch for kql | async search for lucene | eql client search for eql - if language == 'kuery': + if language == "kuery": multi_search_rules.append(rule) - multi_search.append({'index': index_str, 'allow_no_indices': 'true', 'ignore_unavailable': 'true'}) + multi_search.append({"index": index_str, "allow_no_indices": "true", "ignore_unavailable": "true"}) multi_search.append(formatted_dsl) - elif language == 'lucene': + elif language == "lucene": # wait for 0 to try and force async with no immediate results (not guaranteed) - result = async_client.submit(body=formatted_dsl, q=query, index=index_str, - allow_no_indices=True, ignore_unavailable=True, - wait_for_completion_timeout=0) - if result['is_running'] is True: - async_searches.append((rule, result['id'])) + result = async_client.submit( + body=formatted_dsl, + q=query, # type: ignore[reportUnknownArgumentType] + index=index_str, + allow_no_indices=True, + ignore_unavailable=True, + wait_for_completion_timeout=0, + ) + if result["is_running"] is True: + async_searches.append((rule, result["id"])) else: - survey_results[rule.id] = parse_unique_field_results(rule_type, ['process.name'], - result['response']) - elif language == 'eql': - eql_body = { - 'index': index_str, - 'params': {'ignore_unavailable': 'true', 'allow_no_indices': 'true'}, - 'body': {'query': query, 'filter': formatted_dsl['filter']} + survey_results[rule.id] = parse_unique_field_results( + rule_type, ["process.name"], result["response"] + ) + elif language == "eql": + eql_body: dict[str, Any] = { + "index": index_str, + "params": {"ignore_unavailable": "true", "allow_no_indices": "true"}, + "body": {"query": query, "filter": formatted_dsl["filter"]}, } eql_searches.append((rule, eql_body)) # assemble search results multi_search_results = self.client.msearch(searches=multi_search) - for index, result in 
enumerate(multi_search_results['responses']):
+        for index, result in enumerate(multi_search_results["responses"]):
             try:
                 rule = multi_search_rules[index]
-                survey_results[rule.id] = parse_unique_field_results(rule.contents.data.type,
-                                                                     rule.contents.data.unique_fields, result)
+                survey_results[rule.id] = parse_unique_field_results(
+                    rule.contents.data.type,
+                    rule.contents.data.unique_fields,  # type: ignore[reportAttributeAccessIssue]
+                    result,
+                )
             except KeyError:
-                survey_results[multi_search_rules[index].id] = {'error_retrieving_results': True}
+                survey_results[multi_search_rules[index].id] = {"error_retrieving_results": True}
 
         for entry in eql_searches:
-            rule: TOMLRule
-            search_args: dict
             rule, search_args = entry
             try:
                 result = self.client.eql.search(**search_args)
-                survey_results[rule.id] = parse_unique_field_results(rule.contents.data.type,
-                                                                     rule.contents.data.unique_fields, result)
+                survey_results[rule.id] = parse_unique_field_results(
+                    rule.contents.data.type,
+                    rule.contents.data.unique_fields,  # type: ignore[reportAttributeAccessIssue]
+                    result,  # type: ignore[reportAttributeAccessIssue]
+                )
             except (elasticsearch.NotFoundError, elasticsearch.RequestError) as e:
-                survey_results[rule.id] = {'error_retrieving_results': True, 'error': e.info['error']['reason']}
+                survey_results[rule.id] = {"error_retrieving_results": True, "error": e.info["error"]["reason"]}
 
         for entry in async_searches:
             rule: TOMLRule
             rule, async_id = entry
-            result = async_client.get(id=async_id)['response']
-            survey_results[rule.id] = parse_unique_field_results(rule.contents.data.type, ['process.name'], result)
+            result = async_client.get(id=async_id)["response"]
+            survey_results[rule.id] = parse_unique_field_results(rule.contents.data.type, ["process.name"], result)
 
         return survey_results
 
-    def count(self, query, language, index: Union[str, list], start_time=None, end_time='now'):
+    def count(
+        self,
+        query: str,
+        language: str,
+        index: str | list[str],
+        start_time: str | None = None,
+        end_time: str | None = "now",
+    ):
         """Get a count of documents from elasticsearch."""
-        index_str, formatted_dsl, lucene_query = self._prep_query(query=query, language=language, index=index,
-                                                                   start_time=start_time, end_time=end_time)
+        index_str, formatted_dsl, lucene_query = self._prep_query(
+            query=query, language=language, index=index, start_time=start_time, end_time=end_time
+        )
 
         # EQL API has no count endpoint
-        if language == 'eql':
-            results = self.search(query=query, language=language, index=index, start_time=start_time, end_time=end_time,
-                                  size=1000)
+        if language == "eql":
+            results = self.search(
+                query=query, language=language, index=index, start_time=start_time, end_time=end_time, size=1000
+            )
             return len(results)
         else:
-            return self.client.count(body=formatted_dsl, index=index_str, q=lucene_query, allow_no_indices=True,
-                                     ignore_unavailable=True)['count']
-
-    def count_from_rule(self, rules: RuleCollection, start_time=None, end_time='now'):
+            return self.client.count(
+                body=formatted_dsl, index=index_str, q=lucene_query, allow_no_indices=True, ignore_unavailable=True
+            )["count"]
+
+    def count_from_rule(
+        self,
+        rules: RuleCollection,
+        start_time: str | None = None,
+        end_time: str | None = "now",
+    ):
         """Get a count of documents from elasticsearch using a rule."""
-        survey_results = {}
+        survey_results: dict[str, Any] = {}
 
         for rule in rules.rules:
-            rule_results = {'rule_id': rule.id, 'name': rule.name}
+            rule_results: dict[str, Any] = {"rule_id": rule.id, "name": rule.name}
 
-            if not
rule.contents.data.get("query"):
+            if not rule.contents.data.get("query"):
                 continue
 
             try:
-                rule_results['search_count'] = self.count(query=rule.contents.data.query,
-                                                          language=rule.contents.data.language,
-                                                          index=rule.contents.data.get('index', '*'),
-                                                          start_time=start_time,
-                                                          end_time=end_time)
+                rule_results["search_count"] = self.count(
+                    query=rule.contents.data.query,  # type: ignore[reportAttributeAccessIssue]
+                    language=rule.contents.data.language,  # type: ignore[reportAttributeAccessIssue]
+                    index=rule.contents.data.get("index", "*"),  # type: ignore[reportAttributeAccessIssue]
+                    start_time=start_time,
+                    end_time=end_time,
+                )
             except (elasticsearch.NotFoundError, elasticsearch.RequestError):
-                rule_results['search_count'] = -1
+                rule_results["search_count"] = -1
 
             survey_results[rule.id] = rule_results
 
         return survey_results
 
 
+def evaluate(rule: TOMLRule, events: list[Any], normalize_kql_keywords: bool = False) -> list[Any]:
+    """Evaluate a query against events."""
+    evaluator = kql.get_evaluator(kql.parse(rule.query), normalize_kql_keywords=normalize_kql_keywords)  # type: ignore[reportUnknownMemberType]
+    filtered = list(filter(evaluator, events))  # type: ignore[reportUnknownMemberType]
+    return filtered  # type: ignore[reportUnknownMemberType]
+
+
+def combine_sources(*sources: list[Any]) -> list[Any]:
+    """Combine lists of events from multiple sources."""
+    combined: list[Any] = []
+    for source in sources:
+        combined.extend(source.copy())
+
+    return event_sort(combined)
+
+
 class CollectEventsWithDSL(CollectEvents):
     """Collect events from elasticsearch."""
 
     @staticmethod
-    def _group_events_by_type(events):
+    def _group_events_by_type(events: list[Any]):
         """Group events by agent.type."""
-        event_by_type = {}
+        event_by_type: dict[str, list[Any]] = {}
         for event in events:
-            event_by_type.setdefault(event['_source']['agent']['type'], []).append(event['_source'])
+            event_by_type.setdefault(event["_source"]["agent"]["type"], []).append(event["_source"])
 
         return event_by_type
 
-    def run(self, dsl, indexes, start_time):
+    def run(self, dsl: dict[str, Any], indexes: str | list[str], start_time: str):
         """Collect the events."""
-        results = self.search(dsl, language='dsl', index=indexes, start_time=start_time, end_time='now', size=5000,
-                              sort=[{'@timestamp': {'order': 'asc'}}])
+        results = self.search(
+            dsl,
+            language="dsl",
+            index=indexes,
+            start_time=start_time,
+            end_time="now",
+            size=5000,
+            sort=[{"@timestamp": {"order": "asc"}}],
+        )
         events = self._group_events_by_type(results)
         return Events(events)
 
 
-@root.command('normalize-data')
-@click.argument('events-file', type=click.File('r'))
-def normalize_data(events_file):
+@root.command("normalize-data")
+@click.argument("events-file", type=click.File("r"))
+def normalize_data(events_file: IO[Any]):
     """Normalize Elasticsearch data timestamps and sort."""
     file_name = os.path.splitext(os.path.basename(events_file.name))[0]
     events = Events({file_name: [json.loads(e) for e in events_file.readlines()]})
-    events.save(dump_dir=os.path.dirname(events_file.name))
+    dirname = os.path.dirname(events_file.name)
+    events.save(dump_dir=Path(dirname))
 
 
-@root.group('es')
+@root.group("es")
 @add_params(*elasticsearch_options)
 @click.pass_context
-def es_group(ctx: click.Context, **kwargs):
+def es_group(ctx: click.Context, **kwargs: Any):
     """Commands for integrating with Elasticsearch."""
-    ctx.ensure_object(dict)
+    _ = ctx.ensure_object(dict)  # type: ignore[reportUnknownVariableType]
 
     # only initialize an es client if the subcommand is invoked without help
(hacky) if sys.argv[-1] in ctx.help_option_names: - click.echo('Elasticsearch client:') + click.echo("Elasticsearch client:") click.echo(format_command_options(ctx)) else: - ctx.obj['es'] = get_elasticsearch_client(ctx=ctx, **kwargs) + ctx.obj["es"] = get_elasticsearch_client(ctx=ctx, **kwargs) -@es_group.command('collect-events') -@click.argument('host-id') -@click.option('--query', '-q', help='KQL query to scope search') -@click.option('--index', '-i', multiple=True, help='Index(es) to search against (default: all indexes)') -@click.option('--rta-name', '-r', help='Name of RTA in order to save events directly to unit tests data directory') -@click.option('--rule-id', help='Updates rule mapping in rule-mapping.yaml file (requires --rta-name)') -@click.option('--view-events', is_flag=True, help='Print events after saving') +@es_group.command("collect-events") +@click.argument("host-id") +@click.option("--query", "-q", help="KQL query to scope search") +@click.option("--index", "-i", multiple=True, help="Index(es) to search against (default: all indexes)") +@click.option("--rta-name", "-r", help="Name of RTA in order to save events directly to unit tests data directory") +@click.option("--rule-id", help="Updates rule mapping in rule-mapping.yaml file (requires --rta-name)") +@click.option("--view-events", is_flag=True, help="Print events after saving") @click.pass_context -def collect_events(ctx, host_id, query, index, rta_name, rule_id, view_events): +def collect_events( + ctx: click.Context, + host_id: str, + query: str, + index: list[str], + rta_name: str, + rule_id: str, + view_events: bool, +): """Collect events from Elasticsearch.""" - client: Elasticsearch = ctx.obj['es'] - dsl = kql.to_dsl(query) if query else MATCH_ALL - dsl['bool'].setdefault('filter', []).append({'bool': {'should': [{'match_phrase': {'host.id': host_id}}]}}) + client: Elasticsearch = ctx.obj["es"] + dsl = kql.to_dsl(query) if query else MATCH_ALL # type: ignore[reportUnknownMemberType] + dsl["bool"].setdefault("filter", []).append({"bool": {"should": [{"match_phrase": {"host.id": host_id}}]}}) # type: ignore[reportUnknownMemberType] try: collector = CollectEventsWithDSL(client) start = time.time() - click.pause('Press any key once detonation is complete ...') - start_time = f'now-{round(time.time() - start) + 5}s' - events = collector.run(dsl, index or '*', start_time) + click.pause("Press any key once detonation is complete ...") + start_time = f"now-{round(time.time() - start) + 5}s" + events = collector.run(dsl, index or "*", start_time) # type: ignore[reportUnknownArgument] events.save(rta_name=rta_name, host_id=host_id) if rta_name and rule_id: - events.evaluate_against_rule(rule_id) + _ = events.evaluate_against_rule(rule_id) if view_events and events.events: events.echo_events(pager=True) return events except AssertionError as e: - error_msg = 'No events collected! Verify events are streaming and that the agent-hostname is correct' - client_error(error_msg, e, ctx=ctx) + error_msg = "No events collected! 
Verify events are streaming and that the agent-hostname is correct" + raise_client_error(error_msg, e, ctx=ctx) -@es_group.command('index-rules') -@click.option('--query', '-q', help='Optional KQL query to limit to specific rules') -@click.option('--from-file', '-f', type=click.File('r'), help='Load a previously saved uploadable bulk file') -@click.option('--save_files', '-s', is_flag=True, help='Optionally save the bulk request to a file') +@es_group.command("index-rules") +@click.option("--query", "-q", help="Optional KQL query to limit to specific rules") +@click.option("--from-file", "-f", type=click.File("r"), help="Load a previously saved uploadable bulk file") +@click.option("--save_files", "-s", is_flag=True, help="Optionally save the bulk request to a file") @click.pass_context -def index_repo(ctx: click.Context, query, from_file, save_files): +def index_repo(ctx: click.Context, query: str, from_file: IO[Any] | None, save_files: bool): """Index rules based on KQL search results to an elasticsearch instance.""" from .main import generate_rules_index - es_client: Elasticsearch = ctx.obj['es'] + es_client: Elasticsearch = ctx.obj["es"] if from_file: bulk_upload_docs = from_file.read() @@ -414,10 +513,10 @@ def index_repo(ctx: click.Context, query, from_file, save_files): # light validation only try: index_body = [json.loads(line) for line in bulk_upload_docs.splitlines()] - click.echo(f'{len([r for r in index_body if "rule" in r])} rules included') + click.echo(f"{len([r for r in index_body if 'rule' in r])} rules included") except json.JSONDecodeError: - client_error(f'Improperly formatted bulk request file: {from_file.name}') + raise_client_error(f"Improperly formatted bulk request file: {from_file.name}") else: - bulk_upload_docs, importable_rules_docs = ctx.invoke(generate_rules_index, query=query, save_files=save_files) + bulk_upload_docs, _ = ctx.invoke(generate_rules_index, query=query, save_files=save_files) - es_client.bulk(bulk_upload_docs) + _ = es_client.bulk(operations=bulk_upload_docs) diff --git a/detection_rules/exception.py b/detection_rules/exception.py index 1b89d6ab8dc..7dd8e6da521 100644 --- a/detection_rules/exception.py +++ b/detection_rules/exception.py @@ -3,13 +3,14 @@ # 2.0; you may not use this file except in compliance with the Elastic License # 2.0. 
"""Rule exceptions data.""" + from collections import defaultdict from dataclasses import dataclass from datetime import datetime from pathlib import Path -from typing import List, Optional, Union, Tuple, get_args +from typing import get_args, Any -import pytoml +import pytoml # type: ignore[reportMissingTypeStubs] from marshmallow import EXCLUDE, ValidationError, validates_schema from .mixins import MarshmallowDataclassMixin @@ -24,21 +25,23 @@ @dataclass(frozen=True) class ExceptionMeta(MarshmallowDataclassMixin): """Data stored in an exception's [metadata] section of TOML.""" + creation_date: definitions.Date list_name: str - rule_ids: List[definitions.UUIDString] - rule_names: List[str] + rule_ids: list[definitions.UUIDString] + rule_names: list[str] updated_date: definitions.Date # Optional fields - deprecation_date: Optional[definitions.Date] - comments: Optional[str] - maturity: Optional[definitions.Maturity] + deprecation_date: definitions.Date | None + comments: str | None + maturity: definitions.Maturity | None @dataclass(frozen=True) class BaseExceptionItemEntry(MarshmallowDataclassMixin): """Shared object between nested and non-nested exception items.""" + field: str type: definitions.ExceptionEntryType @@ -46,91 +49,99 @@ class BaseExceptionItemEntry(MarshmallowDataclassMixin): @dataclass(frozen=True) class NestedExceptionItemEntry(BaseExceptionItemEntry, MarshmallowDataclassMixin): """Nested exception item entry.""" - entries: List['ExceptionItemEntry'] + + entries: list["ExceptionItemEntry"] @validates_schema - def validate_nested_entry(self, data: dict, **kwargs): + def validate_nested_entry(self, data: dict[str, Any], **kwargs: Any): """More specific validation.""" - if data.get('list') is not None: - raise ValidationError('Nested entries cannot define a list') + if data.get("list") is not None: + raise ValidationError("Nested entries cannot define a list") @dataclass(frozen=True) class ExceptionItemEntry(BaseExceptionItemEntry, MarshmallowDataclassMixin): """Exception item entry.""" + @dataclass(frozen=True) class ListObject: """List object for exception item entry.""" + id: str type: definitions.EsDataTypes - list: Optional[ListObject] + list_vals: ListObject | None operator: definitions.ExceptionEntryOperator - value: Optional[Union[str, List[str]]] + value: str | None | list[str] @validates_schema - def validate_entry(self, data: dict, **kwargs): + def validate_entry(self, data: dict[str, Any], **kwargs: Any): """Validate the entry based on its type.""" - value = data.get('value', '') - if data['type'] in ('exists', 'list') and value is not None: - raise ValidationError(f'Entry of type {data["type"]} cannot have a value') - elif data['type'] in ('match', 'wildcard') and not isinstance(value, str): - raise ValidationError(f'Entry of type {data["type"]} must have a string value') - elif data['type'] == 'match_any' and not isinstance(value, list): - raise ValidationError(f'Entry of type {data["type"]} must have a list of strings as a value') + value = data.get("value", "") + if data["type"] in ("exists", "list") and value is not None: + raise ValidationError(f"Entry of type {data['type']} cannot have a value") + elif data["type"] in ("match", "wildcard") and not isinstance(value, str): + raise ValidationError(f"Entry of type {data['type']} must have a string value") + elif data["type"] == "match_any" and not isinstance(value, list): + raise ValidationError(f"Entry of type {data['type']} must have a list of strings as a value") @dataclass(frozen=True) class 
ExceptionItem(MarshmallowDataclassMixin): """Base exception item.""" + @dataclass(frozen=True) class Comment: """Comment object for exception item.""" + comment: str - comments: List[Optional[Comment]] + comments: list[Comment | None] description: str - entries: List[Union[ExceptionItemEntry, NestedExceptionItemEntry]] + entries: list[ExceptionItemEntry | NestedExceptionItemEntry] list_id: str - item_id: Optional[str] # api sets field when not provided - meta: Optional[dict] + item_id: str | None # api sets field when not provided + meta: dict[str, Any] | None name: str - namespace_type: Optional[definitions.ExceptionNamespaceType] # defaults to "single" if not provided - tags: Optional[List[str]] + namespace_type: definitions.ExceptionNamespaceType | None # defaults to "single" if not provided + tags: list[str] | None type: definitions.ExceptionItemType @dataclass(frozen=True) class EndpointException(ExceptionItem, MarshmallowDataclassMixin): """Endpoint exception item.""" - _tags: List[definitions.ExceptionItemEndpointTags] + + _tags: list[definitions.ExceptionItemEndpointTags] @validates_schema - def validate_endpoint(self, data: dict, **kwargs): + def validate_endpoint(self, data: dict[str, Any], **kwargs: Any): """Validate the endpoint exception.""" - for entry in data['entries']: - if entry['operator'] == "excluded": + for entry in data["entries"]: + if entry["operator"] == "excluded": raise ValidationError("Endpoint exceptions cannot have an `excluded` operator") @dataclass(frozen=True) class DetectionException(ExceptionItem, MarshmallowDataclassMixin): """Detection exception item.""" - expire_time: Optional[str] # fields.DateTime] # maybe this is isoformat? + + expire_time: str | None # fields.DateTime] # maybe this is isoformat? @dataclass(frozen=True) class ExceptionContainer(MarshmallowDataclassMixin): """Exception container.""" + description: str - list_id: Optional[str] - meta: Optional[dict] + list_id: str | None + meta: dict[str, Any] | None name: str - namespace_type: Optional[definitions.ExceptionNamespaceType] - tags: Optional[List[str]] + namespace_type: definitions.ExceptionNamespaceType | None + tags: list[str] | None type: definitions.ExceptionContainerType - def to_rule_entry(self) -> dict: + def to_rule_entry(self) -> dict[str, Any]: """Returns a dict of the format required in rule.exception_list.""" # requires KSO id to be consider valid structure return dict(namespace_type=self.namespace_type, type=self.type, list_id=self.list_id) @@ -139,8 +150,9 @@ def to_rule_entry(self) -> dict: @dataclass(frozen=True) class Data(MarshmallowDataclassMixin): """Data stored in an exception's [exception] section of TOML.""" + container: ExceptionContainer - items: Optional[List[DetectionException]] # Union[DetectionException, EndpointException]] + items: list[DetectionException] | None @dataclass(frozen=True) @@ -148,13 +160,15 @@ class TOMLExceptionContents(MarshmallowDataclassMixin): """Data stored in an exception file.""" metadata: ExceptionMeta - exceptions: List[Data] + exceptions: list[Data] @classmethod - def from_exceptions_dict(cls, exceptions_dict: dict, rule_list: list[dict]) -> "TOMLExceptionContents": + def from_exceptions_dict( + cls, exceptions_dict: dict[str, Any], rule_list: list[dict[str, Any]] + ) -> "TOMLExceptionContents": """Create a TOMLExceptionContents from a kibana rule resource.""" - rule_ids = [] - rule_names = [] + rule_ids: list[str] = [] + rule_names: list[str] = [] for rule in rule_list: rule_ids.append(rule["id"]) @@ -177,9 +191,9 @@ def 
from_exceptions_dict(cls, exceptions_dict: dict, rule_list: list[dict]) -> " return cls.from_dict({"metadata": metadata, "exceptions": [exceptions_dict]}, unknown=EXCLUDE) - def to_api_format(self) -> List[dict]: + def to_api_format(self) -> list[dict[str, Any]]: """Convert the TOML Exception to the API format.""" - converted = [] + converted: list[dict[str, Any]] = [] for exception in self.exceptions: converted.append(exception.container.to_dict()) @@ -193,8 +207,9 @@ def to_api_format(self) -> List[dict]: @dataclass(frozen=True) class TOMLException: """TOML exception object.""" + contents: TOMLExceptionContents - path: Optional[Path] = None + path: Path | None = None @property def name(self): @@ -213,19 +228,20 @@ def save_toml(self): contents_dict = self.contents.to_dict() # Sort the dictionary so that 'metadata' is at the top sorted_dict = dict(sorted(contents_dict.items(), key=lambda item: item[0] != "metadata")) - pytoml.dump(sorted_dict, f) + pytoml.dump(sorted_dict, f) # type: ignore[reportUnknownMemberType] -def parse_exceptions_results_from_api(results: List[dict]) -> tuple[dict, dict, List[str], List[dict]]: +def parse_exceptions_results_from_api( + results: list[dict[str, Any]], +) -> tuple[dict[str, Any], dict[str, Any], list[str], list[dict[str, Any]]]: """Parse exceptions results from the API into containers and items.""" - exceptions_containers = {} - exceptions_items = defaultdict(list) - errors = [] - unparsed_results = [] + exceptions_containers: dict[str, Any] = {} + exceptions_items: dict[str, list[Any]] = defaultdict(list) + unparsed_results: list[dict[str, Any]] = [] for result in results: - result_type = result.get("type") - list_id = result.get("list_id") + result_type = result["type"] + list_id = result["list_id"] if result_type in get_args(definitions.ExceptionContainerType): exceptions_containers[list_id] = result @@ -234,24 +250,29 @@ def parse_exceptions_results_from_api(results: List[dict]) -> tuple[dict, dict, else: unparsed_results.append(result) - return exceptions_containers, exceptions_items, errors, unparsed_results + return exceptions_containers, exceptions_items, [], unparsed_results -def build_exception_objects(exceptions_containers: List[dict], exceptions_items: List[dict], - exception_list_rule_table: dict, exceptions_directory: Path, save_toml: bool = False, - skip_errors: bool = False, verbose=False, - ) -> Tuple[List[TOMLException], List[str], List[str]]: +def build_exception_objects( + exceptions_containers: dict[str, Any], + exceptions_items: dict[str, Any], + exception_list_rule_table: dict[str, Any], + exceptions_directory: Path | None, + save_toml: bool = False, + skip_errors: bool = False, + verbose: bool = False, +) -> tuple[list[TOMLException], list[str], list[str]]: """Build TOMLException objects from a list of exception dictionaries.""" - output = [] - errors = [] - toml_exceptions = [] + output: list[str] = [] + errors: list[str] = [] + toml_exceptions: list[TOMLException] = [] for container in exceptions_containers.values(): try: - list_id = container.get("list_id") - items = exceptions_items.get(list_id) + list_id = container["list_id"] + items = exceptions_items[list_id] contents = TOMLExceptionContents.from_exceptions_dict( {"container": container, "items": items}, - exception_list_rule_table.get(list_id), + exception_list_rule_table[list_id], ) filename = f"{list_id}_exceptions.toml" if RULES_CONFIG.exception_dir is None and not exceptions_directory: @@ -259,9 +280,7 @@ def build_exception_objects(exceptions_containers: 
List[dict], exceptions_items: "No Exceptions directory is specified. Please specify either in the config or CLI." ) exceptions_path = ( - Path(exceptions_directory) / filename - if exceptions_directory - else RULES_CONFIG.exception_dir / filename + Path(exceptions_directory) / filename if exceptions_directory else RULES_CONFIG.exception_dir / filename ) if verbose: output.append(f"[+] Building exception(s) for {exceptions_path}") diff --git a/detection_rules/generic_loader.py b/detection_rules/generic_loader.py index 41f91bee82b..84fd8e30954 100644 --- a/detection_rules/generic_loader.py +++ b/detection_rules/generic_loader.py @@ -4,10 +4,11 @@ # 2.0. """Load generic toml formatted files for exceptions and actions.""" + from pathlib import Path -from typing import Callable, Dict, Iterable, List, Optional, Union +from typing import Callable, Iterable, Any -import pytoml +import pytoml # type: ignore[reportMissingTypeStubs] from .action import TOMLAction, TOMLActionContents from .action_connector import TOMLActionConnector, TOMLActionConnectorContents @@ -19,11 +20,11 @@ RULES_CONFIG = parse_rules_config() -GenericCollectionTypes = Union[TOMLAction, TOMLActionConnector, TOMLException] -GenericCollectionContentTypes = Union[TOMLActionContents, TOMLActionConnectorContents, TOMLExceptionContents] +GenericCollectionTypes = TOMLAction | TOMLActionConnector | TOMLException +GenericCollectionContentTypes = TOMLActionContents | TOMLActionConnectorContents | TOMLExceptionContents -def metadata_filter(**metadata) -> Callable[[GenericCollectionTypes], bool]: +def metadata_filter(**metadata: Any) -> Callable[[GenericCollectionTypes], bool]: """Get a filter callback based off item metadata""" flt = dict_filter(metadata) @@ -37,21 +38,21 @@ def callback(item: GenericCollectionTypes) -> bool: class GenericCollection: """Generic collection for action and exception objects.""" - items: list + items: list[GenericCollectionTypes] __default = None - def __init__(self, items: Optional[List[GenericCollectionTypes]] = None): - self.id_map: Dict[definitions.UUIDString, GenericCollectionTypes] = {} - self.file_map: Dict[Path, GenericCollectionTypes] = {} - self.name_map: Dict[definitions.RuleName, GenericCollectionTypes] = {} - self.items: List[GenericCollectionTypes] = [] - self.errors: Dict[Path, Exception] = {} + def __init__(self, items: list[GenericCollectionTypes] | None = None): + self.id_map: dict[definitions.UUIDString, GenericCollectionTypes] = {} + self.file_map: dict[Path, GenericCollectionTypes] = {} + self.name_map: dict[definitions.RuleName, GenericCollectionTypes] = {} + self.items: list[GenericCollectionTypes] = [] + self.errors: dict[Path, Exception] = {} self.frozen = False - self._toml_load_cache: Dict[Path, dict] = {} + self._toml_load_cache: dict[Path, dict[str, Any]] = {} - for items in (items or []): - self.add_item(items) + for item in items or []: + self.add_item(item) def __len__(self) -> int: """Get the total amount of exceptions in the collection.""" @@ -63,23 +64,23 @@ def __iter__(self) -> Iterable[GenericCollectionTypes]: def __contains__(self, item: GenericCollectionTypes) -> bool: """Check if an item is in the map by comparing IDs.""" - return item.id in self.id_map + return item.id in self.id_map # type: ignore[reportAttributeAccessIssue] - def filter(self, cb: Callable[[TOMLException], bool]) -> 'GenericCollection': + def filter(self, cb: Callable[[TOMLException], bool]) -> "GenericCollection": """Retrieve a filtered collection of items.""" filtered_collection = 
GenericCollection() - for item in filter(cb, self.items): + for item in filter(cb, self.items): # type: ignore[reportCallIssue] filtered_collection.add_item(item) return filtered_collection @staticmethod - def deserialize_toml_string(contents: Union[bytes, str]) -> dict: + def deserialize_toml_string(contents: bytes | str) -> dict[str, Any]: """Deserialize a TOML string into a dictionary.""" - return pytoml.loads(contents) + return pytoml.loads(contents) # type: ignore[reportUnknownVariableType] - def _load_toml_file(self, path: Path) -> dict: + def _load_toml_file(self, path: Path) -> dict[str, Any]: """Load a TOML file into a dictionary.""" if path in self._toml_load_cache: return self._toml_load_cache[path] @@ -92,9 +93,9 @@ def _load_toml_file(self, path: Path) -> dict: self._toml_load_cache[path] = toml_dict return toml_dict - def _get_paths(self, directory: Path, recursive=True) -> List[Path]: + def _get_paths(self, directory: Path, recursive: bool = True) -> list[Path]: """Get all TOML files in a directory.""" - return sorted(directory.rglob('*.toml') if recursive else directory.glob('*.toml')) + return sorted(directory.rglob("*.toml") if recursive else directory.glob("*.toml")) def _assert_new(self, item: GenericCollectionTypes) -> None: """Assert that the item is new and can be added to the collection.""" @@ -102,8 +103,7 @@ def _assert_new(self, item: GenericCollectionTypes) -> None: name_map = self.name_map assert not self.frozen, f"Unable to add item {item.name} to a frozen collection" - assert item.name not in name_map, \ - f"Rule Name {item.name} collides with {name_map[item.name].name}" + assert item.name not in name_map, f"Rule Name {item.name} collides with {name_map[item.name].name}" if item.path is not None: item_path = item.path.resolve() @@ -116,15 +116,19 @@ def add_item(self, item: GenericCollectionTypes) -> None: self.name_map[item.name] = item self.items.append(item) - def load_dict(self, obj: dict, path: Optional[Path] = None) -> GenericCollectionTypes: + def load_dict(self, obj: dict[str, Any], path: Path | None = None) -> GenericCollectionTypes: """Load a dictionary into the collection.""" - if 'exceptions' in obj: + if "exceptions" in obj: contents = TOMLExceptionContents.from_dict(obj) item = TOMLException(path=path, contents=contents) - elif 'actions' in obj: + elif "actions" in obj: contents = TOMLActionContents.from_dict(obj) + if not path: + raise ValueError("No path value provided") item = TOMLAction(path=path, contents=contents) - elif 'action_connectors' in obj: + elif "action_connectors" in obj: + if not path: + raise ValueError("No path value provided") contents = TOMLActionConnectorContents.from_dict(obj) item = TOMLActionConnector(path=path, contents=contents) else: @@ -152,14 +156,17 @@ def load_file(self, path: Path) -> GenericCollectionTypes: print(f"Error loading item in {path}") raise - def load_files(self, paths: Iterable[Path]) -> None: + def load_files(self, paths: Iterable[Path]): """Load multiple files into the collection.""" for path in paths: - self.load_file(path) + _ = self.load_file(path) def load_directory( - self, directory: Path, recursive=True, toml_filter: Optional[Callable[[dict], bool]] = None - ) -> None: + self, + directory: Path, + recursive: bool = True, + toml_filter: Callable[[dict[str, Any]], bool] | None = None, + ): """Load all TOML files in a directory.""" paths = self._get_paths(directory, recursive=recursive) if toml_filter is not None: @@ -168,8 +175,11 @@ def load_directory( self.load_files(paths) def 
load_directories( - self, directories: Iterable[Path], recursive=True, toml_filter: Optional[Callable[[dict], bool]] = None - ) -> None: + self, + directories: Iterable[Path], + recursive: bool = True, + toml_filter: Callable[[dict[str, Any]], bool] | None = None, + ): """Load all TOML files in multiple directories.""" for path in directories: self.load_directory(path, recursive=recursive, toml_filter=toml_filter) @@ -179,7 +189,7 @@ def freeze(self) -> None: self.frozen = True @classmethod - def default(cls) -> 'GenericCollection': + def default(cls) -> "GenericCollection": """Return the default item collection, which retrieves from default config location.""" if cls.__default is None: collection = GenericCollection() diff --git a/detection_rules/ghwrap.py b/detection_rules/ghwrap.py index 133bd7132c2..cabe6fe96b7 100644 --- a/detection_rules/ghwrap.py +++ b/detection_rules/ghwrap.py @@ -12,9 +12,9 @@ import shutil import time from dataclasses import dataclass, field -from datetime import datetime +from datetime import datetime, timezone from pathlib import Path -from typing import Dict, Optional, Tuple +from typing import Any from zipfile import ZipFile import click @@ -23,23 +23,15 @@ from .schemas import definitions -# this is primarily for type hinting - all use of the github client should come from GithubClient class -try: - from github import Github - from github.Repository import Repository - from github.GitRelease import GitRelease - from github.GitReleaseAsset import GitReleaseAsset -except ImportError: - # for type hinting - Github = None # noqa: N806 - Repository = None # noqa: N806 - GitRelease = None # noqa: N806 - GitReleaseAsset = None # noqa: N806 - - -def get_gh_release(repo: Repository, release_name: Optional[str] = None, tag_name: Optional[str] = None) -> GitRelease: +from github import Github +from github.Repository import Repository +from github.GitRelease import GitRelease +from github.GitReleaseAsset import GitReleaseAsset + + +def get_gh_release(repo: Repository, release_name: str | None = None, tag_name: str | None = None) -> GitRelease | None: """Get a list of GitHub releases by repo.""" - assert release_name or tag_name, 'Must specify a release_name or tag_name' + assert release_name or tag_name, "Must specify a release_name or tag_name" releases = repo.get_releases() for release in releases: @@ -49,13 +41,13 @@ def get_gh_release(repo: Repository, release_name: Optional[str] = None, tag_nam return release -def load_zipped_gh_assets_with_metadata(url: str) -> Tuple[str, dict]: +def load_zipped_gh_assets_with_metadata(url: str) -> tuple[str, dict[str, Any]]: """Download and unzip a GitHub assets.""" response = requests.get(url) zipped_asset = ZipFile(io.BytesIO(response.content)) zipped_sha256 = hashlib.sha256(response.content).hexdigest() - assets = {} + assets: dict[str, Any] = {} for zipped in zipped_asset.filelist: if zipped.is_dir(): continue @@ -64,27 +56,27 @@ def load_zipped_gh_assets_with_metadata(url: str) -> Tuple[str, dict]: sha256 = hashlib.sha256(contents).hexdigest() assets[zipped.filename] = { - 'contents': contents, - 'metadata': { - 'compress_size': zipped.compress_size, + "contents": contents, + "metadata": { + "compress_size": zipped.compress_size, # zipfile provides only a 6 tuple datetime; -1 means DST is unknown; 0's set tm_wday and tm_yday - 'created_at': time.strftime('%Y-%m-%dT%H:%M:%SZ', zipped.date_time + (0, 0, -1)), - 'sha256': sha256, - 'size': zipped.file_size, - } + "created_at": time.strftime("%Y-%m-%dT%H:%M:%SZ", 
zipped.date_time + (0, 0, -1)), + "sha256": sha256, + "size": zipped.file_size, + }, } return zipped_sha256, assets -def load_json_gh_asset(url: str) -> dict: +def load_json_gh_asset(url: str) -> dict[str, Any]: """Load and return the contents of a json asset file.""" response = requests.get(url) response.raise_for_status() return response.json() -def download_gh_asset(url: str, path: str, overwrite=False): +def download_gh_asset(url: str, path: str, overwrite: bool = False): """Download and unzip a GitHub asset.""" zipped = requests.get(url) z = ZipFile(io.BytesIO(zipped.content)) @@ -94,27 +86,26 @@ def download_gh_asset(url: str, path: str, overwrite=False): shutil.rmtree(path, ignore_errors=True) z.extractall(path) - click.echo(f'files saved to {path}') + click.echo(f"files saved to {path}") z.close() -def update_gist(token: str, - file_map: Dict[Path, str], - description: str, - gist_id: str, - public=False, - pre_purge=False) -> Response: +def update_gist( + token: str, + file_map: dict[Path, str], + description: str, + gist_id: str, + public: bool = False, + pre_purge: bool = False, +) -> Response: """Update existing gist.""" - url = f'https://api.github.com/gists/{gist_id}' - headers = { - 'accept': 'application/vnd.github.v3+json', - 'Authorization': f'token {token}' - } - body = { - 'description': description, - 'files': {}, # {path.name: {'content': contents} for path, contents in file_map.items()}, - 'public': public + url = f"https://api.github.com/gists/{gist_id}" + headers = {"accept": "application/vnd.github.v3+json", "Authorization": f"token {token}"} + body: dict[str, Any] = { + "description": description, + "files": {}, # {path.name: {'content': contents} for path, contents in file_map.items()}, + "public": public, } if pre_purge: @@ -122,12 +113,12 @@ def update_gist(token: str, response = requests.get(url) response.raise_for_status() data = response.json() - files = list(data['files']) - body['files'] = {file: {} for file in files if file not in file_map} + files = list(data["files"]) + body["files"] = {file: {} for file in files if file not in file_map} response = requests.patch(url, headers=headers, json=body) response.raise_for_status() - body['files'] = {path.name: {'content': contents} for path, contents in file_map.items()} + body["files"] = {path.name: {"content": contents} for path, contents in file_map.items()} response = requests.patch(url, headers=headers, json=body) response.raise_for_status() return response @@ -136,10 +127,10 @@ def update_gist(token: str, class GithubClient: """GitHub client wrapper.""" - def __init__(self, token: Optional[str] = None): + def __init__(self, token: str | None = None): """Get an unauthenticated client, verified authenticated client, or a default client.""" self.assert_github() - self.client: Github = Github(token) + self.client = Github(token) self.unauthenticated_client = Github() self.__token = token self.__authenticated_client = None @@ -147,23 +138,22 @@ def __init__(self, token: Optional[str] = None): @classmethod def assert_github(cls): if not Github: - raise ModuleNotFoundError('Missing PyGithub - try running `pip3 install .[dev]`') + raise ModuleNotFoundError("Missing PyGithub - try running `pip3 install .[dev]`") @property def authenticated_client(self) -> Github: if not self.__token: - raise ValueError('Token not defined! Re-instantiate with a token or use add_token method') + raise ValueError("Token not defined! 
Re-instantiate with a token or use add_token method") if not self.__authenticated_client: self.__authenticated_client = Github(self.__token) return self.__authenticated_client - def add_token(self, token): + def add_token(self, token: str): self.__token = token @dataclass class AssetManifestEntry: - compress_size: int created_at: datetime name: str @@ -173,18 +163,16 @@ class AssetManifestEntry: @dataclass class AssetManifestMetadata: - relative_url: str - entries: Dict[str, AssetManifestEntry] + entries: dict[str, AssetManifestEntry] zipped_sha256: definitions.Sha256 - created_at: datetime = field(default_factory=datetime.utcnow) - description: Optional[str] = None # populated by GitHub release asset label + created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc)) + description: str | None = None # populated by GitHub release asset label @dataclass class ReleaseManifest: - - assets: Dict[str, AssetManifestMetadata] + assets: dict[str, AssetManifestMetadata] assets_url: str author: str # parsed from GitHub release metadata as: author[login] created_at: str @@ -194,15 +182,20 @@ class ReleaseManifest: published_at: str url: str zipball_url: str - tag_name: str = None - description: str = None # parsed from GitHub release metadata as: body + tag_name: str | None = None + description: str | None = None # parsed from GitHub release metadata as: body class ManifestManager: """Manifest handler for GitHub releases.""" - def __init__(self, repo: str = 'elastic/detection-rules', release_name: Optional[str] = None, - tag_name: Optional[str] = None, token: Optional[str] = None): + def __init__( + self, + repo: str = "elastic/detection-rules", + release_name: str | None = None, + tag_name: str | None = None, + token: str | None = None, + ): self.repo_name = repo self.release_name = release_name self.tag_name = tag_name @@ -210,34 +203,37 @@ def __init__(self, repo: str = 'elastic/detection-rules', release_name: Optional self.has_token = token is not None self.repo: Repository = self.gh_client.client.get_repo(repo) - self.release: GitRelease = get_gh_release(self.repo, release_name, tag_name) + release = get_gh_release(self.repo, release_name, tag_name) + if not release: + raise ValueError("No release info found") + self.release = release if not self.release: - raise ValueError(f'No release found for {tag_name or release_name}') + raise ValueError(f"No release found for {tag_name or release_name}") if not self.release_name: self.release_name = self.release.title - self.manifest_name = f'manifest-{self.release_name}.json' - self.assets: dict = self._get_enriched_assets_from_release() + self.manifest_name = f"manifest-{self.release_name}.json" + self.assets = self._get_enriched_assets_from_release() self.release_manifest = self._create() self.__release_manifest_dict = dataclasses.asdict(self.release_manifest) self.manifest_size = len(json.dumps(self.__release_manifest_dict)) @property def release_manifest_fl(self) -> io.BytesIO: - return io.BytesIO(json.dumps(self.__release_manifest_dict, sort_keys=True).encode('utf-8')) + return io.BytesIO(json.dumps(self.__release_manifest_dict, sort_keys=True).encode("utf-8")) def _create(self) -> ReleaseManifest: """Create the manifest from GitHub asset metadata and file contents.""" assets = {} for asset_name, asset_data in self.assets.items(): - entries = {} - data = asset_data['data'] - metadata = asset_data['metadata'] + entries: dict[str, AssetManifestEntry] = {} + data = asset_data["data"] + metadata = asset_data["metadata"] for 
file_name, file_data in data.items(): - file_metadata = file_data['metadata'] + file_metadata = file_data["metadata"] name = Path(file_name).name file_metadata.update(name=name) @@ -245,9 +241,13 @@ def _create(self) -> ReleaseManifest: entry = AssetManifestEntry(**file_metadata) entries[name] = entry - assets[asset_name] = AssetManifestMetadata(metadata['browser_download_url'], entries, - metadata['zipped_sha256'], metadata['created_at'], - metadata['label']) + assets[asset_name] = AssetManifestMetadata( + metadata["browser_download_url"], + entries, + metadata["zipped_sha256"], + metadata["created_at"], + metadata["label"], + ) release_metadata = self._parse_release_metadata() release_metadata.update(assets=assets) @@ -255,49 +255,58 @@ def _create(self) -> ReleaseManifest: return release_manifest - def _parse_release_metadata(self) -> dict: + def _parse_release_metadata(self) -> dict[str, Any]: """Parse relevant info from GitHub metadata for release manifest.""" - ignore = ['assets'] - manual_set_keys = ['author', 'description'] + ignore = ["assets"] + manual_set_keys = ["author", "description"] keys = [f.name for f in dataclasses.fields(ReleaseManifest) if f.name not in ignore + manual_set_keys] parsed = {k: self.release.raw_data[k] for k in keys} - parsed.update(description=self.release.raw_data['body'], author=self.release.raw_data['author']['login']) + parsed.update(description=self.release.raw_data["body"], author=self.release.raw_data["author"]["login"]) return parsed def save(self) -> GitReleaseAsset: """Save manifest files.""" if not self.has_token: - raise ValueError('You must provide a token to save a manifest to a GitHub release') + raise ValueError("You must provide a token to save a manifest to a GitHub release") - asset = self.release.upload_asset_from_memory(self.release_manifest_fl, - self.manifest_size, - self.manifest_name) - click.echo(f'Manifest saved as {self.manifest_name} to {self.release.html_url}') + asset = self.release.upload_asset_from_memory(self.release_manifest_fl, self.manifest_size, self.manifest_name) + click.echo(f"Manifest saved as {self.manifest_name} to {self.release.html_url}") return asset @classmethod - def load(cls, name: str, repo: str = 'elastic/detection-rules', token: Optional[str] = None) -> Optional[dict]: + def load( + cls, + name: str, + repo_name: str = "elastic/detection-rules", + token: str | None = None, + ) -> dict[str, Any] | None: """Load a manifest.""" gh_client = GithubClient(token) - repo = gh_client.client.get_repo(repo) + repo = gh_client.client.get_repo(repo_name) release = get_gh_release(repo, tag_name=name) + if not release: + raise ValueError("No release info found") + for asset in release.get_assets(): - if asset.name == f'manifest-{name}.json': + if asset.name == f"manifest-{name}.json": return load_json_gh_asset(asset.browser_download_url) @classmethod - def load_all(cls, repo: str = 'elastic/detection-rules', token: Optional[str] = None - ) -> Tuple[Dict[str, dict], list]: + def load_all( + cls, + repo_name: str = "elastic/detection-rules", + token: str | None = None, + ) -> tuple[dict[str, dict[str, Any]], list[str]]: """Load a consolidated manifest.""" gh_client = GithubClient(token) - repo = gh_client.client.get_repo(repo) + repo = gh_client.client.get_repo(repo_name) - consolidated = {} - missing = set() + consolidated: dict[str, dict[str, Any]] = {} + missing: set[str] = set() for release in repo.get_releases(): name = release.tag_name - asset = next((a for a in release.get_assets() if a.name == 
f'manifest-{name}.json'), None) + asset = next((a for a in release.get_assets() if a.name == f"manifest-{name}.json"), None) if not asset: missing.add(name) else: @@ -306,28 +315,29 @@ def load_all(cls, repo: str = 'elastic/detection-rules', token: Optional[str] = return consolidated, list(missing) @classmethod - def get_existing_asset_hashes(cls, repo: str = 'elastic/detection-rules', token: Optional[str] = None) -> dict: + def get_existing_asset_hashes( + cls, + repo: str = "elastic/detection-rules", + token: str | None = None, + ) -> dict[str, Any]: """Load all assets with their hashes, by release.""" - flat = {} - consolidated, _ = cls.load_all(repo=repo, token=token) + flat: dict[str, Any] = {} + consolidated, _ = cls.load_all(repo_name=repo, token=token) for release, data in consolidated.items(): - for asset in data['assets'].values(): + for asset in data["assets"].values(): flat_release = flat[release] = {} - for asset_name, asset_data in asset['entries'].items(): - flat_release[asset_name] = asset_data['sha256'] + for asset_name, asset_data in asset["entries"].items(): + flat_release[asset_name] = asset_data["sha256"] return flat - def _get_enriched_assets_from_release(self) -> dict: + def _get_enriched_assets_from_release(self) -> dict[str, Any]: """Get assets and metadata from a GitHub release.""" - assets = {} + assets: dict[str, Any] = {} for asset in [a.raw_data for a in self.release.get_assets()]: - zipped_sha256, data = load_zipped_gh_assets_with_metadata(asset['browser_download_url']) + zipped_sha256, data = load_zipped_gh_assets_with_metadata(asset["browser_download_url"]) asset.update(zipped_sha256=zipped_sha256) - assets[asset['name']] = { - 'metadata': asset, - 'data': data - } + assets[asset["name"]] = {"metadata": asset, "data": data} return assets diff --git a/detection_rules/integrations.py b/detection_rules/integrations.py index 14ce262b8d1..81663b9604e 100644 --- a/detection_rules/integrations.py +++ b/detection_rules/integrations.py @@ -4,20 +4,21 @@ # 2.0. """Functions to support and interact with Kibana integrations.""" -import glob + +import fnmatch import gzip import json import re from collections import defaultdict, OrderedDict from pathlib import Path -from typing import Generator, List, Tuple, Union, Optional +from typing import Iterator, Any import requests from semver import Version import yaml from marshmallow import EXCLUDE, Schema, fields, post_load -import kql +import kql # type: ignore[reportMissingTypeStubs] from . 
import ecs from .config import load_current_package_version @@ -25,22 +26,23 @@ from .utils import cached, get_etc_path, read_gzip, unzip from .schemas import definitions -MANIFEST_FILE_PATH = get_etc_path('integration-manifests.json.gz') +MANIFEST_FILE_PATH = get_etc_path(["integration-manifests.json.gz"]) DEFAULT_MAX_RULE_VERSIONS = 1 -SCHEMA_FILE_PATH = get_etc_path('integration-schemas.json.gz') -_notified_integrations = set() +SCHEMA_FILE_PATH = get_etc_path(["integration-schemas.json.gz"]) + +_notified_integrations: set[str] = set() @cached -def load_integrations_manifests() -> dict: +def load_integrations_manifests() -> dict[str, Any]: """Load the consolidated integrations manifest.""" - return json.loads(read_gzip(get_etc_path('integration-manifests.json.gz'))) + return json.loads(read_gzip(get_etc_path(["integration-manifests.json.gz"]))) @cached -def load_integrations_schemas() -> dict: +def load_integrations_schemas() -> dict[str, Any]: """Load the consolidated integrations schemas.""" - return json.loads(read_gzip(get_etc_path('integration-schemas.json.gz'))) + return json.loads(read_gzip(get_etc_path(["integration-schemas.json.gz"]))) class IntegrationManifestSchema(Schema): @@ -54,35 +56,41 @@ class IntegrationManifestSchema(Schema): owner = fields.Dict(required=False) @post_load - def transform_policy_template(self, data, **kwargs): + def transform_policy_template(self, data: dict[str, Any], **_: Any): if "policy_templates" in data: data["policy_templates"] = [policy["name"] for policy in data["policy_templates"]] return data -def build_integrations_manifest(overwrite: bool, rule_integrations: list = [], - integration: str = None, prerelease: bool = False) -> None: +def build_integrations_manifest( + overwrite: bool, + rule_integrations: list[str] = [], + integration: str | None = None, + prerelease: bool = False, +) -> None: """Builds a new local copy of manifest.yaml from integrations Github.""" - def write_manifests(integrations: dict) -> None: - manifest_file = gzip.open(MANIFEST_FILE_PATH, "w+") + def write_manifests(integrations: dict[str, Any]) -> None: manifest_file_bytes = json.dumps(integrations).encode("utf-8") - manifest_file.write(manifest_file_bytes) - manifest_file.close() + with gzip.open(MANIFEST_FILE_PATH, "wb") as f: + _ = f.write(manifest_file_bytes) if overwrite: if MANIFEST_FILE_PATH.exists(): MANIFEST_FILE_PATH.unlink() - final_integration_manifests = {integration: {} for integration in rule_integrations} \ - or {integration: {}} + final_integration_manifests: dict[str, dict[str, Any]] = dict() + if rule_integrations: + final_integration_manifests = {integration: dict() for integration in rule_integrations} + elif integration: + final_integration_manifests = {integration: dict()} + rule_integrations = [integration] - rule_integrations = rule_integrations or [integration] for integration in rule_integrations: integration_manifests = get_integration_manifests(integration, prerelease=prerelease) for manifest in integration_manifests: - validated_manifest = IntegrationManifestSchema(unknown=EXCLUDE).load(manifest) - package_version = validated_manifest.pop("version") + validated_manifest = IntegrationManifestSchema(unknown=EXCLUDE).load(manifest) # type: ignore[reportUnknownVariableType] + package_version = validated_manifest.pop("version") # type: ignore final_integration_manifests[integration][package_version] = validated_manifest if overwrite and rule_integrations: @@ -98,7 +106,7 @@ def write_manifests(integrations: dict) -> None: print(f"final 
integrations manifests dumped: {MANIFEST_FILE_PATH}") -def build_integrations_schemas(overwrite: bool, integration: str = None) -> None: +def build_integrations_schemas(overwrite: bool, integration: str | None = None) -> None: """Builds a new local copy of integration-schemas.json.gz from EPR integrations.""" saved_integration_schemas = {} @@ -125,7 +133,7 @@ def build_integrations_schemas(overwrite: bool, integration: str = None) -> None # Loop through the packages and versions for package, versions in integration_manifests.items(): print(f"processing {package}") - final_integration_schemas.setdefault(package, {}) + final_integration_schemas.setdefault(package, {}) # type: ignore[reportUnknownMemberType] for version, manifest in versions.items(): if package in saved_integration_schemas and version in saved_integration_schemas[package]: continue @@ -136,74 +144,86 @@ def build_integrations_schemas(overwrite: bool, integration: str = None) -> None response.raise_for_status() # Update the final integration schemas - final_integration_schemas[package].update({version: {}}) + final_integration_schemas[package].update({version: {}}) # type: ignore[reportUnknownMemberType] # Open the zip file with unzip(response.content) as zip_ref: for file in zip_ref.namelist(): file_data_bytes = zip_ref.read(file) # Check if the file is a match - if glob.fnmatch.fnmatch(file, '*/fields/*.yml'): + if fnmatch.fnmatch(file, "*/fields/*.yml"): integration_name = Path(file).parent.parent.name - final_integration_schemas[package][version].setdefault(integration_name, {}) + final_integration_schemas[package][version].setdefault(integration_name, {}) # type: ignore[reportUnknownMemberType] schema_fields = yaml.safe_load(file_data_bytes) # Parse the schema and add to the integration_manifests data = flatten_ecs_schema(schema_fields) - flat_data = {field['name']: field['type'] for field in data} + flat_data = {field["name"]: field["type"] for field in data} - final_integration_schemas[package][version][integration_name].update(flat_data) + final_integration_schemas[package][version][integration_name].update(flat_data) # type: ignore[reportUnknownMemberType] # add machine learning jobs to the schema if package in list(map(str.lower, definitions.MACHINE_LEARNING_PACKAGES)): - if glob.fnmatch.fnmatch(file, '*/ml_module/*ml.json'): + if fnmatch.fnmatch(file, "*/ml_module/*ml.json"): ml_module = json.loads(file_data_bytes) - job_ids = [job['id'] for job in ml_module['attributes']['jobs']] - final_integration_schemas[package][version]['jobs'] = job_ids + job_ids = [job["id"] for job in ml_module["attributes"]["jobs"]] + final_integration_schemas[package][version]["jobs"] = job_ids del file_data_bytes # Write the final integration schemas to disk with gzip.open(SCHEMA_FILE_PATH, "w") as schema_file: schema_file_bytes = json.dumps(final_integration_schemas).encode("utf-8") - schema_file.write(schema_file_bytes) + _ = schema_file.write(schema_file_bytes) print(f"final integrations manifests dumped: {SCHEMA_FILE_PATH}") -def find_least_compatible_version(package: str, integration: str, - current_stack_version: str, packages_manifest: dict) -> str: +def find_least_compatible_version( + package: str, + integration: str, + current_stack_version: str, + packages_manifest: dict[str, Any], +) -> str: """Finds least compatible version for specified integration based on stack version supplied.""" - integration_manifests = {k: v for k, v in sorted(packages_manifest[package].items(), - key=lambda x: Version.parse(x[0]))} - 
current_stack_version = Version.parse(current_stack_version, optional_minor_and_patch=True) + integration_manifests = { + k: v for k, v in sorted(packages_manifest[package].items(), key=lambda x: Version.parse(x[0])) + } + stack_version = Version.parse(current_stack_version, optional_minor_and_patch=True) # filter integration_manifests to only the latest major entries - major_versions = sorted(list(set([Version.parse(manifest_version).major - for manifest_version in integration_manifests])), reverse=True) + major_versions = sorted( + list(set([Version.parse(manifest_version).major for manifest_version in integration_manifests])), reverse=True + ) for max_major in major_versions: - major_integration_manifests = \ - {k: v for k, v in integration_manifests.items() if Version.parse(k).major == max_major} + major_integration_manifests = { + k: v for k, v in integration_manifests.items() if Version.parse(k).major == max_major + } # iterates through ascending integration manifests # returns latest major version that is least compatible - for version, manifest in OrderedDict(sorted(major_integration_manifests.items(), - key=lambda x: Version.parse(x[0]))).items(): - compatible_versions = re.sub(r"\>|\<|\=|\^|\~", "", - manifest["conditions"]["kibana"]["version"]).split(" || ") + for version, manifest in OrderedDict( + sorted(major_integration_manifests.items(), key=lambda x: Version.parse(x[0])) + ).items(): + compatible_versions = re.sub(r"\>|\<|\=|\^|\~", "", manifest["conditions"]["kibana"]["version"]).split( + " || " + ) for kibana_ver in compatible_versions: kibana_ver = Version.parse(kibana_ver) # check versions have the same major - if kibana_ver.major == current_stack_version.major: - if kibana_ver <= current_stack_version: + if kibana_ver.major == stack_version.major: + if kibana_ver <= stack_version: return f"^{version}" raise ValueError(f"no compatible version for integration {package}:{integration}") -def find_latest_compatible_version(package: str, integration: str, - rule_stack_version: Version, - packages_manifest: dict) -> Union[None, Tuple[str, str]]: +def find_latest_compatible_version( + package: str, + integration: str, + rule_stack_version: Version, + packages_manifest: dict[str, Any], +) -> tuple[str, list[str]]: """Finds least compatible version for specified integration based on stack version supplied.""" if not package: @@ -215,7 +235,7 @@ def find_latest_compatible_version(package: str, integration: str, # Converts the dict keys (version numbers) to Version objects for proper sorting (descending) integration_manifests = sorted(package_manifest.items(), key=lambda x: Version.parse(x[0]), reverse=True) - notice = "" + notice = [""] for version, manifest in integration_manifests: kibana_conditions = manifest.get("conditions", {}).get("kibana", {}) @@ -228,16 +248,17 @@ def find_latest_compatible_version(package: str, integration: str, if not compatible_versions: raise ValueError(f"Manifest for {package}:{integration} version {version} is missing compatible versions") - highest_compatible_version = Version.parse(max(compatible_versions, - key=lambda x: Version.parse(x))) + highest_compatible_version = Version.parse(max(compatible_versions, key=lambda x: Version.parse(x))) if highest_compatible_version > rule_stack_version: # generate notice message that a later integration version is available integration = f" {integration.strip()}" if integration else "" - notice = (f"There is a new integration {package}{integration} version {version} available!", - f"Update the rule 
min_stack version from {rule_stack_version} to " - f"{highest_compatible_version} if using new features in this latest version.") + notice = [ + f"There is a new integration {package}{integration} version {version} available!", + f"Update the rule min_stack version from {rule_stack_version} to " + f"{highest_compatible_version} if using new features in this latest version.", + ] if highest_compatible_version.major == rule_stack_version.major: return version, notice @@ -251,18 +272,25 @@ def find_latest_compatible_version(package: str, integration: str, raise ValueError(f"no compatible version for integration {package}:{integration}") -def get_integration_manifests(integration: str, prerelease: Optional[bool] = False, - kibana_version: Optional[str] = "") -> list: +def get_integration_manifests( + integration: str, + prerelease: bool | None = False, + kibana_version: str | None = "", +) -> list[Any]: """Iterates over specified integrations from package-storage and combines manifests per version.""" epr_search_url = "https://epr.elastic.co/search" if not prerelease: - prerelease = "false" + prerelease_str = "false" else: - prerelease = "true" + prerelease_str = "true" # link for search parameters - https://github.com/elastic/package-registry - epr_search_parameters = {"package": f"{integration}", "prerelease": prerelease, - "all": "true", "include_policy_templates": "true"} + epr_search_parameters = { + "package": f"{integration}", + "prerelease": prerelease_str, + "all": "true", + "include_policy_templates": "true", + } if kibana_version: epr_search_parameters["kibana.version"] = kibana_version epr_search_response = requests.get(epr_search_url, params=epr_search_parameters, timeout=10) @@ -273,8 +301,10 @@ def get_integration_manifests(integration: str, prerelease: Optional[bool] = Fal raise ValueError(f"EPR search for {integration} integration package returned empty list") sorted_manifests = sorted(manifests, key=lambda p: Version.parse(p["version"]), reverse=True) - print(f"loaded {integration} manifests from the following package versions: " - f"{[manifest['version'] for manifest in sorted_manifests]}") + print( + f"loaded {integration} manifests from the following package versions: " + f"{[manifest['version'] for manifest in sorted_manifests]}" + ) return manifests @@ -283,36 +313,39 @@ def find_latest_integration_version(integration: str, maturity: str, stack_versi prerelease = False if maturity == "ga" else True existing_pkgs = get_integration_manifests(integration, prerelease, str(stack_version)) if maturity == "ga": - existing_pkgs = [pkg for pkg in existing_pkgs if not - Version.parse(pkg["version"]).prerelease] + existing_pkgs = [pkg for pkg in existing_pkgs if not Version.parse(pkg["version"]).prerelease] if maturity == "beta": - existing_pkgs = [pkg for pkg in existing_pkgs if - Version.parse(pkg["version"]).prerelease] + existing_pkgs = [pkg for pkg in existing_pkgs if Version.parse(pkg["version"]).prerelease] return max([Version.parse(pkg["version"]) for pkg in existing_pkgs]) -def get_integration_schema_data(data, meta, package_integrations: dict) -> Generator[dict, None, None]: +# Using `Any` here because `integrations` and `rule` modules are tightly coupled +def get_integration_schema_data( + data: Any, # type: ignore[reportRedeclaration] + meta: Any, # type: ignore[reportRedeclaration] + package_integrations: list[dict[str, Any]], +) -> Iterator[dict[str, Any]]: """Iterates over specified integrations from package-storage and combines schemas per version.""" # lazy import 
to avoid circular import - from .rule import ( # pylint: disable=import-outside-toplevel - ESQLRuleData, QueryRuleData, RuleMeta) + from .rule import ( + QueryRuleData, + RuleMeta, + ) - data: QueryRuleData = data + data: QueryRuleData = data # type: ignore[reportAssignmentType] meta: RuleMeta = meta packages_manifest = load_integrations_manifests() integrations_schemas = load_integrations_schemas() # validate the query against related integration fields - if (isinstance(data, QueryRuleData) or isinstance(data, ESQLRuleData)) \ - and data.language != 'lucene' and meta.maturity == "production": - + if data.language != "lucene" and meta.maturity == "production": for stack_version, mapping in meta.get_validation_stack_versions().items(): - ecs_version = mapping['ecs'] - endgame_version = mapping['endgame'] + ecs_version = mapping["ecs"] + endgame_version = mapping["endgame"] - ecs_schema = ecs.flatten_multi_fields(ecs.get_schema(ecs_version, name='ecs_flat')) + ecs_schema = ecs.flatten_multi_fields(ecs.get_schema(ecs_version, name="ecs_flat")) for pk_int in package_integrations: package = pk_int["package"] @@ -323,20 +356,42 @@ def get_integration_schema_data(data, meta, package_integrations: dict) -> Gener min_stack = Version.parse(min_stack, optional_minor_and_patch=True) # Extract the integration schema fields - integration_schema, package_version = get_integration_schema_fields(integrations_schemas, package, - integration, min_stack, - packages_manifest, ecs_schema, - data) - - data = {"schema": integration_schema, "package": package, "integration": integration, - "stack_version": stack_version, "ecs_version": ecs_version, - "package_version": package_version, "endgame_version": endgame_version} - yield data - + integration_schema, package_version = get_integration_schema_fields( + integrations_schemas, + package, + integration, + min_stack, + packages_manifest, + ecs_schema, + data, + ) + + yield { + "schema": integration_schema, + "package": package, + "integration": integration, + "stack_version": stack_version, + "ecs_version": ecs_version, + "package_version": package_version, + "endgame_version": endgame_version, + } + + +def get_integration_schema_fields( + integrations_schemas: dict[str, Any], + package: str, + integration: str, + min_stack: Version, + packages_manifest: dict[str, Any], + ecs_schema: dict[str, Any], + data: Any, # type: ignore[reportRedeclaration] +) -> tuple[dict[str, Any], str]: + # lazy import to avoid circular import + from .rule import ( + QueryRuleData, + ) -def get_integration_schema_fields(integrations_schemas: dict, package: str, integration: str, - min_stack: Version, packages_manifest: dict, - ecs_schema: dict, data: dict) -> dict: + data: QueryRuleData = data # type: ignore[reportAssignmentType] """Extracts the integration fields to schema based on package integrations.""" package_version, notice = find_latest_compatible_version(package, integration, min_stack, packages_manifest) notify_user_if_update_available(data, notice, integration) @@ -348,25 +403,43 @@ def get_integration_schema_fields(integrations_schemas: dict, package: str, inte return integration_schema, package_version -def notify_user_if_update_available(data: dict, notice: list, integration: str) -> None: +def notify_user_if_update_available( + data: Any, # type: ignore[reportRedeclaration] + notice: list[str], + integration: str, +) -> None: """Notifies the user if an update is available, only once per integration.""" + # lazy import to avoid circular import + from .rule import ( + 
QueryRuleData, + ) + + data: QueryRuleData = data # type: ignore[reportAssignmentType] + global _notified_integrations if notice and data.get("notify", False) and integration not in _notified_integrations: - # flag to only warn once per integration for available upgrades _notified_integrations.add(integration) print(f"\n{data.get('name')}") - print('\n'.join(notice)) + print("\n".join(notice)) -def collect_schema_fields(integrations_schemas: dict, package: str, package_version: str, - integration: Optional[str] = None) -> dict: +def collect_schema_fields( + integrations_schemas: dict[str, Any], + package: str, + package_version: str, + integration: str | None = None, +) -> dict[str, Any]: """Collects the schema fields for a given integration.""" if integration is None: - return {field: value for dataset in integrations_schemas[package][package_version] if dataset != "jobs" - for field, value in integrations_schemas[package][package_version][dataset].items()} + return { + field: value + for dataset in integrations_schemas[package][package_version] + if dataset != "jobs" + for field, value in integrations_schemas[package][package_version][dataset].items() + } if integration not in integrations_schemas[package][package_version]: raise ValueError(f"Integration {integration} not found in package {package} version {package_version}") @@ -374,21 +447,20 @@ def collect_schema_fields(integrations_schemas: dict, package: str, package_vers return integrations_schemas[package][package_version][integration] -def parse_datasets(datasets: list, package_manifest: dict) -> List[Optional[dict]]: +def parse_datasets(datasets: list[str], package_manifest: dict[str, Any]) -> list[dict[str, Any]]: """Parses datasets into packaged integrations from rule data.""" - packaged_integrations = [] + packaged_integrations: list[dict[str, Any]] = [] for value in sorted(datasets): - # cleanup extra quotes pulled from ast field value = value.strip('"') - integration = 'Unknown' - if '.' in value: - package, integration = value.split('.', 1) + integration = "Unknown" + if "." 
in value: + package, integration = value.split(".", 1) # Handle cases where endpoint event datasource needs to be parsed uniquely (e.g endpoint.events.network) # as endpoint.network if package == "endpoint" and "events" in integration: - integration = integration.split('.')[1] + integration = integration.split(".")[1] else: package = value @@ -403,7 +475,7 @@ class SecurityDetectionEngine: def __init__(self): self.epr_url = "https://epr.elastic.co/package/security_detection_engine/" - def load_integration_assets(self, package_version: Version) -> dict: + def load_integration_assets(self, package_version: Version) -> dict[str, Any]: """Loads integration assets into memory.""" epr_package_url = f"{self.epr_url}{str(package_version)}/" @@ -414,15 +486,19 @@ def load_integration_assets(self, package_version: Version) -> dict: zip_response = requests.get(zip_url) with unzip(zip_response.content) as zip_package: asset_file_names = [asset for asset in zip_package.namelist() if "json" in asset] - assets = {x.split("/")[-1].replace(".json", ""): json.loads(zip_package.read(x).decode('utf-8')) - for x in asset_file_names} + assets = { + x.split("/")[-1].replace(".json", ""): json.loads(zip_package.read(x).decode("utf-8")) + for x in asset_file_names + } return assets - def keep_latest_versions(self, assets: dict, num_versions: int = DEFAULT_MAX_RULE_VERSIONS) -> dict: + def keep_latest_versions( + self, assets: dict[str, Any], num_versions: int = DEFAULT_MAX_RULE_VERSIONS + ) -> dict[str, Any]: """Keeps only the latest N versions of each rule to limit historical rule versions in our release package.""" # Dictionary to hold the sorted list of versions for each base rule ID - rule_versions = defaultdict(list) + rule_versions: dict[str, list[tuple[int, str]]] = defaultdict(list) # Separate rule ID and version, and group by base rule ID for key in assets: @@ -431,7 +507,7 @@ def keep_latest_versions(self, assets: dict, num_versions: int = DEFAULT_MAX_RUL rule_versions[base_id].append((version, key)) # Dictionary to hold the final assets with only the specified number of latest versions - filtered_assets = {} + filtered_assets: dict[str, Any] = {} # Keep only the last/latest num_versions versions for each rule # Sort versions and take the last num_versions diff --git a/detection_rules/kbwrap.py b/detection_rules/kbwrap.py index e52dd507ce7..7c67faa48cd 100644 --- a/detection_rules/kbwrap.py +++ b/detection_rules/kbwrap.py @@ -4,25 +4,29 @@ # 2.0. 
"""Kibana cli commands.""" + import re import sys from pathlib import Path -from typing import Iterable, List, Optional +from typing import Any import click -import kql -from kibana import Signal, RuleResource +import kql # type: ignore[reportMissingTypeStubs] +from kibana import Signal, RuleResource # type: ignore[reportMissingTypeStubs] from .config import parse_rules_config from .cli_utils import multi_collection -from .action_connector import (TOMLActionConnectorContents, - parse_action_connector_results_from_api, build_action_connector_objects) -from .exception import (TOMLExceptionContents, - build_exception_objects, parse_exceptions_results_from_api) +from .action_connector import ( + TOMLActionConnector, + TOMLActionConnectorContents, + parse_action_connector_results_from_api, + build_action_connector_objects, +) +from .exception import TOMLExceptionContents, TOMLException, build_exception_objects, parse_exceptions_results_from_api from .generic_loader import GenericCollection from .main import root -from .misc import add_params, client_error, kibana_options, get_kibana_client, nested_set +from .misc import add_params, raise_client_error, kibana_options, get_kibana_client, nested_set from .rule import downgrade_contents_from_rule, TOMLRuleContents, TOMLRule from .rule_loader import RuleCollection, update_metadata_from_file from .utils import format_command_options, rulename_to_filename @@ -30,30 +34,30 @@ RULES_CONFIG = parse_rules_config() -@root.group('kibana') +@root.group("kibana") @add_params(*kibana_options) @click.pass_context -def kibana_group(ctx: click.Context, **kibana_kwargs): +def kibana_group(ctx: click.Context, **kibana_kwargs: Any): """Commands for integrating with Kibana.""" - ctx.ensure_object(dict) + _ = ctx.ensure_object(dict) # type: ignore[reportUnknownVariableType] # only initialize an kibana client if the subcommand is invoked without help (hacky) if sys.argv[-1] in ctx.help_option_names: - click.echo('Kibana client:') + click.echo("Kibana client:") click.echo(format_command_options(ctx)) else: - ctx.obj['kibana'] = get_kibana_client(**kibana_kwargs) + ctx.obj["kibana"] = get_kibana_client(**kibana_kwargs) @kibana_group.command("upload-rule") @multi_collection -@click.option('--replace-id', '-r', is_flag=True, help='Replace rule IDs with new IDs before export') +@click.option("--replace-id", "-r", is_flag=True, help="Replace rule IDs with new IDs before export") @click.pass_context -def upload_rule(ctx, rules: RuleCollection, replace_id): +def upload_rule(ctx: click.Context, rules: RuleCollection, replace_id: bool): """[Deprecated] Upload a list of rule .toml files to Kibana.""" - kibana = ctx.obj['kibana'] - api_payloads = [] + kibana = ctx.obj["kibana"] + api_payloads: list[RuleResource] = [] click.secho( "WARNING: This command is deprecated as of Elastic Stack version 9.0. 
Please use `kibana import-rules`.", @@ -64,62 +68,69 @@ def upload_rule(ctx, rules: RuleCollection, replace_id): try: payload = downgrade_contents_from_rule(rule, kibana.version, replace_id=replace_id) except ValueError as e: - client_error(f'{e} in version:{kibana.version}, for rule: {rule.name}', e, ctx=ctx) + raise_client_error(f"{e} in version:{kibana.version}, for rule: {rule.name}", e, ctx=ctx) rule = RuleResource(payload) api_payloads.append(rule) with kibana: - results = RuleResource.bulk_create_legacy(api_payloads) + results: list[RuleResource] = RuleResource.bulk_create_legacy(api_payloads) # type: ignore[reportUnknownMemberType] - success = [] - errors = [] + success: list[str] = [] + errors: list[str] = [] for result in results: - if 'error' in result: - errors.append(f'{result["rule_id"]} - {result["error"]["message"]}') + if "error" in result: + errors.append(f"{result['rule_id']} - {result['error']['message']}") else: - success.append(result['rule_id']) + success.append(result["rule_id"]) # type: ignore[reportUnknownArgumentType] if success: - click.echo('Successful uploads:\n - ' + '\n - '.join(success)) + click.echo("Successful uploads:\n - " + "\n - ".join(success)) if errors: - click.echo('Failed uploads:\n - ' + '\n - '.join(errors)) + click.echo("Failed uploads:\n - " + "\n - ".join(errors)) return results -@kibana_group.command('import-rules') +@kibana_group.command("import-rules") @multi_collection -@click.option('--overwrite', '-o', is_flag=True, help='Overwrite existing rules') -@click.option('--overwrite-exceptions', '-e', is_flag=True, help='Overwrite exceptions in existing rules') -@click.option('--overwrite-action-connectors', '-ac', is_flag=True, - help='Overwrite action connectors in existing rules') +@click.option("--overwrite", "-o", is_flag=True, help="Overwrite existing rules") +@click.option("--overwrite-exceptions", "-e", is_flag=True, help="Overwrite exceptions in existing rules") +@click.option( + "--overwrite-action-connectors", "-ac", is_flag=True, help="Overwrite action connectors in existing rules" +) @click.pass_context -def kibana_import_rules(ctx: click.Context, rules: RuleCollection, overwrite: Optional[bool] = False, - overwrite_exceptions: Optional[bool] = False, - overwrite_action_connectors: Optional[bool] = False) -> (dict, List[RuleResource]): +def kibana_import_rules( + ctx: click.Context, + rules: RuleCollection, + overwrite: bool = False, + overwrite_exceptions: bool = False, + overwrite_action_connectors: bool = False, +) -> tuple[dict[str, Any], list[RuleResource]]: """Import custom rules into Kibana.""" - def _handle_response_errors(response: dict): + + def _handle_response_errors(response: dict[str, Any]): """Handle errors from the import response.""" + def _parse_list_id(s: str): """Parse the list ID from the error message.""" match = re.search(r'list_id: "(.*?)"', s) return match.group(1) if match else None # Re-try to address known Kibana issue: https://github.com/elastic/kibana/issues/143864 - workaround_errors = [] - workaround_error_types = set() + workaround_errors: list[str] = [] + workaround_error_types: set[str] = set() flattened_exceptions = [e for sublist in exception_dicts for e in sublist] all_exception_list_ids = {exception["list_id"] for exception in flattened_exceptions} - click.echo(f'{len(response["errors"])} rule(s) failed to import!') + click.echo(f"{len(response['errors'])} rule(s) failed to import!") action_connector_validation_error = "Error validating create data" action_connector_type_error = "expected 
value of type [string] but got [undefined]" - for error in response['errors']: + for error in response["errors"]: error_message = error["error"]["message"] - click.echo(f' - {error["rule_id"]}: ({error["error"]["status_code"]}) {error_message}') + click.echo(f" - {error['rule_id']}: ({error['error']['status_code']}) {error_message}") if "references a non existent exception list" in error_message: list_id = _parse_list_id(error_message) @@ -147,15 +158,17 @@ def _parse_list_id(s: str): ) click.echo() - def _process_imported_items(imported_items_list, item_type_description, item_key): + def _process_imported_items( + imported_items_list: list[list[dict[str, Any]]], item_type_description: str, item_key: str + ): """Displays appropriately formatted success message that all items imported successfully.""" all_ids = {item[item_key] for sublist in imported_items_list for item in sublist} if all_ids: - click.echo(f'{len(all_ids)} {item_type_description} successfully imported') - ids_str = '\n - '.join(all_ids) - click.echo(f' - {ids_str}') + click.echo(f"{len(all_ids)} {item_type_description} successfully imported") + ids_str = "\n - ".join(all_ids) + click.echo(f" - {ids_str}") - kibana = ctx.obj['kibana'] + kibana = ctx.obj["kibana"] rule_dicts = [r.contents.to_api_format() for r in rules] with kibana: cl = GenericCollection.default() @@ -165,26 +178,26 @@ def _process_imported_items(imported_items_list, item_type_description, item_key action_connectors_dicts = [ d.contents.to_api_format() for d in cl.items if isinstance(d.contents, TOMLActionConnectorContents) ] - response, successful_rule_ids, results = RuleResource.import_rules( + response, successful_rule_ids, results = RuleResource.import_rules( # type: ignore[reportUnknownMemberType] rule_dicts, exception_dicts, action_connectors_dicts, overwrite=overwrite, overwrite_exceptions=overwrite_exceptions, - overwrite_action_connectors=overwrite_action_connectors + overwrite_action_connectors=overwrite_action_connectors, ) if successful_rule_ids: - click.echo(f'{len(successful_rule_ids)} rule(s) successfully imported') - rule_str = '\n - '.join(successful_rule_ids) - click.echo(f' - {rule_str}') - if response['errors']: - _handle_response_errors(response) + click.echo(f"{len(successful_rule_ids)} rule(s) successfully imported") # type: ignore[reportUnknownArgumentType] + rule_str = "\n - ".join(successful_rule_ids) # type: ignore[reportUnknownArgumentType] + click.echo(f" - {rule_str}") + if response["errors"]: + _handle_response_errors(response) # type: ignore[reportUnknownArgumentType] else: - _process_imported_items(exception_dicts, 'exception list(s)', 'list_id') - _process_imported_items(action_connectors_dicts, 'action connector(s)', 'id') + _process_imported_items(exception_dicts, "exception list(s)", "list_id") + _process_imported_items(action_connectors_dicts, "action connector(s)", "id") - return response, results + return response, results # type: ignore[reportUnknownVariableType] @kibana_group.command("export-rules") @@ -195,15 +208,23 @@ def _process_imported_items(imported_items_list, item_type_description, item_key @click.option("--exceptions-directory", "-ed", required=False, type=Path, help="Directory to export exceptions to") @click.option("--default-author", "-da", type=str, required=False, help="Default author for rules missing one") @click.option("--rule-id", "-r", multiple=True, help="Optional Rule IDs to restrict export to") -@click.option("--rule-name", "-rn", required=False, help="Optional Rule name to restrict export to 
" - "(KQL, case-insensitive, supports wildcards)") +@click.option( + "--rule-name", + "-rn", + required=False, + help="Optional Rule name to restrict export to (KQL, case-insensitive, supports wildcards)", +) @click.option("--export-action-connectors", "-ac", is_flag=True, help="Include action connectors in export") @click.option("--export-exceptions", "-e", is_flag=True, help="Include exceptions in export") @click.option("--skip-errors", "-s", is_flag=True, help="Skip errors when exporting rules") @click.option("--strip-version", "-sv", is_flag=True, help="Strip the version fields from all rules") -@click.option("--no-tactic-filename", "-nt", is_flag=True, - help="Exclude tactic prefix in exported filenames for rules. " - "Use same flag for import-rules to prevent warnings and disable its unit test.") +@click.option( + "--no-tactic-filename", + "-nt", + is_flag=True, + help="Exclude tactic prefix in exported filenames for rules. " + "Use same flag for import-rules to prevent warnings and disable its unit test.", +) @click.option("--local-creation-date", "-lc", is_flag=True, help="Preserve the local creation date of the rule") @click.option("--local-updated-date", "-lu", is_flag=True, help="Preserve the local updated date of the rule") @click.option("--custom-rules-only", "-cro", is_flag=True, help="Only export custom rules") @@ -214,18 +235,28 @@ def _process_imported_items(imported_items_list, item_type_description, item_key required=False, help=( "Apply a query filter to exporting rules e.g. " - "\"alert.attributes.tags: \\\"test\\\"\" to filter for rules that have the tag \"test\"" - ) + '"alert.attributes.tags: \\"test\\"" to filter for rules that have the tag "test"' + ), ) @click.pass_context -def kibana_export_rules(ctx: click.Context, directory: Path, action_connectors_directory: Optional[Path], - exceptions_directory: Optional[Path], default_author: str, - rule_id: Optional[Iterable[str]] = None, rule_name: Optional[str] = None, - export_action_connectors: bool = False, - export_exceptions: bool = False, skip_errors: bool = False, strip_version: bool = False, - no_tactic_filename: bool = False, local_creation_date: bool = False, - local_updated_date: bool = False, custom_rules_only: bool = False, - export_query: Optional[str] = None) -> List[TOMLRule]: +def kibana_export_rules( + ctx: click.Context, + directory: Path, + action_connectors_directory: Path | None, + exceptions_directory: Path | None, + default_author: str, + rule_id: list[str] | None = None, + rule_name: str | None = None, + export_action_connectors: bool = False, + export_exceptions: bool = False, + skip_errors: bool = False, + strip_version: bool = False, + no_tactic_filename: bool = False, + local_creation_date: bool = False, + local_updated_date: bool = False, + custom_rules_only: bool = False, + export_query: str | None = None, +) -> list[TOMLRule]: """Export custom rules from Kibana.""" kibana = ctx.obj["kibana"] kibana_include_details = export_exceptions or export_action_connectors @@ -237,22 +268,20 @@ def kibana_export_rules(ctx: click.Context, directory: Path, action_connectors_d with kibana: # Look up rule IDs by name if --rule-name was provided if rule_name: - found = RuleResource.find(filter=f"alert.attributes.name:{rule_name}") - rule_id = [r["rule_id"] for r in found] + found = RuleResource.find(filter=f"alert.attributes.name:{rule_name}") # type: ignore + rule_id = [r["rule_id"] for r in found] # type: ignore[reportUnknownVariableType] query = ( - export_query if not custom_rules_only + export_query 
+ if not custom_rules_only else ( - f"alert.attributes.params.ruleSource.type: \"internal\"" - f"{f' and ({export_query})' if export_query else ''}" + f'alert.attributes.params.ruleSource.type: "internal"{f" and ({export_query})" if export_query else ""}' ) ) - results = ( - RuleResource.bulk_export(rule_ids=list(rule_id), query=query) + results = ( # type: ignore[reportUnknownVariableType] + RuleResource.bulk_export(rule_ids=list(rule_id), query=query) # type: ignore if query - else RuleResource.export_rules( - list(rule_id), exclude_export_details=not kibana_include_details - ) + else RuleResource.export_rules(list(rule_id), exclude_export_details=not kibana_include_details) # type: ignore[reportArgumentType] ) # Handle Exceptions Directory Location if results and exceptions_directory: @@ -274,48 +303,48 @@ def kibana_export_rules(ctx: click.Context, directory: Path, action_connectors_d click.echo("No rules found to export") return [] - rules_results = results + rules_results = results # type: ignore[reportUnknownVariableType] action_connector_results = [] exception_results = [] if kibana_include_details: # Assign counts to variables - rules_count = results[-1]["exported_rules_count"] - exception_list_count = results[-1]["exported_exception_list_count"] - exception_list_item_count = results[-1]["exported_exception_list_item_count"] - action_connector_count = results[-1]["exported_action_connector_count"] + rules_count = results[-1]["exported_rules_count"] # type: ignore[reportUnknownVariableType] + exception_list_count = results[-1]["exported_exception_list_count"] # type: ignore[reportUnknownVariableType] + exception_list_item_count = results[-1]["exported_exception_list_item_count"] # type: ignore[reportUnknownVariableType] + action_connector_count = results[-1]["exported_action_connector_count"] # type: ignore[reportUnknownVariableType] # Parse rules results and exception results from API return - rules_results = results[:rules_count] - exception_results = results[rules_count:rules_count + exception_list_count + exception_list_item_count] - rules_and_exceptions_count = rules_count + exception_list_count + exception_list_item_count - action_connector_results = results[ - rules_and_exceptions_count: rules_and_exceptions_count + action_connector_count + rules_results = results[:rules_count] # type: ignore[reportUnknownVariableType] + exception_results = results[rules_count : rules_count + exception_list_count + exception_list_item_count] # type: ignore[reportUnknownVariableType] + rules_and_exceptions_count = rules_count + exception_list_count + exception_list_item_count # type: ignore[reportUnknownVariableType] + action_connector_results = results[ # type: ignore[reportUnknownVariableType] + rules_and_exceptions_count : rules_and_exceptions_count + action_connector_count ] - errors = [] - exported = [] - exception_list_rule_table = {} - action_connector_rule_table = {} - for rule_resource in rules_results: + errors: list[str] = [] + exported: list[TOMLRule] = [] + exception_list_rule_table: dict[str, list[dict[str, Any]]] = {} + action_connector_rule_table: dict[str, list[dict[str, Any]]] = {} + for rule_resource in rules_results: # type: ignore[reportUnknownVariableType] try: if strip_version: - rule_resource.pop("revision", None) - rule_resource.pop("version", None) - rule_resource["author"] = rule_resource.get("author") or default_author or [rule_resource.get("created_by")] + rule_resource.pop("revision", None) # type: ignore[reportUnknownMemberType] + rule_resource.pop("version", 
None) # type: ignore[reportUnknownMemberType] + rule_resource["author"] = rule_resource.get("author") or default_author or [rule_resource.get("created_by")] # type: ignore[reportUnknownMemberType] if isinstance(rule_resource["author"], str): rule_resource["author"] = [rule_resource["author"]] # Inherit maturity and optionally local dates from the rule if it already exists - params = { + params: dict[str, Any] = { "rule": rule_resource, "maturity": "development", } - threat = rule_resource.get("threat") - first_tactic = threat[0].get("tactic").get("name") if threat else "" + threat = rule_resource.get("threat") # type: ignore[reportUnknownMemberType] + first_tactic = threat[0].get("tactic").get("name") if threat else "" # type: ignore[reportUnknownMemberType] # Check if flag or config is set to not include tactic in the filename no_tactic_filename = no_tactic_filename or RULES_CONFIG.no_tactic_filename # Check if the flag is set to not include tactic in the filename - tactic_name = first_tactic if not no_tactic_filename else None - rule_name = rulename_to_filename(rule_resource.get("name"), tactic_name=tactic_name) + tactic_name = first_tactic if not no_tactic_filename else None # type: ignore[reportUnknownMemberType] + rule_name = rulename_to_filename(rule_resource.get("name"), tactic_name=tactic_name) # type: ignore[reportUnknownMemberType] save_path = directory / f"{rule_name}" params.update( @@ -323,12 +352,12 @@ def kibana_export_rules(ctx: click.Context, directory: Path, action_connectors_d save_path, {"creation_date": local_creation_date, "updated_date": local_updated_date} ) ) - contents = TOMLRuleContents.from_rule_resource(**params) + contents = TOMLRuleContents.from_rule_resource(**params) # type: ignore[reportArgumentType] rule = TOMLRule(contents=contents, path=save_path) except Exception as e: if skip_errors: - print(f'- skipping {rule_resource.get("name")} - {type(e).__name__}') - errors.append(f'- {rule_resource.get("name")} - {e}') + print(f"- skipping {rule_resource.get('name')} - {type(e).__name__}") # type: ignore[reportUnknownMemberType] + errors.append(f"- {rule_resource.get('name')} - {e}") # type: ignore[reportUnknownMemberType] continue raise if rule.contents.data.exceptions_list: @@ -354,7 +383,7 @@ def kibana_export_rules(ctx: click.Context, directory: Path, action_connectors_d exceptions_containers = {} exceptions_items = {} - exceptions_containers, exceptions_items, parse_errors, _ = parse_exceptions_results_from_api(exception_results) + exceptions_containers, exceptions_items, parse_errors, _ = parse_exceptions_results_from_api(exception_results) # type: ignore[reportArgumentType] errors.extend(parse_errors) # Build TOMLException Objects @@ -374,7 +403,7 @@ def kibana_export_rules(ctx: click.Context, directory: Path, action_connectors_d # Parse action connector results from API return action_connectors = [] if export_action_connectors: - action_connector_results, _ = parse_action_connector_results_from_api(action_connector_results) + action_connector_results, _ = parse_action_connector_results_from_api(action_connector_results) # type: ignore[reportArgumentType] # Build TOMLActionConnector Objects action_connectors, ac_output, ac_errors = build_action_connector_objects( @@ -389,7 +418,7 @@ def kibana_export_rules(ctx: click.Context, directory: Path, action_connectors_d click.echo(line) errors.extend(ac_errors) - saved = [] + saved: list[TOMLRule] = [] for rule in exported: try: rule.save_toml() @@ -402,20 +431,20 @@ def kibana_export_rules(ctx: click.Context, 
directory: Path, action_connectors_d saved.append(rule) - saved_exceptions = [] + saved_exceptions: list[TOMLException] = [] for exception in exceptions: try: exception.save_toml() except Exception as e: if skip_errors: - print(f"- skipping {exception.rule_name} - {type(e).__name__}") - errors.append(f"- {exception.rule_name} - {e}") + print(f"- skipping {exception.rule_name} - {type(e).__name__}") # type: ignore[reportUnknownMemberType] + errors.append(f"- {exception.rule_name} - {e}") # type: ignore[reportUnknownMemberType] continue raise saved_exceptions.append(exception) - saved_action_connectors = [] + saved_action_connectors: list[TOMLActionConnector] = [] for action in action_connectors: try: action.save_toml() @@ -428,7 +457,7 @@ def kibana_export_rules(ctx: click.Context, directory: Path, action_connectors_d saved_action_connectors.append(action) - click.echo(f"{len(results)} results exported") + click.echo(f"{len(results)} results exported") # type: ignore[reportUnknownArgumentType] click.echo(f"{len(exported)} rules converted") click.echo(f"{len(exceptions)} exceptions exported") click.echo(f"{len(action_connectors)} action connectors exported") @@ -437,54 +466,55 @@ def kibana_export_rules(ctx: click.Context, directory: Path, action_connectors_d click.echo(f"{len(saved_action_connectors)} action connectors saved to {action_connectors_directory}") if errors: err_file = directory / "_errors.txt" - err_file.write_text("\n".join(errors)) + _ = err_file.write_text("\n".join(errors)) click.echo(f"{len(errors)} errors saved to {err_file}") return exported -@kibana_group.command('search-alerts') -@click.argument('query', required=False) -@click.option('--date-range', '-d', type=(str, str), default=('now-7d', 'now'), help='Date range to scope search') -@click.option('--columns', '-c', multiple=True, help='Columns to display in table') -@click.option('--extend', '-e', is_flag=True, help='If columns are specified, extend the original columns') -@click.option('--max-count', '-m', default=100, help='The max number of alerts to return') +@kibana_group.command("search-alerts") +@click.argument("query", required=False) +@click.option("--date-range", "-d", type=(str, str), default=("now-7d", "now"), help="Date range to scope search") +@click.option("--columns", "-c", multiple=True, help="Columns to display in table") +@click.option("--extend", "-e", is_flag=True, help="If columns are specified, extend the original columns") +@click.option("--max-count", "-m", default=100, help="The max number of alerts to return") @click.pass_context -def search_alerts(ctx, query, date_range, columns, extend, max_count): +def search_alerts( + ctx: click.Context, query: str, date_range: tuple[str, str], columns: list[str], extend: bool, max_count: int +): """Search detection engine alerts with KQL.""" - from eql.table import Table + from eql.table import Table # type: ignore[reportMissingTypeStubs] from .eswrap import MATCH_ALL, add_range_to_dsl - kibana = ctx.obj['kibana'] + kibana = ctx.obj["kibana"] start_time, end_time = date_range - kql_query = kql.to_dsl(query) if query else MATCH_ALL - add_range_to_dsl(kql_query['bool'].setdefault('filter', []), start_time, end_time) + kql_query = kql.to_dsl(query) if query else MATCH_ALL # type: ignore[reportUnknownMemberType] + add_range_to_dsl(kql_query["bool"].setdefault("filter", []), start_time, end_time) # type: ignore[reportUnknownArgumentType] with kibana: - alerts = [a['_source'] for a in Signal.search({'query': kql_query}, size=max_count)['hits']['hits']] + 
alerts = [a["_source"] for a in Signal.search({"query": kql_query}, size=max_count)["hits"]["hits"]] # type: ignore # check for events with nested signal fields if alerts: - table_columns = ['host.hostname'] + table_columns = ["host.hostname"] - if 'signal' in alerts[0]: - table_columns += ['signal.rule.name', 'signal.status', 'signal.original_time'] - elif 'kibana.alert.rule.name' in alerts[0]: - table_columns += ['kibana.alert.rule.name', 'kibana.alert.status', 'kibana.alert.original_time'] + if "signal" in alerts[0]: + table_columns += ["signal.rule.name", "signal.status", "signal.original_time"] + elif "kibana.alert.rule.name" in alerts[0]: + table_columns += ["kibana.alert.rule.name", "kibana.alert.status", "kibana.alert.original_time"] else: - table_columns += ['rule.name', '@timestamp'] + table_columns += ["rule.name", "@timestamp"] if columns: columns = list(columns) table_columns = table_columns + columns if extend else columns # Table requires the data to be nested, but depending on the version, some data uses dotted keys, so # they must be nested explicitly - for alert in alerts: + for alert in alerts: # type: ignore[reportUnknownVariableType] for key in table_columns: if key in alert: - nested_set(alert, key, alert[key]) + nested_set(alert, key, alert[key]) # type: ignore[reportUnknownArgumentType] - click.echo(Table.from_list(table_columns, alerts)) + click.echo(Table.from_list(table_columns, alerts)) # type: ignore[reportUnknownMemberType] else: - click.echo('No alerts detected') - return alerts + click.echo("No alerts detected") diff --git a/detection_rules/main.py b/detection_rules/main.py index 1e1dcfcac06..74ccc352bd6 100644 --- a/detection_rules/main.py +++ b/detection_rules/main.py @@ -4,83 +4,92 @@ # 2.0. """CLI commands for detection_rules.""" + import dataclasses import glob import json import os import time -from datetime import datetime +from datetime import datetime, timezone -import pytoml +import pytoml # type: ignore[reportMissingTypeStubs] from marshmallow_dataclass import class_schema from pathlib import Path from semver import Version -from typing import Dict, Iterable, List, Optional, get_args +from typing import Iterable, get_args, Any, Literal from uuid import uuid4 import click -from .action_connector import (TOMLActionConnectorContents, - build_action_connector_objects, parse_action_connector_results_from_api) +from .action_connector import ( + TOMLActionConnectorContents, + build_action_connector_objects, + parse_action_connector_results_from_api, +) from .attack import build_threat_map_entry from .cli_utils import rule_prompt, multi_collection from .config import load_current_package_version, parse_rules_config from .generic_loader import GenericCollection -from .exception import (TOMLExceptionContents, - build_exception_objects, parse_exceptions_results_from_api) -from .misc import ( - add_client, client_error, nested_set, parse_user_config -) +from .exception import TOMLExceptionContents, build_exception_objects, parse_exceptions_results_from_api +from .misc import add_client, raise_client_error, nested_set, parse_user_config from .rule import TOMLRule, TOMLRuleContents, QueryRuleData from .rule_formatter import toml_write from .rule_loader import RuleCollection, update_metadata_from_file from .schemas import all_versions, definitions, get_incompatible_fields, get_schema_file -from .utils import Ndjson, get_path, get_etc_path, clear_caches, load_dump, load_rule_contents, rulename_to_filename +from .utils import Ndjson, get_path, get_etc_path, 
clear_caches, load_rule_contents, rulename_to_filename +from .utils import load_dump # type: ignore[reportUnknownVariableType] RULES_CONFIG = parse_rules_config() RULES_DIRS = RULES_CONFIG.rule_dirs @click.group( - 'detection-rules', + "detection-rules", context_settings={ - 'help_option_names': ['-h', '--help'], - 'max_content_width': int(os.getenv('DR_CLI_MAX_WIDTH', 240)), + "help_option_names": ["-h", "--help"], + "max_content_width": int(os.getenv("DR_CLI_MAX_WIDTH", 240)), }, ) -@click.option('--debug/--no-debug', '-D/-N', is_flag=True, default=None, - help='Print full exception stacktrace on errors') +@click.option( + "--debug/--no-debug", + "-D/-N", + is_flag=True, + default=None, + help="Print full exception stacktrace on errors", +) @click.pass_context -def root(ctx, debug): +def root(ctx: click.Context, debug: bool): """Commands for detection-rules repository.""" - debug = debug if debug is not None else parse_user_config().get('debug') - ctx.obj = {'debug': debug, 'rules_config': RULES_CONFIG} + debug = debug if debug else parse_user_config().get("debug") + ctx.obj = {"debug": debug, "rules_config": RULES_CONFIG} if debug: - click.secho('DEBUG MODE ENABLED', fg='yellow') + click.secho("DEBUG MODE ENABLED", fg="yellow") -@root.command('create-rule') -@click.argument('path', type=Path) -@click.option('--config', '-c', type=click.Path(exists=True, dir_okay=False, path_type=Path), - help='Rule or config file') -@click.option('--required-only', is_flag=True, help='Only prompt for required fields') -@click.option('--rule-type', '-t', type=click.Choice(sorted(TOMLRuleContents.all_rule_types())), - help='Type of rule to create') -def create_rule(path, config, required_only, rule_type): +@root.command("create-rule") +@click.argument("path", type=Path) +@click.option( + "--config", "-c", type=click.Path(exists=True, dir_okay=False, path_type=Path), help="Rule or config file" +) +@click.option("--required-only", is_flag=True, help="Only prompt for required fields") +@click.option( + "--rule-type", "-t", type=click.Choice(sorted(TOMLRuleContents.all_rule_types())), help="Type of rule to create" +) +def create_rule(path: Path, config: Path, required_only: bool, rule_type: str): """Create a detection rule.""" - contents = load_rule_contents(config, single_only=True)[0] if config else {} + contents: dict[str, Any] = load_rule_contents(config, single_only=True)[0] if config else {} return rule_prompt(path, rule_type=rule_type, required_only=required_only, save=True, **contents) -@root.command('generate-rules-index') -@click.option('--query', '-q', help='Optional KQL query to limit to specific rules') -@click.option('--overwrite', is_flag=True, help='Overwrite files in an existing folder') +@root.command("generate-rules-index") +@click.option("--query", "-q", help="Optional KQL query to limit to specific rules") +@click.option("--overwrite", is_flag=True, help="Overwrite files in an existing folder") @click.pass_context -def generate_rules_index(ctx: click.Context, query, overwrite, save_files=True): +def generate_rules_index(ctx: click.Context, query: str, overwrite: bool, save_files: bool = True): """Generate enriched indexes of rules, based on a KQL search, for indexing/importing into elasticsearch/kibana.""" from .packaging import Package if query: - rule_paths = [r['file'] for r in ctx.invoke(search_rules, query=query, verbose=False)] + rule_paths = [r["file"] for r in ctx.invoke(search_rules, query=query, verbose=False)] rules = RuleCollection() rules.load_files(Path(p) for p in 
rule_paths) else: @@ -92,37 +101,40 @@ def generate_rules_index(ctx: click.Context, query, overwrite, save_files=True): bulk_upload_docs, importable_rules_docs = package.create_bulk_index_body() if save_files: - path = get_path('enriched-rule-indexes', package_hash) + path = get_path(["enriched-rule-indexes", package_hash]) path.mkdir(parents=True, exist_ok=overwrite) - bulk_upload_docs.dump(path.joinpath('enriched-rules-index-uploadable.ndjson'), sort_keys=True) - importable_rules_docs.dump(path.joinpath('enriched-rules-index-importable.ndjson'), sort_keys=True) + bulk_upload_docs.dump(path.joinpath("enriched-rules-index-uploadable.ndjson"), sort_keys=True) + importable_rules_docs.dump(path.joinpath("enriched-rules-index-importable.ndjson"), sort_keys=True) - click.echo(f'files saved to: {path}') + click.echo(f"files saved to: {path}") - click.echo(f'{rule_count} rules included') + click.echo(f"{rule_count} rules included") return bulk_upload_docs, importable_rules_docs @root.command("import-rules-to-repo") -@click.argument("input-file", type=click.Path(dir_okay=False, exists=True), nargs=-1, required=False) +@click.argument("input-file", type=click.Path(dir_okay=False, exists=True, path_type=Path), nargs=-1, required=False) @click.option("--action-connector-import", "-ac", is_flag=True, help="Include action connectors in export") @click.option("--exceptions-import", "-e", is_flag=True, help="Include exceptions in export") @click.option("--required-only", is_flag=True, help="Only prompt for required fields") @click.option("--directory", "-d", type=click.Path(file_okay=False, exists=True), help="Load files from a directory") @click.option( - "--save-directory", "-s", type=click.Path(file_okay=False, exists=True), help="Save imported rules to a directory" + "--save-directory", + "-s", + type=click.Path(file_okay=False, exists=True, path_type=Path), + help="Save imported rules to a directory", ) @click.option( "--exceptions-directory", "-se", - type=click.Path(file_okay=False, exists=True), + type=click.Path(file_okay=False, exists=True, path_type=Path), help="Save imported exceptions to a directory", ) @click.option( "--action-connectors-directory", "-sa", - type=click.Path(file_okay=False, exists=True), + type=click.Path(file_okay=False, exists=True, path_type=Path), help="Save imported actions to a directory", ) @click.option("--skip-errors", "-ske", is_flag=True, help="Skip rule import errors") @@ -130,17 +142,28 @@ def generate_rules_index(ctx: click.Context, query, overwrite, save_files=True): @click.option("--strip-none-values", "-snv", is_flag=True, help="Strip None values from the rule") @click.option("--local-creation-date", "-lc", is_flag=True, help="Preserve the local creation date of the rule") @click.option("--local-updated-date", "-lu", is_flag=True, help="Preserve the local updated date of the rule") -def import_rules_into_repo(input_file: click.Path, required_only: bool, action_connector_import: bool, - exceptions_import: bool, directory: click.Path, save_directory: click.Path, - action_connectors_directory: click.Path, exceptions_directory: click.Path, - skip_errors: bool, default_author: str, strip_none_values: bool, local_creation_date: bool, - local_updated_date: bool): +def import_rules_into_repo( + input_file: Path | None, + required_only: bool, + action_connector_import: bool, + exceptions_import: bool, + directory: Path | None, + save_directory: Path, + action_connectors_directory: Path | None, + exceptions_directory: Path | None, + skip_errors: bool, + 
default_author: str,
+    strip_none_values: bool,
+    local_creation_date: bool,
+    local_updated_date: bool,
+):
     """Import rules from json, toml, or yaml files containing Kibana exported rule(s)."""
-    errors = []
-    rule_files = glob.glob(os.path.join(directory, "**", "*.*"), recursive=True) if directory else []
-    rule_files = sorted(set(rule_files + list(input_file)))
+    errors: list[str] = []
+    rule_files = glob.glob(os.path.join(str(directory), "**", "*.*"), recursive=True) if directory else []
+    if input_file:
+        rule_files = sorted(set(rule_files + list(input_file)))
-    file_contents = []
+    file_contents: list[Any] = []
     for rule_file in rule_files:
         file_contents.extend(load_rule_contents(Path(rule_file)))
@@ -156,19 +179,19 @@ def import_rules_into_repo(input_file: click.Path, required_only: bool, action_c
     file_contents = unparsed_results
-    exception_list_rule_table = {}
-    action_connector_rule_table = {}
+    exception_list_rule_table: dict[str, Any] = {}
+    action_connector_rule_table: dict[str, Any] = {}
     rule_count = 0
     for contents in file_contents:
         # Don't load exceptions as rules
         if contents.get("type") not in get_args(definitions.RuleType):
-            click.echo(f"Skipping - {contents.get("type")} is not a supported rule type")
+            click.echo(f"Skipping - {contents.get('type')} is not a supported rule type")
             continue
         base_path = contents.get("name") or contents.get("rule", {}).get("name")
         base_path = rulename_to_filename(base_path) if base_path else base_path
         if base_path is None:
             raise ValueError(f"Invalid rule file, please ensure the rule has a name field: {contents}")
-        rule_path = os.path.join(save_directory if save_directory is not None else RULES_DIRS[0], base_path)
+        rule_path = Path(os.path.join(str(save_directory) if save_directory else RULES_DIRS[0], base_path))
         # handle both rule json formats loaded from kibana and toml
         data_view_id = contents.get("data_view_id") or contents.get("rule", {}).get("data_view_id")
@@ -255,15 +278,20 @@ def import_rules_into_repo(input_file: click.Path, required_only: bool, action_c
     click.echo(f"{exceptions_count} exceptions exported")
     click.echo(f"{len(action_connectors)} actions connectors exported")
     if errors:
-        err_file = save_directory if save_directory is not None else RULES_DIRS[0] / "_errors.txt"
-        err_file.write_text("\n".join(errors))
+        dir = save_directory if save_directory else RULES_DIRS[0]
+        err_file = dir / "_errors.txt"
+        _ = err_file.write_text("\n".join(errors))
         click.echo(f"{len(errors)} errors saved to {err_file}")
-@root.command('build-limited-rules')
-@click.option('--stack-version', type=click.Choice(all_versions()), required=True,
-              help='Version to downgrade to be compatible with the older instance of Kibana')
-@click.option('--output-file', '-o', type=click.Path(dir_okay=False, exists=False), required=True)
+@root.command("build-limited-rules")
+@click.option(
+    "--stack-version",
+    type=click.Choice(all_versions()),
+    required=True,
+    help="Version to downgrade to be compatible with the older instance of Kibana",
+)
+@click.option("--output-file", "-o", type=click.Path(dir_okay=False, exists=False), required=True)
 def build_limited_rules(stack_version: str, output_file: str):
     """
     Import rules from json, toml, or Kibana exported rule file(s),
@@ -272,9 +300,10 @@ def build_limited_rules(stack_version: str, output_file: str):
     # Schema generation and incompatible fields detection
     query_rule_data = class_schema(QueryRuleData)()
-    fields = getattr(query_rule_data, 'fields', {})
-    incompatible_fields = 
get_incompatible_fields(list(fields.values()), - Version.parse(stack_version, optional_minor_and_patch=True)) + fields = getattr(query_rule_data, "fields", {}) + incompatible_fields = get_incompatible_fields( + list(fields.values()), Version.parse(stack_version, optional_minor_and_patch=True) + ) # Load all rules rules = RuleCollection.default() @@ -289,10 +318,11 @@ def build_limited_rules(stack_version: str, output_file: str): api_schema = get_schema_file(stack_version, "base")["properties"]["type"]["enum"] # Function to process each rule - def process_rule(rule, incompatible_fields: List[str]): + def process_rule(rule: TOMLRule, incompatible_fields: list[str]): if rule.contents.type not in api_schema: - click.secho(f'{rule.contents.name} - Skipping unsupported rule type: {rule.contents.get("type")}', - fg='yellow') + click.secho( + f"{rule.contents.name} - Skipping unsupported rule type: {rule.contents.get('type')}", fg="yellow" + ) return None # Remove unsupported fields from rule @@ -311,13 +341,18 @@ def process_rule(rule, incompatible_fields: List[str]): # Write ndjson_output to file ndjson_output.dump(output_path) - click.echo(f'Success: Rules written to {output_file}') + click.echo(f"Success: Rules written to {output_file}") -@root.command('toml-lint') -@click.option('--rule-file', '-f', multiple=True, type=click.Path(exists=True), - help='Specify one or more rule files.') -def toml_lint(rule_file): +@root.command("toml-lint") +@click.option( + "--rule-file", + "-f", + multiple=True, + type=click.Path(exists=True, path_type=Path), + help="Specify one or more rule files.", +) +def toml_lint(rule_file: list[Path]): """Cleanup files with some simple toml formatting.""" if rule_file: rules = RuleCollection() @@ -329,19 +364,25 @@ def toml_lint(rule_file): for rule in rules: rule.save_toml() - click.echo('TOML file linting complete') + click.echo("TOML file linting complete") -@root.command('mass-update') -@click.argument('query') -@click.option('--metadata', '-m', is_flag=True, help='Make an update to the rule metadata rather than contents.') -@click.option('--language', type=click.Choice(["eql", "kql"]), default="kql") -@click.option('--field', type=(str, str), multiple=True, - help='Use rule-search to retrieve a subset of rules and modify values ' - '(ex: --field management.ecs_version 1.1.1).\n' - 'Note this is limited to string fields only. Nested fields should use dot notation.') +@root.command("mass-update") +@click.argument("query") +@click.option("--metadata", "-m", is_flag=True, help="Make an update to the rule metadata rather than contents.") +@click.option("--language", type=click.Choice(["eql", "kql"]), default="kql") +@click.option( + "--field", + type=(str, str), + multiple=True, + help="Use rule-search to retrieve a subset of rules and modify values " + "(ex: --field management.ecs_version 1.1.1).\n" + "Note this is limited to string fields only. 
Nested fields should use dot notation.", +) @click.pass_context -def mass_update(ctx, query, metadata, language, field): +def mass_update( + ctx: click.Context, query: str, metadata: bool, language: Literal["eql", "kql"], field: tuple[str, str] +): """Update multiple rules based on eql results.""" rules = RuleCollection().default() results = ctx.invoke(search_rules, query=query, language=language, verbose=False) @@ -350,27 +391,31 @@ def mass_update(ctx, query, metadata, language, field): for rule in rules: for key, value in field: - nested_set(rule.metadata if metadata else rule.contents, key, value) + nested_set(rule.metadata if metadata else rule.contents, key, value) # type: ignore[reportAttributeAccessIssue] - rule.validate(as_rule=True) - rule.save(as_rule=True) + rule.validate(as_rule=True) # type: ignore[reportAttributeAccessIssue] + rule.save(as_rule=True) # type: ignore[reportAttributeAccessIssue] - return ctx.invoke(search_rules, query=query, language=language, - columns=['rule_id', 'name'] + [k[0].split('.')[-1] for k in field]) + return ctx.invoke( + search_rules, + query=query, + language=language, + columns=["rule_id", "name"] + [k[0].split(".")[-1] for k in field], + ) -@root.command('view-rule') -@click.argument('rule-file', type=Path) -@click.option('--api-format/--rule-format', default=True, help='Print the rule in final api or rule format') +@root.command("view-rule") +@click.argument("rule-file", type=Path) +@click.option("--api-format/--rule-format", default=True, help="Print the rule in final api or rule format") @click.pass_context -def view_rule(ctx, rule_file, api_format): +def view_rule(_: click.Context, rule_file: Path, api_format: str): """View an internal rule or specified rule file.""" rule = RuleCollection().load_file(rule_file) if api_format: click.echo(json.dumps(rule.contents.to_api_format(), indent=2, sort_keys=True)) else: - click.echo(toml_write(rule.contents.to_dict())) + click.echo(toml_write(rule.contents.to_dict())) # type: ignore[reportAttributeAccessIssue] return rule @@ -378,9 +423,9 @@ def view_rule(ctx, rule_file, api_format): def _export_rules( rules: RuleCollection, outfile: Path, - downgrade_version: Optional[definitions.SemVer] = None, - verbose=True, - skip_unsupported=False, + downgrade_version: definitions.SemVer | None = None, + verbose: bool = True, + skip_unsupported: bool = False, include_metadata: bool = False, include_action_connectors: bool = False, include_exceptions: bool = False, @@ -388,29 +433,37 @@ def _export_rules( """Export rules and exceptions into a consolidated ndjson file.""" from .rule import downgrade_contents_from_rule - outfile = outfile.with_suffix('.ndjson') - unsupported = [] + outfile = outfile.with_suffix(".ndjson") + unsupported: list[str] = [] if downgrade_version: if skip_unsupported: - output_lines = [] + output_lines: list[str] = [] for rule in rules: try: - output_lines.append(json.dumps(downgrade_contents_from_rule(rule, downgrade_version, - include_metadata=include_metadata), - sort_keys=True)) + output_lines.append( + json.dumps( + downgrade_contents_from_rule(rule, downgrade_version, include_metadata=include_metadata), + sort_keys=True, + ) + ) except ValueError as e: - unsupported.append(f'{e}: {rule.id} - {rule.name}') + unsupported.append(f"{e}: {rule.id} - {rule.name}") continue else: - output_lines = [json.dumps(downgrade_contents_from_rule(r, downgrade_version, - include_metadata=include_metadata), sort_keys=True) - for r in rules] + output_lines = [ + json.dumps( + 
downgrade_contents_from_rule(r, downgrade_version, include_metadata=include_metadata), + sort_keys=True, + ) + for r in rules + ] else: - output_lines = [json.dumps(r.contents.to_api_format(include_metadata=include_metadata), - sort_keys=True) for r in rules] + output_lines = [ + json.dumps(r.contents.to_api_format(include_metadata=include_metadata), sort_keys=True) for r in rules + ] # Add exceptions to api format here and add to output_lines if include_exceptions or include_action_connectors: @@ -427,14 +480,14 @@ def _export_rules( actions = [a for sublist in action_connectors for a in sublist] output_lines.extend(json.dumps(a, sort_keys=True) for a in actions) - outfile.write_text('\n'.join(output_lines) + '\n') + _ = outfile.write_text("\n".join(output_lines) + "\n") if verbose: - click.echo(f'Exported {len(rules) - len(unsupported)} rules into {outfile}') + click.echo(f"Exported {len(rules) - len(unsupported)} rules into {outfile}") if skip_unsupported and unsupported: - unsupported_str = '\n- '.join(unsupported) - click.echo(f'Skipped {len(unsupported)} unsupported rules: \n- {unsupported_str}') + unsupported_str = "\n- ".join(unsupported) + click.echo(f"Skipped {len(unsupported)} unsupported rules: \n- {unsupported_str}") @root.command("export-rules-from-repo") @@ -442,7 +495,7 @@ def _export_rules( @click.option( "--outfile", "-o", - default=Path(get_path("exports", f'{time.strftime("%Y%m%dT%H%M%SL")}.ndjson')), + default=Path(get_path(["exports", f"{time.strftime('%Y%m%dT%H%M%SL')}.ndjson"])), type=Path, help="Name of file for exported rules", ) @@ -456,7 +509,7 @@ def _export_rules( "--skip-unsupported", "-s", is_flag=True, - help="If `--stack-version` is passed, skip rule types which are unsupported " "(an error will be raised otherwise)", + help="If `--stack-version` is passed, skip rule types which are unsupported (an error will be raised otherwise)", ) @click.option("--include-metadata", type=bool, is_flag=True, default=False, help="Add metadata to the exported rules") @click.option( @@ -470,8 +523,16 @@ def _export_rules( @click.option( "--include-exceptions", "-e", type=bool, is_flag=True, default=False, help="Include Exceptions Lists in export" ) -def export_rules_from_repo(rules, outfile: Path, replace_id, stack_version, skip_unsupported, include_metadata: bool, - include_action_connectors: bool, include_exceptions: bool) -> RuleCollection: +def export_rules_from_repo( + rules: RuleCollection, + outfile: Path, + replace_id: bool, + stack_version: str, + skip_unsupported: bool, + include_metadata: bool, + include_action_connectors: bool, + include_exceptions: bool, +) -> RuleCollection: """Export rule(s) and exception(s) into an importable ndjson file.""" assert len(rules) > 0, "No rules found" @@ -500,79 +561,91 @@ def export_rules_from_repo(rules, outfile: Path, replace_id, stack_version, skip return rules -@root.command('validate-rule') -@click.argument('path') +@root.command("validate-rule") +@click.argument("path") @click.pass_context -def validate_rule(ctx, path): +def validate_rule(_: click.Context, path: str): """Check if a rule staged in rules dir validates against a schema.""" rule = RuleCollection().load_file(Path(path)) - click.echo('Rule validation successful') + click.echo("Rule validation successful") return rule -@root.command('validate-all') +@root.command("validate-all") def validate_all(): """Check if all rules validates against a schema.""" - RuleCollection.default() - click.echo('Rule validation successful') - - -@root.command('rule-search') 
-@click.argument('query', required=False) -@click.option('--columns', '-c', multiple=True, help='Specify columns to add the table') -@click.option('--language', type=click.Choice(["eql", "kql"]), default="kql") -@click.option('--count', is_flag=True, help='Return a count rather than table') -def search_rules(query, columns, language, count, verbose=True, rules: Dict[str, TOMLRule] = None, pager=False): + _ = RuleCollection.default() + click.echo("Rule validation successful") + + +@root.command("rule-search") +@click.argument("query", required=False) +@click.option("--columns", "-c", multiple=True, help="Specify columns to add the table") +@click.option("--language", type=click.Choice(["eql", "kql"]), default="kql") +@click.option("--count", is_flag=True, help="Return a count rather than table") +def search_rules( + query: str | None, + columns: list[str], + language: Literal["eql", "kql"], + count: bool, + verbose: bool = True, + rules: dict[str, TOMLRule] | None = None, + pager: bool = False, +): """Use KQL or EQL to find matching rules.""" - from kql import get_evaluator - from eql.table import Table - from eql.build import get_engine - from eql import parse_query - from eql.pipes import CountPipe + from kql import get_evaluator # type: ignore[reportMissingTypeStubs] + from eql.table import Table # type: ignore[reportMissingTypeStubs] + from eql.build import get_engine # type: ignore[reportMissingTypeStubs] + from eql import parse_query # type: ignore[reportMissingTypeStubs] + from eql.pipes import CountPipe # type: ignore[reportMissingTypeStubs] from .rule import get_unique_query_fields - flattened_rules = [] + flattened_rules: list[dict[str, Any]] = [] rules = rules or {str(rule.path): rule for rule in RuleCollection.default()} for file_name, rule in rules.items(): - flat: dict = {"file": os.path.relpath(file_name)} + flat: dict[str, Any] = {"file": os.path.relpath(file_name)} flat.update(rule.contents.to_dict()) flat.update(flat["metadata"]) flat.update(flat["rule"]) - tactic_names = [] - technique_ids = [] - subtechnique_ids = [] + tactic_names: list[str] = [] + technique_ids: list[str] = [] + subtechnique_ids: list[str] = [] - for entry in flat['rule'].get('threat', []): + for entry in flat["rule"].get("threat", []): if entry["framework"] != "MITRE ATT&CK": continue - techniques = entry.get('technique', []) - tactic_names.append(entry['tactic']['name']) - technique_ids.extend([t['id'] for t in techniques]) - subtechnique_ids.extend([st['id'] for t in techniques for st in t.get('subtechnique', [])]) + techniques = entry.get("technique", []) + tactic_names.append(entry["tactic"]["name"]) + technique_ids.extend([t["id"] for t in techniques]) + subtechnique_ids.extend([st["id"] for t in techniques for st in t.get("subtechnique", [])]) - flat.update(techniques=technique_ids, tactics=tactic_names, subtechniques=subtechnique_ids, - unique_fields=get_unique_query_fields(rule)) + flat.update( + techniques=technique_ids, + tactics=tactic_names, + subtechniques=subtechnique_ids, + unique_fields=get_unique_query_fields(rule), + ) flattened_rules.append(flat) flattened_rules.sort(key=lambda dct: dct["name"]) - filtered = [] + filtered: list[dict[str, Any]] = [] if language == "kql": - evaluator = get_evaluator(query) if query else lambda x: True - filtered = list(filter(evaluator, flattened_rules)) + evaluator = get_evaluator(query) if query else lambda _: True # type: ignore[reportUnknownLambdaType] + filtered = list(filter(evaluator, flattened_rules)) # type: ignore[reportCallIssue] elif 
language == "eql": - parsed = parse_query(query, implied_any=True, implied_base=True) - evaluator = get_engine(parsed) - filtered = [result.events[0].data for result in evaluator(flattened_rules)] + parsed = parse_query(query, implied_any=True, implied_base=True) # type: ignore[reportUnknownVariableType] + evaluator = get_engine(parsed) # type: ignore[reportUnknownVariableType] + filtered = [result.events[0].data for result in evaluator(flattened_rules)] # type: ignore[reportUnknownVariableType] - if not columns and any(isinstance(pipe, CountPipe) for pipe in parsed.pipes): + if not columns and any(isinstance(pipe, CountPipe) for pipe in parsed.pipes): # type: ignore[reportAttributeAccessIssue] columns = ["key", "count", "percent"] if count: - click.echo(f'{len(filtered)} rules') + click.echo(f"{len(filtered)} rules") return filtered if columns: @@ -580,7 +653,7 @@ def search_rules(query, columns, language, count, verbose=True, rules: Dict[str, else: columns = ["rule_id", "file", "name"] - table = Table.from_list(columns, filtered) + table: Table = Table.from_list(columns, filtered) # type: ignore[reportUnknownMemberType] if verbose: click.echo_via_pager(table) if pager else click.echo(table) @@ -588,70 +661,71 @@ def search_rules(query, columns, language, count, verbose=True, rules: Dict[str, return filtered -@root.command('build-threat-map-entry') -@click.argument('tactic') -@click.argument('technique-ids', nargs=-1) +@root.command("build-threat-map-entry") +@click.argument("tactic") +@click.argument("technique-ids", nargs=-1) def build_threat_map(tactic: str, technique_ids: Iterable[str]): """Build a threat map entry.""" entry = build_threat_map_entry(tactic, *technique_ids) - rendered = pytoml.dumps({'rule': {'threat': [entry]}}) + rendered = pytoml.dumps({"rule": {"threat": [entry]}}) # type: ignore[reportUnknownMemberType] # strip out [rule] - cleaned = '\n'.join(rendered.splitlines()[2:]) + cleaned = "\n".join(rendered.splitlines()[2:]) print(cleaned) return entry @root.command("test") @click.pass_context -def test_rules(ctx): +def test_rules(ctx: click.Context): """Run unit tests over all of the rules.""" import pytest - rules_config = ctx.obj['rules_config'] + rules_config = ctx.obj["rules_config"] test_config = rules_config.test_config tests, skipped = test_config.get_test_names(formatted=True) if skipped: - click.echo(f'Tests skipped per config ({len(skipped)}):') - click.echo('\n'.join(skipped)) + click.echo(f"Tests skipped per config ({len(skipped)}):") + click.echo("\n".join(skipped)) clear_caches() if tests: - ctx.exit(pytest.main(['-v'] + tests)) + ctx.exit(pytest.main(["-v"] + tests)) else: - click.echo('No tests found to execute!') + click.echo("No tests found to execute!") -@root.group('typosquat') +@root.group("typosquat") def typosquat_group(): """Commands for generating typosquat detections.""" -@typosquat_group.command('create-dnstwist-index') -@click.argument('input-file', type=click.Path(exists=True, dir_okay=False), required=True) +@typosquat_group.command("create-dnstwist-index") +@click.argument("input-file", type=click.Path(exists=True, dir_okay=False), required=True) @click.pass_context -@add_client('elasticsearch', add_func_arg=False) +@add_client(["elasticsearch"], add_func_arg=False) def create_dnstwist_index(ctx: click.Context, input_file: click.Path): """Create a dnstwist index in Elasticsearch to work with a threat match rule.""" from elasticsearch import Elasticsearch - es_client: Elasticsearch = ctx.obj['es'] + es_client: Elasticsearch = ctx.obj["es"] 
- click.echo(f'Attempting to load dnstwist data from {input_file}') - dnstwist_data: dict = load_dump(str(input_file)) - click.echo(f'{len(dnstwist_data)} records loaded') + click.echo(f"Attempting to load dnstwist data from {input_file}") + dnstwist_data: list[dict[str, Any]] = load_dump(str(input_file)) # type: ignore[reportAssignmentType] + click.echo(f"{len(dnstwist_data)} records loaded") - original_domain = next(r['domain-name'] for r in dnstwist_data if r.get('fuzzer', '') == 'original*') - click.echo(f'Original domain name identified: {original_domain}') + original_domain = next(r["domain-name"] for r in dnstwist_data if r.get("fuzzer", "") == "original*") # type: ignore[reportAttributeAccessIssue] + click.echo(f"Original domain name identified: {original_domain}") - domain = original_domain.split('.')[0] - domain_index = f'dnstwist-{domain}' + domain = original_domain.split(".")[0] + domain_index = f"dnstwist-{domain}" # If index already exists, prompt user to confirm if they want to overwrite if es_client.indices.exists(index=domain_index): if click.confirm( - f"dnstwist index: {domain_index} already exists for {original_domain}. Do you want to overwrite?", - abort=True): - es_client.indices.delete(index=domain_index) + f"dnstwist index: {domain_index} already exists for {original_domain}. Do you want to overwrite?", + abort=True, + ): + _ = es_client.indices.delete(index=domain_index) fields = [ "dns-a", @@ -661,52 +735,52 @@ def create_dnstwist_index(ctx: click.Context, input_file: click.Path): "banner-http", "fuzzer", "original-domain", - "dns.question.registered_domain" + "dns.question.registered_domain", ] timestamp_field = "@timestamp" mappings = {"mappings": {"properties": {f: {"type": "keyword"} for f in fields}}} mappings["mappings"]["properties"][timestamp_field] = {"type": "date"} - es_client.indices.create(index=domain_index, body=mappings) + _ = es_client.indices.create(index=domain_index, body=mappings) # handle dns.question.registered_domain separately - fields.pop() - es_updates = [] - now = datetime.utcnow() + _ = fields.pop() + es_updates: list[dict[str, Any]] = [] + now = datetime.now(timezone.utc) for item in dnstwist_data: - if item['fuzzer'] == 'original*': + if item["fuzzer"] == "original*": continue record = item.copy() - record.setdefault('dns', {}).setdefault('question', {}).setdefault('registered_domain', item.get('domain-name')) + record.setdefault("dns", {}).setdefault("question", {}).setdefault("registered_domain", item.get("domain-name")) for field in fields: - record.setdefault(field, None) + _ = record.setdefault(field, None) - record['@timestamp'] = now + record["@timestamp"] = now - es_updates.extend([{'create': {'_index': domain_index}}, record]) + es_updates.extend([{"create": {"_index": domain_index}}, record]) - click.echo(f'Indexing data for domain {original_domain}') + click.echo(f"Indexing data for domain {original_domain}") results = es_client.bulk(body=es_updates) - if results['errors']: - error = {r['create']['result'] for r in results['items'] if r['create']['status'] != 201} - client_error(f'Errors occurred during indexing:\n{error}') + if results["errors"]: + error = {r["create"]["result"] for r in results["items"] if r["create"]["status"] != 201} + raise_client_error(f"Errors occurred during indexing:\n{error}") - click.echo(f'{len(results["items"])} watchlist domains added to index') - click.echo('Run `prep-rule` and import to Kibana to create alerts on this index') + click.echo(f"{len(results['items'])} watchlist domains added to 
index") + click.echo("Run `prep-rule` and import to Kibana to create alerts on this index") -@typosquat_group.command('prep-rule') -@click.argument('author') +@typosquat_group.command("prep-rule") +@click.argument("author") def prep_rule(author: str): """Prep the detection threat match rule for dnstwist data with a rule_id and author.""" - rule_template_file = get_etc_path('rule_template_typosquatting_domain.json') + rule_template_file = get_etc_path(["rule_template_typosquatting_domain.json"]) template_rule = json.loads(rule_template_file.read_text()) template_rule.update(author=[author], rule_id=str(uuid4())) - updated_rule = get_path('rule_typosquatting_domain.ndjson') - updated_rule.write_text(json.dumps(template_rule, sort_keys=True)) - click.echo(f'Rule saved to: {updated_rule}. Import this to Kibana to create alerts on all dnstwist-* indexes') - click.echo('Note: you only need to import and enable this rule one time for all dnstwist-* indexes') + updated_rule = get_path(["rule_typosquatting_domain.ndjson"]) + _ = updated_rule.write_text(json.dumps(template_rule, sort_keys=True)) + click.echo(f"Rule saved to: {updated_rule}. Import this to Kibana to create alerts on all dnstwist-* indexes") + click.echo("Note: you only need to import and enable this rule one time for all dnstwist-* indexes") diff --git a/detection_rules/misc.py b/detection_rules/misc.py index dcf1fc51ad0..c2356cc2b13 100644 --- a/detection_rules/misc.py +++ b/detection_rules/misc.py @@ -4,6 +4,7 @@ # 2.0. """Misc support.""" + import os import re import time @@ -11,16 +12,14 @@ import uuid from pathlib import Path from functools import wraps -from typing import NoReturn, Optional +from typing import NoReturn, IO, Any, Callable import click import requests -from kibana import Kibana - -from .utils import add_params, cached, get_path, load_etc_dump +from kibana import Kibana # type: ignore[reportMissingTypeStubs] -_CONFIG = {} +from .utils import add_params, cached, load_etc_dump LICENSE_HEADER = """ Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. 
under one @@ -35,7 +34,7 @@ /* {} */ -""".strip().format("\n".join(' * ' + line for line in LICENSE_LINES)) +""".strip().format("\n".join(" * " + line for line in LICENSE_LINES)) ROOT_DIR = Path(__file__).parent.parent @@ -44,57 +43,60 @@ class ClientError(click.ClickException): """Custom CLI error to format output or full debug stacktrace.""" - def __init__(self, message, original_error=None): + def __init__(self, message: str, original_error: Exception | None = None): super(ClientError, self).__init__(message) self.original_error = original_error - self.original_error_type = type(original_error).__name__ if original_error else '' + self.original_error_type = type(original_error).__name__ if original_error else "" - def show(self, file=None, err=True): + def show(self, file: IO[Any] | None = None, err: bool = True): """Print the error to the console.""" # err_msg = f' {self.original_error_type}' if self.original_error else '' - msg = f'{click.style(f"CLI Error ({self.original_error_type})", fg="red", bold=True)}: {self.format_message()}' + msg = f"{click.style(f'CLI Error ({self.original_error_type})', fg='red', bold=True)}: {self.format_message()}" click.echo(msg, err=err, file=file) -def client_error(message, exc: Exception = None, debug=None, ctx: click.Context = None, file=None, - err=None) -> NoReturn: - config_debug = True if ctx and ctx.ensure_object(dict) and ctx.obj.get('debug') is True else False +def raise_client_error( + message: str, + exc: Exception | None = None, + debug: bool | None = False, + ctx: click.Context | None = None, + file: IO[Any] | None = None, + err: bool = False, +) -> NoReturn: + config_debug = True if ctx and ctx.ensure_object(dict) and ctx.obj.get("debug") is True else False debug = debug if debug is not None else config_debug if debug: - click.echo(click.style('DEBUG: ', fg='yellow') + message, err=err, file=file) + click.echo(click.style("DEBUG: ", fg="yellow") + message, err=err, file=file) raise else: raise ClientError(message, original_error=exc) -def nested_get(_dict, dot_key, default=None): +def nested_get(_dict: dict[str, Any] | None, dot_key: str | None, default: Any | None = None) -> Any: """Get a nested field from a nested dict with dot notation.""" if _dict is None or dot_key is None: return default - elif '.' in dot_key and isinstance(_dict, dict): - dot_key = dot_key.split('.') - this_key = dot_key.pop(0) - return nested_get(_dict.get(this_key, default), '.'.join(dot_key), default) + elif "." 
in dot_key: + dot_key_parts = dot_key.split(".") + this_key = dot_key_parts.pop(0) + return nested_get(_dict.get(this_key, default), ".".join(dot_key_parts), default) else: return _dict.get(dot_key, default) -def nested_set(_dict, dot_key, value): +def nested_set(_dict: dict[str, Any], dot_key: str, value: Any): """Set a nested field from a key in dot notation.""" - keys = dot_key.split('.') + keys = dot_key.split(".") for key in keys[:-1]: _dict = _dict.setdefault(key, {}) - if isinstance(_dict, dict): - _dict[keys[-1]] = value - else: - raise ValueError('dict cannot set a value to a non-dict for {}'.format(dot_key)) + _dict[keys[-1]] = value -def nest_from_dot(dots, value): +def nest_from_dot(dots: str, value: Any) -> Any: """Nest a dotted field and set the innermost value.""" - fields = dots.split('.') + fields = dots.split(".") if not fields: return {} @@ -107,65 +109,70 @@ def nest_from_dot(dots, value): return nested -def schema_prompt(name, value=None, is_required=False, **options): +def schema_prompt(name: str, value: Any | None = None, is_required: bool = False, **options: Any) -> Any: """Interactively prompt based on schema requirements.""" - name = str(name) - field_type = options.get('type') - pattern = options.get('pattern') - enum = options.get('enum', []) - minimum = options.get('minimum') - maximum = options.get('maximum') - min_item = options.get('min_items', 0) - max_items = options.get('max_items', 9999) - - default = options.get('default') - if default is not None and str(default).lower() in ('true', 'false'): + field_type = options.get("type") + pattern: str | None = options.get("pattern") + enum = options.get("enum", []) + minimum = int(options["minimum"]) if "minimum" in options else None + maximum = int(options["maximum"]) if "maximum" in options else None + min_item = int(options.get("min_items", 0)) + max_items = int(options.get("max_items", 9999)) + + default = options.get("default") + if default is not None and str(default).lower() in ("true", "false"): default = str(default).lower() - if 'date' in name: - default = time.strftime('%Y/%m/%d') + if "date" in name: + default = time.strftime("%Y/%m/%d") - if name == 'rule_id': + if name == "rule_id": default = str(uuid.uuid4()) if len(enum) == 1 and is_required and field_type != "array": return enum[0] - def _check_type(_val): - if field_type in ('number', 'integer') and not str(_val).isdigit(): - print('Number expected but got: {}'.format(_val)) - return False - if pattern and (not re.match(pattern, _val) or len(re.match(pattern, _val).group(0)) != len(_val)): - print('{} did not match pattern: {}!'.format(_val, pattern)) + def _check_type(_val: Any): + if field_type in ("number", "integer") and not str(_val).isdigit(): + print("Number expected but got: {}".format(_val)) return False + if pattern: + match = re.match(pattern, _val) + if not match or len(match.group(0)) != len(_val): + print("{} did not match pattern: {}!".format(_val, pattern)) + return False if enum and _val not in enum: - print('{} not in valid options: {}'.format(_val, ', '.join(enum))) + print("{} not in valid options: {}".format(_val, ", ".join(enum))) return False if minimum and (type(_val) is int and int(_val) < minimum): - print('{} is less than the minimum: {}'.format(str(_val), str(minimum))) + print("{} is less than the minimum: {}".format(str(_val), str(minimum))) return False if maximum and (type(_val) is int and int(_val) > maximum): - print('{} is greater than the maximum: {}'.format(str(_val), str(maximum))) + print("{} is greater 
than the maximum: {}".format(str(_val), str(maximum))) return False - if field_type == 'boolean' and _val.lower() not in ('true', 'false'): - print('Boolean expected but got: {}'.format(str(_val))) + if type(_val) is str and field_type == "boolean" and _val.lower() not in ("true", "false"): + print("Boolean expected but got: {}".format(str(_val))) return False return True - def _convert_type(_val): - if field_type == 'boolean' and not type(_val) is bool: - _val = True if _val.lower() == 'true' else False - return int(_val) if field_type in ('number', 'integer') else _val - - prompt = '{name}{default}{required}{multi}'.format( - name=name, - default=' [{}] ("n/a" to leave blank) '.format(default) if default else '', - required=' (required) ' if is_required else '', - multi=' (multi, comma separated) ' if field_type == 'array' else '').strip() + ': ' + def _convert_type(_val: Any): + if field_type == "boolean" and type(_val) is not bool: + _val = True if _val.lower() == "true" else False + return int(_val) if field_type in ("number", "integer") else _val + + prompt = ( + "{name}{default}{required}{multi}".format( + name=name, + default=' [{}] ("n/a" to leave blank) '.format(default) if default else "", + required=" (required) " if is_required else "", + multi=" (multi, comma separated) " if field_type == "array" else "", + ).strip() + + ": " + ) while True: result = value or input(prompt) or default - if result == 'n/a': + if result == "n/a": result = None if not result: @@ -175,8 +182,8 @@ def _convert_type(_val): else: return - if field_type == 'array': - result_list = result.split(',') + if field_type == "array": + result_list = result.split(",") if not (min_item < len(result_list) < max_items): if is_required: @@ -205,51 +212,85 @@ def _convert_type(_val): return -def get_kibana_rules_map(repo='elastic/kibana', branch='master'): +def get_kibana_rules_map(repo: str = "elastic/kibana", branch: str = "master") -> dict[str, Any]: """Get list of available rules from the Kibana repo and return a list of URLs.""" # ensure branch exists - r = requests.get(f'https://api.github.com/repos/{repo}/branches/{branch}') + r = requests.get(f"https://api.github.com/repos/{repo}/branches/{branch}") r.raise_for_status() - url = ('https://api.github.com/repos/{repo}/contents/x-pack/{legacy}plugins/{app}/server/lib/' - 'detection_engine/rules/prepackaged_rules?ref={branch}') + url = ( + "https://api.github.com/repos/{repo}/contents/x-pack/{legacy}plugins/{app}/server/lib/" + "detection_engine/rules/prepackaged_rules?ref={branch}" + ) - gh_rules = requests.get(url.format(legacy='', app='security_solution', branch=branch, repo=repo)).json() + r = requests.get(url.format(legacy="", app="security_solution", branch=branch, repo=repo)) + r.raise_for_status() + + gh_rules = r.json() # pre-7.9 app was siem - if isinstance(gh_rules, dict) and gh_rules.get('message', '') == 'Not Found': - gh_rules = requests.get(url.format(legacy='', app='siem', branch=branch, repo=repo)).json() + if isinstance(gh_rules, dict) and gh_rules.get("message", "") == "Not Found": # type: ignore[reportUnknownMemberType] + gh_rules = requests.get(url.format(legacy="", app="siem", branch=branch, repo=repo)).json() # pre-7.8 the siem was under the legacy directory - if isinstance(gh_rules, dict) and gh_rules.get('message', '') == 'Not Found': - gh_rules = requests.get(url.format(legacy='legacy/', app='siem', branch=branch, repo=repo)).json() + if isinstance(gh_rules, dict) and gh_rules.get("message", "") == "Not Found": # type: 
ignore[reportUnknownMemberType] + gh_rules = requests.get(url.format(legacy="legacy/", app="siem", branch=branch, repo=repo)).json() + + if isinstance(gh_rules, dict) and gh_rules.get("message", "") == "Not Found": # type: ignore[reportUnknownMemberType] + raise ValueError(f"rules directory does not exist for {repo} branch: {branch}") + + if not isinstance(gh_rules, list): + raise ValueError("Expected to receive a list") + + results: dict[str, Any] = {} - if isinstance(gh_rules, dict) and gh_rules.get('message', '') == 'Not Found': - raise ValueError(f'rules directory does not exist for {repo} branch: {branch}') + for r in gh_rules: # type: ignore[reportUnknownMemberType] + if "name" not in r: + raise ValueError("Name value is expected") - return {os.path.splitext(r['name'])[0]: r['download_url'] for r in gh_rules if r['name'].endswith('.json')} + name = r["name"] # type: ignore[reportUnknownMemberType] + if not isinstance(name, str): + raise ValueError("String value is expected for name") -def get_kibana_rules(*rule_paths, repo='elastic/kibana', branch='master', verbose=True, threads=50): + if name.endswith(".json"): + name_parts = os.path.splitext(name) + key = name_parts[0] + val = r["download_url"] # type: ignore[reportUnknownMemberType] + results[key] = val + + return results + + +def get_kibana_rules( + repo: str = "elastic/kibana", + branch: str = "master", + verbose: bool = True, + threads: int = 50, + rule_paths: list[str] = [], +) -> dict[str, Any]: """Retrieve prepackaged rules from kibana repo.""" from multiprocessing.pool import ThreadPool - kibana_rules = {} + kibana_rules: dict[str, Any] = {} if verbose: - thread_use = f' using {threads} threads' if threads > 1 else '' - click.echo(f'Downloading rules from {repo} {branch} branch in kibana repo{thread_use} ...') + thread_use = f" using {threads} threads" if threads > 1 else "" + click.echo(f"Downloading rules from {repo} {branch} branch in kibana repo{thread_use} ...") rule_paths = [os.path.splitext(os.path.basename(p))[0] for p in rule_paths] - rules_mapping = [(n, u) for n, u in get_kibana_rules_map(repo=repo, branch=branch).items() if n in rule_paths] \ - if rule_paths else get_kibana_rules_map(repo=repo, branch=branch).items() + rules_mapping = ( + [(n, u) for n, u in get_kibana_rules_map(repo=repo, branch=branch).items() if n in rule_paths] + if rule_paths + else get_kibana_rules_map(repo=repo, branch=branch).items() + ) - def download_worker(rule_info): + def download_worker(rule_info: tuple[str, str]) -> None: n, u = rule_info kibana_rules[n] = requests.get(u).json() pool = ThreadPool(processes=threads) - pool.map(download_worker, rules_mapping) + _ = pool.map(download_worker, rules_mapping) pool.close() pool.join() @@ -259,81 +300,92 @@ def download_worker(rule_info): @cached def load_current_package_version() -> str: """Load the current package version from config file.""" - return load_etc_dump('packages.yaml')['package']['name'] + data = load_etc_dump(["packages.yaml"]) + return data["package"]["name"] -def get_default_config() -> Optional[Path]: - return next(get_path().glob('.detection-rules-cfg.*'), None) +def get_default_config() -> Path | None: + return next(ROOT_DIR.glob(".detection-rules-cfg.*"), None) @cached def parse_user_config(): """Parse a default config file.""" - import eql + import eql # type: ignore[reportMissingTypeStubs] config_file = get_default_config() config = {} if config_file and config_file.exists(): - config = eql.utils.load_dump(str(config_file)) - - click.secho(f'Loaded config file: 
{config_file}', fg='yellow') + config = eql.utils.load_dump(str(config_file)) # type: ignore[reportUnknownMemberType] + click.secho(f"Loaded config file: {config_file}", fg="yellow") return config -def discover_tests(start_dir: str = 'tests', pattern: str = 'test*.py', top_level_dir: Optional[str] = None): +def discover_tests(start_dir: str = "tests", pattern: str = "test*.py", top_level_dir: str | None = None) -> list[str]: """Discover all unit tests in a directory.""" - def list_tests(s, tests=None): - if tests is None: - tests = [] + + tests: list[str] = [] + + def list_tests(s: unittest.TestSuite): for test in s: if isinstance(test, unittest.TestSuite): - list_tests(test, tests) + list_tests(test) else: tests.append(test.id()) - return tests loader = unittest.defaultTestLoader suite = loader.discover(start_dir, pattern=pattern, top_level_dir=top_level_dir or str(ROOT_DIR)) - return list_tests(suite) + list_tests(suite) + return tests -def getdefault(name): +def getdefault(name: str): """Callback function for `default` to get an environment variable.""" envvar = f"DR_{name.upper()}" config = parse_user_config() return lambda: os.environ.get(envvar, config.get(name)) -def get_elasticsearch_client(cloud_id: str = None, elasticsearch_url: str = None, es_user: str = None, - es_password: str = None, ctx: click.Context = None, api_key: str = None, **kwargs): +def get_elasticsearch_client( + cloud_id: str | None = None, + elasticsearch_url: str | None = None, + es_user: str | None = None, + es_password: str | None = None, + ctx: click.Context | None = None, + api_key: str | None = None, + **kwargs: Any, +): """Get an authenticated elasticsearch client.""" from elasticsearch import AuthenticationException, Elasticsearch if not (cloud_id or elasticsearch_url): - client_error("Missing required --cloud-id or --elasticsearch-url") + raise_client_error("Missing required --cloud-id or --elasticsearch-url") # don't prompt for these until there's a cloud id or elasticsearch URL - basic_auth: (str, str) | None = None + basic_auth: tuple[str, str] | None = None if not api_key: es_user = es_user or click.prompt("es_user") es_password = es_password or click.prompt("es_password", hide_input=True) + if not es_user or not es_password: + raise ValueError("Both username and password must be provided") basic_auth = (es_user, es_password) hosts = [elasticsearch_url] if elasticsearch_url else None - timeout = kwargs.pop('timeout', 60) - kwargs['verify_certs'] = not kwargs.pop('ignore_ssl_errors', False) + timeout = kwargs.pop("timeout", 60) + kwargs["verify_certs"] = not kwargs.pop("ignore_ssl_errors", False) try: - client = Elasticsearch(hosts=hosts, cloud_id=cloud_id, http_auth=basic_auth, timeout=timeout, api_key=api_key, - **kwargs) + client = Elasticsearch( + hosts=hosts, cloud_id=cloud_id, http_auth=basic_auth, timeout=timeout, api_key=api_key, **kwargs + ) # force login to test auth - client.info() + _ = client.info() return client except AuthenticationException as e: - error_msg = f'Failed authentication for {elasticsearch_url or cloud_id}' - client_error(error_msg, e, ctx=ctx, err=True) + error_msg = f"Failed authentication for {elasticsearch_url or cloud_id}" + raise_client_error(error_msg, e, ctx=ctx, err=True) def get_kibana_client( @@ -343,71 +395,71 @@ def get_kibana_client( kibana_url: str | None = None, space: str | None = None, ignore_ssl_errors: bool = False, - **kwargs + **kwargs: Any, ): """Get an authenticated Kibana client.""" if not (cloud_id or kibana_url): - client_error("Missing required 
--cloud-id or --kibana-url") + raise_client_error("Missing required --cloud-id or --kibana-url") verify = not ignore_ssl_errors return Kibana(cloud_id=cloud_id, kibana_url=kibana_url, space=space, verify=verify, api_key=api_key, **kwargs) client_options = { - 'kibana': { - 'kibana_url': click.Option(['--kibana-url'], default=getdefault('kibana_url')), - 'cloud_id': click.Option(['--cloud-id'], default=getdefault('cloud_id'), help="ID of the cloud instance."), - 'api_key': click.Option(['--api-key'], default=getdefault('api_key')), - 'space': click.Option(['--space'], default=None, help='Kibana space'), - 'ignore_ssl_errors': click.Option(['--ignore-ssl-errors'], default=getdefault('ignore_ssl_errors')) + "kibana": { + "kibana_url": click.Option(["--kibana-url"], default=getdefault("kibana_url")), + "cloud_id": click.Option(["--cloud-id"], default=getdefault("cloud_id"), help="ID of the cloud instance."), + "api_key": click.Option(["--api-key"], default=getdefault("api_key")), + "space": click.Option(["--space"], default=None, help="Kibana space"), + "ignore_ssl_errors": click.Option(["--ignore-ssl-errors"], default=getdefault("ignore_ssl_errors")), + }, + "elasticsearch": { + "cloud_id": click.Option(["--cloud-id"], default=getdefault("cloud_id")), + "api_key": click.Option(["--api-key"], default=getdefault("api_key")), + "elasticsearch_url": click.Option(["--elasticsearch-url"], default=getdefault("elasticsearch_url")), + "es_user": click.Option(["--es-user", "-eu"], default=getdefault("es_user")), + "es_password": click.Option(["--es-password", "-ep"], default=getdefault("es_password")), + "timeout": click.Option(["--timeout", "-et"], default=60, help="Timeout for elasticsearch client"), + "ignore_ssl_errors": click.Option(["--ignore-ssl-errors"], default=getdefault("ignore_ssl_errors")), }, - 'elasticsearch': { - 'cloud_id': click.Option(['--cloud-id'], default=getdefault("cloud_id")), - 'api_key': click.Option(['--api-key'], default=getdefault('api_key')), - 'elasticsearch_url': click.Option(['--elasticsearch-url'], default=getdefault("elasticsearch_url")), - 'es_user': click.Option(['--es-user', '-eu'], default=getdefault("es_user")), - 'es_password': click.Option(['--es-password', '-ep'], default=getdefault("es_password")), - 'timeout': click.Option(['--timeout', '-et'], default=60, help='Timeout for elasticsearch client'), - 'ignore_ssl_errors': click.Option(['--ignore-ssl-errors'], default=getdefault('ignore_ssl_errors')) - } } -kibana_options = list(client_options['kibana'].values()) -elasticsearch_options = list(client_options['elasticsearch'].values()) +kibana_options = list(client_options["kibana"].values()) +elasticsearch_options = list(client_options["elasticsearch"].values()) -def add_client(*client_type, add_to_ctx=True, add_func_arg=True): +def add_client(client_types: list[str], add_to_ctx: bool = True, add_func_arg: bool = True): """Wrapper to add authed client.""" from elasticsearch import Elasticsearch from elasticsearch.exceptions import AuthenticationException - from kibana import Kibana + from kibana import Kibana # type: ignore[reportMissingTypeStubs] - def _wrapper(func): - client_ops_dict = {} - client_ops_keys = {} - for c_type in client_type: - ops = client_options.get(c_type) + def _wrapper(func: Callable[..., Any]): + client_ops_dict: dict[str, click.Option] = {} + client_ops_keys: dict[str, list[str]] = {} + for c_type in client_types: + ops = client_options[c_type] client_ops_dict.update(ops) client_ops_keys[c_type] = list(ops) if not client_ops_dict: - 
raise ValueError(f'Unknown client: {client_type} in {func.__name__}') + client_types_str = ", ".join(client_types) + raise ValueError(f"Unknown client: {client_types_str} in {func.__name__}") client_ops = list(client_ops_dict.values()) @wraps(func) @add_params(*client_ops) - def _wrapped(*args, **kwargs): - ctx: click.Context = next((a for a in args if isinstance(a, click.Context)), None) - es_client_args = {k: kwargs.pop(k, None) for k in client_ops_keys.get('elasticsearch', [])} + def _wrapped(*args: Any, **kwargs: Any): + ctx: click.Context | None = next((a for a in args if isinstance(a, click.Context)), None) + es_client_args = {k: kwargs.pop(k, None) for k in client_ops_keys.get("elasticsearch", [])} # shared args like cloud_id - kibana_client_args = {k: kwargs.pop(k, es_client_args.get(k)) for k in client_ops_keys.get('kibana', [])} + kibana_client_args = {k: kwargs.pop(k, es_client_args.get(k)) for k in client_ops_keys.get("kibana", [])} - if 'elasticsearch' in client_type: + if "elasticsearch" in client_types: # for nested ctx invocation, no need to re-auth if an existing client is already passed - elasticsearch_client: Elasticsearch = kwargs.get('elasticsearch_client') + elasticsearch_client: Elasticsearch | None = kwargs.get("elasticsearch_client") try: - if elasticsearch_client and isinstance(elasticsearch_client, Elasticsearch) and \ - elasticsearch_client.info(): + if elasticsearch_client and elasticsearch_client.info(): pass else: elasticsearch_client = get_elasticsearch_client(**es_client_args) @@ -415,15 +467,14 @@ def _wrapped(*args, **kwargs): elasticsearch_client = get_elasticsearch_client(**es_client_args) if add_func_arg: - kwargs['elasticsearch_client'] = elasticsearch_client + kwargs["elasticsearch_client"] = elasticsearch_client if ctx and add_to_ctx: - ctx.obj['es'] = elasticsearch_client + ctx.obj["es"] = elasticsearch_client - if 'kibana' in client_type: + if "kibana" in client_types: # for nested ctx invocation, no need to re-auth if an existing client is already passed - kibana_client: Kibana = kwargs.get('kibana_client') - if kibana_client and isinstance(kibana_client, Kibana): - + kibana_client: Kibana | None = kwargs.get("kibana_client") + if kibana_client: try: with kibana_client: if kibana_client.version: @@ -435,9 +486,9 @@ def _wrapped(*args, **kwargs): kibana_client = get_kibana_client(**kibana_client_args) if add_func_arg: - kwargs['kibana_client'] = kibana_client + kwargs["kibana_client"] = kibana_client if ctx and add_to_ctx: - ctx.obj['kibana'] = kibana_client + ctx.obj["kibana"] = kibana_client return func(*args, **kwargs) diff --git a/detection_rules/mixins.py b/detection_rules/mixins.py index b22677d2920..cf377f796ff 100644 --- a/detection_rules/mixins.py +++ b/detection_rules/mixins.py @@ -7,13 +7,13 @@ import dataclasses from pathlib import Path -from typing import Any, Optional, TypeVar, Type, Literal +from typing import Any, Literal import json import marshmallow_dataclass import marshmallow_dataclass.union_field -import marshmallow_jsonschema -import marshmallow_union +import marshmallow_jsonschema # type: ignore[reportMissingTypeStubs] +import marshmallow_union # type: ignore[reportMissingTypeStubs] import marshmallow from marshmallow import Schema, ValidationError, validates_schema, fields as marshmallow_fields @@ -23,26 +23,24 @@ from semver import Version from .utils import cached, dict_hash -T = TypeVar('T') -ClassT = TypeVar('ClassT') # bound=dataclass? 
-UNKNOWN_VALUES = Literal['raise', 'exclude', 'include'] +UNKNOWN_VALUES = Literal["raise", "exclude", "include"] -def _strip_none_from_dict(obj: T) -> T: +def _strip_none_from_dict(obj: Any) -> Any: """Strip none values from a dict recursively.""" if isinstance(obj, dict): - return {key: _strip_none_from_dict(value) for key, value in obj.items() if value is not None} - if isinstance(obj, list): - return [_strip_none_from_dict(o) for o in obj] - if isinstance(obj, tuple): - return tuple(_strip_none_from_dict(list(obj))) + return {key: _strip_none_from_dict(value) for key, value in obj.items() if value is not None} # type: ignore[reportUnknownVariableType] + elif isinstance(obj, list): + return [_strip_none_from_dict(o) for o in obj] # type: ignore[reportUnknownVariableType] + elif isinstance(obj, tuple): + return tuple(_strip_none_from_dict(list(obj))) # type: ignore[reportUnknownVariableType] return obj -def patch_jsonschema(obj: dict) -> dict: +def patch_jsonschema(obj: Any) -> dict[str, Any]: """Patch marshmallow-jsonschema output to look more like JSL.""" - def dive(child: dict) -> dict: + def dive(child: dict[str, Any]) -> dict[str, Any]: if "$ref" in child: name = child["$ref"].split("/")[-1] definition = obj["definitions"][name] @@ -58,10 +56,12 @@ def dive(child: dict) -> dict: child["anyOf"] = [dive(c) for c in child["anyOf"]] elif isinstance(child["type"], list): - if 'null' in child["type"]: - child["type"] = [t for t in child["type"] if t != 'null'] + type_vals: list[str] = child["type"] # type: ignore[reportUnknownVariableType] - if len(child["type"]) == 1: + if "null" in type_vals: + child["type"] = [t for t in type_vals if t != "null"] + + if len(type_vals) == 1: child["type"] = child["type"][0] if "items" in child: @@ -86,29 +86,9 @@ def dive(child: dict) -> dict: class BaseSchema(Schema): """Base schema for marshmallow dataclasses with unknown.""" - class Meta: - """Meta class for marshmallow schema.""" - - -def exclude_class_schema( - clazz, base_schema: type[Schema] = BaseSchema, unknown: UNKNOWN_VALUES = marshmallow.EXCLUDE, **kwargs -) -> type[Schema]: - """Get a marshmallow schema for a dataclass with unknown=EXCLUDE.""" - base_schema.Meta.unknown = unknown - return marshmallow_dataclass.class_schema(clazz, base_schema=base_schema, **kwargs) - -def recursive_class_schema( - clazz, base_schema: type[Schema] = BaseSchema, unknown: UNKNOWN_VALUES = marshmallow.EXCLUDE, **kwargs -) -> type[Schema]: - """Recursively apply the unknown parameter for nested schemas.""" - schema = exclude_class_schema(clazz, base_schema=base_schema, unknown=unknown, **kwargs) - for field in dataclasses.fields(clazz): - if dataclasses.is_dataclass(field.type): - nested_cls = field.type - nested_schema = recursive_class_schema(nested_cls, base_schema=base_schema, **kwargs) - setattr(schema, field.name, nested_schema) - return schema + class Meta: # type: ignore[reportIncompatibleVariableOverride] + """Meta class for marshmallow schema.""" class MarshmallowDataclassMixin: @@ -116,14 +96,14 @@ class MarshmallowDataclassMixin: @classmethod @cached - def __schema(cls: ClassT, unknown: Optional[UNKNOWN_VALUES] = None) -> Schema: + def __schema(cls, unknown: UNKNOWN_VALUES | None = None) -> Schema: """Get the marshmallow schema for the data class""" if unknown: return recursive_class_schema(cls, unknown=unknown)() else: return marshmallow_dataclass.class_schema(cls)() - def get(self, key: str, default: Optional[Any] = None): + def get(self, key: str, default: Any | None = None): """Get a key from the 
query data without raising attribute errors.""" return getattr(self, key, default) @@ -131,20 +111,20 @@ def get(self, key: str, default: Optional[Any] = None): @cached def jsonschema(cls): """Get the jsonschema representation for this class.""" - jsonschema = PatchedJSONSchema().dump(cls.__schema()) + jsonschema = PatchedJSONSchema().dump(cls.__schema()) # type: ignore[reportUnknownMemberType] jsonschema = patch_jsonschema(jsonschema) return jsonschema @classmethod - def from_dict(cls: Type[ClassT], obj: dict, unknown: Optional[UNKNOWN_VALUES] = None) -> ClassT: + def from_dict(cls, obj: dict[str, Any], unknown: UNKNOWN_VALUES | None = None) -> Any: """Deserialize and validate a dataclass from a dict using marshmallow.""" schema = cls.__schema(unknown=unknown) return schema.load(obj) - def to_dict(self, strip_none_values=True) -> dict: + def to_dict(self, strip_none_values: bool = True) -> dict[str, Any]: """Serialize a dataclass to a dictionary using marshmallow.""" schema = self.__schema() - serialized: dict = schema.dump(self) + serialized = schema.dump(self) if strip_none_values: serialized = _strip_none_from_dict(serialized) @@ -152,44 +132,78 @@ def to_dict(self, strip_none_values=True) -> dict: return serialized +def exclude_class_schema( + cls: type, + base_schema: type[Schema] = BaseSchema, + unknown: UNKNOWN_VALUES = marshmallow.EXCLUDE, + **kwargs: dict[str, Any], +) -> type[Schema]: + """Get a marshmallow schema for a dataclass with unknown=EXCLUDE.""" + base_schema.Meta.unknown = unknown # type: ignore[reportAttributeAccessIssue] + return marshmallow_dataclass.class_schema(cls, base_schema=base_schema, **kwargs) + + +def recursive_class_schema( + cls: type, + base_schema: type[Schema] = BaseSchema, + unknown: UNKNOWN_VALUES = marshmallow.EXCLUDE, + **kwargs: dict[str, Any], +) -> type[Schema]: + """Recursively apply the unknown parameter for nested schemas.""" + schema = exclude_class_schema(cls, base_schema=base_schema, unknown=unknown, **kwargs) + for field in dataclasses.fields(cls): + if dataclasses.is_dataclass(field.type): + nested_cls = field.type + nested_schema = recursive_class_schema( + nested_cls, # type: ignore[reportArgumentType] + base_schema=base_schema, + unknown=unknown, + **kwargs, + ) + setattr(schema, field.name, nested_schema) + return schema + + class LockDataclassMixin: """Mixin class for version and deprecated rules lock files.""" @classmethod @cached - def __schema(cls: ClassT) -> Schema: + def __schema(cls) -> Schema: """Get the marshmallow schema for the data class""" return marshmallow_dataclass.class_schema(cls)() - def get(self, key: str, default: Optional[Any] = None): + def get(self, key: str, default: Any = None): """Get a key from the query data without raising attribute errors.""" return getattr(self, key, default) @classmethod - def from_dict(cls: Type[ClassT], obj: dict) -> ClassT: + def from_dict(cls, obj: dict[str, Any]) -> Any: """Deserialize and validate a dataclass from a dict using marshmallow.""" schema = cls.__schema() try: loaded = schema.load(obj) except ValidationError as e: - err_msg = json.dumps(e.messages, indent=2) - raise ValidationError(f'Validation error loading: {cls.__name__}\n{err_msg}') from None + err_msg = json.dumps(e.normalized_messages(), indent=2) + raise ValidationError(f"Validation error loading: {cls.__name__}\n{err_msg}") return loaded - def to_dict(self, strip_none_values=True) -> dict: + def to_dict(self, strip_none_values: bool = True) -> dict[str, Any]: """Serialize a dataclass to a dictionary using 
marshmallow.""" schema = self.__schema() - serialized: dict = schema.dump(self) + serialized: dict[str, Any] = schema.dump(self) if strip_none_values: serialized = _strip_none_from_dict(serialized) - return serialized['data'] + return serialized["data"] @classmethod - def load_from_file(cls: Type[ClassT], lock_file: Optional[Path] = None) -> ClassT: + def load_from_file(cls, lock_file: Path | None = None) -> Any: """Load and validate a version lock file.""" - path: Path = getattr(cls, 'file_path', lock_file) + path = getattr(cls, "file_path", lock_file) + if not path: + raise ValueError("No file path found") contents = json.loads(path.read_text()) loaded = cls.from_dict(dict(data=contents)) return loaded @@ -199,22 +213,23 @@ def sha256(self) -> definitions.Sha256: contents = self.to_dict() return dict_hash(contents) - def save_to_file(self, lock_file: Optional[Path] = None): + def save_to_file(self, lock_file: Path | None = None): """Save and validate a version lock file.""" - path: Path = lock_file or getattr(self, 'file_path', None) - assert path, 'No path passed or set' + path = lock_file or getattr(self, "file_path", None) + if not path: + raise ValueError("No file path found") contents = self.to_dict() - path.write_text(json.dumps(contents, indent=2, sort_keys=True)) + _ = path.write_text(json.dumps(contents, indent=2, sort_keys=True)) class StackCompatMixin: """Mixin to restrict schema compatibility to defined stack versions.""" @validates_schema - def validate_field_compatibility(self, data: dict, **kwargs): + def validate_field_compatibility(self, data: dict[str, Any], **_: dict[str, Any]): """Verify stack-specific fields are properly applied to schema.""" package_version = Version.parse(load_current_package_version(), optional_minor_and_patch=True) - schema_fields = getattr(self, 'fields', {}) + schema_fields = getattr(self, "fields", {}) incompatible = get_incompatible_fields(list(schema_fields.values()), package_version) if not incompatible: return @@ -223,14 +238,15 @@ def validate_field_compatibility(self, data: dict, **kwargs): for field, bounds in incompatible.items(): min_compat, max_compat = bounds if data.get(field) is not None: - raise ValidationError(f'Invalid field: "{field}" for stack version: {package_version}, ' - f'min compatibility: {min_compat}, max compatibility: {max_compat}') + raise ValidationError( + f'Invalid field: "{field}" for stack version: {package_version}, ' + f"min compatibility: {min_compat}, max compatibility: {max_compat}" + ) class PatchedJSONSchema(marshmallow_jsonschema.JSONSchema): - # Patch marshmallow-jsonschema to support marshmallow-dataclass[union] - def _get_schema_for_field(self, obj, field): + def _get_schema_for_field(self, obj: Any, field: Any) -> Any: """Patch marshmallow_jsonschema.base.JSONSchema to support marshmallow-dataclass[union].""" if isinstance(field, marshmallow_fields.Raw) and field.allow_none and not field.validate: # raw fields shouldn't be type string but type any. 
bug in marshmallow_dataclass:__init__.py: @@ -241,11 +257,17 @@ def _get_schema_for_field(self, obj, field): if isinstance(field, marshmallow_dataclass.union_field.Union): # convert to marshmallow_union.Union - field = marshmallow_union.Union([subfield for _, subfield in field.union_fields], - metadata=field.metadata, - required=field.required, name=field.name, - parent=field.parent, root=field.root, error_messages=field.error_messages, - default_error_messages=field.default_error_messages, default=field.default, - allow_none=field.allow_none) - - return super()._get_schema_for_field(obj, field) + field = marshmallow_union.Union( + [subfield for _, subfield in field.union_fields], + metadata=field.metadata, # type: ignore[reportUnknownMemberType] + required=field.required, + name=field.name, # type: ignore[reportUnknownMemberType] + parent=field.parent, # type: ignore[reportUnknownMemberType] + root=field.root, # type: ignore[reportUnknownMemberType] + error_messages=field.error_messages, + default_error_messages=field.default_error_messages, + default=field.default, # type: ignore[reportUnknownMemberType] + allow_none=field.allow_none, + ) + + return super()._get_schema_for_field(obj, field) # type: ignore[reportUnknownMemberType] diff --git a/detection_rules/ml.py b/detection_rules/ml.py index 13573ae51d4..9a02249d42d 100644 --- a/detection_rules/ml.py +++ b/detection_rules/ml.py @@ -10,7 +10,7 @@ from dataclasses import dataclass from functools import cached_property, lru_cache from pathlib import Path -from typing import Dict, List, Literal, Optional +from typing import Any, Literal import click import elasticsearch @@ -24,14 +24,20 @@ from .utils import get_path, unzip_to_dict -ML_PATH = get_path('machine-learning') +ML_PATH = get_path(["machine-learning"]) -def info_from_tag(tag: str) -> (Literal['ml'], definitions.MachineLearningType, str, int): +def info_from_tag(tag: str) -> tuple[Literal["ml"], str, str, int]: try: - ml, release_type, release_date, release_number = tag.split('-') + ml, release_type, release_date, release_number = tag.split("-") except ValueError as exc: - raise ValueError(f'{tag} is not of valid release format: ml-type-date-number. {exc}') + raise ValueError(f"{tag} is not of valid release format: ml-type-date-number. 
{exc}") + + if ml != "ml": + raise ValueError(f"Invalid type from the tag: {ml}") + + if release_type not in definitions.MACHINE_LEARNING_PACKAGES: + raise ValueError(f"Unexpected release type encountered: {release_type}") return ml, release_type, release_date, int(release_number) @@ -45,15 +51,15 @@ class MachineLearningClient: """Class for experimental machine learning release clients.""" es_client: Elasticsearch - bundle: dict + bundle: dict[str, Any] @cached_property def model_id(self) -> str: - return next(data['model_id'] for name, data in self.bundle.items() if Path(name).stem.lower().endswith('model')) + return next(data["model_id"] for name, data in self.bundle.items() if Path(name).stem.lower().endswith("model")) @cached_property def bundle_type(self) -> str: - return self.model_id.split('_')[0].lower() + return self.model_id.split("_")[0].lower() @cached_property def ml_client(self) -> MlClient: @@ -66,36 +72,41 @@ def ingest_client(self) -> IngestClient: @cached_property def license(self) -> str: license_client = LicenseClient(self.es_client) - return license_client.get()['license']['type'].lower() + return license_client.get()["license"]["type"].lower() @staticmethod @lru_cache - def ml_manifests() -> Dict[str, ReleaseManifest]: + def ml_manifests() -> dict[str, ReleaseManifest]: return get_ml_model_manifests_by_model_id() def verify_license(self): - valid_license = self.license in ('platinum', 'enterprise') + valid_license = self.license in ("platinum", "enterprise") if not valid_license: - err_msg = 'Your subscription level does not support Machine Learning. See ' \ - 'https://www.elastic.co/subscriptions for more information.' + err_msg = ( + "Your subscription level does not support Machine Learning. See " + "https://www.elastic.co/subscriptions for more information." + ) raise InvalidLicenseError(err_msg) @classmethod - def from_release(cls, es_client: Elasticsearch, release_tag: str, - repo: str = 'elastic/detection-rules') -> 'MachineLearningClient': + def from_release( + cls, es_client: Elasticsearch, release_tag: str, repo: str = "elastic/detection-rules" + ) -> "MachineLearningClient": """Load from a GitHub release.""" - full_type = '-'.join(info_from_tag(release_tag)[:2]) - release_url = f'https://api.github.com/repos/{repo}/releases/tags/{release_tag}' + + ml, release_type, _, _ = info_from_tag(release_tag) + + full_type = "-".join([ml, release_type]) + release_url = f"https://api.github.com/repos/{repo}/releases/tags/{release_tag}" release = requests.get(release_url) release.raise_for_status() # check that the release only has a single zip file - assets = [a for a in release.json()['assets'] if - a['name'].startswith(full_type) and a['name'].endswith('.zip')] - assert len(assets) == 1, f'Malformed release: expected 1 {full_type} zip file, found: {len(assets)}!' + assets = [a for a in release.json()["assets"] if a["name"].startswith(full_type) and a["name"].endswith(".zip")] + assert len(assets) == 1, f"Malformed release: expected 1 {full_type} zip file, found: {len(assets)}!" 
- zipped_url = assets[0]['browser_download_url'] + zipped_url = assets[0]["browser_download_url"] zipped_raw = requests.get(zipped_url) zipped_bundle = zipfile.ZipFile(io.BytesIO(zipped_raw.content)) bundle = unzip_to_dict(zipped_bundle) @@ -103,23 +114,23 @@ def from_release(cls, es_client: Elasticsearch, release_tag: str, return cls(es_client=es_client, bundle=bundle) @classmethod - def from_directory(cls, es_client: Elasticsearch, directory: Path) -> 'MachineLearningClient': + def from_directory(cls, es_client: Elasticsearch, directory: Path) -> "MachineLearningClient": """Load from an unzipped local directory.""" bundle = json.loads(directory.read_text()) return cls(es_client=es_client, bundle=bundle) - def remove(self) -> dict: + def remove(self) -> dict[str, Any]: """Remove machine learning files from a stack.""" results = dict(script={}, pipeline={}, model={}) for pipeline in list(self.get_related_pipelines()): - results['pipeline'][pipeline] = self.ingest_client.delete_pipeline(pipeline) + results["pipeline"][pipeline] = self.ingest_client.delete_pipeline(id=pipeline) for script in list(self.get_related_scripts()): - results['script'][script] = self.es_client.delete_script(script) + results["script"][script] = self.es_client.delete_script(id=script) - results['model'][self.model_id] = self.ml_client.delete_trained_model(self.model_id) + results["model"][self.model_id] = self.ml_client.delete_trained_model(model_id=self.model_id) return results - def setup(self) -> dict: + def setup(self) -> dict[str, Any]: """Setup machine learning bundle on a stack.""" self.verify_license() results = dict(script={}, pipeline={}, model={}) @@ -128,89 +139,95 @@ def setup(self) -> dict: parsed_bundle = dict(model={}, script={}, pipeline={}) for filename, data in self.bundle.items(): fp = Path(filename) - file_type = fp.stem.split('_')[-1] + file_type = fp.stem.split("_")[-1] parsed_bundle[file_type][fp.stem] = data - model = list(parsed_bundle['model'].values())[0] - results['model'][model['model_id']] = self.upload_model(model['model_id'], model) + model = list(parsed_bundle["model"].values())[0] + results["model"][model["model_id"]] = self.upload_model(model["model_id"], model) - for script_name, script in parsed_bundle['script'].items(): - results['script'][script_name] = self.upload_script(script_name, script) + for script_name, script in parsed_bundle["script"].items(): + results["script"][script_name] = self.upload_script(script_name, script) - for pipeline_name, pipeline in parsed_bundle['pipeline'].items(): - results['pipeline'][pipeline_name] = self.upload_ingest_pipeline(pipeline_name, pipeline) + for pipeline_name, pipeline in parsed_bundle["pipeline"].items(): + results["pipeline"][pipeline_name] = self.upload_ingest_pipeline(pipeline_name, pipeline) return results - def get_all_scripts(self) -> Dict[str, dict]: + def get_all_scripts(self) -> dict[str, dict[str, Any]]: """Get all scripts from an elasticsearch instance.""" - return self.es_client.cluster.state()['metadata']['stored_scripts'] + return self.es_client.cluster.state()["metadata"]["stored_scripts"] - def get_related_scripts(self) -> Dict[str, dict]: + def get_related_scripts(self) -> dict[str, dict[str, Any]]: """Get all scripts which start with ml_*.""" scripts = self.get_all_scripts() - return {n: s for n, s in scripts.items() if n.lower().startswith(f'ml_{self.bundle_type}')} + return {n: s for n, s in scripts.items() if n.lower().startswith(f"ml_{self.bundle_type}")} - def get_related_pipelines(self) -> Dict[str, dict]: + 
def get_related_pipelines(self) -> dict[str, dict[str, Any]]: """Get all pipelines which start with ml_*.""" pipelines = self.ingest_client.get_pipeline() - return {n: s for n, s in pipelines.items() if n.lower().startswith(f'ml_{self.bundle_type}')} + return {n: s for n, s in pipelines.items() if n.lower().startswith(f"ml_{self.bundle_type}")} - def get_related_model(self) -> Optional[dict]: + def get_related_model(self) -> dict[str, Any] | None: """Get a model from an elasticsearch instance matching the model_id.""" for model in self.get_all_existing_model_files(): - if model['model_id'] == self.model_id: + if model["model_id"] == self.model_id: return model - def get_all_existing_model_files(self) -> dict: + def get_all_existing_model_files(self) -> list[dict[str, Any]]: """Get available models from a stack.""" - return self.ml_client.get_trained_models()['trained_model_configs'] + return self.ml_client.get_trained_models()["trained_model_configs"] @classmethod - def get_existing_model_ids(cls, es_client: Elasticsearch) -> List[str]: + def get_existing_model_ids(cls, es_client: Elasticsearch) -> list[str]: """Get model IDs for existing ML models.""" ml_client = MlClient(es_client) - return [m['model_id'] for m in ml_client.get_trained_models()['trained_model_configs'] - if m['model_id'] in cls.ml_manifests()] + return [ + m["model_id"] + for m in ml_client.get_trained_models()["trained_model_configs"] + if m["model_id"] in cls.ml_manifests() + ] @classmethod def check_model_exists(cls, es_client: Elasticsearch, model_id: str) -> bool: """Check if a model exists on a stack by model id.""" ml_client = MlClient(es_client) - return model_id in [m['model_id'] for m in ml_client.get_trained_models()['trained_model_configs']] + return model_id in [m["model_id"] for m in ml_client.get_trained_models()["trained_model_configs"]] - def get_related_files(self) -> dict: + def get_related_files(self) -> dict[str, Any]: """Check for the presence and status of ML bundle files on a stack.""" files = { - 'pipeline': self.get_related_pipelines(), - 'script': self.get_related_scripts(), - 'model': self.get_related_model(), - 'release': self.get_related_release() + "pipeline": self.get_related_pipelines(), + "script": self.get_related_scripts(), + "model": self.get_related_model(), + "release": self.get_related_release(), } return files def get_related_release(self) -> ReleaseManifest: """Get the GitHub release related to a model.""" - return self.ml_manifests.get(self.model_id) + return self.ml_manifests.get(self.model_id) # type: ignore[reportAttributeAccessIssue] @classmethod - def get_all_ml_files(cls, es_client: Elasticsearch) -> dict: + def get_all_ml_files(cls, es_client: Elasticsearch) -> dict[str, Any]: """Get all scripts, pipelines, and models which start with ml_*.""" pipelines = IngestClient(es_client).get_pipeline() - scripts = es_client.cluster.state()['metadata']['stored_scripts'] - models = MlClient(es_client).get_trained_models()['trained_model_configs'] + scripts = es_client.cluster.state()["metadata"]["stored_scripts"] + models = MlClient(es_client).get_trained_models()["trained_model_configs"] manifests = get_ml_model_manifests_by_model_id() files = { - 'pipeline': {n: s for n, s in pipelines.items() if n.lower().startswith('ml_')}, - 'script': {n: s for n, s in scripts.items() if n.lower().startswith('ml_')}, - 'model': {m['model_id']: {'model': m, 'release': manifests[m['model_id']]} - for m in models if m['model_id'] in manifests}, + "pipeline": {n: s for n, s in pipelines.items() if 
n.lower().startswith("ml_")}, + "script": {n: s for n, s in scripts.items() if n.lower().startswith("ml_")}, + "model": { + m["model_id"]: {"model": m, "release": manifests[m["model_id"]]} + for m in models + if m["model_id"] in manifests + }, } return files @classmethod - def remove_ml_scripts_pipelines(cls, es_client: Elasticsearch, ml_type: List[str]) -> dict: + def remove_ml_scripts_pipelines(cls, es_client: Elasticsearch, ml_type: list[str]) -> dict[str, Any]: """Remove all ML script and pipeline files.""" results = dict(script={}, pipeline={}) ingest_client = IngestClient(es_client) @@ -218,52 +235,52 @@ def remove_ml_scripts_pipelines(cls, es_client: Elasticsearch, ml_type: List[str files = cls.get_all_ml_files(es_client=es_client) for file_type, data in files.items(): for name in list(data): - this_type = name.split('_')[1].lower() + this_type = name.split("_")[1].lower() if this_type not in ml_type: continue - if file_type == 'script': - results[file_type][name] = es_client.delete_script(name) - elif file_type == 'pipeline': - results[file_type][name] = ingest_client.delete_pipeline(name) + if file_type == "script": + results[file_type][name] = es_client.delete_script(id=name) + elif file_type == "pipeline": + results[file_type][name] = ingest_client.delete_pipeline(id=name) return results - def upload_model(self, model_id: str, body: dict) -> dict: + def upload_model(self, model_id: str, body: dict[str, Any]): """Upload an ML model file.""" return self.ml_client.put_trained_model(model_id=model_id, body=body) - def upload_script(self, script_id: str, body: dict) -> dict: + def upload_script(self, script_id: str, body: dict[str, Any]): """Install a script file.""" return self.es_client.put_script(id=script_id, body=body) - def upload_ingest_pipeline(self, pipeline_id: str, body: dict) -> dict: + def upload_ingest_pipeline(self, pipeline_id: str, body: dict[str, Any]): """Install a pipeline file.""" return self.ingest_client.put_pipeline(id=pipeline_id, body=body) @staticmethod def _build_script_error(exc: elasticsearch.RequestError, pipeline_file: str): """Build an error for a failed script upload.""" - error = exc.info['error'] - cause = error['caused_by'] + error = exc.info["error"] + cause = error["caused_by"] error_msg = [ - f'Script error while uploading {pipeline_file}: {cause["type"]} - {cause["reason"]}', - ' '.join(f'{k}: {v}' for k, v in error['position'].items()), - '\n'.join(error['script_stack']) + f"Script error while uploading {pipeline_file}: {cause['type']} - {cause['reason']}", + " ".join(f"{k}: {v}" for k, v in error["position"].items()), + "\n".join(error["script_stack"]), ] - return click.style('\n'.join(error_msg), fg='red') + return click.style("\n".join(error_msg), fg="red") -def get_ml_model_manifests_by_model_id(repo: str = 'elastic/detection-rules') -> Dict[str, ReleaseManifest]: +def get_ml_model_manifests_by_model_id(repo_name: str = "elastic/detection-rules") -> dict[str, ReleaseManifest]: """Load all ML DGA model release manifests by model id.""" - manifests, _ = ManifestManager.load_all(repo=repo) - model_manifests = {} - - for manifest_name, manifest in manifests.items(): - for asset_name, asset in manifest['assets'].items(): - for entry_name, entry_data in asset['entries'].items(): - if entry_name.startswith('dga') and entry_name.endswith('model.json'): - model_id, _ = entry_name.rsplit('_', 1) + manifests, _ = ManifestManager.load_all(repo_name=repo_name) + model_manifests: dict[str, ReleaseManifest] = {} + + for _, manifest in manifests.items(): 
+ for _, asset in manifest["assets"].items(): + for entry_name, _ in asset["entries"].items(): + if entry_name.startswith("dga") and entry_name.endswith("model.json"): + model_id, _ = entry_name.rsplit("_", 1) model_manifests[model_id] = ReleaseManifest(**manifest) break diff --git a/detection_rules/navigator.py b/detection_rules/navigator.py index 125ab1bbee0..9ace9be0180 100644 --- a/detection_rules/navigator.py +++ b/detection_rules/navigator.py @@ -7,9 +7,9 @@ from functools import reduce from collections import defaultdict -from dataclasses import dataclass, field, fields +from dataclasses import dataclass, field from pathlib import Path -from typing import Dict, List, Optional +from typing import Any from marshmallow import pre_load import json @@ -31,17 +31,15 @@ "Office 365", "PRE", "SaaS", - "Windows" + "Windows", ] -_DEFAULT_NAVIGATOR_LINKS = { - "label": "repo", - "url": "https://github.com/elastic/detection-rules" -} +_DEFAULT_NAVIGATOR_LINKS = {"label": "repo", "url": "https://github.com/elastic/detection-rules"} @dataclass class NavigatorMetadata(MarshmallowDataclassMixin): """Metadata for ATT&CK navigator objects.""" + name: str value: str @@ -49,6 +47,7 @@ class NavigatorMetadata(MarshmallowDataclassMixin): @dataclass class NavigatorLinks(MarshmallowDataclassMixin): """Metadata for ATT&CK navigator objects.""" + label: str url: str @@ -56,40 +55,42 @@ class NavigatorLinks(MarshmallowDataclassMixin): @dataclass class Techniques(MarshmallowDataclassMixin): """ATT&CK navigator techniques array class.""" + techniqueID: str tactic: str score: int - metadata: List[NavigatorMetadata] - links: List[NavigatorLinks] + metadata: list[NavigatorMetadata] + links: list[NavigatorLinks] - color: str = '' - comment: str = '' + color: str = "" + comment: str = "" enabled: bool = True showSubtechniques: bool = False @pre_load - def set_score(self, data: dict, **kwargs): - data['score'] = len(data['metadata']) + def set_score(self, data: dict[str, Any], **_: Any): + data["score"] = len(data["metadata"]) return data @dataclass class Navigator(MarshmallowDataclassMixin): """ATT&CK navigator class.""" + @dataclass class Versions: attack: str - layer: str = '4.4' - navigator: str = '4.5.5' + layer: str = "4.4" + navigator: str = "4.5.5" @dataclass class Filters: - platforms: list = field(default_factory=_DEFAULT_PLATFORMS.copy) + platforms: list[str] = field(default_factory=_DEFAULT_PLATFORMS.copy) @dataclass class Layout: - layout: str = 'side' - aggregateFunction: str = 'average' + layout: str = "side" + aggregateFunction: str = "average" showID: bool = True showName: bool = True showAggregateScores: bool = False @@ -97,116 +98,115 @@ class Layout: @dataclass class Gradient: - colors: list = field(default_factory=['#d3e0fa', '#0861fb'].copy) + colors: list[str] = field(default_factory=["#d3e0fa", "#0861fb"].copy) minValue: int = 0 maxValue: int = 10 # not all defaults set name: str versions: Versions - techniques: List[Techniques] + techniques: list[Techniques] # all defaults set - filters: Filters = fields(Filters) - layout: Layout = fields(Layout) - gradient: Gradient = fields(Gradient) + filters: Filters = field(default_factory=Filters) + layout: Layout = field(default_factory=Layout) + gradient: Gradient = field(default_factory=Gradient) - domain: str = 'enterprise-attack' - description: str = 'Elastic detection-rules coverage' + domain: str = "enterprise-attack" + description: str = "Elastic detection-rules coverage" hideDisabled: bool = False - legendItems: list = 
field(default_factory=list) - links: List[NavigatorLinks] = field(default_factory=[_DEFAULT_NAVIGATOR_LINKS].copy) - metadata: Optional[List[NavigatorLinks]] = field(default_factory=list) + legendItems: list[Any] = field(default_factory=list) # type: ignore[reportUnknownVariableType] + + links: list[NavigatorLinks] = field(default_factory=[_DEFAULT_NAVIGATOR_LINKS].copy) # type: ignore[reportAssignmentType] + metadata: list[NavigatorLinks] | None = field(default_factory=list) # type: ignore[reportAssignmentType] showTacticRowBackground: bool = False selectTechniquesAcrossTactics: bool = False selectSubtechniquesWithParent: bool = False sorting: int = 0 - tacticRowBackground: str = '#dddddd' + tacticRowBackground: str = "#dddddd" -def technique_dict() -> dict: - return {'metadata': [], 'links': []} +def technique_dict() -> dict[str, Any]: + return {"metadata": [], "links": []} class NavigatorBuilder: """Rule navigator mappings and management.""" - def __init__(self, detection_rules: List[TOMLRule]): + def __init__(self, detection_rules: list[TOMLRule]): self.detection_rules = detection_rules - self.layers = { - 'all': defaultdict(lambda: defaultdict(technique_dict)), - 'platforms': defaultdict(lambda: defaultdict(technique_dict)), - + self.layers: dict[str, Any] = { + "all": defaultdict(lambda: defaultdict(technique_dict)), # type: ignore[reportUnknownLambdaType] + "platforms": defaultdict(lambda: defaultdict(technique_dict)), # type: ignore[reportUnknownLambdaType] # these will build multiple layers - 'indexes': defaultdict(lambda: defaultdict(lambda: defaultdict(technique_dict))), - 'tags': defaultdict(lambda: defaultdict(lambda: defaultdict(technique_dict))) + "indexes": defaultdict(lambda: defaultdict(lambda: defaultdict(technique_dict))), # type: ignore[reportUnknownLambdaType] + "tags": defaultdict(lambda: defaultdict(lambda: defaultdict(technique_dict))), # type: ignore[reportUnknownLambdaType] } self.process_rules() @staticmethod - def meta_dict(name: str, value: any) -> dict: - meta = { - 'name': name, - 'value': value - } + def meta_dict(name: str, value: Any) -> dict[str, Any]: + meta = {"name": name, "value": value} return meta @staticmethod - def links_dict(label: str, url: any) -> dict: - links = { - 'label': label, - 'url': url - } + def links_dict(label: str, url: Any) -> dict[str, Any]: + links = {"label": label, "url": url} return links - def rule_links_dict(self, rule: TOMLRule) -> dict: + def rule_links_dict(self, rule: TOMLRule) -> dict[str, Any]: """Create a links dictionary for a rule.""" - base_url = 'https://github.com/elastic/detection-rules/blob/main/rules/' - base_path = str(rule.get_base_rule_dir()) + base_url = "https://github.com/elastic/detection-rules/blob/main/rules/" + base_path = rule.get_base_rule_dir() - if base_path is None: + if not base_path: raise ValueError("Could not find a valid base path for the rule") - url = f'{base_url}{base_path}' + base_path_str = str(base_path) + url = f"{base_url}{base_path_str}" return self.links_dict(rule.name, url) - def get_layer(self, layer_name: str, layer_key: Optional[str] = None) -> dict: + def get_layer(self, layer_name: str, layer_key: str | None = None) -> dict[str, Any]: """Safely retrieve a layer with optional sub-keys.""" return self.layers[layer_name][layer_key] if layer_key else self.layers[layer_name] def _update_all(self, rule: TOMLRule, tactic: str, technique_id: str): - value = f'{rule.contents.data.type}/{rule.contents.data.get("language")}' - self.add_rule_to_technique(rule, 'all', tactic, 
technique_id, value) + value = f"{rule.contents.data.type}/{rule.contents.data.get('language')}" + self.add_rule_to_technique(rule, "all", tactic, technique_id, value) def _update_platforms(self, rule: TOMLRule, tactic: str, technique_id: str): + if not rule.path: + raise ValueError("No rule path found") value = rule.path.parent.name - self.add_rule_to_technique(rule, 'platforms', tactic, technique_id, value) + self.add_rule_to_technique(rule, "platforms", tactic, technique_id, value) def _update_indexes(self, rule: TOMLRule, tactic: str, technique_id: str): - for index in rule.contents.data.get('index') or []: + for index in rule.contents.data.get("index") or []: # type: ignore[reportUnknownVariableType] value = rule.id - self.add_rule_to_technique(rule, 'indexes', tactic, technique_id, value, layer_key=index.lower()) + self.add_rule_to_technique(rule, "indexes", tactic, technique_id, value, layer_key=index.lower()) # type: ignore[reportUnknownVariableType] def _update_tags(self, rule: TOMLRule, tactic: str, technique_id: str): - for tag in rule.contents.data.get('tags', []): + for tag in rule.contents.data.get("tags") or []: # type: ignore[reportUnknownVariableType] value = rule.id expected_prefixes = set([tag.split(":")[0] + ":" for tag in definitions.EXPECTED_RULE_TAGS]) - tag = reduce(lambda s, substr: s.replace(substr, ''), expected_prefixes, tag).lstrip() - layer_key = tag.replace(' ', '-').lower() - self.add_rule_to_technique(rule, 'tags', tactic, technique_id, value, layer_key=layer_key) - - def add_rule_to_technique(self, - rule: TOMLRule, - layer_name: str, - tactic: str, - technique_id: str, - value: str, - layer_key: Optional[str] = None): + tag = reduce(lambda s, substr: s.replace(substr, ""), expected_prefixes, tag).lstrip() # type: ignore[reportUnknownMemberType] + layer_key = tag.replace(" ", "-").lower() # type: ignore[reportUnknownVariableType] + self.add_rule_to_technique(rule, "tags", tactic, technique_id, value, layer_key=layer_key) # type: ignore[reportUnknownArgumentType] + + def add_rule_to_technique( + self, + rule: TOMLRule, + layer_name: str, + tactic: str, + technique_id: str, + value: str, + layer_key: str | None = None, + ): """Add a rule to a technique metadata and links.""" layer = self.get_layer(layer_name, layer_key) - layer[tactic][technique_id]['metadata'].append(self.meta_dict(rule.name, value)) - layer[tactic][technique_id]['links'].append(self.rule_links_dict(rule)) + layer[tactic][technique_id]["metadata"].append(self.meta_dict(rule.name, value)) + layer[tactic][technique_id]["links"].append(self.rule_links_dict(rule)) def process_rule(self, rule: TOMLRule, tactic: str, technique_id: str): self._update_all(rule, tactic, technique_id) @@ -230,15 +230,15 @@ def process_rules(self): for sub in technique_entry.subtechnique: self.process_rule(rule, tactic, sub.id) - def build_navigator(self, layer_name: str, layer_key: Optional[str] = None) -> Navigator: - populated_techniques = [] + def build_navigator(self, layer_name: str, layer_key: str | None = None) -> Navigator: + populated_techniques: list[dict[str, Any]] = [] layer = self.get_layer(layer_name, layer_key) - base_name = f'{layer_name}-{layer_key}' if layer_key else layer_name - base_name = base_name.replace('*', 'WILDCARD') - name = f'Elastic-detection-rules-{base_name}' + base_name = f"{layer_name}-{layer_key}" if layer_key else layer_name + base_name = base_name.replace("*", "WILDCARD") + name = f"Elastic-detection-rules-{base_name}" for tactic, techniques in layer.items(): - tactic_normalized = 
'-'.join(tactic.lower().split()) + tactic_normalized = "-".join(tactic.lower().split()) for technique_id, rules_data in techniques.items(): rules_data.update(tactic=tactic_normalized, techniqueID=technique_id) techniques = Techniques.from_dict(rules_data) @@ -246,47 +246,48 @@ def build_navigator(self, layer_name: str, layer_key: Optional[str] = None) -> N populated_techniques.append(techniques.to_dict()) base_nav_obj = { - 'name': name, - 'techniques': populated_techniques, - 'versions': {'attack': CURRENT_ATTACK_VERSION} + "name": name, + "techniques": populated_techniques, + "versions": {"attack": CURRENT_ATTACK_VERSION}, } navigator = Navigator.from_dict(base_nav_obj) return navigator - def build_all(self) -> List[Navigator]: - built = [] + def build_all(self) -> list[Navigator]: + built: list[Navigator] = [] for layer_name, data in self.layers.items(): # this is a single layer - if 'defense evasion' in data: + if "defense evasion" in data: built.append(self.build_navigator(layer_name)) else: # multi layers - for layer_key, sub_data in data.items(): + for layer_key, _ in data.items(): built.append(self.build_navigator(layer_name, layer_key)) return built @staticmethod - def _save(built: Navigator, directory: Path, verbose=True) -> Path: - path = directory.joinpath(built.name).with_suffix('.json') - path.write_text(json.dumps(built.to_dict(), indent=2)) + def _save(built: Navigator, directory: Path, verbose: bool = True) -> Path: + path = directory.joinpath(built.name).with_suffix(".json") + _ = path.write_text(json.dumps(built.to_dict(), indent=2)) if verbose: - print(f'saved: {path}') + print(f"saved: {path}") return path - def save_layer(self, - layer_name: str, - directory: Path, - layer_key: Optional[str] = None, - verbose=True - ) -> (Path, dict): + def save_layer( + self, + layer_name: str, + directory: Path, + layer_key: str | None = None, + verbose: bool = True, + ) -> tuple[Path, Navigator]: built = self.build_navigator(layer_name, layer_key) return self._save(built, directory, verbose), built - def save_all(self, directory: Path, verbose=True) -> Dict[Path, Navigator]: - paths = {} + def save_all(self, directory: Path, verbose: bool = True) -> dict[Path, Navigator]: + paths: dict[Path, Navigator] = {} for built in self.build_all(): path = self._save(built, directory, verbose) diff --git a/detection_rules/packaging.py b/detection_rules/packaging.py index a3bef733f7c..de5c16821c8 100644 --- a/detection_rules/packaging.py +++ b/detection_rules/packaging.py @@ -4,8 +4,9 @@ # 2.0. 
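The navigator.py changes above keep the same layer-building flow that packaging.py's generate_attack_navigator (later in this diff) drives; a minimal usage sketch, assuming the default rule collection, a hypothetical output directory, and assumed import paths:

    from pathlib import Path

    from detection_rules.navigator import NavigatorBuilder
    from detection_rules.rule_loader import RuleCollection  # import path is an assumption

    # Build the per-platform/index/tag layers from all rules, then write one JSON file per layer.
    rules = RuleCollection.default()
    builder = NavigatorBuilder(rules.rules)

    out_dir = Path("navigator_layers")  # hypothetical output directory
    out_dir.mkdir(exist_ok=True)
    saved = builder.save_all(out_dir, verbose=True)  # returns dict[Path, Navigator]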
"""Packaging and preparation for releases.""" + import base64 -import datetime +from datetime import datetime, timezone, date import hashlib import json import os @@ -13,7 +14,7 @@ import textwrap from collections import defaultdict from pathlib import Path -from typing import Dict, Optional, Tuple +from typing import Any from semver import Version import click @@ -30,16 +31,16 @@ RULES_CONFIG = parse_rules_config() -RELEASE_DIR = get_path("releases") +RELEASE_DIR = get_path(["releases"]) PACKAGE_FILE = str(RULES_CONFIG.packages_file) -NOTICE_FILE = get_path('NOTICE.txt') -FLEET_PKG_LOGO = get_etc_path("security-logo-color-64px.svg") +NOTICE_FILE = get_path(["NOTICE.txt"]) +FLEET_PKG_LOGO = get_etc_path(["security-logo-color-64px.svg"]) # CHANGELOG_FILE = Path(get_etc_path('rules-changelog.json')) -def filter_rule(rule: TOMLRule, config_filter: dict, exclude_fields: Optional[dict] = None) -> bool: +def filter_rule(rule: TOMLRule, config_filter: dict[str, Any], exclude_fields: dict[str, Any] | None = None) -> bool: """Filter a rule based off metadata and a package configuration.""" flat_rule = rule.contents.flattened_dict() @@ -51,7 +52,7 @@ def filter_rule(rule: TOMLRule, config_filter: dict, exclude_fields: Optional[di rule_value = flat_rule[key] if isinstance(rule_value, list): - rule_values = {v.lower() if isinstance(v, str) else v for v in rule_value} + rule_values: set[Any] = {v.lower() if isinstance(v, str) else v for v in rule_value} # type: ignore[reportUnknownVariableType] else: rule_values = {rule_value.lower() if isinstance(rule_value, str) else rule_value} @@ -65,7 +66,7 @@ def filter_rule(rule: TOMLRule, config_filter: dict, exclude_fields: Optional[di unique_fields = get_unique_query_fields(rule) for index, fields in exclude_fields.items(): - if unique_fields and (rule.contents.data.index_or_dataview == index or index == 'any'): + if unique_fields and (rule.contents.data.index_or_dataview == index or index == "any"): # type: ignore[reportAttributeAccessIssue] if set(unique_fields) & set(fields): return False @@ -78,10 +79,18 @@ def filter_rule(rule: TOMLRule, config_filter: dict, exclude_fields: Optional[di class Package(object): """Packaging object for siem rules and releases.""" - def __init__(self, rules: RuleCollection, name: str, release: Optional[bool] = False, - min_version: Optional[int] = None, max_version: Optional[int] = None, - registry_data: Optional[dict] = None, verbose: Optional[bool] = True, - generate_navigator: bool = False, historical: bool = False): + def __init__( + self, + rules: RuleCollection, + name: str, + release: bool | None = False, + min_version: int | None = None, + max_version: int | None = None, + registry_data: dict[str, Any] | None = None, + generate_navigator: bool = False, + verbose: bool = True, + historical: bool = False, + ): """Initialize a package.""" self.name = name self.rules = rules @@ -92,44 +101,49 @@ def __init__(self, rules: RuleCollection, name: str, release: Optional[bool] = F self.historical = historical if min_version is not None: - self.rules = self.rules.filter(lambda r: min_version <= r.contents.saved_version) + self.rules = self.rules.filter(lambda r: min_version <= r.contents.saved_version) # type: ignore[reportOperatorIssue] if max_version is not None: - self.rules = self.rules.filter(lambda r: max_version >= r.contents.saved_version) + self.rules = self.rules.filter(lambda r: max_version >= r.contents.saved_version) # type: ignore[reportOperatorIssue] assert not RULES_CONFIG.bypass_version_lock, "Packaging can not 
be used when version locking is bypassed." - self.changed_ids, self.new_ids, self.removed_ids = \ - loaded_version_lock.manage_versions(self.rules, verbose=verbose, save_changes=False) + self.changed_ids, self.new_ids, self.removed_ids = loaded_version_lock.manage_versions( + self.rules, + verbose=verbose, + save_changes=False, + ) @classmethod def load_configs(cls): """Load configs from packages.yaml.""" - return RULES_CONFIG.packages['package'] + return RULES_CONFIG.packages["package"] @staticmethod - def _package_kibana_notice_file(save_dir): + def _package_kibana_notice_file(save_dir: Path): """Convert and save notice file with package.""" - with open(NOTICE_FILE, 'rt') as f: + with open(NOTICE_FILE, "rt") as f: notice_txt = f.read() - with open(os.path.join(save_dir, 'notice.ts'), 'wt') as f: - commented_notice = [f' * {line}'.rstrip() for line in notice_txt.splitlines()] - lines = ['/* eslint-disable @kbn/eslint/require-license-header */', '', '/* @notice'] - lines = lines + commented_notice + [' */', ''] - f.write('\n'.join(lines)) + with open(os.path.join(save_dir, "notice.ts"), "wt") as f: + commented_notice = [f" * {line}".rstrip() for line in notice_txt.splitlines()] + lines = ["/* eslint-disable @kbn/eslint/require-license-header */", "", "/* @notice"] + lines = lines + commented_notice + [" */", ""] + _ = f.write("\n".join(lines)) - def _package_kibana_index_file(self, save_dir): + def _package_kibana_index_file(self, save_dir: Path): """Convert and save index file with package.""" - sorted_rules = sorted(self.rules, key=lambda k: (k.contents.metadata.creation_date, os.path.basename(k.path))) + sorted_rules = sorted(self.rules, key=lambda k: (k.contents.metadata.creation_date, k.path.name)) # type: ignore[reportOptionalMemberAccess] comments = [ - '// Auto generated file from either:', - '// - scripts/regen_prepackage_rules_index.sh', - '// - detection-rules repo using CLI command build-release', - '// Do not hand edit. Run script/command to regenerate package information instead', + "// Auto generated file from either:", + "// - scripts/regen_prepackage_rules_index.sh", + "// - detection-rules repo using CLI command build-release", + "// Do not hand edit. 
Run script/command to regenerate package information instead", + ] + rule_imports = [ + f"import rule{i} from './{os.path.splitext(r.path.name)[0] + '.json'}';" # type: ignore[reportOptionalMemberAccess] + for i, r in enumerate(sorted_rules, 1) ] - rule_imports = [f"import rule{i} from './{os.path.splitext(os.path.basename(r.path))[0] + '.json'}';" - for i, r in enumerate(sorted_rules, 1)] - const_exports = ['export const rawRules = ['] + const_exports = ["export const rawRules = ["] const_exports.extend(f" rule{i}," for i, _ in enumerate(sorted_rules, 1)) const_exports.append("];") const_exports.append("") @@ -141,47 +155,55 @@ def _package_kibana_index_file(self, save_dir): index_ts.append("") index_ts.extend(const_exports) - with open(os.path.join(save_dir, 'index.ts'), 'wt') as f: - f.write('\n'.join(index_ts)) + with open(os.path.join(save_dir, "index.ts"), "wt") as f: + _ = f.write("\n".join(index_ts)) - def save_release_files(self, directory: str, changed_rules: list, new_rules: list, removed_rules: list): + def save_release_files( + self, + directory: Path, + changed_rules: list[definitions.UUIDString], + new_rules: list[str], + removed_rules: list[str], + ): """Release a package.""" summary, changelog = self.generate_summary_and_changelog(changed_rules, new_rules, removed_rules) - with open(os.path.join(directory, f'{self.name}-summary.txt'), 'w') as f: - f.write(summary) - with open(os.path.join(directory, f'{self.name}-changelog-entry.md'), 'w') as f: - f.write(changelog) + with open(os.path.join(directory, f"{self.name}-summary.txt"), "w") as f: + _ = f.write(summary) + with open(os.path.join(directory, f"{self.name}-changelog-entry.md"), "w") as f: + _ = f.write(changelog) if self.generate_navigator: - self.generate_attack_navigator(Path(directory)) + _ = self.generate_attack_navigator(Path(directory)) consolidated = json.loads(self.get_consolidated()) - with open(os.path.join(directory, f'{self.name}-consolidated-rules.json'), 'w') as f: + with open(os.path.join(directory, f"{self.name}-consolidated-rules.json"), "w") as f: json.dump(consolidated, f, sort_keys=True, indent=2) consolidated_rules = Ndjson(consolidated) - consolidated_rules.dump(Path(directory).joinpath(f'{self.name}-consolidated-rules.ndjson'), sort_keys=True) + consolidated_rules.dump(Path(directory).joinpath(f"{self.name}-consolidated-rules.ndjson"), sort_keys=True) - self.generate_xslx(os.path.join(directory, f'{self.name}-summary.xlsx')) + self.generate_xslx(os.path.join(directory, f"{self.name}-summary.xlsx")) bulk_upload, rules_ndjson = self.create_bulk_index_body() - bulk_upload.dump(Path(directory).joinpath(f'{self.name}-enriched-rules-index-uploadable.ndjson'), - sort_keys=True) - rules_ndjson.dump(Path(directory).joinpath(f'{self.name}-enriched-rules-index-importable.ndjson'), - sort_keys=True) - - def get_consolidated(self, as_api=True): + bulk_upload.dump( + Path(directory).joinpath(f"{self.name}-enriched-rules-index-uploadable.ndjson"), sort_keys=True + ) + rules_ndjson.dump( + Path(directory).joinpath(f"{self.name}-enriched-rules-index-importable.ndjson"), sort_keys=True + ) + + def get_consolidated(self, as_api: bool = True): """Get a consolidated package of the rules in a single file.""" - full_package = [] + full_package: list[dict[str, Any]] = [] for rule in self.rules: full_package.append(rule.contents.to_api_format() if as_api else rule.contents.to_dict()) return json.dumps(full_package, sort_keys=True) - def save(self, verbose=True): + def save(self, verbose: bool = True): """Save a package 
and all artifacts.""" save_dir = RELEASE_DIR / self.name - rules_dir = save_dir / 'rules' - extras_dir = save_dir / 'extras' + rules_dir = save_dir / "rules" + extras_dir = save_dir / "extras" # remove anything that existed before shutil.rmtree(save_dir, ignore_errors=True) @@ -189,7 +211,9 @@ def save(self, verbose=True): extras_dir.mkdir(parents=True, exist_ok=True) for rule in self.rules: - rule.save_json(rules_dir / Path(rule.path.name).with_suffix('.json')) + if not rule.path: + raise ValueError("Rule path is not found") + rule.save_json(rules_dir / Path(rule.path.name).with_suffix(".json")) self._package_kibana_notice_file(rules_dir) self._package_kibana_index_file(rules_dir) @@ -199,43 +223,67 @@ def save(self, verbose=True): self.save_release_files(extras_dir, self.changed_ids, self.new_ids, self.removed_ids) # zip all rules only and place in extras - shutil.make_archive(extras_dir / self.name, 'zip', root_dir=rules_dir.parent, base_dir=rules_dir.name) + _ = shutil.make_archive( + str(extras_dir / self.name), + "zip", + root_dir=rules_dir.parent, + base_dir=rules_dir.name, + ) # zip everything and place in release root - shutil.make_archive( - save_dir / f"{self.name}-all", "zip", root_dir=extras_dir.parent, base_dir=extras_dir.name + _ = shutil.make_archive( + str(save_dir / f"{self.name}-all"), + "zip", + root_dir=extras_dir.parent, + base_dir=extras_dir.name, ) if verbose: - click.echo(f'Package saved to: {save_dir}') - - def export(self, outfile, downgrade_version=None, verbose=True, skip_unsupported=False): + click.echo(f"Package saved to: {save_dir}") + + def export( + self, + outfile: Path, + downgrade_version: definitions.SemVer | None = None, + verbose: bool = True, + skip_unsupported: bool = False, + ): """Export rules into a consolidated ndjson file.""" - from .main import _export_rules + from .main import _export_rules # type: ignore[reportPrivateUsage] - _export_rules(self.rules, outfile=outfile, downgrade_version=downgrade_version, verbose=verbose, - skip_unsupported=skip_unsupported) + _export_rules( + self.rules, + outfile=outfile, + downgrade_version=downgrade_version, + verbose=verbose, + skip_unsupported=skip_unsupported, + ) - def get_package_hash(self, as_api=True, verbose=True): + def get_package_hash(self, as_api: bool = True, verbose: bool = True): """Get hash of package contents.""" - contents = base64.b64encode(self.get_consolidated(as_api=as_api).encode('utf-8')) + contents = base64.b64encode(self.get_consolidated(as_api=as_api).encode("utf-8")) sha256 = hashlib.sha256(contents).hexdigest() if verbose: - click.echo('- sha256: {}'.format(sha256)) + click.echo("- sha256: {}".format(sha256)) return sha256 @classmethod - def from_config(cls, rule_collection: Optional[RuleCollection] = None, config: Optional[dict] = None, - verbose: Optional[bool] = False, historical: Optional[bool] = True) -> 'Package': + def from_config( + cls, + rule_collection: RuleCollection | None = None, + config: dict[str, Any] | None = None, + verbose: bool = False, + historical: bool = True, + ) -> "Package": """Load a rules package given a config.""" all_rules = rule_collection or RuleCollection.default() config = config or {} - exclude_fields = config.pop('exclude_fields', {}) + exclude_fields = config.pop("exclude_fields", {}) # deprecated rules are now embedded in the RuleCollection.deprecated - this is left here for backwards compat - config.pop('log_deprecated', False) - rule_filter = config.pop('filter', {}) + config.pop("log_deprecated", False) + rule_filter = 
config.pop("filter", {}) rules = all_rules.filter(lambda r: filter_rule(r, rule_filter, exclude_fields)) @@ -243,31 +291,36 @@ def from_config(cls, rule_collection: Optional[RuleCollection] = None, config: O rules.deprecated = all_rules.deprecated if verbose: - click.echo(f' - {len(all_rules) - len(rules)} rules excluded from package') + click.echo(f" - {len(all_rules) - len(rules)} rules excluded from package") package = cls(rules, verbose=verbose, historical=historical, **config) return package - def generate_summary_and_changelog(self, changed_rule_ids, new_rule_ids, removed_rules): + def generate_summary_and_changelog( + self, + changed_rule_ids: list[definitions.UUIDString], + new_rule_ids: list[str], + removed_rules: list[str], + ): """Generate stats on package.""" - summary = { - 'changed': defaultdict(list), - 'added': defaultdict(list), - 'removed': defaultdict(list), - 'unchanged': defaultdict(list) + summary: dict[str, dict[str, list[str]]] = { + "changed": defaultdict(list), + "added": defaultdict(list), + "removed": defaultdict(list), + "unchanged": defaultdict(list), } - changelog = { - 'changed': defaultdict(list), - 'added': defaultdict(list), - 'removed': defaultdict(list), - 'unchanged': defaultdict(list) + changelog: dict[str, dict[str, list[str]]] = { + "changed": defaultdict(list), + "added": defaultdict(list), + "removed": defaultdict(list), + "unchanged": defaultdict(list), } # Build an index map first longest_name = 0 - indexes = set() + indexes: set[str] = set() for rule in self.rules: longest_name = max(longest_name, len(rule.name)) index_list = getattr(rule.contents.data, "index", []) @@ -276,103 +329,110 @@ def generate_summary_and_changelog(self, changed_rule_ids, new_rule_ids, removed index_map = {index: str(i) for i, index in enumerate(sorted(indexes))} - def get_summary_rule_info(r: TOMLRule): - r = r.contents - rule_str = f'{r.name:<{longest_name}} (v:{r.autobumped_version} t:{r.data.type}' + def get_summary_rule_info(r: TOMLRule) -> str: + contents = r.contents + rule_str = f"{r.name:<{longest_name}} (v:{contents.autobumped_version} t:{contents.data.type}" if isinstance(rule.contents.data, QueryRuleData): - index = rule.contents.data.get("index") or [] - rule_str += f'-{r.data.language}' - rule_str += f'(indexes:{"".join(index_map[idx] for idx in index) or "none"}' + index: list[str] = rule.contents.data.get("index") or [] + rule_str += f"-{contents.data.language}" # type: ignore[reportAttributeAccessIssue] + rule_str += f"(indexes:{''.join(index_map[idx] for idx in index) or 'none'}" return rule_str - def get_markdown_rule_info(r: TOMLRule, sd): + def get_markdown_rule_info(r: TOMLRule, sd: str): # lookup the rule in the GitHub tag v{major.minor.patch} data = r.contents.data - rules_dir_link = f'https://github.com/elastic/detection-rules/tree/v{self.name}/rules/{sd}/' + rules_dir_link = f"https://github.com/elastic/detection-rules/tree/v{self.name}/rules/{sd}/" rule_type = data.language if isinstance(data, QueryRuleData) else data.type - return f'`{r.id}` **[{r.name}]({rules_dir_link + os.path.basename(str(r.path))})** (_{rule_type}_)' + return f"`{r.id}` **[{r.name}]({rules_dir_link + os.path.basename(str(r.path))})** (_{rule_type}_)" for rule in self.rules: + if not rule.path: + raise ValueError("Unknown rule path") sub_dir = os.path.basename(os.path.dirname(rule.path)) if rule.id in changed_rule_ids: - summary['changed'][sub_dir].append(get_summary_rule_info(rule)) - changelog['changed'][sub_dir].append(get_markdown_rule_info(rule, sub_dir)) + 
summary["changed"][sub_dir].append(get_summary_rule_info(rule)) + changelog["changed"][sub_dir].append(get_markdown_rule_info(rule, sub_dir)) elif rule.id in new_rule_ids: - summary['added'][sub_dir].append(get_summary_rule_info(rule)) - changelog['added'][sub_dir].append(get_markdown_rule_info(rule, sub_dir)) + summary["added"][sub_dir].append(get_summary_rule_info(rule)) + changelog["added"][sub_dir].append(get_markdown_rule_info(rule, sub_dir)) else: - summary['unchanged'][sub_dir].append(get_summary_rule_info(rule)) - changelog['unchanged'][sub_dir].append(get_markdown_rule_info(rule, sub_dir)) + summary["unchanged"][sub_dir].append(get_summary_rule_info(rule)) + changelog["unchanged"][sub_dir].append(get_markdown_rule_info(rule, sub_dir)) for rule in self.deprecated_rules: sub_dir = os.path.basename(os.path.dirname(rule.path)) + if not rule.name: + raise ValueError("Rule name is not found") + if rule.id in removed_rules: - summary['removed'][sub_dir].append(rule.name) - changelog['removed'][sub_dir].append(rule.name) + summary["removed"][sub_dir].append(rule.name) + changelog["removed"][sub_dir].append(rule.name) - def format_summary_rule_str(rule_dict): - str_fmt = '' + def format_summary_rule_str(rule_dict: dict[str, Any]): + str_fmt = "" for sd, rules in sorted(rule_dict.items(), key=lambda x: x[0]): - str_fmt += f'\n{sd} ({len(rules)})\n' - str_fmt += '\n'.join(' - ' + s for s in sorted(rules)) - return str_fmt or '\nNone' + str_fmt += f"\n{sd} ({len(rules)})\n" + str_fmt += "\n".join(" - " + s for s in sorted(rules)) + return str_fmt or "\nNone" - def format_changelog_rule_str(rule_dict): - str_fmt = '' + def format_changelog_rule_str(rule_dict: dict[str, Any]): + str_fmt = "" for sd, rules in sorted(rule_dict.items(), key=lambda x: x[0]): - str_fmt += f'\n- **{sd}** ({len(rules)})\n' - str_fmt += '\n'.join(' - ' + s for s in sorted(rules)) - return str_fmt or '\nNone' + str_fmt += f"\n- **{sd}** ({len(rules)})\n" + str_fmt += "\n".join(" - " + s for s in sorted(rules)) + return str_fmt or "\nNone" - def rule_count(rule_dict): + def rule_count(rule_dict: dict[str, Any]): count = 0 for _, rules in rule_dict.items(): count += len(rules) return count - today = str(datetime.date.today()) - summary_fmt = [f'{sf.capitalize()} ({rule_count(summary[sf])}): \n{format_summary_rule_str(summary[sf])}\n' - for sf in ('added', 'changed', 'removed', 'unchanged') if summary[sf]] - - change_fmt = [f'{sf.capitalize()} ({rule_count(changelog[sf])}): \n{format_changelog_rule_str(changelog[sf])}\n' - for sf in ('added', 'changed', 'removed') if changelog[sf]] - - summary_str = '\n'.join([ - f'Version {self.name}', - f'Generated: {today}', - f'Total Rules: {len(self.rules)}', - f'Package Hash: {self.get_package_hash(verbose=False)}', - '---', - '(v: version, t: rule_type-language)', - 'Index Map:\n{}'.format("\n".join(f" {v}: {k}" for k, v in index_map.items())), - '', - 'Rules', - *summary_fmt - ]) - - changelog_str = '\n'.join([ - f'# Version {self.name}', - f'_Released {today}_', - '', - '### Rules', - *change_fmt, - '', - '### CLI' - ]) + today = str(date.today()) + summary_fmt = [ + f"{sf.capitalize()} ({rule_count(summary[sf])}): \n{format_summary_rule_str(summary[sf])}\n" + for sf in ("added", "changed", "removed", "unchanged") + if summary[sf] + ] + + change_fmt = [ + f"{sf.capitalize()} ({rule_count(changelog[sf])}): \n{format_changelog_rule_str(changelog[sf])}\n" + for sf in ("added", "changed", "removed") + if changelog[sf] + ] + + summary_str = "\n".join( + [ + f"Version {self.name}", + 
f"Generated: {today}", + f"Total Rules: {len(self.rules)}", + f"Package Hash: {self.get_package_hash(verbose=False)}", + "---", + "(v: version, t: rule_type-language)", + "Index Map:\n{}".format("\n".join(f" {v}: {k}" for k, v in index_map.items())), + "", + "Rules", + *summary_fmt, + ] + ) + + changelog_str = "\n".join( + [f"# Version {self.name}", f"_Released {today}_", "", "### Rules", *change_fmt, "", "### CLI"] + ) return summary_str, changelog_str - def generate_attack_navigator(self, path: Path) -> Dict[Path, Navigator]: + def generate_attack_navigator(self, path: Path) -> dict[Path, Navigator]: """Generate ATT&CK navigator layer files.""" - save_dir = path / 'navigator_layers' + save_dir = path / "navigator_layers" save_dir.mkdir() lb = NavigatorBuilder(self.rules.rules) return lb.save_all(save_dir, verbose=False) - def generate_xslx(self, path): + def generate_xslx(self, path: str): """Generate a detailed breakdown of a package in an excel file.""" from .docs import PackageDocument @@ -380,29 +440,28 @@ def generate_xslx(self, path): doc.populate() doc.close() - def _generate_registry_package(self, save_dir): + def _generate_registry_package(self, save_dir: Path): """Generate the artifact for the oob package-storage.""" - from .schemas.registry_package import (RegistryPackageManifestV1, - RegistryPackageManifestV3) + from .schemas.registry_package import RegistryPackageManifestV1, RegistryPackageManifestV3 # 8.12.0+ we use elastic package v3 stack_version = Version.parse(self.name, optional_minor_and_patch=True) - if stack_version >= Version.parse('8.12.0'): + if stack_version >= Version.parse("8.12.0"): manifest = RegistryPackageManifestV3.from_dict(self.registry_data) else: manifest = RegistryPackageManifestV1.from_dict(self.registry_data) - package_dir = Path(save_dir) / 'fleet' / manifest.version - docs_dir = package_dir / 'docs' - rules_dir = package_dir / 'kibana' / definitions.ASSET_TYPE + package_dir = Path(save_dir) / "fleet" / manifest.version + docs_dir = package_dir / "docs" + rules_dir = package_dir / "kibana" / definitions.ASSET_TYPE docs_dir.mkdir(parents=True) rules_dir.mkdir(parents=True) - manifest_file = package_dir / 'manifest.yml' - readme_file = docs_dir / 'README.md' - notice_file = package_dir / 'NOTICE.txt' - logo_file = package_dir / 'img' / 'security-logo-color-64px.svg' + manifest_file = package_dir / "manifest.yml" + readme_file = docs_dir / "README.md" + notice_file = package_dir / "NOTICE.txt" + logo_file = package_dir / "img" / "security-logo-color-64px.svg" manifest_file.write_text(yaml.safe_dump(manifest.to_dict())) @@ -416,7 +475,7 @@ def _generate_registry_package(self, save_dir): # asset['id] and the file name needs to resemble RULEID_VERSION instead of RULEID asset_id = f"{asset['attributes']['rule_id']}_{asset['attributes']['version']}" asset["id"] = asset_id - asset_path = rules_dir / f'{asset_id}.json' + asset_path = rules_dir / f"{asset_id}.json" asset_path.write_text(json.dumps(asset, indent=4, sort_keys=True), encoding="utf-8") @@ -437,57 +496,59 @@ def _generate_registry_package(self, save_dir): # notice only needs to be appended to the README for 7.13.x # in 7.14+ there's a separate modal to display this if self.name == "7.13": - textwrap.indent(notice_contents, prefix=" ") + notice_contents = textwrap.indent(notice_contents, prefix=" ") readme_file.write_text(readme_text) notice_file.write_text(notice_contents) - def create_bulk_index_body(self) -> Tuple[Ndjson, Ndjson]: + def create_bulk_index_body(self) -> tuple[Ndjson, Ndjson]: 
"""Create a body to bulk index into a stack.""" package_hash = self.get_package_hash(verbose=False) - now = datetime.datetime.isoformat(datetime.datetime.utcnow()) - create = {'create': {'_index': f'rules-repo-{self.name}-{package_hash}'}} + now = datetime.now(timezone.utc).isoformat() + create = {"create": {"_index": f"rules-repo-{self.name}-{package_hash}"}} # first doc is summary stats - summary_doc = { - 'group_hash': package_hash, - 'package_version': self.name, - 'rule_count': len(self.rules), - 'rule_ids': [], - 'rule_names': [], - 'rule_hashes': [], - 'source': 'repo', - 'details': {'datetime_uploaded': now} + summary_doc: dict[str, Any] = { + "group_hash": package_hash, + "package_version": self.name, + "rule_count": len(self.rules), + "rule_ids": [], + "rule_names": [], + "rule_hashes": [], + "source": "repo", + "details": {"datetime_uploaded": now}, } bulk_upload_docs = Ndjson([create, summary_doc]) importable_rules_docs = Ndjson() for rule in self.rules: - summary_doc['rule_ids'].append(rule.id) - summary_doc['rule_names'].append(rule.name) - summary_doc['rule_hashes'].append(rule.contents.get_hash()) + summary_doc["rule_ids"].append(rule.id) + summary_doc["rule_names"].append(rule.name) + summary_doc["rule_hashes"].append(rule.contents.get_hash()) if rule.id in self.new_ids: - status = 'new' + status = "new" elif rule.id in self.changed_ids: - status = 'modified' + status = "modified" else: - status = 'unmodified' + status = "unmodified" bulk_upload_docs.append(create) relative_path = str(rule.get_base_rule_dir()) - if relative_path is None: + if not relative_path: raise ValueError(f"Could not find a valid relative path for the rule: {rule.id}") - rule_doc = dict(hash=rule.contents.get_hash(), - source='repo', - datetime_uploaded=now, - status=status, - package_version=self.name, - flat_mitre=ThreatMapping.flatten(rule.contents.data.threat).to_dict(), - relative_path=relative_path) + rule_doc = dict( + hash=rule.contents.get_hash(), + source="repo", + datetime_uploaded=now, + status=status, + package_version=self.name, + flat_mitre=ThreatMapping.flatten(rule.contents.data.threat).to_dict(), + relative_path=relative_path, + ) rule_doc.update(**rule.contents.to_api_format()) bulk_upload_docs.append(rule_doc) importable_rules_docs.append(rule_doc) @@ -495,14 +556,17 @@ def create_bulk_index_body(self) -> Tuple[Ndjson, Ndjson]: return bulk_upload_docs, importable_rules_docs @staticmethod - def add_historical_rules(historical_rules: Dict[str, dict], manifest_version: str) -> list: + def add_historical_rules( + historical_rules: dict[str, dict[str, Any]], + manifest_version: str, + ) -> list[dict[str, Any]] | None: """Adds historical rules to existing build package.""" - rules_dir = CURRENT_RELEASE_PATH / 'fleet' / manifest_version / 'kibana' / 'security_rule' + rules_dir = CURRENT_RELEASE_PATH / "fleet" / manifest_version / "kibana" / "security_rule" # iterates over historical rules from previous package and writes them to disk for _, historical_rule_contents in historical_rules.items(): rule_id = historical_rule_contents["attributes"]["rule_id"] - historical_rule_version = historical_rule_contents['attributes']['version'] + historical_rule_version = historical_rule_contents["attributes"]["version"] # checks if the rule exists in the current package first current_rule_path = list(rules_dir.glob(f"{rule_id}*.json")) @@ -512,7 +576,7 @@ def add_historical_rules(historical_rules: Dict[str, dict], manifest_version: st # load the current rule from disk current_rule_path = 
current_rule_path[0] current_rule_json = json.load(current_rule_path.open(encoding="UTF-8")) - current_rule_version = current_rule_json['attributes']['version'] + current_rule_version = current_rule_json["attributes"]["version"] # if the historical rule version and current rules version differ, write # the historical rule to disk @@ -524,4 +588,4 @@ def add_historical_rules(historical_rules: Dict[str, dict], manifest_version: st @cached def current_stack_version() -> str: - return Package.load_configs()['name'] + return Package.load_configs()["name"] diff --git a/detection_rules/remote_validation.py b/detection_rules/remote_validation.py index db30c5e953c..49df65ef3b5 100644 --- a/detection_rules/remote_validation.py +++ b/detection_rules/remote_validation.py @@ -4,17 +4,17 @@ # 2.0. from dataclasses import dataclass -from datetime import datetime +from datetime import datetime, timezone from functools import cached_property from multiprocessing.pool import ThreadPool -from typing import Dict, List, Optional +from typing import Any, Callable import elasticsearch from elasticsearch import Elasticsearch from marshmallow import ValidationError from requests import HTTPError -from kibana import Kibana +from kibana import Kibana # type: ignore[reportMissingTypeStubs] from .config import load_current_package_version from .misc import ClientError, getdefault, get_elasticsearch_client, get_kibana_client @@ -25,13 +25,14 @@ @dataclass class RemoteValidationResult: """Dataclass for remote validation results.""" + rule_id: definitions.UUIDString rule_name: str - contents: dict + contents: dict[str, Any] rule_version: int stack_version: str - query_results: Optional[dict] - engine_results: Optional[dict] + query_results: dict[str, Any] + engine_results: dict[str, Any] class RemoteConnector: @@ -39,17 +40,17 @@ class RemoteConnector: MAX_RETRIES = 5 - def __init__(self, parse_config: bool = False, **kwargs): - es_args = ['cloud_id', 'ignore_ssl_errors', 'elasticsearch_url', 'es_user', 'es_password', 'timeout'] - kibana_args = ['cloud_id', 'ignore_ssl_errors', 'kibana_url', 'api_key', 'space'] + def __init__(self, parse_config: bool = False, **kwargs: Any): + es_args = ["cloud_id", "ignore_ssl_errors", "elasticsearch_url", "es_user", "es_password", "timeout"] + kibana_args = ["cloud_id", "ignore_ssl_errors", "kibana_url", "api_key", "space"] if parse_config: es_kwargs = {arg: getdefault(arg)() for arg in es_args} kibana_kwargs = {arg: getdefault(arg)() for arg in kibana_args} try: - if 'max_retries' not in es_kwargs: - es_kwargs['max_retries'] = self.MAX_RETRIES + if "max_retries" not in es_kwargs: + es_kwargs["max_retries"] = self.MAX_RETRIES self.es_client = get_elasticsearch_client(**es_kwargs, **kwargs) except ClientError: self.es_client = None @@ -59,15 +60,29 @@ def __init__(self, parse_config: bool = False, **kwargs): except HTTPError: self.kibana_client = None - def auth_es(self, *, cloud_id: Optional[str] = None, ignore_ssl_errors: Optional[bool] = None, - elasticsearch_url: Optional[str] = None, es_user: Optional[str] = None, - es_password: Optional[str] = None, timeout: Optional[int] = None, **kwargs) -> Elasticsearch: + def auth_es( + self, + *, + cloud_id: str | None = None, + ignore_ssl_errors: bool | None = None, + elasticsearch_url: str | None = None, + es_user: str | None = None, + es_password: str | None = None, + timeout: int | None = None, + **kwargs: Any, + ) -> Elasticsearch: """Return an authenticated Elasticsearch client.""" - if 'max_retries' not in kwargs: - 
kwargs['max_retries'] = self.MAX_RETRIES - self.es_client = get_elasticsearch_client(cloud_id=cloud_id, ignore_ssl_errors=ignore_ssl_errors, - elasticsearch_url=elasticsearch_url, es_user=es_user, - es_password=es_password, timeout=timeout, **kwargs) + if "max_retries" not in kwargs: + kwargs["max_retries"] = self.MAX_RETRIES + self.es_client = get_elasticsearch_client( + cloud_id=cloud_id, + ignore_ssl_errors=ignore_ssl_errors, + elasticsearch_url=elasticsearch_url, + es_user=es_user, + es_password=es_password, + timeout=timeout, + **kwargs, + ) return self.es_client def auth_kibana( @@ -78,7 +93,7 @@ def auth_kibana( kibana_url: str | None = None, space: str | None = None, ignore_ssl_errors: bool = False, - **kwargs + **kwargs: Any, ) -> Kibana: """Return an authenticated Kibana client.""" self.kibana_client = get_kibana_client( @@ -87,7 +102,7 @@ def auth_kibana( kibana_url=kibana_url, api_key=api_key, space=space, - **kwargs + **kwargs, ) return self.kibana_client @@ -99,41 +114,53 @@ def __init__(self, parse_config: bool = False): super(RemoteValidator, self).__init__(parse_config=parse_config) @cached_property - def get_validate_methods(self) -> List[str]: + def get_validate_methods(self) -> list[str]: """Return all validate methods.""" - exempt = ('validate_rule', 'validate_rules') - methods = [m for m in self.__dir__() if m.startswith('validate_') and m not in exempt] + exempt = ("validate_rule", "validate_rules") + methods = [m for m in self.__dir__() if m.startswith("validate_") and m not in exempt] return methods - def get_validate_method(self, name: str) -> callable: + def get_validate_method(self, name: str) -> Callable[..., Any]: """Return validate method by name.""" - assert name in self.get_validate_methods, f'validate method {name} not found' + assert name in self.get_validate_methods, f"validate method {name} not found" return getattr(self, name) @staticmethod - def prep_for_preview(contents: TOMLRuleContents) -> dict: + def prep_for_preview(contents: TOMLRuleContents) -> dict[str, Any]: """Prepare rule for preview.""" - end_time = datetime.utcnow().isoformat() + end_time = datetime.now(timezone.utc).isoformat() dumped = contents.to_api_format().copy() dumped.update(timeframeEnd=end_time, invocationCount=1) return dumped - def engine_preview(self, contents: TOMLRuleContents) -> dict: + def engine_preview(self, contents: TOMLRuleContents) -> dict[str, Any]: """Get results from detection engine preview API.""" dumped = self.prep_for_preview(contents) - return self.kibana_client.post('/api/detection_engine/rules/preview', json=dumped) + if not self.kibana_client: + raise ValueError("No Kibana client found") + return self.kibana_client.post("/api/detection_engine/rules/preview", json=dumped) # type: ignore[reportReturnType] def validate_rule(self, contents: TOMLRuleContents) -> RemoteValidationResult: """Validate a single rule query.""" - method = self.get_validate_method(f'validate_{contents.data.type}') + method = self.get_validate_method(f"validate_{contents.data.type}") query_results = method(contents) engine_results = self.engine_preview(contents) rule_version = contents.autobumped_version stack_version = load_current_package_version() - return RemoteValidationResult(contents.data.rule_id, contents.data.name, contents.to_api_format(), - rule_version, stack_version, query_results, engine_results) + if not rule_version: + raise ValueError("No rule version found") + + return RemoteValidationResult( + contents.data.rule_id, + contents.data.name, + contents.to_api_format(), + 
rule_version, + stack_version, + query_results, + engine_results, + ) - def validate_rules(self, rules: List[TOMLRule], threads: int = 5) -> Dict[str, RemoteValidationResult]: + def validate_rules(self, rules: list[TOMLRule], threads: int = 5) -> dict[str, RemoteValidationResult]: """Validate a collection of rules via threads.""" responses = {} @@ -141,69 +168,83 @@ def request(c: TOMLRuleContents): try: responses[c.data.rule_id] = self.validate_rule(c) except ValidationError as e: - responses[c.data.rule_id] = e.messages + responses[c.data.rule_id] = e.messages # type: ignore[reportUnknownMemberType] pool = ThreadPool(processes=threads) - pool.map(request, [r.contents for r in rules]) + _ = pool.map(request, [r.contents for r in rules]) pool.close() pool.join() - return responses + return responses # type: ignore[reportUnknownVariableType] - def validate_esql(self, contents: TOMLRuleContents) -> dict: - query = contents.data.query + def validate_esql(self, contents: TOMLRuleContents) -> dict[str, Any]: + query = contents.data.query # type: ignore[reportAttributeAccessIssue] rule_id = contents.data.rule_id headers = {"accept": "application/json", "content-type": "application/json"} - body = {'query': f'{query} | LIMIT 0'} + body = {"query": f"{query} | LIMIT 0"} + if not self.es_client: + raise ValueError("No ES client found") try: - response = self.es_client.perform_request('POST', '/_query', headers=headers, params={'pretty': True}, - body=body) + response = self.es_client.perform_request( + "POST", + "/_query", + headers=headers, + params={"pretty": True}, + body=body, + ) except Exception as exc: if isinstance(exc, elasticsearch.BadRequestError): - raise ValidationError(f'ES|QL query failed: {exc} for rule: {rule_id}, query: \n{query}') + raise ValidationError(f"ES|QL query failed: {exc} for rule: {rule_id}, query: \n{query}") else: - raise Exception(f'ES|QL query failed for rule: {rule_id}, query: \n{query}') from exc + raise Exception(f"ES|QL query failed for rule: {rule_id}, query: \n{query}") from exc return response.body - def validate_eql(self, contents: TOMLRuleContents) -> dict: + def validate_eql(self, contents: TOMLRuleContents) -> dict[str, Any]: """Validate query for "eql" rule types.""" - query = contents.data.query + query = contents.data.query # type: ignore[reportAttributeAccessIssue] rule_id = contents.data.rule_id - index = contents.data.index - time_range = {"range": {"@timestamp": {"gt": 'now-1h/h', "lte": 'now', "format": "strict_date_optional_time"}}} - body = {'query': query} + index = contents.data.index # type: ignore[reportAttributeAccessIssue] + time_range = {"range": {"@timestamp": {"gt": "now-1h/h", "lte": "now", "format": "strict_date_optional_time"}}} + body: dict[str, Any] = {"query": query} + + if not self.es_client: + raise ValueError("No ES client found") + + if not index: + raise ValueError("Indices not found") + try: - response = self.es_client.eql.search(index=index, body=body, ignore_unavailable=True, filter=time_range) + response = self.es_client.eql.search(index=index, body=body, ignore_unavailable=True, filter=time_range) # type: ignore[reportUnknownArgumentType] except Exception as exc: if isinstance(exc, elasticsearch.BadRequestError): - raise ValidationError(f'EQL query failed: {exc} for rule: {rule_id}, query: \n{query}') + raise ValidationError(f"EQL query failed: {exc} for rule: {rule_id}, query: \n{query}") else: - raise Exception(f'EQL query failed for rule: {rule_id}, query: \n{query}') from exc + raise Exception(f"EQL query failed for 
rule: {rule_id}, query: \n{query}") from exc return response.body @staticmethod - def validate_query(self, contents: TOMLRuleContents) -> dict: + def validate_query(_, __: TOMLRuleContents) -> dict[str, str]: """Validate query for "query" rule types.""" - return {'results': 'Unable to remote validate query rules'} + return {"results": "Unable to remote validate query rules"} @staticmethod - def validate_threshold(self, contents: TOMLRuleContents) -> dict: + def validate_threshold(_, __: TOMLRuleContents) -> dict[str, str]: """Validate query for "threshold" rule types.""" - return {'results': 'Unable to remote validate threshold rules'} + return {"results": "Unable to remote validate threshold rules"} @staticmethod - def validate_new_terms(self, contents: TOMLRuleContents) -> dict: + def validate_new_terms(_, __: TOMLRuleContents) -> dict[str, str]: """Validate query for "new_terms" rule types.""" - return {'results': 'Unable to remote validate new_terms rules'} + return {"results": "Unable to remote validate new_terms rules"} @staticmethod - def validate_threat_match(self, contents: TOMLRuleContents) -> dict: + def validate_threat_match(_, __: TOMLRuleContents) -> dict[str, str]: """Validate query for "threat_match" rule types.""" - return {'results': 'Unable to remote validate threat_match rules'} + return {"results": "Unable to remote validate threat_match rules"} @staticmethod - def validate_machine_learning(self, contents: TOMLRuleContents) -> dict: + def validate_machine_learning(_, __: TOMLRuleContents) -> dict[str, str]: """Validate query for "machine_learning" rule types.""" - return {'results': 'Unable to remote validate machine_learning rules'} + return {"results": "Unable to remote validate machine_learning rules"} diff --git a/detection_rules/rule.py b/detection_rules/rule.py index 06c645d814c..10c7c65a99a 100644 --- a/detection_rules/rule.py +++ b/detection_rules/rule.py @@ -3,6 +3,7 @@ # 2.0; you may not use this file except in compliance with the Elastic License # 2.0. """Rule object.""" + import copy import dataclasses import json @@ -14,36 +15,44 @@ from dataclasses import dataclass, field from functools import cached_property from pathlib import Path -from typing import Any, Dict, List, Literal, Optional, Tuple, Union +from typing import Any, Literal from urllib.parse import urlparse from uuid import uuid4 -import eql +import eql # type: ignore[reportMissingTypeStubs] import marshmallow from semver import Version from marko.block import Document as MarkoDocument from marko.ext.gfm import gfm from marshmallow import ValidationError, pre_load, validates_schema -import kql +import kql # type: ignore[reportMissingTypeStubs] from . 
import beats, ecs, endgame, utils +from .version_lock import loaded_version_lock, VersionLock from .config import load_current_package_version, parse_rules_config -from .integrations import (find_least_compatible_version, get_integration_schema_fields, - load_integrations_manifests, load_integrations_schemas, - parse_datasets) +from .integrations import ( + find_least_compatible_version, + get_integration_schema_fields, + load_integrations_manifests, + load_integrations_schemas, +) from .mixins import MarshmallowDataclassMixin, StackCompatMixin from .rule_formatter import nested_normalize, toml_write -from .schemas import (SCHEMA_DIR, definitions, downgrade, - get_min_supported_stack_version, get_stack_schemas, - strip_non_public_fields) +from .schemas import ( + SCHEMA_DIR, + definitions, + downgrade, + get_min_supported_stack_version, + get_stack_schemas, + strip_non_public_fields, +) from .schemas.stack_compat import get_restricted_fields from .utils import PatchedTemplate, cached, convert_time_span, get_nested_value, set_nested_value -_META_SCHEMA_REQ_DEFAULTS = {} -MIN_FLEET_PACKAGE_VERSION = '7.13.0' -TIME_NOW = time.strftime('%Y/%m/%d') +MIN_FLEET_PACKAGE_VERSION = "7.13.0" +TIME_NOW = time.strftime("%Y/%m/%d") RULES_CONFIG = parse_rules_config() DEFAULT_PREBUILT_RULES_DIRS = RULES_CONFIG.rule_dirs DEFAULT_PREBUILT_BBR_DIRS = RULES_CONFIG.bbr_rules_dirs @@ -51,38 +60,38 @@ BUILD_FIELD_VERSIONS = { - "related_integrations": (Version.parse('8.3.0'), None), - "required_fields": (Version.parse('8.3.0'), None), - "setup": (Version.parse('8.3.0'), None) + "related_integrations": (Version.parse("8.3.0"), None), + "required_fields": (Version.parse("8.3.0"), None), + "setup": (Version.parse("8.3.0"), None), } -@dataclass +@dataclass(kw_only=True) class DictRule: """Simple object wrapper for raw rule dicts.""" - contents: dict - path: Optional[Path] = None + contents: dict[str, Any] + path: Path | None = None @property - def metadata(self) -> dict: + def metadata(self) -> dict[str, Any]: """Metadata portion of TOML file rule.""" - return self.contents.get('metadata', {}) + return self.contents.get("metadata", {}) @property - def data(self) -> dict: + def data(self) -> dict[str, Any]: """Rule portion of TOML file rule.""" - return self.contents.get('data') or self.contents + return self.contents.get("data") or self.contents @property def id(self) -> str: """Get the rule ID.""" - return self.data['rule_id'] + return self.data["rule_id"] # type: ignore[reportUnknownMemberType] @property def name(self) -> str: """Get the rule name.""" - return self.data['name'] + return self.data["name"] # type: ignore[reportUnknownMemberType] def __hash__(self) -> int: """Get the hash of the rule.""" @@ -93,35 +102,36 @@ def __repr__(self) -> str: return f"Rule({self.name} {self.id})" -@dataclass(frozen=True) +@dataclass(frozen=True, kw_only=True) class RuleMeta(MarshmallowDataclassMixin): """Data stored in a rule's [metadata] section of TOML.""" + creation_date: definitions.Date updated_date: definitions.Date - deprecation_date: Optional[definitions.Date] + deprecation_date: definitions.Date | None = None # Optional fields - bypass_bbr_timing: Optional[bool] - comments: Optional[str] - integration: Optional[Union[str, List[str]]] - maturity: Optional[definitions.Maturity] - min_stack_version: Optional[definitions.SemVer] - min_stack_comments: Optional[str] - os_type_list: Optional[List[definitions.OSType]] - query_schema_validation: Optional[bool] - related_endpoint_rules: Optional[List[str]] - promotion: 
Optional[bool] + bypass_bbr_timing: bool | None = None + comments: str | None = None + integration: str | list[str] | None = None + maturity: definitions.Maturity | None = None + min_stack_version: definitions.SemVer | None = None + min_stack_comments: str | None = None + os_type_list: list[definitions.OSType] | None = None + query_schema_validation: bool | None = None + related_endpoint_rules: list[str] | None = None + promotion: bool | None = None # Extended information as an arbitrary dictionary - extended: Optional[Dict[str, Any]] + extended: dict[str, Any] | None = None - def get_validation_stack_versions(self) -> Dict[str, dict]: + def get_validation_stack_versions(self) -> dict[str, dict[str, Any]]: """Get a dict of beats and ecs versions per stack release.""" stack_versions = get_stack_schemas(self.min_stack_version) return stack_versions -@dataclass(frozen=True) +@dataclass(frozen=True, kw_only=True) class RuleTransform(MarshmallowDataclassMixin): """Data stored in a rule's [transform] section of TOML.""" @@ -131,13 +141,13 @@ class RuleTransform(MarshmallowDataclassMixin): # timelines out of scope at the moment - @dataclass(frozen=True) + @dataclass(frozen=True, kw_only=True) class OsQuery: label: str query: str - ecs_mapping: Optional[Dict[str, Dict[Literal['field', 'value'], str]]] + ecs_mapping: dict[str, dict[Literal["field", "value"], str]] | None = None - @dataclass(frozen=True) + @dataclass(frozen=True, kw_only=True) class Investigate: @dataclass(frozen=True) class Provider: @@ -148,23 +158,25 @@ class Provider: valueType: definitions.InvestigateProviderValueType label: str - description: Optional[str] - providers: List[List[Provider]] - relativeFrom: Optional[str] - relativeTo: Optional[str] + description: str | None = None + providers: list[list[Provider]] + relativeFrom: str | None = None + relativeTo: str | None = None # these must be lists in order to have more than one. 
Their index in the list is how they will be referenced in the # note string templates - osquery: Optional[List[OsQuery]] - investigate: Optional[List[Investigate]] + osquery: list[OsQuery] | None = None + investigate: list[Investigate] | None = None - def render_investigate_osquery_to_string(self) -> Dict[definitions.TransformTypes, List[str]]: + def render_investigate_osquery_to_string(self) -> dict[definitions.TransformTypes, list[str]]: obj = self.to_dict() - rendered: Dict[definitions.TransformTypes, List[str]] = {'osquery': [], 'investigate': []} + rendered: dict[definitions.TransformTypes, list[str]] = {"osquery": [], "investigate": []} for plugin, entries in obj.items(): for entry in entries: - rendered[plugin].append(f'!{{{plugin}{json.dumps(entry, sort_keys=True, separators=(",", ":"))}}}') + if plugin not in rendered: + raise ValueError(f"Unexpected field value: {plugin}") + rendered[plugin].append(f"!{{{plugin}{json.dumps(entry, sort_keys=True, separators=(',', ':'))}}}") return rendered @@ -178,9 +190,10 @@ class BaseThreatEntry: reference: str @pre_load - def modify_url(self, data: Dict[str, Any], **kwargs): + def modify_url(self, data: dict[str, Any], **_: Any): """Modify the URL to support MITRE ATT&CK URLS with and without trailing forward slash.""" - if urlparse(data["reference"]).scheme: + p = urlparse(data["reference"]) # type: ignore[reportUnknownVariableType] + if p.scheme: # type: ignore[reportUnknownMemberType] if not data["reference"].endswith("/"): data["reference"] += "/" return data @@ -189,49 +202,53 @@ def modify_url(self, data: Dict[str, Any], **kwargs): @dataclass(frozen=True) class SubTechnique(BaseThreatEntry): """Mapping to threat subtechnique.""" + reference: definitions.SubTechniqueURL -@dataclass(frozen=True) +@dataclass(frozen=True, kw_only=True) class Technique(BaseThreatEntry): """Mapping to threat subtechnique.""" + # subtechniques are stored at threat[].technique.subtechnique[] reference: definitions.TechniqueURL - subtechnique: Optional[List[SubTechnique]] + subtechnique: list[SubTechnique] | None = None @dataclass(frozen=True) class Tactic(BaseThreatEntry): """Mapping to a threat tactic.""" + reference: definitions.TacticURL -@dataclass(frozen=True) +@dataclass(frozen=True, kw_only=True) class ThreatMapping(MarshmallowDataclassMixin): """Mapping to a threat framework.""" + framework: Literal["MITRE ATT&CK"] tactic: Tactic - technique: Optional[List[Technique]] + technique: list[Technique] | None = None @staticmethod - def flatten(threat_mappings: Optional[List]) -> 'FlatThreatMapping': + def flatten(threat_mappings: list["ThreatMapping"] | None) -> "FlatThreatMapping": """Get flat lists of tactic and technique info.""" - tactic_names = [] - tactic_ids = [] - technique_ids = set() - technique_names = set() - sub_technique_ids = set() - sub_technique_names = set() - - for entry in (threat_mappings or []): + tactic_names: list[str] = [] + tactic_ids: list[str] = [] + technique_ids: set[str] = set() + technique_names: set[str] = set() + sub_technique_ids: set[str] = set() + sub_technique_names: set[str] = set() + + for entry in threat_mappings or []: tactic_names.append(entry.tactic.name) tactic_ids.append(entry.tactic.id) - for technique in (entry.technique or []): + for technique in entry.technique or []: technique_names.add(technique.name) technique_ids.add(technique.id) - for subtechnique in (technique.subtechnique or []): + for subtechnique in technique.subtechnique or []: sub_technique_ids.add(subtechnique.id) 
sub_technique_names.add(subtechnique.name) @@ -241,48 +258,49 @@ def flatten(threat_mappings: Optional[List]) -> 'FlatThreatMapping': technique_names=sorted(technique_names), technique_ids=sorted(technique_ids), sub_technique_names=sorted(sub_technique_names), - sub_technique_ids=sorted(sub_technique_ids) + sub_technique_ids=sorted(sub_technique_ids), ) -@dataclass(frozen=True) +@dataclass(frozen=True, kw_only=True) class RiskScoreMapping(MarshmallowDataclassMixin): field: str - operator: Optional[definitions.Operator] - value: Optional[str] + operator: definitions.Operator | None = None + value: str | None = None -@dataclass(frozen=True) +@dataclass(frozen=True, kw_only=True) class SeverityMapping(MarshmallowDataclassMixin): field: str - operator: Optional[definitions.Operator] - value: Optional[str] - severity: Optional[str] + operator: definitions.Operator | None = None + value: str | None = None + severity: str | None = None @dataclass(frozen=True) class FlatThreatMapping(MarshmallowDataclassMixin): - tactic_names: List[str] - tactic_ids: List[str] - technique_names: List[str] - technique_ids: List[str] - sub_technique_names: List[str] - sub_technique_ids: List[str] + tactic_names: list[str] + tactic_ids: list[str] + technique_names: list[str] + technique_ids: list[str] + sub_technique_names: list[str] + sub_technique_ids: list[str] @dataclass(frozen=True) class AlertSuppressionDuration: """Mapping to alert suppression duration.""" + unit: definitions.TimeUnits value: definitions.AlertSuppressionValue -@dataclass(frozen=True) +@dataclass(frozen=True, kw_only=True) class AlertSuppressionMapping(MarshmallowDataclassMixin, StackCompatMixin): """Mapping to alert suppression.""" group_by: definitions.AlertSuppressionGroupBy - duration: Optional[AlertSuppressionDuration] + duration: AlertSuppressionDuration | None = None missing_fields_strategy: definitions.AlertSuppressionMissing @@ -298,19 +316,19 @@ class FilterStateStore: store: definitions.StoreType -@dataclass(frozen=True) +@dataclass(frozen=True, kw_only=True) class FilterMeta: - alias: Optional[Union[str, None]] = None - disabled: Optional[bool] = None - negate: Optional[bool] = None - controlledBy: Optional[str] = None # identify who owns the filter - group: Optional[str] = None # allows grouping of filters - index: Optional[str] = None - isMultiIndex: Optional[bool] = None - type: Optional[str] = None - key: Optional[str] = None - params: Optional[str] = None # Expand to FilterMetaParams when needed - value: Optional[str] = None + alias: str | None = None + disabled: bool | None = None + negate: bool | None = None + controlledBy: str | None # identify who owns the filter + group: str | None # allows grouping of filters + index: str | None = None + isMultiIndex: bool | None = None + type: str | None = None + key: str | None = None + params: str | None = None # Expand to FilterMetaParams when needed + value: str | None = None @dataclass(frozen=True) @@ -319,28 +337,29 @@ class WildcardQuery: value: str -@dataclass(frozen=True) +@dataclass(frozen=True, kw_only=True) class Query: - wildcard: Optional[Dict[str, WildcardQuery]] = None + wildcard: dict[str, WildcardQuery] | None = None -@dataclass(frozen=True) +@dataclass(frozen=True, kw_only=True) class Filter: """Kibana Filter for Base Rule Data.""" + # TODO: Currently unused in BaseRuleData. Revisit to extend or remove. 
# https://github.com/elastic/detection-rules/issues/3773 meta: FilterMeta - state: Optional[FilterStateStore] = field(metadata=dict(data_key="$state")) - query: Optional[Union[Query, Dict[str, Any]]] = None + state: FilterStateStore | None = field(metadata=dict(data_key="$state")) + query: Query | dict[str, Any] | None = None -@dataclass(frozen=True) +@dataclass(frozen=True, kw_only=True) class BaseRuleData(MarshmallowDataclassMixin, StackCompatMixin): """Base rule data.""" @dataclass class InvestigationFields: - field_names: List[definitions.NonEmptyStr] + field_names: list[definitions.NonEmptyStr] @dataclass class RequiredFields: @@ -352,53 +371,52 @@ class RequiredFields: class RelatedIntegrations: package: definitions.NonEmptyStr version: definitions.NonEmptyStr - integration: Optional[definitions.NonEmptyStr] + integration: definitions.NonEmptyStr | None = None - actions: Optional[list] - author: List[str] - building_block_type: Optional[definitions.BuildingBlockType] - description: str - enabled: Optional[bool] - exceptions_list: Optional[list] - license: Optional[str] - false_positives: Optional[List[str]] - filters: Optional[List[dict]] - # trailing `_` required since `from` is a reserved word in python - from_: Optional[str] = field(metadata=dict(data_key="from")) - interval: Optional[definitions.Interval] - investigation_fields: Optional[InvestigationFields] = field(metadata=dict(metadata=dict(min_compat="8.11"))) - max_signals: Optional[definitions.MaxSignals] - meta: Optional[Dict[str, Any]] name: definitions.RuleName - note: Optional[definitions.Markdown] - # can we remove this comment? - # explicitly NOT allowed! - # output_index: Optional[str] - references: Optional[List[str]] - related_integrations: Optional[List[RelatedIntegrations]] = field(metadata=dict(metadata=dict(min_compat="8.3"))) - required_fields: Optional[List[RequiredFields]] = field(metadata=dict(metadata=dict(min_compat="8.3"))) - revision: Optional[int] = field(metadata=dict(metadata=dict(min_compat="8.8"))) + + author: list[str] + description: str + from_: str | None = field(metadata=dict(data_key="from")) + investigation_fields: InvestigationFields | None = field(metadata=dict(metadata=dict(min_compat="8.11"))) + related_integrations: list[RelatedIntegrations] | None = field(metadata=dict(metadata=dict(min_compat="8.3"))) + required_fields: list[RequiredFields] | None = field(metadata=dict(metadata=dict(min_compat="8.3"))) + revision: int | None = field(metadata=dict(metadata=dict(min_compat="8.8"))) + setup: definitions.Markdown | None = field(metadata=dict(metadata=dict(min_compat="8.3"))) + risk_score: definitions.RiskScore - risk_score_mapping: Optional[List[RiskScoreMapping]] rule_id: definitions.UUIDString - rule_name_override: Optional[str] - setup: Optional[definitions.Markdown] = field(metadata=dict(metadata=dict(min_compat="8.3"))) - severity_mapping: Optional[List[SeverityMapping]] severity: definitions.Severity - tags: Optional[List[str]] - throttle: Optional[str] - timeline_id: Optional[definitions.TimelineTemplateId] - timeline_title: Optional[definitions.TimelineTemplateTitle] - timestamp_override: Optional[str] - to: Optional[str] type: definitions.RuleType - threat: Optional[List[ThreatMapping]] - version: Optional[definitions.PositiveInteger] + + actions: list[dict[str, Any]] | None = None + building_block_type: definitions.BuildingBlockType | None = None + enabled: bool | None = None + exceptions_list: list[dict[str, str]] | None = None + false_positives: list[str] | None = None + 
filters: list[dict[str, Any]] | None = None + interval: definitions.Interval | None = None + license: str | None = None + max_signals: definitions.MaxSignals | None = None + meta: dict[str, Any] | None = None + note: definitions.Markdown | None = None + references: list[str] | None = None + risk_score_mapping: list[RiskScoreMapping] | None = None + rule_name_override: str | None = None + severity_mapping: list[SeverityMapping] | None = None + tags: list[str] | None = None + threat: list[ThreatMapping] | None = None + throttle: str | None = None + timeline_id: definitions.TimelineTemplateId | None = None + timeline_title: definitions.TimelineTemplateTitle | None = None + timestamp_override: str | None = None + to: str | None = None + version: definitions.PositiveInteger | None = None @classmethod def save_schema(cls): """Save the schema as a jsonschema.""" - fields: Tuple[dataclasses.Field, ...] = dataclasses.fields(cls) + fields: tuple[dataclasses.Field[Any], ...] = dataclasses.fields(cls) type_field = next(f for f in fields if f.name == "type") rule_type = typing.get_args(type_field.type)[0] if cls != BaseRuleData else "base" schema = cls.jsonschema() @@ -409,36 +427,36 @@ def save_schema(cls): with (version_dir / f"master.{rule_type}.json").open("w") as f: json.dump(schema, f, indent=2, sort_keys=True) - def validate_query(self, meta: RuleMeta) -> None: + def validate_query(self, _: RuleMeta) -> None: pass @cached_property - def get_restricted_fields(self) -> Optional[Dict[str, tuple]]: + def get_restricted_fields(self) -> dict[str, tuple[Version | None, Version | None]] | None: """Get stack version restricted fields.""" - fields: List[dataclasses.Field, ...] = list(dataclasses.fields(self)) + fields: list[dataclasses.Field[Any]] = list(dataclasses.fields(self)) return get_restricted_fields(fields) @cached_property - def data_validator(self) -> Optional['DataValidator']: + def data_validator(self) -> "DataValidator | None": return DataValidator(is_elastic_rule=self.is_elastic_rule, **self.to_dict()) @cached_property def notify(self) -> bool: - return os.environ.get('DR_NOTIFY_INTEGRATION_UPDATE_AVAILABLE') is not None + return os.environ.get("DR_NOTIFY_INTEGRATION_UPDATE_AVAILABLE") is not None @cached_property - def parsed_note(self) -> Optional[MarkoDocument]: + def parsed_note(self) -> MarkoDocument | None: dv = self.data_validator if dv: return dv.parsed_note @property def is_elastic_rule(self): - return 'elastic' in [a.lower() for a in self.author] + return "elastic" in [a.lower() for a in self.author] - def get_build_fields(self) -> {}: + def get_build_fields(self) -> dict[str, tuple[Version, None]]: """Get a list of build-time fields along with the stack versions which they will build within.""" - build_fields = {} + build_fields: dict[str, tuple[Version, None]] = {} rule_fields = {f.name: f for f in dataclasses.fields(self)} for fld in BUILD_FIELD_VERSIONS: @@ -448,37 +466,36 @@ def get_build_fields(self) -> {}: return build_fields @classmethod - def process_transforms(cls, transform: RuleTransform, obj: dict) -> dict: + def process_transforms(cls, transform: RuleTransform, obj: dict[str, Any]) -> dict[str, Any]: """Process transforms from toml [transform] called in TOMLRuleContents.to_dict.""" # only create functions that CAREFULLY mutate the obj dict def process_note_plugins(): """Format the note field with osquery and investigate plugin strings.""" - note = obj.get('note') + note = obj.get("note") if not note: return rendered = transform.render_investigate_osquery_to_string() 
- rendered_patterns = {} + rendered_patterns: dict[str, Any] = {} for plugin, entries in rendered.items(): - rendered_patterns.update(**{f'{plugin}_{i}': e for i, e in enumerate(entries)}) + rendered_patterns.update(**{f"{plugin}_{i}": e for i, e in enumerate(entries)}) # type: ignore[reportUnknownMemberType] note_template = PatchedTemplate(note) rendered_note = note_template.safe_substitute(**rendered_patterns) - obj['note'] = rendered_note + obj["note"] = rendered_note # call transform functions - if transform: - process_note_plugins() + process_note_plugins() return obj @validates_schema - def validates_data(self, data, **kwargs): + def validates_data(self, data: dict[str, Any], **_: Any): """Validate fields and data for marshmallow schemas.""" # Validate version and revision fields not supplied. - disallowed_fields = [field for field in ['version', 'revision'] if data.get(field) is not None] + disallowed_fields = [field for field in ["version", "revision"] if data.get(field) is not None] if not disallowed_fields: return @@ -486,35 +503,39 @@ def validates_data(self, data, **kwargs): # If version and revision fields are supplied, and using locked versions raise an error. if BYPASS_VERSION_LOCK is not True: - msg = (f"Configuration error: Rule {data['name']} - {data['rule_id']} " - f"should not contain rules with `{error_message}` set.") + msg = ( + f"Configuration error: Rule {data['name']} - {data['rule_id']} " + f"should not contain rules with `{error_message}` set." + ) raise ValidationError(msg) class DataValidator: """Additional validation beyond base marshmallow schema validation.""" - def __init__(self, - name: definitions.RuleName, - is_elastic_rule: bool, - note: Optional[definitions.Markdown] = None, - interval: Optional[definitions.Interval] = None, - building_block_type: Optional[definitions.BuildingBlockType] = None, - setup: Optional[str] = None, - **extras): + def __init__( + self, + name: definitions.RuleName, + is_elastic_rule: bool, + note: definitions.Markdown | None = None, + interval: definitions.Interval | None = None, + building_block_type: definitions.BuildingBlockType | None = None, + setup: str | None = None, + **extras: Any, + ): # only define fields needing additional validation self.name = name self.is_elastic_rule = is_elastic_rule self.note = note # Need to use extras because from is a reserved word in python - self.from_ = extras.get('from') + self.from_ = extras.get("from") self.interval = interval self.building_block_type = building_block_type self.setup = setup self._setup_in_note = False @cached_property - def parsed_note(self) -> Optional[MarkoDocument]: + def parsed_note(self) -> MarkoDocument | None: if self.note: return gfm.parse(self.note) @@ -528,11 +549,11 @@ def setup_in_note(self, value: bool): @cached_property def skip_validate_note(self) -> bool: - return os.environ.get('DR_BYPASS_NOTE_VALIDATION_AND_PARSE') is not None + return os.environ.get("DR_BYPASS_NOTE_VALIDATION_AND_PARSE") is not None @cached_property def skip_validate_bbr(self) -> bool: - return os.environ.get('DR_BYPASS_BBR_LOOKBACK_VALIDATION') is not None + return os.environ.get("DR_BYPASS_BBR_LOOKBACK_VALIDATION") is not None def validate_bbr(self, bypass: bool = False): """Validate building block type and rule type.""" @@ -586,31 +607,35 @@ def validate_note(self): if self.skip_validate_note or not self.note: return + if not self.parsed_note: + return + try: for child in self.parsed_note.children: if child.get_type() == "Heading": header = gfm.renderer.render_children(child) 
if header.lower() == "setup": - # check that the Setup header is correctly formatted at level 2 - if child.level != 2: - raise ValidationError(f"Setup section with wrong header level: {child.level}") + if child.level != 2: # type: ignore[reportAttributeAccessIssue] + raise ValidationError(f"Setup section with wrong header level: {child.level}") # type: ignore[reportAttributeAccessIssue] # check that the Setup header is capitalized - if child.level == 2 and header != "Setup": + if child.level == 2 and header != "Setup": # type: ignore[reportAttributeAccessIssue] raise ValidationError(f"Setup header has improper casing: {header}") self.setup_in_note = True else: # check that the header Config does not exist in the Setup section - if child.level == 2 and "config" in header.lower(): + if child.level == 2 and "config" in header.lower(): # type: ignore[reportAttributeAccessIssue] raise ValidationError(f"Setup header contains Config: {header}") except Exception as e: - raise ValidationError(f"Invalid markdown in rule `{self.name}`: {e}. To bypass validation on the `note`" - f"field, use the environment variable `DR_BYPASS_NOTE_VALIDATION_AND_PARSE`") + raise ValidationError( + f"Invalid markdown in rule `{self.name}`: {e}. To bypass validation on the `note`" + f"field, use the environment variable `DR_BYPASS_NOTE_VALIDATION_AND_PARSE`" + ) # raise if setup header is in note and in setup if self.setup_in_note and (self.setup and self.setup != "None"): @@ -629,87 +654,91 @@ def ast(self) -> Any: def unique_fields(self) -> Any: raise NotImplementedError() - def validate(self, data: 'QueryRuleData', meta: RuleMeta) -> None: + def validate(self, _: "QueryRuleData", __: RuleMeta) -> None: raise NotImplementedError() @cached - def get_required_fields(self, index: str) -> List[Optional[dict]]: + def get_required_fields(self, index: str) -> list[dict[str, Any]]: """Retrieves fields needed for the query along with type information from the schema.""" if isinstance(self, ESQLValidator): return [] current_version = Version.parse(load_current_package_version(), optional_minor_and_patch=True) - ecs_version = get_stack_schemas()[str(current_version)]['ecs'] - beats_version = get_stack_schemas()[str(current_version)]['beats'] - endgame_version = get_stack_schemas()[str(current_version)]['endgame'] + ecs_version = get_stack_schemas()[str(current_version)]["ecs"] + beats_version = get_stack_schemas()[str(current_version)]["beats"] + endgame_version = get_stack_schemas()[str(current_version)]["endgame"] ecs_schema = ecs.get_schema(ecs_version) - beat_types, beat_schema, schema = self.get_beats_schema(index or [], beats_version, ecs_version) + _, beat_schema, schema = self.get_beats_schema(index or [], beats_version, ecs_version) endgame_schema = self.get_endgame_schema(index or [], endgame_version) # construct integration schemas packages_manifest = load_integrations_manifests() integrations_schemas = load_integrations_schemas() datasets, _ = beats.get_datasets_and_modules(self.ast) - package_integrations = parse_datasets(datasets, packages_manifest) - int_schema = {} + package_integrations = parse_datasets(list(datasets), packages_manifest) + int_schema: dict[str, Any] = {} data = {"notify": False} for pk_int in package_integrations: package = pk_int["package"] integration = pk_int["integration"] - schema, _ = get_integration_schema_fields(integrations_schemas, package, integration, - current_version, packages_manifest, {}, data) + schema, _ = get_integration_schema_fields( + integrations_schemas, package, 
integration, current_version, packages_manifest, {}, data + ) int_schema.update(schema) - required = [] - unique_fields = self.unique_fields or [] + required: list[dict[str, Any]] = [] + unique_fields: list[str] = self.unique_fields or [] for fld in unique_fields: - field_type = ecs_schema.get(fld, {}).get('type') + field_type = ecs_schema.get(fld, {}).get("type") is_ecs = field_type is not None if not is_ecs: if int_schema: - field_type = int_schema.get(fld, None) + field_type = int_schema.get(fld) elif beat_schema: - field_type = beat_schema.get(fld, {}).get('type') + field_type = beat_schema.get(fld, {}).get("type") elif endgame_schema: field_type = endgame_schema.endgame_schema.get(fld, None) - required.append(dict(name=fld, type=field_type or 'unknown', ecs=is_ecs)) + required.append(dict(name=fld, type=field_type or "unknown", ecs=is_ecs)) - return sorted(required, key=lambda f: f['name']) + return sorted(required, key=lambda f: f["name"]) @cached - def get_beats_schema(self, index: list, beats_version: str, ecs_version: str) -> (list, dict, dict): + def get_beats_schema( + self, indices: list[str], beats_version: str, ecs_version: str + ) -> tuple[list[str], dict[str, Any] | None, dict[str, Any]]: """Get an assembled beats schema.""" - beat_types = beats.parse_beats_from_index(index) + beat_types = beats.parse_beats_from_index(indices) beat_schema = beats.get_schema_from_kql(self.ast, beat_types, version=beats_version) if beat_types else None - schema = ecs.get_kql_schema(version=ecs_version, indexes=index, beat_schema=beat_schema) + schema = ecs.get_kql_schema(version=ecs_version, indexes=indices, beat_schema=beat_schema) return beat_types, beat_schema, schema @cached - def get_endgame_schema(self, index: list, endgame_version: str) -> Optional[endgame.EndgameSchema]: + def get_endgame_schema(self, indices: list[str], endgame_version: str) -> endgame.EndgameSchema | None: """Get an assembled flat endgame schema.""" - if index and "endgame-*" not in index: + if indices and "endgame-*" not in indices: return None endgame_schema = endgame.read_endgame_schema(endgame_version=endgame_version) return endgame.EndgameSchema(endgame_schema) -@dataclass(frozen=True) +@dataclass(frozen=True, kw_only=True) class QueryRuleData(BaseRuleData): """Specific fields for query event types.""" - type: Literal["query"] - index: Optional[List[str]] - data_view_id: Optional[str] + type: Literal["query"] query: str language: definitions.FilterLanguages - alert_suppression: Optional[AlertSuppressionMapping] = field(metadata=dict(metadata=dict(min_compat="8.8"))) + alert_suppression: AlertSuppressionMapping | None = field(metadata=dict(metadata=dict(min_compat="8.8"))) + + index: list[str] | None = None + data_view_id: str | None = None @cached_property def index_or_dataview(self) -> list[str]: @@ -722,7 +751,7 @@ def index_or_dataview(self) -> list[str]: return [] @cached_property - def validator(self) -> Optional[QueryValidator]: + def validator(self) -> QueryValidator | None: if self.language == "kuery": return KQLValidator(self.query) elif self.language == "eql": @@ -730,10 +759,10 @@ def validator(self) -> Optional[QueryValidator]: elif self.language == "esql": return ESQLValidator(self.query) - def validate_query(self, meta: RuleMeta) -> None: + def validate_query(self, meta: RuleMeta) -> None: # type: ignore[reportIncompatibleMethodOverride] validator = self.validator - if validator is not None: - return validator.validate(self, meta) + if validator: + validator.validate(self, meta) @cached_property 
def ast(self): @@ -748,32 +777,32 @@ def unique_fields(self): return validator.unique_fields @cached - def get_required_fields(self, index: str) -> List[dict]: + def get_required_fields(self, index: str) -> list[dict[str, Any]] | None: validator = self.validator if validator is not None: return validator.get_required_fields(index or []) @validates_schema - def validates_index_and_data_view_id(self, data, **kwargs): + def validates_index_and_data_view_id(self, data: dict[str, Any], **_: Any): """Validate that either index or data_view_id is set, but not both.""" - if data.get('index') and data.get('data_view_id'): + if data.get("index") and data.get("data_view_id"): raise ValidationError("Only one of index or data_view_id should be set.") -@dataclass(frozen=True) +@dataclass(frozen=True, kw_only=True) class MachineLearningRuleData(BaseRuleData): type: Literal["machine_learning"] anomaly_threshold: int - machine_learning_job_id: Union[str, List[str]] - alert_suppression: Optional[AlertSuppressionMapping] = field(metadata=dict(metadata=dict(min_compat="8.15"))) + machine_learning_job_id: str | list[str] + alert_suppression: AlertSuppressionMapping | None = field(metadata=dict(metadata=dict(min_compat="8.15"))) -@dataclass(frozen=True) +@dataclass(frozen=True, kw_only=True) class ThresholdQueryRuleData(QueryRuleData): """Specific fields for query event types.""" - @dataclass(frozen=True) + @dataclass(frozen=True, kw_only=True) class ThresholdMapping(MarshmallowDataclassMixin): @dataclass(frozen=True) class ThresholdCardinality: @@ -782,14 +811,14 @@ class ThresholdCardinality: field: definitions.CardinalityFields value: definitions.ThresholdValue - cardinality: Optional[List[ThresholdCardinality]] + cardinality: list[ThresholdCardinality] | None = None - type: Literal["threshold"] + type: Literal["threshold"] # type: ignore[reportIncompatibleVariableOverride] threshold: ThresholdMapping - alert_suppression: Optional[ThresholdAlertSuppression] = field(metadata=dict(metadata=dict(min_compat="8.12"))) + alert_suppression: ThresholdAlertSuppression | None = field(metadata=dict(metadata=dict(min_compat="8.12"))) # type: ignore[reportIncompatibleVariableOverride] -@dataclass(frozen=True) +@dataclass(frozen=True, kw_only=True) class NewTermsRuleData(QueryRuleData): """Specific fields for new terms field rule.""" @@ -802,25 +831,20 @@ class HistoryWindowStart: field: definitions.NonEmptyStr value: definitions.NewTermsFields - history_window_start: List[HistoryWindowStart] + history_window_start: list[HistoryWindowStart] - type: Literal["new_terms"] + type: Literal["new_terms"] # type: ignore[reportIncompatibleVariableOverride] new_terms: NewTermsMapping - alert_suppression: Optional[AlertSuppressionMapping] = field(metadata=dict(metadata=dict(min_compat="8.14"))) + alert_suppression: AlertSuppressionMapping | None = field(metadata=dict(metadata=dict(min_compat="8.14"))) @pre_load - def preload_data(self, data: dict, **kwargs) -> dict: + def preload_data(self, data: dict[str, Any], **_: Any) -> dict[str, Any]: """Preloads and formats the data to match the required schema.""" if "new_terms_fields" in data and "history_window_start" in data: new_terms_mapping = { "field": "new_terms_fields", "value": data["new_terms_fields"], - "history_window_start": [ - { - "field": "history_window_start", - "value": data["history_window_start"] - } - ] + "history_window_start": [{"field": "history_window_start", "value": data["history_window_start"]}], } data["new_terms"] = new_terms_mapping @@ -829,7 +853,7 @@ def 
preload_data(self, data: dict, **kwargs) -> dict: data.pop("history_window_start") return data - def transform(self, obj: dict) -> dict: + def transform(self, obj: dict[str, Any]) -> dict[str, Any]: """Transforms new terms data to API format for Kibana.""" obj[obj["new_terms"].get("field")] = obj["new_terms"].get("value") obj["history_window_start"] = obj["new_terms"]["history_window_start"][0].get("value") @@ -837,22 +861,23 @@ def transform(self, obj: dict) -> dict: return obj -@dataclass(frozen=True) +@dataclass(frozen=True, kw_only=True) class EQLRuleData(QueryRuleData): """EQL rules are a special case of query rules.""" - type: Literal["eql"] + + type: Literal["eql"] # type: ignore[reportIncompatibleVariableOverride] language: Literal["eql"] - timestamp_field: Optional[str] = field(metadata=dict(metadata=dict(min_compat="8.0"))) - event_category_override: Optional[str] = field(metadata=dict(metadata=dict(min_compat="8.0"))) - tiebreaker_field: Optional[str] = field(metadata=dict(metadata=dict(min_compat="8.0"))) - alert_suppression: Optional[AlertSuppressionMapping] = field(metadata=dict(metadata=dict(min_compat="8.14"))) + timestamp_field: str | None = field(metadata=dict(metadata=dict(min_compat="8.0"))) + event_category_override: str | None = field(metadata=dict(metadata=dict(min_compat="8.0"))) + tiebreaker_field: str | None = field(metadata=dict(metadata=dict(min_compat="8.0"))) + alert_suppression: AlertSuppressionMapping | None = field(metadata=dict(metadata=dict(min_compat="8.14"))) def convert_relative_delta(self, lookback: str) -> int: now = len("now") - min_length = now + len('+5m') + min_length = now + len("+5m") if lookback.startswith("now") and len(lookback) >= min_length: - lookback = lookback[len("now"):] + lookback = lookback[len("now") :] sign = lookback[0] # + or - span = lookback[1:] amount = convert_time_span(span) @@ -863,59 +888,62 @@ def convert_relative_delta(self, lookback: str) -> int: @cached_property def is_sample(self) -> bool: """Checks if the current rule is a sample-based rule.""" - return eql.utils.get_query_type(self.ast) == 'sample' + return eql.utils.get_query_type(self.ast) == "sample" # type: ignore[reportUnknownMemberType] @cached_property def is_sequence(self) -> bool: """Checks if the current rule is a sequence-based rule.""" - return eql.utils.get_query_type(self.ast) == 'sequence' + return eql.utils.get_query_type(self.ast) == "sequence" # type: ignore[reportUnknownMemberType] @cached_property - def max_span(self) -> Optional[int]: + def max_span(self) -> int | None: """Maxspan value for sequence rules if defined.""" - if self.is_sequence and hasattr(self.ast.first, 'max_span'): + if not self.ast: + raise ValueError("No AST found") + if self.is_sequence and hasattr(self.ast.first, "max_span"): return self.ast.first.max_span.as_milliseconds() if self.ast.first.max_span else None @cached_property - def look_back(self) -> Optional[Union[int, Literal['unknown']]]: + def look_back(self) -> int | Literal["unknown"] | None: """Lookback value of a rule.""" # https://www.elastic.co/guide/en/elasticsearch/reference/current/common-options.html#date-math to = self.convert_relative_delta(self.to) if self.to else 0 from_ = self.convert_relative_delta(self.from_ or "now-6m") if not (to or from_): - return 'unknown' + return "unknown" else: return to - from_ @cached_property - def interval_ratio(self) -> Optional[float]: + def interval_ratio(self) -> float | None: """Ratio of interval time window / max_span time window.""" if self.max_span: - interval = 
convert_time_span(self.interval or '5m') + interval = convert_time_span(self.interval or "5m") return interval / self.max_span -@dataclass(frozen=True) +@dataclass(frozen=True, kw_only=True) class ESQLRuleData(QueryRuleData): """ESQL rules are a special case of query rules.""" - type: Literal["esql"] + + type: Literal["esql"] # type: ignore[reportIncompatibleVariableOverride] language: Literal["esql"] query: str - alert_suppression: Optional[AlertSuppressionMapping] = field(metadata=dict(metadata=dict(min_compat="8.15"))) + alert_suppression: AlertSuppressionMapping | None = field(metadata=dict(metadata=dict(min_compat="8.15"))) @validates_schema - def validates_esql_data(self, data, **kwargs): + def validates_esql_data(self, data: dict[str, Any], **_: Any): """Custom validation for query rule type and subclasses.""" - if data.get('index'): + if data.get("index"): raise ValidationError("Index is not a valid field for ES|QL rule type.") # Convert the query string to lowercase to handle case insensitivity - query_lower = data['query'].lower() + query_lower = data["query"].lower() # Combine both patterns using an OR operator and compile the regex combined_pattern = re.compile( - r'(from\s+\S+\s+metadata\s+_id,\s*_version,\s*_index)|(\bstats\b.*?\bby\b)', re.DOTALL + r"(from\s+\S+\s+metadata\s+_id,\s*_version,\s*_index)|(\bstats\b.*?\bby\b)", re.DOTALL ) # Ensure that non-aggregate queries have metadata @@ -927,47 +955,45 @@ def validates_esql_data(self, data, **kwargs): ) # Enforce KEEP command for ESQL rules - if '| keep' not in query_lower: + if "| keep" not in query_lower: raise ValidationError( - f"Rule: {data['name']} does not contain a 'keep' command ->" - f" Add a 'keep' command to the query." + f"Rule: {data['name']} does not contain a 'keep' command -> Add a 'keep' command to the query." 
) -@dataclass(frozen=True) +@dataclass(frozen=True, kw_only=True) class ThreatMatchRuleData(QueryRuleData): """Specific fields for indicator (threat) match rule.""" @dataclass(frozen=True) class Entries: - @dataclass(frozen=True) class ThreatMapEntry: field: definitions.NonEmptyStr type: Literal["mapping"] value: definitions.NonEmptyStr - entries: List[ThreatMapEntry] + entries: list[ThreatMapEntry] - type: Literal["threat_match"] + type: Literal["threat_match"] # type: ignore[reportIncompatibleVariableOverride] - concurrent_searches: Optional[definitions.PositiveInteger] - items_per_search: Optional[definitions.PositiveInteger] + concurrent_searches: definitions.PositiveInteger | None = None + items_per_search: definitions.PositiveInteger | None = None - threat_mapping: List[Entries] - threat_filters: Optional[List[dict]] - threat_query: Optional[str] - threat_language: Optional[definitions.FilterLanguages] - threat_index: List[str] - threat_indicator_path: Optional[str] - alert_suppression: Optional[AlertSuppressionMapping] = field(metadata=dict(metadata=dict(min_compat="8.13"))) + threat_mapping: list[Entries] + threat_filters: list[dict[str, Any]] | None = None + threat_query: str | None = None + threat_language: definitions.FilterLanguages | None = None + threat_index: list[str] + threat_indicator_path: str | None = None + alert_suppression: AlertSuppressionMapping | None = field(metadata=dict(metadata=dict(min_compat="8.13"))) def validate_query(self, meta: RuleMeta) -> None: super(ThreatMatchRuleData, self).validate_query(meta) if self.threat_query: if not self.threat_language: - raise ValidationError('`threat_language` required when a `threat_query` is defined') + raise ValidationError("`threat_language` required when a `threat_query` is defined") if self.threat_language == "kuery": threat_query_validator = KQLValidator(self.threat_query) @@ -981,8 +1007,15 @@ def validate_query(self, meta: RuleMeta) -> None: # All of the possible rule types # Sort inverse of any inheritance - see comment in TOMLRuleContents.to_dict -AnyRuleData = Union[EQLRuleData, ESQLRuleData, ThresholdQueryRuleData, ThreatMatchRuleData, - MachineLearningRuleData, QueryRuleData, NewTermsRuleData] +AnyRuleData = ( + EQLRuleData + | ESQLRuleData + | ThresholdQueryRuleData + | ThreatMatchRuleData + | MachineLearningRuleData + | QueryRuleData + | NewTermsRuleData +) class BaseRuleContents(ABC): @@ -990,28 +1023,27 @@ class BaseRuleContents(ABC): @property @abstractmethod - def id(self): + def id(self) -> str: pass @property @abstractmethod - def name(self): + def name(self) -> str: pass @property @abstractmethod - def version_lock(self): + def version_lock(self) -> "VersionLock": pass @property @abstractmethod - def type(self): + def type(self) -> str: pass - def lock_info(self, bump=True) -> dict: + def lock_info(self, bump: bool = True) -> dict[str, Any]: version = self.autobumped_version if bump else (self.saved_version or 1) contents = {"rule_name": self.name, "sha256": self.get_hash(), "version": version, "type": self.type} - return contents @property @@ -1031,7 +1063,7 @@ def is_dirty(self) -> bool: return is_dirty @property - def lock_entry(self) -> Optional[dict]: + def lock_entry(self) -> dict[str, Any] | None: lock_entry = self.version_lock.version_lock.data.get(self.id) if lock_entry: return lock_entry.to_dict() @@ -1041,7 +1073,7 @@ def has_forked(self) -> bool: """Determine if the rule has forked at any point (has a previous entry).""" lock_entry = self.lock_entry if lock_entry: - return 'previous' in 
lock_entry + return "previous" in lock_entry return False @property @@ -1049,36 +1081,44 @@ def is_in_forked_version(self) -> bool: """Determine if the rule is in a forked version.""" if not self.has_forked: return False - locked_min_stack = Version.parse(self.lock_entry['min_stack_version'], optional_minor_and_patch=True) + if not self.lock_entry: + raise ValueError("No lock entry found") + locked_min_stack = Version.parse(self.lock_entry["min_stack_version"], optional_minor_and_patch=True) current_package_ver = Version.parse(load_current_package_version(), optional_minor_and_patch=True) return current_package_ver < locked_min_stack - def get_version_space(self) -> Optional[int]: + def get_version_space(self) -> int | None: """Retrieve the number of version spaces available (None for unbound).""" if self.is_in_forked_version: - current_entry = self.lock_entry['previous'][self.metadata.min_stack_version] - current_version = current_entry['version'] - max_allowable_version = current_entry['max_allowable_version'] + if not self.lock_entry: + raise ValueError("No lock entry found") + + current_entry = self.lock_entry["previous"][self.metadata.min_stack_version] # type: ignore[reportAttributeAccessIssue] + current_version = current_entry["version"] + max_allowable_version = current_entry["max_allowable_version"] return max_allowable_version - current_version - 1 @property - def saved_version(self) -> Optional[int]: + def saved_version(self) -> int | None: """Retrieve the version from the version.lock or from the file if version locking is bypassed.""" - toml_version = self.data.get("version") + + toml_version = self.data.get("version") # type: ignore[reportAttributeAccessIssue] if BYPASS_VERSION_LOCK: - return toml_version + return toml_version # type: ignore[reportUnknownVariableType] if toml_version: - print(f"WARNING: Rule {self.name} - {self.id} has a version set in the rule TOML." - " This `version` will be ignored and defaulted to the version.lock.json file." - " Set `bypass_version_lock` to `True` in the rules config to use the TOML version.") + print( + f"WARNING: Rule {self.name} - {self.id} has a version set in the rule TOML." + " This `version` will be ignored and defaulted to the version.lock.json file." + " Set `bypass_version_lock` to `True` in the rules config to use the TOML version." 
+ ) return self.version_lock.get_locked_version(self.id, self.get_supported_version()) @property - def autobumped_version(self) -> Optional[int]: + def autobumped_version(self) -> int | None: """Retrieve the current version of the rule, accounting for automatic increments.""" version = self.saved_version @@ -1092,7 +1132,7 @@ def autobumped_version(self) -> Optional[int]: # Auto-increment version if the rule is 'dirty' and not bypassing version lock return version + 1 if self.is_dirty else version - def get_synthetic_version(self, use_default: bool) -> Optional[int]: + def get_synthetic_version(self, use_default: bool) -> int | None: """ Get the latest actual representation of a rule's version, where changes are accounted for automatically when version locking is used, otherwise, return the version defined in the rule toml if present else optionally @@ -1101,7 +1141,7 @@ def get_synthetic_version(self, use_default: bool) -> Optional[int]: return self.autobumped_version or self.saved_version or (1 if use_default else None) @classmethod - def convert_supported_version(cls, stack_version: Optional[str]) -> Version: + def convert_supported_version(cls, stack_version: str | None) -> Version: """Convert an optional stack version to the minimum for the lock in the form major.minor.""" min_version = get_min_supported_stack_version() if stack_version is None: @@ -1110,11 +1150,11 @@ def convert_supported_version(cls, stack_version: Optional[str]) -> Version: def get_supported_version(self) -> str: """Get the lowest stack version for the rule that is currently supported in the form major.minor.""" - rule_min_stack = self.metadata.get('min_stack_version') - min_stack = self.convert_supported_version(rule_min_stack) + rule_min_stack = self.metadata.get("min_stack_version") # type: ignore[reportAttributeAccessIssue] + min_stack = self.convert_supported_version(rule_min_stack) # type: ignore[reportUnknownArgumentType] return f"{min_stack.major}.{min_stack.minor}" - def _post_dict_conversion(self, obj: dict) -> dict: + def _post_dict_conversion(self, obj: dict[str, Any]) -> dict[str, Any]: """Transform the converted API in place before sending to Kibana.""" # cleanup the whitespace in the rule @@ -1127,10 +1167,10 @@ def _post_dict_conversion(self, obj: dict) -> dict: return obj @abstractmethod - def to_api_format(self, include_version: bool = True) -> dict: + def to_api_format(self, include_version: bool = True) -> dict[str, Any]: """Convert the rule to the API format.""" - def get_hashable_content(self, include_version: bool = False, include_integrations: bool = False) -> dict: + def get_hashable_content(self, include_version: bool = False, include_integrations: bool = False) -> dict[str, Any]: """Returns the rule content to be used for calculating the hash value for the rule""" # get the API dict without the version by default, otherwise it'll always be dirty. 
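As an aside on the version-lock handling above: the snippet below is a minimal standalone sketch of the decision that saved_version and autobumped_version encode. The function and argument names are illustrative only (not the module's API); assume locked_version comes from version.lock.json and toml_version from the rule TOML file.

def effective_version(locked_version: int | None, toml_version: int | None, is_dirty: bool, bypass_version_lock: bool) -> int | None:
    """Illustrative sketch: resolve a rule version the way the properties above describe."""
    if bypass_version_lock:
        # locking bypassed: the TOML-declared version is used as-is
        return toml_version
    # otherwise version.lock.json wins; a TOML-declared version only triggers a warning
    if locked_version is None:
        return None
    # auto-bump by one when the rule is considered dirty (its hashed contents
    # no longer match the lock entry), mirroring autobumped_version
    return locked_version + 1 if is_dirty else locked_version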
@@ -1155,38 +1195,35 @@ def get_hash(self, include_version: bool = False, include_integrations: bool = F @dataclass(frozen=True) class TOMLRuleContents(BaseRuleContents, MarshmallowDataclassMixin): """Rule object which maps directly to the TOML layout.""" + metadata: RuleMeta - transform: Optional[RuleTransform] data: AnyRuleData = field(metadata=dict(data_key="rule")) + transform: RuleTransform | None = None @cached_property - def version_lock(self): - # VersionLock - from .version_lock import loaded_version_lock - + def version_lock(self) -> VersionLock: # type: ignore[reportIncompatibleMethodOverride] if RULES_CONFIG.bypass_version_lock is True: - err_msg = "Cannot access the version lock when the versioning strategy is configured to bypass the" \ - " version lock. Set `bypass_version_lock` to `false` in the rules config to use the version lock." + err_msg = ( + "Cannot access the version lock when the versioning strategy is configured to bypass the" + " version lock. Set `bypass_version_lock` to `false` in the rules config to use the version lock." + ) raise ValueError(err_msg) - return getattr(self, '_version_lock', None) or loaded_version_lock + return getattr(self, "_version_lock", None) or loaded_version_lock - def set_version_lock(self, value): - from .version_lock import VersionLock - - err_msg = "Cannot set the version lock when the versioning strategy is configured to bypass the version lock." \ - " Set `bypass_version_lock` to `false` in the rules config to use the version lock." + def set_version_lock(self, value: VersionLock): + err_msg = ( + "Cannot set the version lock when the versioning strategy is configured to bypass the version lock." + " Set `bypass_version_lock` to `false` in the rules config to use the version lock." + ) assert not RULES_CONFIG.bypass_version_lock, err_msg - if value and not isinstance(value, VersionLock): - raise TypeError(f'version lock property must be set with VersionLock objects only. 
Got {type(value)}') - # circumvent frozen class - self.__dict__['_version_lock'] = value + self.__dict__["_version_lock"] = value @classmethod - def all_rule_types(cls) -> set: - types = set() + def all_rule_types(cls) -> set[str]: + types: set[str] = set() for subclass in typing.get_args(AnyRuleData): field = next(field for field in dataclasses.fields(subclass) if field.name == "type") types.update(typing.get_args(field.type)) @@ -1198,7 +1235,7 @@ def get_data_subclass(cls, rule_type: str) -> typing.Type[BaseRuleData]: """Get the proper subclass depending on the rule type""" for subclass in typing.get_args(AnyRuleData): field = next(field for field in dataclasses.fields(subclass) if field.name == "type") - if (rule_type, ) == typing.get_args(field.type): + if (rule_type,) == typing.get_args(field.type): return subclass raise ValueError(f"Unknown rule type {rule_type}") @@ -1215,23 +1252,25 @@ def name(self) -> str: def type(self) -> str: return self.data.type - def _add_known_nulls(self, rule_dict: dict) -> dict: + def _add_known_nulls(self, rule_dict: dict[str, Any]) -> dict[str, Any]: """Add known nulls to the rule.""" # Note this is primarily as a stopgap until add support for Rule Actions for pair in definitions.KNOWN_NULL_ENTRIES: for compound_key, sub_key in pair.items(): value = get_nested_value(rule_dict, compound_key) if isinstance(value, list): - items_to_update = [ - item for item in value if isinstance(item, dict) and get_nested_value(item, sub_key) is None + items_to_update: list[dict[str, Any]] = [ + item + for item in value # type: ignore[reportUnknownVariableType] + if isinstance(item, dict) and get_nested_value(item, sub_key) is None ] for item in items_to_update: set_nested_value(item, sub_key, None) return rule_dict - def _post_dict_conversion(self, obj: dict) -> dict: + def _post_dict_conversion(self, obj: dict[str, Any]) -> dict[str, Any]: """Transform the converted API in place before sending to Kibana.""" - super()._post_dict_conversion(obj) + _ = super()._post_dict_conversion(obj) # build time fields self._convert_add_related_integrations(obj) @@ -1239,16 +1278,16 @@ def _post_dict_conversion(self, obj: dict) -> dict: self._convert_add_setup(obj) # validate new fields against the schema - rule_type = obj['type'] + rule_type = obj["type"] subclass = self.get_data_subclass(rule_type) subclass.from_dict(obj) # rule type transforms - self.data.transform(obj) if hasattr(self.data, 'transform') else False + self.data.transform(obj) if hasattr(self.data, "transform") else False # type: ignore[reportAttributeAccessIssue] return obj - def _convert_add_related_integrations(self, obj: dict) -> None: + def _convert_add_related_integrations(self, obj: dict[str, Any]) -> None: """Add restricted field related_integrations to the obj.""" field_name = "related_integrations" package_integrations = obj.get(field_name, []) @@ -1258,11 +1297,15 @@ def _convert_add_related_integrations(self, obj: dict) -> None: current_stack_version = load_current_package_version() if self.check_restricted_field_version(field_name): - if (isinstance(self.data, QueryRuleData) or isinstance(self.data, MachineLearningRuleData)): - if (self.data.get('language') is not None and self.data.get('language') != 'lucene') or \ - self.data.get('type') == 'machine_learning': - package_integrations = self.get_packaged_integrations(self.data, self.metadata, - packages_manifest) + if isinstance(self.data, QueryRuleData) or isinstance(self.data, MachineLearningRuleData): # type: ignore[reportUnnecessaryIsInstance] + if 
( + self.data.get("language") is not None and self.data.get("language") != "lucene" + ) or self.data.get("type") == "machine_learning": + package_integrations = self.get_packaged_integrations( + self.data, # type: ignore[reportArgumentType] + self.metadata, + packages_manifest, + ) if not package_integrations: return @@ -1272,26 +1315,29 @@ def _convert_add_related_integrations(self, obj: dict) -> None: package=package["package"], integration=package["integration"], current_stack_version=current_stack_version, - packages_manifest=packages_manifest) + packages_manifest=packages_manifest, + ) # if integration is not a policy template remove if package["version"]: - version_data = packages_manifest.get(package["package"], - {}).get(package["version"].strip("^"), {}) + version_data = packages_manifest.get(package["package"], {}).get( + package["version"].strip("^"), {} + ) policy_templates = version_data.get("policy_templates", []) if package["integration"] not in policy_templates: del package["integration"] # remove duplicate entries - package_integrations = list({json.dumps(d, sort_keys=True): - d for d in package_integrations}.values()) + package_integrations = list( + {json.dumps(d, sort_keys=True): d for d in package_integrations}.values() + ) obj.setdefault("related_integrations", package_integrations) - def _convert_add_required_fields(self, obj: dict) -> None: + def _convert_add_required_fields(self, obj: dict[str, Any]) -> None: """Add restricted field required_fields to the obj, derived from the query AST.""" - if isinstance(self.data, QueryRuleData) and self.data.language != 'lucene': - index = obj.get('index') or [] + if isinstance(self.data, QueryRuleData) and self.data.language != "lucene": + index: list[str] = obj.get("index") or [] required_fields = self.data.get_required_fields(index) else: required_fields = [] @@ -1300,7 +1346,7 @@ def _convert_add_required_fields(self, obj: dict) -> None: if required_fields and self.check_restricted_field_version(field_name=field_name): obj.setdefault(field_name, required_fields) - def _convert_add_setup(self, obj: dict) -> None: + def _convert_add_setup(self, obj: dict[str, Any]) -> None: """Add restricted field setup to the obj.""" rule_note = obj.get("note", "") field_name = "setup" @@ -1311,13 +1357,19 @@ def _convert_add_setup(self, obj: dict) -> None: data_validator = self.data.data_validator + if not data_validator: + raise ValueError("No data validator found") + if not data_validator.skip_validate_note and data_validator.setup_in_note and not field_value: parsed_note = self.data.parsed_note + if not parsed_note: + raise ValueError("No parsed note found") + # parse note tree for i, child in enumerate(parsed_note.children): - if child.get_type() == "Heading" and "Setup" in gfm.render(child): - field_value = self._convert_get_setup_content(parsed_note.children[i + 1:]) + if child.get_type() == "Heading" and "Setup" in gfm.render(child): # type: ignore[reportArgumentType] + field_value = self._convert_get_setup_content(parsed_note.children[i + 1 :]) # clean up old note field investigation_guide = rule_note.replace("## Setup\n\n", "") @@ -1327,14 +1379,14 @@ def _convert_add_setup(self, obj: dict) -> None: break @cached - def _convert_get_setup_content(self, note_tree: list) -> str: + def _convert_get_setup_content(self, note_tree: list[Any]) -> str: """Get note paragraph starting from the setup header.""" - setup = [] + setup: list[str] = [] for child in note_tree: if child.get_type() == "BlankLine" or child.get_type() == "LineBreak": 
setup.append("\n") elif child.get_type() == "CodeSpan": - setup.append(f"`{gfm.renderer.render_raw_text(child)}`") + setup.append(f"`{gfm.renderer.render_raw_text(child)}`") # type: ignore[reportUnknownMemberType] elif child.get_type() == "Paragraph": setup.append(self._convert_get_setup_content(child.children)) setup.append("\n") @@ -1353,11 +1405,17 @@ def _convert_get_setup_content(self, note_tree: list) -> str: def check_explicit_restricted_field_version(self, field_name: str) -> bool: """Explicitly check restricted fields against global min and max versions.""" min_stack, max_stack = BUILD_FIELD_VERSIONS[field_name] + if not min_stack or not max_stack: + return True return self.compare_field_versions(min_stack, max_stack) def check_restricted_field_version(self, field_name: str) -> bool: """Check restricted fields against schema min and max versions.""" - min_stack, max_stack = self.data.get_restricted_fields.get(field_name) + if not self.data.get_restricted_fields: + raise ValueError("No restricted fields found") + min_stack, max_stack = self.data.get_restricted_fields[field_name] + if not min_stack or not max_stack: + return True return self.compare_field_versions(min_stack, max_stack) @staticmethod @@ -1368,10 +1426,14 @@ def compare_field_versions(min_stack: Version, max_stack: Version) -> bool: return min_stack <= current_version >= max_stack @classmethod - def get_packaged_integrations(cls, data: QueryRuleData, meta: RuleMeta, - package_manifest: dict) -> Optional[List[dict]]: - packaged_integrations = [] - datasets, _ = beats.get_datasets_and_modules(data.get('ast') or []) + def get_packaged_integrations( + cls, + data: QueryRuleData, + meta: RuleMeta, + package_manifest: dict[str, Any], + ) -> list[dict[str, Any]] | None: + packaged_integrations: list[dict[str, Any]] = [] + datasets, _ = beats.get_datasets_and_modules(data.get("ast") or []) # type: ignore[reportArgumentType] # integration is None to remove duplicate references upstream in Kibana # chronologically, event.dataset is checked for package:integration, then rule tags @@ -1381,78 +1443,92 @@ def get_packaged_integrations(cls, data: QueryRuleData, meta: RuleMeta, rule_integrations = meta.get("integration", []) if rule_integrations: for integration in rule_integrations: - ineligible_integrations = definitions.NON_DATASET_PACKAGES + \ - [*map(str.lower, definitions.MACHINE_LEARNING_PACKAGES)] + ineligible_integrations = definitions.NON_DATASET_PACKAGES + [ + *map(str.lower, definitions.MACHINE_LEARNING_PACKAGES) + ] if integration in ineligible_integrations or isinstance(data, MachineLearningRuleData): packaged_integrations.append({"package": integration, "integration": None}) - packaged_integrations.extend(parse_datasets(datasets, package_manifest)) + packaged_integrations.extend(parse_datasets(list(datasets), package_manifest)) return packaged_integrations @validates_schema - def post_conversion_validation(self, value: dict, **kwargs): + def post_conversion_validation(self, value: dict[str, Any], **_: Any): """Additional validations beyond base marshmallow schemas.""" data: AnyRuleData = value["data"] metadata: RuleMeta = value["metadata"] + if not data.data_validator: + raise ValueError("No data validator found") + test_config = RULES_CONFIG.test_config - if not test_config.check_skip_by_rule_id(value['data'].rule_id): + if not test_config.check_skip_by_rule_id(value["data"].rule_id): + bypass = metadata.get("bypass_bbr_timing") or False data.validate_query(metadata) data.data_validator.validate_note() - 
data.data_validator.validate_bbr(metadata.get('bypass_bbr_timing')) - data.validate(metadata) if hasattr(data, 'validate') else False + data.data_validator.validate_bbr(bypass) + data.validate(metadata) if hasattr(data, "validate") else False # type: ignore[reportUnknownMemberType] @staticmethod - def validate_remote(remote_validator: 'RemoteValidator', contents: 'TOMLRuleContents'): - remote_validator.validate_rule(contents) + def validate_remote(remote_validator: "RemoteValidator", contents: "TOMLRuleContents"): + _ = remote_validator.validate_rule(contents) @classmethod def from_rule_resource( - cls, rule: dict, creation_date: str = TIME_NOW, updated_date: str = TIME_NOW, maturity: str = 'development' - ) -> 'TOMLRuleContents': + cls, + rule: dict[str, Any], + creation_date: str = TIME_NOW, + updated_date: str = TIME_NOW, + maturity: str = "development", + ) -> "TOMLRuleContents": """Create a TOMLRuleContents from a kibana rule resource.""" - integrations = [r.get("package") for r in rule.get("related_integrations")] + integrations = [r["package"] for r in rule["related_integrations"]] meta = { "creation_date": creation_date, "updated_date": updated_date, "maturity": maturity, "integration": integrations, } - contents = cls.from_dict({'metadata': meta, 'rule': rule, 'transforms': None}, unknown=marshmallow.EXCLUDE) + contents = cls.from_dict({"metadata": meta, "rule": rule, "transforms": None}, unknown=marshmallow.EXCLUDE) return contents - def to_dict(self, strip_none_values=True) -> dict: + def to_dict(self, strip_none_values: bool = True) -> dict[str, Any]: # Load schemas directly from the data and metadata classes to avoid schema ambiguity which can # result from union fields which contain classes and related subclasses (AnyRuleData). See issue #1141 metadata = self.metadata.to_dict(strip_none_values=strip_none_values) data = self.data.to_dict(strip_none_values=strip_none_values) - self.data.process_transforms(self.transform, data) + if self.transform: + _ = self.data.process_transforms(self.transform, data) dict_obj = dict(metadata=metadata, rule=data) return nested_normalize(dict_obj) - def flattened_dict(self) -> dict: - flattened = dict() + def flattened_dict(self) -> dict[str, Any]: + flattened: dict[str, Any] = dict() flattened.update(self.data.to_dict()) flattened.update(self.metadata.to_dict()) return flattened - def to_api_format(self, include_version: bool = not BYPASS_VERSION_LOCK, include_metadata: bool = False) -> dict: + def to_api_format( + self, + include_version: bool = not BYPASS_VERSION_LOCK, + include_metadata: bool = False, + ) -> dict[str, Any]: """Convert the TOML rule to the API format.""" rule_dict = self.to_dict() rule_dict = self._add_known_nulls(rule_dict) - converted_data = rule_dict['rule'] + converted_data = rule_dict["rule"] converted = self._post_dict_conversion(converted_data) if include_metadata: - converted["meta"] = rule_dict['metadata'] + converted["meta"] = rule_dict["metadata"] if include_version: converted["version"] = self.autobumped_version return converted - def check_restricted_fields_compatibility(self) -> Dict[str, dict]: + def check_restricted_fields_compatibility(self) -> dict[str, dict[str, Any]]: """Check for compatibility between restricted fields and the min_stack_version of the rule.""" default_min_stack = get_min_supported_stack_version() if self.metadata.min_stack_version is not None: @@ -1461,12 +1537,19 @@ def check_restricted_fields_compatibility(self) -> Dict[str, dict]: min_stack = default_min_stack restricted = 
self.data.get_restricted_fields - invalid = {} + if not restricted: + raise ValueError("No restricted fields found") + + invalid: dict[str, dict[str, Any]] = {} for _field, values in restricted.items(): if self.data.get(_field) is not None: min_allowed, _ = values + + if not min_allowed: + raise ValueError("Min allowed version is None") + if min_stack < min_allowed: - invalid[_field] = {'min_stack_version': min_stack, 'min_allowed_version': min_allowed} + invalid[_field] = {"min_stack_version": min_stack, "min_allowed_version": min_allowed} return invalid @@ -1474,7 +1557,7 @@ def check_restricted_fields_compatibility(self) -> Dict[str, dict]: @dataclass class TOMLRule: contents: TOMLRuleContents = field(hash=True) - path: Optional[Path] = None + path: Path | None = None gh_pr: Any = field(hash=False, compare=False, default=None, repr=False) @property @@ -1485,12 +1568,14 @@ def id(self): def name(self): return self.contents.data.name - def get_asset(self) -> dict: + def get_asset(self) -> dict[str, Any]: """Generate the relevant fleet compatible asset.""" return {"id": self.id, "attributes": self.contents.to_api_format(), "type": definitions.SAVED_OBJECT_TYPE} def get_base_rule_dir(self) -> Path | None: """Get the base rule directory for the rule.""" + if not self.path: + raise ValueError("No path found") rule_path = self.path.resolve() for rules_dir in DEFAULT_PREBUILT_RULES_DIRS + DEFAULT_PREBUILT_BBR_DIRS: if rule_path.is_relative_to(rules_dir): @@ -1505,65 +1590,65 @@ def save_toml(self, strip_none_values: bool = True): ) if self.contents.transform: converted["transform"] = self.contents.transform.to_dict() - toml_write(converted, str(self.path.absolute())) + + if not self.path: + raise ValueError("No path found") + + toml_write(converted, self.path.absolute()) def save_json(self, path: Path, include_version: bool = True): - path = path.with_suffix('.json') - with open(str(path.absolute()), 'w', newline='\n') as f: + path = path.with_suffix(".json") + with open(str(path.absolute()), "w", newline="\n") as f: json.dump(self.contents.to_api_format(include_version=include_version), f, sort_keys=True, indent=2) - f.write('\n') + _ = f.write("\n") @dataclass(frozen=True) class DeprecatedRuleContents(BaseRuleContents): - metadata: dict - data: dict - transform: Optional[dict] + metadata: dict[str, Any] + data: dict[str, Any] + transform: dict[str, Any] | None = None @cached_property - def version_lock(self): + def version_lock(self): # type: ignore[reportIncompatibleMethodOverride] # VersionLock - from .version_lock import loaded_version_lock - - return getattr(self, '_version_lock', None) or loaded_version_lock - def set_version_lock(self, value): - from .version_lock import VersionLock + return getattr(self, "_version_lock", None) or loaded_version_lock - err_msg = "Cannot set the version lock when the versioning strategy is configured to bypass the version lock." \ " Set `bypass_version_lock` to `false` in the rules config to use the version lock." + def set_version_lock(self, value: "VersionLock | None"): + err_msg = ( + "Cannot set the version lock when the versioning strategy is configured to bypass the version lock." + " Set `bypass_version_lock` to `false` in the rules config to use the version lock." + ) assert not RULES_CONFIG.bypass_version_lock, err_msg - if value and not isinstance(value, VersionLock): - raise TypeError(f'version lock property must be set with VersionLock objects only. 
Got {type(value)}') - # circumvent frozen class - self.__dict__['_version_lock'] = value + self.__dict__["_version_lock"] = value @property - def id(self) -> str: - return self.data.get('rule_id') + def id(self) -> str | None: # type: ignore[reportIncompatibleMethodOverride] + return self.data.get("rule_id") @property - def name(self) -> str: - return self.data.get('name') + def name(self) -> str | None: # type: ignore[reportIncompatibleMethodOverride] + return self.data.get("name") @property - def type(self) -> str: - return self.data.get('type') + def type(self) -> str | None: # type: ignore[reportIncompatibleMethodOverride] + return self.data.get("type") @classmethod - def from_dict(cls, obj: dict): - kwargs = dict(metadata=obj['metadata'], data=obj['rule']) - kwargs['transform'] = obj['transform'] if 'transform' in obj else None + def from_dict(cls, obj: dict[str, Any]): + kwargs = dict(metadata=obj["metadata"], data=obj["rule"]) + kwargs["transform"] = obj["transform"] if "transform" in obj else None return cls(**kwargs) - def to_api_format(self, include_version: bool = not BYPASS_VERSION_LOCK) -> dict: + def to_api_format(self, include_version: bool = not BYPASS_VERSION_LOCK) -> dict[str, Any]: """Convert the TOML rule to the API format.""" data = copy.deepcopy(self.data) if self.transform: transform = RuleTransform.from_dict(self.transform) - BaseRuleData.process_transforms(transform, data) + _ = BaseRuleData.process_transforms(transform, data) converted = data if include_version: @@ -1573,33 +1658,36 @@ def to_api_format(self, include_version: bool = not BYPASS_VERSION_LOCK) -> dict return converted -class DeprecatedRule(dict): +class DeprecatedRule(dict[str, Any]): """Minimal dict object for deprecated rule.""" - def __init__(self, path: Path, contents: DeprecatedRuleContents, *args, **kwargs): + def __init__(self, path: Path, contents: DeprecatedRuleContents, *args: Any, **kwargs: Any): super(DeprecatedRule, self).__init__(*args, **kwargs) self.path = path self.contents = contents def __repr__(self): - return f'{type(self).__name__}(contents={self.contents}, path={self.path})' + return f"{type(self).__name__}(contents={self.contents}, path={self.path})" @property - def id(self) -> str: + def id(self) -> str | None: return self.contents.id @property - def name(self) -> str: + def name(self) -> str | None: return self.contents.name -def downgrade_contents_from_rule(rule: TOMLRule, target_version: str, - replace_id: bool = True, include_metadata: bool = False) -> dict: +def downgrade_contents_from_rule( + rule: TOMLRule, + target_version: str, + replace_id: bool = True, + include_metadata: bool = False, +) -> dict[str, Any]: """Generate the downgraded contents from a rule.""" rule_dict = rule.contents.to_dict()["rule"] min_stack_version = target_version or rule.contents.metadata.min_stack_version or "8.3.0" - min_stack_version = Version.parse(min_stack_version, - optional_minor_and_patch=True) + min_stack_version = Version.parse(min_stack_version, optional_minor_and_patch=True) rule_dict.setdefault("meta", {}).update(rule.contents.metadata.to_dict()) if replace_id: @@ -1618,35 +1706,63 @@ def downgrade_contents_from_rule(rule: TOMLRule, target_version: str, return payload -def set_eql_config(min_stack_version: str) -> eql.parser.ParserConfig: +def set_eql_config(min_stack_version_val: str) -> eql.parser.ParserConfig: """Based on the rule version set the eql functions allowed.""" - if not min_stack_version: - min_stack_version = Version.parse(load_current_package_version(), 
optional_minor_and_patch=True) + if min_stack_version_val: + min_stack_version = Version.parse(min_stack_version_val, optional_minor_and_patch=True) else: - min_stack_version = Version.parse(min_stack_version, optional_minor_and_patch=True) + min_stack_version = Version.parse(load_current_package_version(), optional_minor_and_patch=True) config = eql.parser.ParserConfig() for feature, version_range in definitions.ELASTICSEARCH_EQL_FEATURES.items(): if version_range[0] <= min_stack_version <= (version_range[1] or min_stack_version): - config.context[feature] = True + config.context[feature] = True # type: ignore[reportUnknownMemberType] return config -def get_unique_query_fields(rule: TOMLRule) -> List[str]: +def get_unique_query_fields(rule: TOMLRule) -> list[str] | None: """Get a list of unique fields used in a rule query from rule contents.""" contents = rule.contents.to_api_format() - language = contents.get('language') - query = contents.get('query') - if language in ('kuery', 'eql'): + language = contents.get("language") + query = contents.get("query") + if language in ("kuery", "eql"): # TODO: remove once py-eql supports ipv6 for cidrmatch - cfg = set_eql_config(rule.contents.metadata.get('min_stack_version')) + min_stack_version = rule.contents.metadata.get("min_stack_version") + if not min_stack_version: + raise ValueError("Min stack version not found") + cfg = set_eql_config(min_stack_version) with eql.parser.elasticsearch_syntax, eql.parser.ignore_missing_functions, eql.parser.skip_optimizations, cfg: - parsed = (kql.parse(query, normalize_kql_keywords=RULES_CONFIG.normalize_kql_keywords) - if language == 'kuery' else eql.parse_query(query)) - return sorted(set(str(f) for f in parsed if isinstance(f, (eql.ast.Field, kql.ast.Field)))) + parsed = ( # type: ignore[reportUnknownVariableType] + kql.parse(query, normalize_kql_keywords=RULES_CONFIG.normalize_kql_keywords) # type: ignore[reportUnknownMemberType] + if language == "kuery" + else eql.parse_query(query) # type: ignore[reportUnknownMemberType] + ) + return sorted(set(str(f) for f in parsed if isinstance(f, (eql.ast.Field, kql.ast.Field)))) # type: ignore[reportUnknownVariableType] + + +def parse_datasets(datasets: list[str], package_manifest: dict[str, Any]) -> list[dict[str, Any]]: + """Parses datasets into packaged integrations from rule data.""" + packaged_integrations: list[dict[str, Any]] = [] + for value in sorted(datasets): + # cleanup extra quotes pulled from ast field + value = value.strip('"') + + integration = "Unknown" + if "." in value: + package, integration = value.split(".", 1) + # Handle cases where endpoint event datasource needs to be parsed uniquely (e.g endpoint.events.network) + # as endpoint.network + if package == "endpoint" and "events" in integration: + integration = integration.split(".")[1] + else: + package = value + + if package in list(package_manifest): + packaged_integrations.append({"package": package, "integration": integration}) + return packaged_integrations # avoid a circular import diff --git a/detection_rules/rule_formatter.py b/detection_rules/rule_formatter.py index f080ed0bf55..af851f050a9 100644 --- a/detection_rules/rule_formatter.py +++ b/detection_rules/rule_formatter.py @@ -4,14 +4,16 @@ # 2.0. 
"""Helper functions for managing rules in the repository.""" + import copy import dataclasses -import io import json import textwrap -import typing +from pathlib import Path from collections import OrderedDict +from typing import Any, Iterable + import toml from .schemas import definitions @@ -26,61 +28,73 @@ @cached def get_preserved_fmt_fields(): from .rule import BaseRuleData - preserved_keys = set() - for field in dataclasses.fields(BaseRuleData): # type: dataclasses.Field - if field.type in (definitions.Markdown, typing.Optional[definitions.Markdown]): + preserved_keys: set[str] = set() + + for field in dataclasses.fields(BaseRuleData): + if field.type in (definitions.Markdown, None): preserved_keys.add(field.metadata.get("data_key", field.name)) return preserved_keys -def cleanup_whitespace(val): +def cleanup_whitespace(val: Any) -> Any: if isinstance(val, str): return " ".join(line.strip() for line in val.strip().splitlines()) return val -def nested_normalize(d, skip_cleanup=False): +def nested_normalize(d: Any, skip_cleanup: bool = False) -> Any: if isinstance(d, str): return d if skip_cleanup else cleanup_whitespace(d) elif isinstance(d, list): - return [nested_normalize(val) for val in d] + return [nested_normalize(val) for val in d] # type: ignore[reportUnknownVariableType] elif isinstance(d, dict): - for k, v in d.items(): - if k == 'query': + for k, v in d.items(): # type: ignore[reportUnknownVariableType] + if k == "query": # TODO: the linter still needs some work, but once up to par, uncomment to implement - kql.lint(v) # do not normalize queries - d.update({k: v}) + d.update({k: v}) # type: ignore[reportUnknownMemberType] elif k in get_preserved_fmt_fields(): # let these maintain newlines and whitespace for markdown support - d.update({k: nested_normalize(v, skip_cleanup=True)}) + d.update({k: nested_normalize(v, skip_cleanup=True)}) # type: ignore[reportUnknownMemberType] else: - d.update({k: nested_normalize(v)}) - return d + d.update({k: nested_normalize(v)}) # type: ignore[reportUnknownMemberType] + return d # type: ignore[reportUnknownVariableType] else: return d -def wrap_text(v, block_indent=0, join=False): +def wrap_text(v: str, block_indent: int = 0) -> list[str]: """Block and indent a blob of text.""" - v = ' '.join(v.split()) - lines = textwrap.wrap(v, initial_indent=' ' * block_indent, subsequent_indent=' ' * block_indent, width=120, - break_long_words=False, break_on_hyphens=False) - lines = [line + '\n' for line in lines] + v = " ".join(v.split()) + lines = textwrap.wrap( + v, + initial_indent=" " * block_indent, + subsequent_indent=" " * block_indent, + width=120, + break_long_words=False, + break_on_hyphens=False, + ) + lines = [line + "\n" for line in lines] # If there is a single line that contains a quote, add a new blank line to trigger multiline formatting if len(lines) == 1 and '"' in lines[0]: - lines = lines + [''] - return lines if not join else ''.join(lines) + lines = lines + [""] + return lines + + +def wrap_text_and_join(v: str, block_indent: int = 0) -> str: + lines = wrap_text(v, block_indent=block_indent) + return "".join(lines) class NonformattedField(str): """Non-formatting class.""" -def preserve_formatting_for_fields(data: OrderedDict, fields_to_preserve: list) -> OrderedDict: +def preserve_formatting_for_fields(data: OrderedDict[str, Any], fields_to_preserve: list[str]) -> OrderedDict[str, Any]: """Preserve formatting for specified nested fields in an action.""" - def apply_preservation(target: OrderedDict, keys: list) -> None: + def 
apply_preservation(target: OrderedDict[str, Any], keys: list[str]) -> None: """Apply NonformattedField preservation based on keys path.""" for key in keys[:-1]: # Iterate to the key, diving into nested dictionaries @@ -96,28 +110,28 @@ def apply_preservation(target: OrderedDict, keys: list) -> None: target[final_key] = NonformattedField(target[final_key]) for field_path in fields_to_preserve: - keys = field_path.split('.') + keys = field_path.split(".") apply_preservation(data, keys) return data -class RuleTomlEncoder(toml.TomlEncoder): +class RuleTomlEncoder(toml.TomlEncoder): # type: ignore[reportMissingTypeArgument] """Generate a pretty form of toml.""" - def __init__(self, _dict=dict, preserve=False): + def __init__(self, *args: Any, **kwargs: Any): """Create the encoder but override some default functions.""" - super(RuleTomlEncoder, self).__init__(_dict, preserve) + super(RuleTomlEncoder, self).__init__(*args, **kwargs) # type: ignore[reportUnknownMemberType] self._old_dump_str = toml.TomlEncoder().dump_funcs[str] self._old_dump_list = toml.TomlEncoder().dump_funcs[list] self.dump_funcs[str] = self.dump_str - self.dump_funcs[type(u"")] = self.dump_str + self.dump_funcs[type("")] = self.dump_str self.dump_funcs[list] = self.dump_list self.dump_funcs[NonformattedField] = self.dump_str - def dump_str(self, v): + def dump_str(self, v: str | NonformattedField) -> str: """Change the TOML representation to multi-line or single quote when logical.""" - initial_newline = ['\n'] + initial_newline = ["\n"] if isinstance(v, NonformattedField): # first line break is not forced like other multiline string dumps @@ -136,131 +150,136 @@ def dump_str(self, v): else: return "\n".join([TRIPLE_SQ] + [self._old_dump_str(line)[1:-1] for line in lines] + [TRIPLE_SQ]) elif raw: - return u"'{:s}'".format(lines[0]) + return "'{:s}'".format(lines[0]) return self._old_dump_str(v) - def _dump_flat_list(self, v): + def _dump_flat_list(self, v: Iterable[Any]): """A slightly tweaked version of original dump_list, removing trailing commas.""" if not v: return "[]" - retval = "[" + str(self.dump_value(v[0])) + "," - for u in v[1:]: + v_list = list(v) + + retval = "[" + str(self.dump_value(v_list[0])) + "," + for u in v_list[1:]: retval += " " + str(self.dump_value(u)) + "," - retval = retval.rstrip(',') + "]" + retval = retval.rstrip(",") + "]" return retval - def dump_list(self, v): + def dump_list(self, v: Iterable[Any]) -> str: """Dump a list more cleanly.""" if all([isinstance(d, str) for d in v]) and sum(len(d) + 3 for d in v) > 100: - dump = [] + dump: list[str] = [] for item in v: - if len(item) > (120 - 4 - 3 - 3) and ' ' in item: - dump.append(' """\n{} """'.format(wrap_text(item, block_indent=4, join=True))) + if len(item) > (120 - 4 - 3 - 3) and " " in item: + dump.append(' """\n{} """'.format(wrap_text_and_join(item, block_indent=4))) else: - dump.append(' ' * 4 + self.dump_value(item)) - return '[\n{},\n]'.format(',\n'.join(dump)) + dump.append(" " * 4 + self.dump_value(item)) + return "[\n{},\n]".format(",\n".join(dump)) if v and all(isinstance(i, dict) for i in v): # Compact inline format for lists of dictionaries with proper indentation - retval = "\n" + ' ' * 2 + "[\n" - retval += ",\n".join([' ' * 4 + self.dump_inline_table(u).strip() for u in v]) - retval += "\n" + ' ' * 2 + "]\n" + retval = "\n" + " " * 2 + "[\n" + retval += ",\n".join([" " * 4 + self.dump_inline_table(u).strip() for u in v]) + retval += "\n" + " " * 2 + "]\n" return retval return self._dump_flat_list(v) -def 
toml_write(rule_contents, outfile=None): +def toml_write(rule_contents: dict[str, Any], out_file_path: Path | None = None): """Write rule in TOML.""" - def write(text, nl=True): - if outfile: - outfile.write(text) + + def write(text: str, nl: bool = True): + if out_file_path: + # Append data to a file + with out_file_path.open("a") as f: + _ = f.write(text) if nl: - outfile.write(u"\n") + with out_file_path.open("a") as f: + _ = f.write("\n") else: - print(text, end='' if not nl else '\n') + print(text, end="" if not nl else "\n") encoder = RuleTomlEncoder() contents = copy.deepcopy(rule_contents) - needs_close = False - def order_rule(obj): + def order_rule(obj: Any): if isinstance(obj, dict): - obj = OrderedDict(sorted(obj.items())) + obj = OrderedDict(sorted(obj.items())) # type: ignore[reportUnknownArgumentType, reportUnknownVariableType] for k, v in obj.items(): if isinstance(v, dict) or isinstance(v, list): obj[k] = order_rule(v) if isinstance(obj, list): - for i, v in enumerate(obj): + for i, v in enumerate(obj): # type: ignore[reportUnknownMemberType] if isinstance(v, dict) or isinstance(v, list): obj[i] = order_rule(v) - obj = sorted(obj, key=lambda x: json.dumps(x)) + obj = sorted(obj, key=lambda x: json.dumps(x)) # type: ignore[reportUnknownArgumentType, reportUnknownVariableType] return obj - def _do_write(_data, _contents): + def _do_write(_data: str, _contents: dict[str, Any]): query = None threat_query = None - if _data == 'rule': + if _data == "rule": # - We want to avoid the encoder for the query and instead use kql-lint. # - Linting is done in rule.normalize() which is also called in rule.validate(). # - Until lint has tabbing, this is going to result in all queries being flattened with no wrapping, # but will at least purge extraneous white space - query = contents['rule'].pop('query', '').strip() + query = contents["rule"].pop("query", "").strip() # - As tags are expanding, we may want to reconsider the need to have them in alphabetical order # tags = contents['rule'].get("tags", []) # # if tags and isinstance(tags, list): # contents['rule']["tags"] = list(sorted(set(tags))) - threat_query = contents['rule'].pop('threat_query', '').strip() + threat_query = contents["rule"].pop("threat_query", "").strip() - top = OrderedDict() - bottom = OrderedDict() + top: OrderedDict[str, Any] = OrderedDict() + bottom: OrderedDict[str, Any] = OrderedDict() for k in sorted(list(_contents)): v = _contents.pop(k) - if k == 'actions': + if k == "actions": # explicitly preserve formatting for message field in actions preserved_fields = ["params.message"] v = [preserve_formatting_for_fields(action, preserved_fields) for action in v] if v is not None else [] - if k == 'filters': + if k == "filters": # explicitly preserve formatting for value field in filters preserved_fields = ["meta.value"] v = [preserve_formatting_for_fields(meta, preserved_fields) for meta in v] if v is not None else [] - if k == 'note' and isinstance(v, str): + if k == "note" and isinstance(v, str): # Transform instances of \ to \\ as calling write will convert \\ to \. # This will ensure that the output file has the correct number of backslashes. v = v.replace("\\", "\\\\") - if k == 'setup' and isinstance(v, str): + if k == "setup" and isinstance(v, str): # Transform instances of \ to \\ as calling write will convert \\ to \. # This will ensure that the output file has the correct number of backslashes. 
v = v.replace("\\", "\\\\") - if k == 'description' and isinstance(v, str): + if k == "description" and isinstance(v, str): # Transform instances of \ to \\ as calling write will convert \\ to \. # This will ensure that the output file has the correct number of backslashes. v = v.replace("\\", "\\\\") - if k == 'osquery' and isinstance(v, list): + if k == "osquery" and isinstance(v, list): # Specifically handle transform.osquery queries - for osquery_item in v: - if 'query' in osquery_item and isinstance(osquery_item['query'], str): + for osquery_item in v: # type: ignore[reportUnknownVariableType] + if "query" in osquery_item and isinstance(osquery_item["query"], str): # Transform instances of \ to \\ as calling write will convert \\ to \. # This will ensure that the output file has the correct number of backslashes. - osquery_item['query'] = osquery_item['query'].replace("\\", "\\\\") + osquery_item["query"] = osquery_item["query"].replace("\\", "\\\\") # type: ignore[reportUnknownMemberType] if isinstance(v, dict): - bottom[k] = OrderedDict(sorted(v.items())) + bottom[k] = OrderedDict(sorted(v.items())) # type: ignore[reportUnknownArgumentType] elif isinstance(v, list): - if any([isinstance(value, (dict, list)) for value in v]): + if any([isinstance(value, (dict, list)) for value in v]): # type: ignore[reportUnknownArgumentType] bottom[k] = v else: top[k] = v @@ -270,39 +289,30 @@ def _do_write(_data, _contents): top[k] = v if query: - top.update({'query': "XXxXX"}) + top.update({"query": "XXxXX"}) # type: ignore[reportUnknownMemberType] if threat_query: - top.update({'threat_query': "XXxXX"}) + top.update({"threat_query": "XXxXX"}) # type: ignore[reportUnknownMemberType] - top.update(bottom) - top = toml.dumps(OrderedDict({data: top}), encoder=encoder) + top.update(bottom) # type: ignore[reportUnknownMemberType] + top_out = toml.dumps(OrderedDict({data: top}), encoder=encoder) # type: ignore[reportUnknownMemberType] # we want to preserve the threat_query format, but want to modify it in the context of encoded dump if threat_query: - formatted_threat_query = "\nthreat_query = '''\n{}\n'''{}".format(threat_query, '\n\n' if bottom else '') - top = top.replace('threat_query = "XXxXX"', formatted_threat_query) + formatted_threat_query = "\nthreat_query = '''\n{}\n'''{}".format(threat_query, "\n\n" if bottom else "") + top_out = top_out.replace('threat_query = "XXxXX"', formatted_threat_query) # we want to preserve the query format, but want to modify it in the context of encoded dump if query: - formatted_query = "\nquery = '''\n{}\n'''{}".format(query, '\n\n' if bottom else '') - top = top.replace('query = "XXxXX"', formatted_query) - - write(top) - - try: - - if outfile and not isinstance(outfile, io.IOBase): - needs_close = True - outfile = open(outfile, 'w') - - for data in ('metadata', 'transform', 'rule'): - _contents = contents.get(data, {}) - if not _contents: - continue - order_rule(_contents) - _do_write(data, _contents) - - finally: - if needs_close and hasattr(outfile, "close"): - outfile.close() + formatted_query = "\nquery = '''\n{}\n'''{}".format(query, "\n\n" if bottom else "") + top_out = top_out.replace('query = "XXxXX"', formatted_query) + + write(top_out) + + for data in ("metadata", "transform", "rule"): + _contents = contents.get(data, {}) + if not _contents: + continue + # FIXME: commenting out the call here as order_rule has subtle side-effects while its output is not used explicitly + # order_rule(_contents) + _do_write(data, _contents) diff --git 
a/detection_rules/rule_loader.py b/detection_rules/rule_loader.py index b2d943c9e7d..87b1a80be74 100644 --- a/detection_rules/rule_loader.py +++ b/detection_rules/rule_loader.py @@ -4,39 +4,42 @@ # 2.0. """Load rule metadata transform between rule and api formats.""" + +import requests from collections import OrderedDict from dataclasses import dataclass, field from pathlib import Path +from multiprocessing.pool import ThreadPool from subprocess import CalledProcessError -from typing import Callable, Dict, Iterable, List, Optional, Union +from typing import Callable, Iterable, Any, Iterator import click -import pytoml +import pytoml # type: ignore[reportMissingTypeStubs] import json from marshmallow.exceptions import ValidationError +from github.PullRequest import PullRequest +from github.File import File from . import utils from .config import parse_rules_config -from .rule import ( - DeprecatedRule, DeprecatedRuleContents, DictRule, TOMLRule, - TOMLRuleContents -) +from .rule import DeprecatedRule, DeprecatedRuleContents, DictRule, TOMLRule, TOMLRuleContents +from .ghwrap import GithubClient from .schemas import definitions from .utils import cached, get_path RULES_CONFIG = parse_rules_config() DEFAULT_PREBUILT_RULES_DIRS = RULES_CONFIG.rule_dirs DEFAULT_PREBUILT_BBR_DIRS = RULES_CONFIG.bbr_rules_dirs -FILE_PATTERN = r'^([a-z0-9_])+\.(json|toml)$' +FILE_PATTERN = r"^([a-z0-9_])+\.(json|toml)$" -def path_getter(value: str) -> Callable[[dict], bool]: +def path_getter(value: str) -> Callable[[dict[str, Any]], Any]: """Get the path from a Python object.""" path = value.replace("__", ".").split(".") - def callback(obj: dict): + def callback(obj: dict[str, Any]) -> Any: for p in path: - if isinstance(obj, dict) and p in path: + if p in path: obj = obj[p] else: return None @@ -46,28 +49,36 @@ def callback(obj: dict): return callback -def dict_filter(_obj: Optional[dict] = None, **critieria) -> Callable[[dict], bool]: +def dict_filter(_obj: dict[str, Any] | None = None, **criteria: Any) -> Callable[[dict[str, Any]], bool]: """Get a callable that will return true if a dictionary matches a set of criteria. 
* each key is a dotted (or __ delimited) path into a dictionary to check * each value is a value or list of values to match """ - critieria.update(_obj or {}) - checkers = [(path_getter(k), set(v) if isinstance(v, (list, set, tuple)) else {v}) for k, v in critieria.items()] - - def callback(obj: dict) -> bool: + criteria.update(_obj or {}) + checkers = [ + # FIXME: v might not be hashable + (path_getter(k), set(v if isinstance(v, (list, set, tuple)) else (v,))) # type: ignore[reportUnknownArgumentType] + for k, v in criteria.items() + ] + + def callback(obj: dict[str, Any]) -> bool: for getter, expected in checkers: target_values = getter(obj) - target_values = set(target_values) if isinstance(target_values, (list, set, tuple)) else {target_values} + target_values = ( # type: ignore[reportUnknownVariableType] + set(target_values) # type: ignore[reportUnknownVariableType] + if isinstance(target_values, (list, set, tuple)) + else set((target_values,)) + ) - return bool(expected.intersection(target_values)) + return bool(expected.intersection(target_values)) # type: ignore[reportUnknownArgumentType] return False return callback -def metadata_filter(**metadata) -> Callable[[TOMLRule], bool]: +def metadata_filter(**metadata: Any) -> Callable[[TOMLRule], bool]: """Get a filter callback based off rule metadata""" flt = dict_filter(metadata) @@ -81,48 +92,60 @@ def callback(rule: TOMLRule) -> bool: production_filter = metadata_filter(maturity="production") -def load_locks_from_tag(remote: str, tag: str, version_lock: str = 'detection_rules/etc/version.lock.json', - deprecated_file: str = 'detection_rules/etc/deprecated_rules.json') -> (str, dict, dict): +def load_locks_from_tag( + remote: str, + tag: str, + version_lock: str = "detection_rules/etc/version.lock.json", + deprecated_file: str = "detection_rules/etc/deprecated_rules.json", +) -> tuple[str, dict[str, Any], dict[str, Any]]: """Loads version and deprecated lock files from git tag.""" import json + git = utils.make_git() - exists_args = ['ls-remote'] + exists_args = ["ls-remote"] if remote: exists_args.append(remote) - exists_args.append(f'refs/tags/{tag}') + exists_args.append(f"refs/tags/{tag}") - assert git(*exists_args), f'tag: {tag} does not exist in {remote or "local"}' + assert git(*exists_args), f"tag: {tag} does not exist in {remote or 'local'}" - fetch_tags = ['fetch'] + fetch_tags = ["fetch"] if remote: - fetch_tags += [remote, '--tags', '-f', tag] + fetch_tags += [remote, "--tags", "-f", tag] else: - fetch_tags += ['--tags', '-f', tag] + fetch_tags += ["--tags", "-f", tag] - git(*fetch_tags) + _ = git(*fetch_tags) - commit_hash = git('rev-list', '-1', tag) + commit_hash = git("rev-list", "-1", tag) try: - version = json.loads(git('show', f'{tag}:{version_lock}')) + version = json.loads(git("show", f"{tag}:{version_lock}")) except CalledProcessError: # Adding resiliency to account for the old directory structure - version = json.loads(git('show', f'{tag}:etc/version.lock.json')) + version = json.loads(git("show", f"{tag}:etc/version.lock.json")) try: - deprecated = json.loads(git('show', f'{tag}:{deprecated_file}')) + deprecated = json.loads(git("show", f"{tag}:{deprecated_file}")) except CalledProcessError: # Adding resiliency to account for the old directory structure - deprecated = json.loads(git('show', f'{tag}:etc/deprecated_rules.json')) + deprecated = json.loads(git("show", f"{tag}:etc/deprecated_rules.json")) return commit_hash, version, deprecated -def update_metadata_from_file(rule_path: Path, fields_to_update: dict) 
-> dict: +def update_metadata_from_file(rule_path: Path, fields_to_update: dict[str, Any]) -> dict[str, Any]: """Update metadata fields for a rule with local contents.""" - contents = {} + + contents: dict[str, Any] = dict() if not rule_path.exists(): return contents - local_metadata = RuleCollection().load_file(rule_path).contents.metadata.to_dict() + + rule_contents = RuleCollection().load_file(rule_path).contents + + if not isinstance(rule_contents, TOMLRuleContents): + raise ValueError("TOML rule expected") + + local_metadata = rule_contents.metadata.to_dict() if local_metadata: contents["maturity"] = local_metadata.get("maturity", "development") for field_name, should_update in fields_to_update.items(): @@ -132,34 +155,34 @@ def update_metadata_from_file(rule_path: Path, fields_to_update: dict) -> dict: @dataclass -class BaseCollection: +class BaseCollection[T]: """Base class for collections.""" - rules: list + rules: list[T] def __len__(self): """Get the total amount of rules in the collection.""" return len(self.rules) - def __iter__(self): + def __iter__(self) -> Iterator[T]: """Iterate over all rules in the collection.""" return iter(self.rules) @dataclass -class DeprecatedCollection(BaseCollection): +class DeprecatedCollection(BaseCollection[DeprecatedRule]): """Collection of loaded deprecated rule dicts.""" - id_map: Dict[str, DeprecatedRule] = field(default_factory=dict) - file_map: Dict[Path, DeprecatedRule] = field(default_factory=dict) - name_map: Dict[str, DeprecatedRule] = field(default_factory=dict) - rules: List[DeprecatedRule] = field(default_factory=list) + id_map: dict[str, DeprecatedRule] = field(default_factory=dict) # type: ignore[reportUnknownVariableType] + file_map: dict[Path, DeprecatedRule] = field(default_factory=dict) # type: ignore[reportUnknownVariableType] + name_map: dict[str, DeprecatedRule] = field(default_factory=dict) # type: ignore[reportUnknownVariableType] + rules: list[DeprecatedRule] = field(default_factory=list) # type: ignore[reportUnknownVariableType] def __contains__(self, rule: DeprecatedRule): """Check if a rule is in the map by comparing IDs.""" return rule.id in self.id_map - def filter(self, cb: Callable[[DeprecatedRule], bool]) -> 'RuleCollection': + def filter(self, cb: Callable[[DeprecatedRule], bool]) -> "RuleCollection": """Retrieve a filtered collection of rules.""" filtered_collection = RuleCollection() @@ -169,33 +192,33 @@ def filter(self, cb: Callable[[DeprecatedRule], bool]) -> 'RuleCollection': return filtered_collection -class RawRuleCollection(BaseCollection): +class RawRuleCollection(BaseCollection[DictRule]): """Collection of rules in raw dict form.""" __default = None __default_bbr = None - def __init__(self, rules: Optional[List[dict]] = None, ext_patterns: Optional[List[str]] = None): + def __init__(self, rules: list[DictRule] | None = None, ext_patterns: list[str] | None = None): """Create a new raw rule collection, with optional file ext pattern override.""" # ndjson is unsupported since it breaks the contract of 1 rule per file, so rules should be manually broken out # first - self.ext_patterns = ext_patterns or ['*.toml', '*.json'] - self.id_map: Dict[definitions.UUIDString, DictRule] = {} - self.file_map: Dict[Path, DictRule] = {} - self.name_map: Dict[definitions.RuleName, DictRule] = {} - self.rules: List[DictRule] = [] - self.errors: Dict[Path, Exception] = {} + self.ext_patterns = ext_patterns or ["*.toml", "*.json"] + self.id_map: dict[definitions.UUIDString, DictRule] = {} + self.file_map: dict[Path, 
DictRule] = {} + self.name_map: dict[definitions.RuleName, DictRule] = {} + self.rules: list[DictRule] = [] + self.errors: dict[Path, Exception] = {} self.frozen = False - self._raw_load_cache: Dict[Path, dict] = {} - for rule in (rules or []): + self._raw_load_cache: dict[Path, dict[str, Any]] = {} + for rule in rules or []: self.add_rule(rule) def __contains__(self, rule: DictRule): """Check if a rule is in the map by comparing IDs.""" return rule.id in self.id_map - def filter(self, cb: Callable[[DictRule], bool]) -> 'RawRuleCollection': + def filter(self, cb: Callable[[DictRule], bool]) -> "RawRuleCollection": """Retrieve a filtered collection of rules.""" filtered_collection = RawRuleCollection() @@ -204,7 +227,7 @@ def filter(self, cb: Callable[[DictRule], bool]) -> 'RawRuleCollection': return filtered_collection - def _load_rule_file(self, path: Path) -> dict: + def _load_rule_file(self, path: Path) -> dict[str, Any]: """Load a rule file into a dictionary.""" if path in self._raw_load_cache: return self._raw_load_cache[path] @@ -213,20 +236,20 @@ def _load_rule_file(self, path: Path) -> dict: # use pytoml instead of toml because of annoying bugs # https://github.com/uiri/toml/issues/152 # might also be worth looking at https://github.com/sdispater/tomlkit - raw_dict = pytoml.loads(path.read_text()) + raw_dict = pytoml.loads(path.read_text()) # type: ignore[reportUnknownMemberType] elif path.suffix == ".json": raw_dict = json.loads(path.read_text()) elif path.suffix == ".ndjson": - raise ValueError('ndjson is not supported in RawRuleCollection. Break out the rules individually.') + raise ValueError("ndjson is not supported in RawRuleCollection. Break out the rules individually.") else: raise ValueError(f"Unsupported file type {path.suffix} for rule {path}") self._raw_load_cache[path] = raw_dict - return raw_dict + return raw_dict # type: ignore[reportUnknownVariableType] - def _get_paths(self, directory: Path, recursive=True) -> List[Path]: + def _get_paths(self, directory: Path, recursive: bool = True) -> list[Path]: """Get all paths in a directory that match the ext patterns.""" - paths = [] + paths: list[Path] = [] for pattern in self.ext_patterns: paths.extend(sorted(directory.rglob(pattern) if recursive else directory.glob(pattern))) return paths @@ -238,10 +261,10 @@ def _assert_new(self, rule: DictRule): name_map = self.name_map assert not self.frozen, f"Unable to add rule {rule.name} {rule.id} to a frozen collection" - assert rule.id not in id_map, \ - f"Rule ID {rule.id} for {rule.name} collides with rule {id_map.get(rule.id).name}" - assert rule.name not in name_map, \ - f"Rule Name {rule.name} for {rule.id} collides with rule ID {name_map.get(rule.name).id}" + assert rule.id not in id_map, f"Rule ID {rule.id} for {rule.name} collides with rule {id_map[rule.id].name}" + assert rule.name not in name_map, ( + f"Rule Name {rule.name} for {rule.id} collides with rule ID {name_map[rule.name].id}" + ) if rule.path is not None: rule_path = rule.path.resolve() @@ -255,7 +278,7 @@ def add_rule(self, rule: DictRule): self.name_map[rule.name] = rule self.rules.append(rule) - def load_dict(self, obj: dict, path: Optional[Path] = None) -> DictRule: + def load_dict(self, obj: dict[str, Any], path: Path | None = None) -> DictRule: """Load a rule from a dictionary.""" rule = DictRule(contents=obj, path=path) self.add_rule(rule) @@ -282,9 +305,9 @@ def load_file(self, path: Path) -> DictRule: def load_files(self, paths: Iterable[Path]): """Load multiple files into the collection.""" for 
path in paths: - self.load_file(path) + _ = self.load_file(path) - def load_directory(self, directory: Path, recursive=True, obj_filter: Optional[Callable[[dict], bool]] = None): + def load_directory(self, directory: Path, recursive: bool = True, obj_filter: Callable[..., bool] | None = None): """Load all rules in a directory.""" paths = self._get_paths(directory, recursive=recursive) if obj_filter is not None: @@ -292,8 +315,12 @@ def load_directory(self, directory: Path, recursive=True, obj_filter: Optional[C self.load_files(paths) - def load_directories(self, directories: Iterable[Path], recursive=True, - obj_filter: Optional[Callable[[dict], bool]] = None): + def load_directories( + self, + directories: Iterable[Path], + recursive: bool = True, + obj_filter: Callable[..., bool] | None = None, + ): """Load all rules in multiple directories.""" for path in directories: self.load_directory(path, recursive=recursive, obj_filter=obj_filter) @@ -303,7 +330,7 @@ def freeze(self): self.frozen = True @classmethod - def default(cls) -> 'RawRuleCollection': + def default(cls) -> "RawRuleCollection": """Return the default rule collection, which retrieves from rules/.""" if cls.__default is None: collection = RawRuleCollection() @@ -315,7 +342,7 @@ def default(cls) -> 'RawRuleCollection': return cls.__default @classmethod - def default_bbr(cls) -> 'RawRuleCollection': + def default_bbr(cls) -> "RawRuleCollection": """Return the default BBR collection, which retrieves from building_block_rules/.""" if cls.__default_bbr is None: collection = RawRuleCollection() @@ -326,34 +353,34 @@ def default_bbr(cls) -> 'RawRuleCollection': return cls.__default_bbr -class RuleCollection(BaseCollection): +class RuleCollection(BaseCollection[TOMLRule]): """Collection of rule objects.""" __default = None __default_bbr = None - def __init__(self, rules: Optional[List[TOMLRule]] = None): + def __init__(self, rules: list[TOMLRule] | None = None): from .version_lock import VersionLock - self.id_map: Dict[definitions.UUIDString, TOMLRule] = {} - self.file_map: Dict[Path, TOMLRule] = {} - self.name_map: Dict[definitions.RuleName, TOMLRule] = {} - self.rules: List[TOMLRule] = [] + self.id_map: dict[definitions.UUIDString, TOMLRule] = {} + self.file_map: dict[Path, TOMLRule] = {} + self.name_map: dict[definitions.RuleName, TOMLRule] = {} + self.rules: list[TOMLRule] = [] self.deprecated: DeprecatedCollection = DeprecatedCollection() - self.errors: Dict[Path, Exception] = {} + self.errors: dict[Path, Exception] = {} self.frozen = False - self._toml_load_cache: Dict[Path, dict] = {} - self._version_lock: Optional[VersionLock] = None + self._toml_load_cache: dict[Path, dict[str, Any]] = {} + self._version_lock: VersionLock | None = None - for rule in (rules or []): + for rule in rules or []: self.add_rule(rule) def __contains__(self, rule: TOMLRule): """Check if a rule is in the map by comparing IDs.""" return rule.id in self.id_map - def filter(self, cb: Callable[[TOMLRule], bool]) -> 'RuleCollection': + def filter(self, cb: Callable[[TOMLRule], bool]) -> "RuleCollection": """Retrieve a filtered collection of rules.""" filtered_collection = RuleCollection() @@ -363,10 +390,10 @@ def filter(self, cb: Callable[[TOMLRule], bool]) -> 'RuleCollection': return filtered_collection @staticmethod - def deserialize_toml_string(contents: Union[bytes, str]) -> dict: - return pytoml.loads(contents) + def deserialize_toml_string(contents: bytes | str) -> dict[str, Any]: + return pytoml.loads(contents) # type: ignore[reportUnknownMemberType] 
- def _load_toml_file(self, path: Path) -> dict: + def _load_toml_file(self, path: Path) -> dict[str, Any]: if path in self._toml_load_cache: return self._toml_load_cache[path] @@ -378,10 +405,10 @@ def _load_toml_file(self, path: Path) -> dict: self._toml_load_cache[path] = toml_dict return toml_dict - def _get_paths(self, directory: Path, recursive=True) -> List[Path]: - return sorted(directory.rglob('*.toml') if recursive else directory.glob('*.toml')) + def _get_paths(self, directory: Path, recursive: bool = True) -> list[Path]: + return sorted(directory.rglob("*.toml") if recursive else directory.glob("*.toml")) - def _assert_new(self, rule: Union[TOMLRule, DeprecatedRule], is_deprecated=False): + def _assert_new(self, rule: TOMLRule | DeprecatedRule, is_deprecated: bool = False): if is_deprecated: id_map = self.deprecated.id_map file_map = self.deprecated.file_map @@ -391,16 +418,23 @@ def _assert_new(self, rule: Union[TOMLRule, DeprecatedRule], is_deprecated=False file_map = self.file_map name_map = self.name_map + if not rule.id: + raise ValueError("Rule has no ID") + assert not self.frozen, f"Unable to add rule {rule.name} {rule.id} to a frozen collection" - assert rule.id not in id_map, \ - f"Rule ID {rule.id} for {rule.name} collides with rule {id_map.get(rule.id).name}" - assert rule.name not in name_map, \ - f"Rule Name {rule.name} for {rule.id} collides with rule ID {name_map.get(rule.name).id}" + assert rule.id not in id_map, f"Rule ID {rule.id} for {rule.name} collides with rule {id_map[rule.id].name}" + + if not rule.name: + raise ValueError("Rule has no name") + + assert rule.name not in name_map, ( + f"Rule Name {rule.name} for {rule.id} collides with rule ID {name_map[rule.name].id}" + ) if rule.path is not None: rule_path = rule.path.resolve() assert rule_path not in file_map, f"Rule file {rule_path} already loaded" - file_map[rule_path] = rule + file_map[rule_path] = rule # type: ignore[reportArgumentType] def add_rule(self, rule: TOMLRule): self._assert_new(rule) @@ -410,16 +444,24 @@ def add_rule(self, rule: TOMLRule): def add_deprecated_rule(self, rule: DeprecatedRule): self._assert_new(rule, is_deprecated=True) + + if not rule.id: + raise ValueError("Rule has no ID") + if not rule.name: + raise ValueError("Rule has no name") + self.deprecated.id_map[rule.id] = rule self.deprecated.name_map[rule.name] = rule self.deprecated.rules.append(rule) - def load_dict(self, obj: dict, path: Optional[Path] = None) -> Union[TOMLRule, DeprecatedRule]: + def load_dict(self, obj: dict[str, Any], path: Path | None = None) -> TOMLRule | DeprecatedRule: # bypass rule object load (load_dict) and load as a dict only - if obj.get('metadata', {}).get('maturity', '') == 'deprecated': + if obj.get("metadata", {}).get("maturity", "") == "deprecated": contents = DeprecatedRuleContents.from_dict(obj) if not RULES_CONFIG.bypass_version_lock: contents.set_version_lock(self._version_lock) + if not path: + raise ValueError("No path value provided") deprecated_rule = DeprecatedRule(path, contents) self.add_deprecated_rule(deprecated_rule) return deprecated_rule @@ -431,7 +473,7 @@ def load_dict(self, obj: dict, path: Optional[Path] = None) -> Union[TOMLRule, D self.add_rule(rule) return rule - def load_file(self, path: Path) -> Union[TOMLRule, DeprecatedRule]: + def load_file(self, path: Path) -> TOMLRule | DeprecatedRule: try: path = path.resolve() @@ -453,18 +495,19 @@ def load_file(self, path: Path) -> Union[TOMLRule, DeprecatedRule]: print(f"Error loading rule in {path}") raise - def 
load_git_tag(self, branch: str, remote: Optional[str] = None, skip_query_validation=False): + def load_git_tag(self, branch: str, remote: str, skip_query_validation: bool = False): """Load rules from a Git branch.""" from .version_lock import VersionLock, add_rule_types_to_lock git = utils.make_git() - paths = [] + paths: list[str] = [] for rules_dir in DEFAULT_PREBUILT_RULES_DIRS: - rules_dir = rules_dir.relative_to(get_path(".")) - paths.extend(git("ls-tree", "-r", "--name-only", branch, rules_dir).splitlines()) + rules_dir = rules_dir.relative_to(get_path(["."])) + git_output = git("ls-tree", "-r", "--name-only", branch, rules_dir) + paths.extend(git_output.splitlines()) - rule_contents = [] - rule_map = {} + rule_contents: list[tuple[dict[str, Any], Path]] = [] + rule_map: dict[str, Any] = {} for path in paths: path = Path(path) if path.suffix != ".toml": @@ -474,15 +517,15 @@ def load_git_tag(self, branch: str, remote: Optional[str] = None, skip_query_val toml_dict = self.deserialize_toml_string(contents) if skip_query_validation: - toml_dict['metadata']['query_schema_validation'] = False + toml_dict["metadata"]["query_schema_validation"] = False rule_contents.append((toml_dict, path)) - rule_map[toml_dict['rule']['rule_id']] = toml_dict + rule_map[toml_dict["rule"]["rule_id"]] = toml_dict commit_hash, v_lock, d_lock = load_locks_from_tag(remote, branch) - v_lock_name_prefix = f'{remote}/' if remote else '' - v_lock_name = f'{v_lock_name_prefix}{branch}-{commit_hash}' + v_lock_name_prefix = f"{remote}/" if remote else "" + v_lock_name = f"{v_lock_name_prefix}{branch}-{commit_hash}" # For backwards compatibility with tagged branches that existed before the types were added and validation # enforced, we will need to manually add the rule types to the version lock allow them to pass validation. 
@@ -494,25 +537,34 @@ def load_git_tag(self, branch: str, remote: Optional[str] = None, skip_query_val for rule_content in rule_contents: toml_dict, path = rule_content try: - self.load_dict(toml_dict, path) + _ = self.load_dict(toml_dict, path) except ValidationError as e: self.errors[path] = e continue - def load_files(self, paths: Iterable[Path]): + def load_files(self, paths: Iterable[Path]) -> None: """Load multiple files into the collection.""" for path in paths: - self.load_file(path) - - def load_directory(self, directory: Path, recursive=True, obj_filter: Optional[Callable[[dict], bool]] = None): + _ = self.load_file(path) + + def load_directory( + self, + directory: Path, + recursive: bool = True, + obj_filter: Callable[..., bool] | None = None, + ): paths = self._get_paths(directory, recursive=recursive) if obj_filter is not None: paths = [path for path in paths if obj_filter(self._load_toml_file(path))] self.load_files(paths) - def load_directories(self, directories: Iterable[Path], recursive=True, - obj_filter: Optional[Callable[[dict], bool]] = None): + def load_directories( + self, + directories: Iterable[Path], + recursive: bool = True, + obj_filter: Callable[..., bool] | None = None, + ): for path in directories: self.load_directory(path, recursive=recursive, obj_filter=obj_filter) @@ -521,7 +573,7 @@ def freeze(self): self.frozen = True @classmethod - def default(cls) -> 'RuleCollection': + def default(cls) -> "RuleCollection": """Return the default rule collection, which retrieves from rules/.""" if cls.__default is None: collection = RuleCollection() @@ -533,7 +585,7 @@ def default(cls) -> 'RuleCollection': return cls.__default @classmethod - def default_bbr(cls) -> 'RuleCollection': + def default_bbr(cls) -> "RuleCollection": """Return the default BBR collection, which retrieves from building_block_rules/.""" if cls.__default_bbr is None: collection = RuleCollection() @@ -543,17 +595,18 @@ def default_bbr(cls) -> 'RuleCollection': return cls.__default_bbr - def compare_collections(self, other: 'RuleCollection' - ) -> (Dict[str, TOMLRule], Dict[str, TOMLRule], Dict[str, DeprecatedRule]): + def compare_collections( + self, other: "RuleCollection" + ) -> tuple[dict[str, TOMLRule], dict[str, TOMLRule], dict[str, DeprecatedRule]]: """Get the changes between two sets of rules.""" - assert self._version_lock, 'RuleCollection._version_lock must be set for self' - assert other._version_lock, 'RuleCollection._version_lock must be set for other' + assert self._version_lock, "RuleCollection._version_lock must be set for self" + assert other._version_lock, "RuleCollection._version_lock must be set for other" # we cannot trust the assumption that either of the versions or deprecated files were pre-locked, which means we # have to perform additional checks beyond what is done in manage_versions - changed_rules = {} - new_rules = {} - newly_deprecated = {} + changed_rules: dict[str, TOMLRule] = dict() + new_rules: dict[str, TOMLRule] = dict() + newly_deprecated: dict[str, DeprecatedRule] = dict() pre_versions_hash = utils.dict_hash(self._version_lock.version_lock.to_dict()) post_versions_hash = utils.dict_hash(other._version_lock.version_lock.to_dict()) @@ -564,7 +617,7 @@ def compare_collections(self, other: 'RuleCollection' return changed_rules, new_rules, newly_deprecated for rule in other: - if rule.contents.metadata.maturity != 'production': + if rule.contents.metadata.maturity != "production": continue if rule.id not in self.id_map: @@ -575,46 +628,44 @@ def 
compare_collections(self, other: 'RuleCollection' changed_rules[rule.id] = rule for rule in other.deprecated: - if rule.id not in self.deprecated.id_map: + if rule.id and rule.id not in self.deprecated.id_map: newly_deprecated[rule.id] = rule return changed_rules, new_rules, newly_deprecated @cached -def load_github_pr_rules(labels: list = None, repo: str = 'elastic/detection-rules', token=None, threads=50, - verbose=True) -> (Dict[str, TOMLRule], Dict[str, TOMLRule], Dict[str, list]): +def load_github_pr_rules( + labels: list[str] | None = None, + repo_name: str = "elastic/detection-rules", + token: str | None = None, + threads: int = 50, + verbose: bool = True, +) -> tuple[dict[str, TOMLRule], dict[str, list[TOMLRule]], dict[str, list[str]]]: """Load all rules active as a GitHub PR.""" - from multiprocessing.pool import ThreadPool - from pathlib import Path - - import pytoml - import requests - - from .ghwrap import GithubClient github = GithubClient(token=token) - repo = github.client.get_repo(repo) - labels = set(labels or []) - open_prs = [r for r in repo.get_pulls() if not labels.difference(set(list(lbl.name for lbl in r.get_labels())))] + repo = github.client.get_repo(repo_name) + labels_set = set(labels or []) + open_prs = [r for r in repo.get_pulls() if not labels_set.difference(set(lbl.name for lbl in r.get_labels()))] - new_rules: List[TOMLRule] = [] - modified_rules: List[TOMLRule] = [] - errors: Dict[str, list] = {} + new_rules: list[TOMLRule] = [] + modified_rules: list[TOMLRule] = [] + errors: dict[str, list[str]] = {} existing_rules = RuleCollection.default() - pr_rules = [] + pr_rules: list[tuple[PullRequest, File]] = [] if verbose: - click.echo('Downloading rules from GitHub PRs') + click.echo("Downloading rules from GitHub PRs") - def download_worker(pr_info): + def download_worker(pr_info: tuple[PullRequest, File]): pull, rule_file = pr_info response = requests.get(rule_file.raw_url) try: - raw_rule = pytoml.loads(response.text) - contents = TOMLRuleContents.from_dict(raw_rule) - rule = TOMLRule(path=rule_file.filename, contents=contents) + raw_rule = pytoml.loads(response.text) # type: ignore[reportUnknownVariableType] + contents = TOMLRuleContents.from_dict(raw_rule) # type: ignore[reportUnknownArgumentType] + rule = TOMLRule(path=Path(rule_file.filename), contents=contents) rule.gh_pr = pull if rule in existing_rules: @@ -623,19 +674,21 @@ def download_worker(pr_info): new_rules.append(rule) except Exception as e: - errors.setdefault(Path(rule_file.filename).name, []).append(str(e)) + name = Path(rule_file.filename).name + errors.setdefault(name, []).append(str(e)) for pr in open_prs: - pr_rules.extend([(pr, f) for f in pr.get_files() - if f.filename.startswith('rules/') and f.filename.endswith('.toml')]) + pr_rules.extend( + [(pr, f) for f in pr.get_files() if f.filename.startswith("rules/") and f.filename.endswith(".toml")] + ) pool = ThreadPool(processes=threads) - pool.map(download_worker, pr_rules) + _ = pool.map(download_worker, pr_rules) pool.close() pool.join() new = OrderedDict([(rule.contents.id, rule) for rule in sorted(new_rules, key=lambda r: r.contents.name)]) - modified = OrderedDict() + modified: OrderedDict[str, list[TOMLRule]] = OrderedDict() for modified_rule in sorted(modified_rules, key=lambda r: r.contents.name): modified.setdefault(modified_rule.contents.id, []).append(modified_rule) diff --git a/detection_rules/rule_validators.py b/detection_rules/rule_validators.py index 5fe957ec1a5..d1ac53adaa6 100644 --- a/detection_rules/rule_validators.py 
+++ b/detection_rules/rule_validators.py @@ -4,38 +4,39 @@ # 2.0. """Validation logic for rules containing queries.""" + import re from enum import Enum from functools import cached_property, wraps -from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from typing import Any, Callable -import eql -from eql import ast -from eql.parser import KvTree, LarkToEQL, NodeInfo, TypeHint -from eql.parser import _parse as base_parse +import eql # type: ignore[reportMissingTypeStubs] +from eql import ast # type: ignore[reportMissingTypeStubs] +from eql.parser import KvTree, LarkToEQL, NodeInfo, TypeHint # type: ignore[reportMissingTypeStubs] +from eql.parser import _parse as base_parse # type: ignore[reportMissingTypeStubs] from marshmallow import ValidationError from semver import Version -import kql +import kql # type: ignore[reportMissingTypeStubs] import click from . import ecs, endgame from .config import CUSTOM_RULES_DIR, load_current_package_version, parse_rules_config from .custom_schemas import update_auto_generated_schema -from .integrations import (get_integration_schema_data, - load_integrations_manifests) -from .rule import (EQLRuleData, QueryRuleData, QueryValidator, RuleMeta, - TOMLRuleContents, set_eql_config) +from .integrations import get_integration_schema_data, load_integrations_manifests +from .rule import EQLRuleData, QueryRuleData, QueryValidator, RuleMeta, TOMLRuleContents, set_eql_config from .schemas import get_stack_schemas -EQL_ERROR_TYPES = Union[eql.EqlCompileError, - eql.EqlError, - eql.EqlParseError, - eql.EqlSchemaError, - eql.EqlSemanticError, - eql.EqlSyntaxError, - eql.EqlTypeMismatchError] -KQL_ERROR_TYPES = Union[kql.KqlCompileError, kql.KqlParseError] +EQL_ERROR_TYPES = ( + eql.EqlCompileError + | eql.EqlError + | eql.EqlParseError + | eql.EqlSchemaError + | eql.EqlSemanticError + | eql.EqlSyntaxError + | eql.EqlTypeMismatchError +) +KQL_ERROR_TYPES = kql.KqlCompileError | kql.KqlParseError RULES_CONFIG = parse_rules_config() @@ -52,47 +53,50 @@ def is_primitive(self): return self in self.primitives() -def custom_in_set(self, node: KvTree) -> NodeInfo: +def custom_in_set(self: LarkToEQL, node: KvTree) -> NodeInfo: """Override and address the limitations of the eql in_set method.""" - # return BaseInSetMethod(self, node) - outer, container = self.visit(node.child_trees) # type: (NodeInfo, list[NodeInfo]) + response = self.visit(node.child_trees) # type: ignore + if not response: + raise ValueError("Child trees are not provided") + + outer, container = response # type: ignore[reportUnknownVariableType] - if not outer.validate_type(ExtendedTypeHint.primitives()): + if not outer.validate_type(ExtendedTypeHint.primitives()): # type: ignore # can't compare non-primitives to sets - raise self._type_error(outer, ExtendedTypeHint.primitives()) + raise self._type_error(outer, ExtendedTypeHint.primitives()) # type: ignore # Check that everything inside the container has the same type as outside error_message = "Unable to compare {expected_type} to {actual_type}" - for inner in container: - if not inner.validate_type(outer): - raise self._type_error(inner, outer, error_message) + for inner in container: # type: ignore[reportUnknownVariableType] + if not inner.validate_type(outer): # type: ignore[reportUnknownMemberType] + raise self._type_error(inner, outer, error_message) # type: ignore - if self._elasticsearch_syntax and hasattr(outer, "type_info"): + if self._elasticsearch_syntax and hasattr(outer, "type_info"): # type: ignore # Check edge case of in_set 
and ip/string comparison - outer_type = outer.type_info - if isinstance(self._schema, ecs.KqlSchema2Eql): - type_hint = self._schema.kql_schema.get(str(outer.node), "unknown") - if hasattr(self._schema, "type_mapping") and type_hint == "ip": - outer.type_info = ExtendedTypeHint.IP - for inner in container: - if not inner.validate_type(outer): - raise self._type_error(inner, outer, error_message) + outer_type = outer.type_info # type: ignore + if isinstance(self._schema, ecs.KqlSchema2Eql): # type: ignore + type_hint = self._schema.kql_schema.get(str(outer.node), "unknown") # type: ignore + if hasattr(self._schema, "type_mapping") and type_hint == "ip": # type: ignore + outer.type_info = ExtendedTypeHint.IP # type: ignore + for inner in container: # type: ignore + if not inner.validate_type(outer): # type: ignore + raise self._type_error(inner, outer, error_message) # type: ignore # reset the type - outer.type_info = outer_type + outer.type_info = outer_type # type: ignore # This will always evaluate to true/false, so it should be a boolean - term = ast.InSet(outer.node, [c.node for c in container]) - nullable = outer.nullable or any(c.nullable for c in container) - return NodeInfo(term, TypeHint.Boolean, nullable=nullable, source=node) + term = ast.InSet(outer.node, [c.node for c in container]) # type: ignore + nullable = outer.nullable or any(c.nullable for c in container) # type: ignore + return NodeInfo(term, TypeHint.Boolean, nullable=nullable, source=node) # type: ignore def custom_base_parse_decorator(func: Callable[..., Any]) -> Callable[..., Any]: """Override and address the limitations of the eql in_set method.""" @wraps(func) - def wrapper(query: str, start: Optional[str] = None, **kwargs: Dict[str, Any]) -> Any: - original_in_set = LarkToEQL.in_set + def wrapper(query: str, start: str | None = None, **kwargs: dict[str, Any]) -> Any: + original_in_set = LarkToEQL.in_set # type: ignore[reportUnknownMemberType] LarkToEQL.in_set = custom_in_set try: result = func(query, start=start, **kwargs) @@ -103,40 +107,42 @@ def wrapper(query: str, start: Optional[str] = None, **kwargs: Dict[str, Any]) - return wrapper -eql.parser._parse = custom_base_parse_decorator(base_parse) +eql.parser._parse = custom_base_parse_decorator(base_parse) # type: ignore[reportPrivateUsage] class KQLValidator(QueryValidator): """Specific fields for KQL query event types.""" @cached_property - def ast(self) -> kql.ast.Expression: - return kql.parse(self.query, normalize_kql_keywords=RULES_CONFIG.normalize_kql_keywords) + def ast(self) -> kql.ast.Expression: # type: ignore[reportIncompatibleMethod] + return kql.parse(self.query, normalize_kql_keywords=RULES_CONFIG.normalize_kql_keywords) # type: ignore[reportUnknownMemberType] @cached_property - def unique_fields(self) -> List[str]: - return list(set(str(f) for f in self.ast if isinstance(f, kql.ast.Field))) + def unique_fields(self) -> list[str]: # type: ignore[reportIncompatibleMethod] + return list(set(str(f) for f in self.ast if isinstance(f, kql.ast.Field))) # type: ignore[reportUnknownVariableType] def auto_add_field(self, validation_checks_error: kql.errors.KqlParseError, index_or_dataview: str) -> None: """Auto add a missing field to the schema.""" field_name = extract_error_field(self.query, validation_checks_error) + if not field_name: + raise ValueError("No field name found for the error") field_type = ecs.get_all_flattened_schema().get(field_name) update_auto_generated_schema(index_or_dataview, field_name, field_type) def to_eql(self) ->
eql.ast.Expression: - return kql.to_eql(self.query) + return kql.to_eql(self.query) # type: ignore[reportUnknownVariableType] - def validate(self, data: QueryRuleData, meta: RuleMeta, max_attempts: int = 10) -> None: + def validate(self, data: QueryRuleData, meta: RuleMeta, max_attempts: int = 10) -> None: # type: ignore[reportIncompatibleMethod] """Validate the query, called from the parent which contains [metadata] information.""" if meta.query_schema_validation is False or meta.maturity == "deprecated": # syntax only, which is done via self.ast return - if isinstance(data, QueryRuleData) and data.language != 'lucene': + if data.language != "lucene": packages_manifest = load_integrations_manifests() package_integrations = TOMLRuleContents.get_packaged_integrations(data, meta, packages_manifest) for _ in range(max_attempts): - validation_checks = {"stack": None, "integrations": None} + validation_checks: dict[str, KQL_ERROR_TYPES | None] = {"stack": None, "integrations": None} # validate the query against fields within beats validation_checks["stack"] = self.validate_stack_combos(data, meta) @@ -144,57 +150,59 @@ def validate(self, data: QueryRuleData, meta: RuleMeta, max_attempts: int = 10) # validate the query against related integration fields validation_checks["integrations"] = self.validate_integration(data, meta, package_integrations) - if (validation_checks["stack"] and not package_integrations): + if validation_checks["stack"] and not package_integrations: # if auto add, try auto adding and then call stack_combo validation again - if validation_checks["stack"].error_msg == "Unknown field" and RULES_CONFIG.auto_gen_schema_file: + if validation_checks["stack"].error_msg == "Unknown field" and RULES_CONFIG.auto_gen_schema_file: # type: ignore[reportAttributeAccessIssue] # auto add the field and re-validate - self.auto_add_field(validation_checks["stack"], data.index_or_dataview[0]) + self.auto_add_field(validation_checks["stack"], data.index_or_dataview[0]) # type: ignore[reportArgumentType] else: raise validation_checks["stack"] - if (validation_checks["stack"] and validation_checks["integrations"]): + if validation_checks["stack"] and validation_checks["integrations"]: # if auto add, try auto adding and then call stack_combo validation again - if validation_checks["stack"].error_msg == "Unknown field" and RULES_CONFIG.auto_gen_schema_file: + if validation_checks["stack"].error_msg == "Unknown field" and RULES_CONFIG.auto_gen_schema_file: # type: ignore[reportAttributeAccessIssue] # auto add the field and re-validate - self.auto_add_field(validation_checks["stack"], data.index_or_dataview[0]) + self.auto_add_field(validation_checks["stack"], data.index_or_dataview[0]) # type: ignore[reportArgumentType] else: - click.echo(f"Stack Error Trace: {validation_checks["stack"]}") - click.echo(f"Integrations Error Trace: {validation_checks["integrations"]}") + click.echo(f"Stack Error Trace: {validation_checks['stack']}") + click.echo(f"Integrations Error Trace: {validation_checks['integrations']}") raise ValueError("Error in both stack and integrations checks") else: break - else: - raise ValueError(f"Maximum validation attempts exceeded for {data.rule_id} - {data.name}") - - def validate_stack_combos(self, data: QueryRuleData, meta: RuleMeta) -> Union[KQL_ERROR_TYPES, None, TypeError]: + def validate_stack_combos(self, data: QueryRuleData, meta: RuleMeta) -> KQL_ERROR_TYPES | None: """Validate the query against ECS and beats schemas across stack combinations.""" for stack_version, mapping 
in meta.get_validation_stack_versions().items(): - beats_version = mapping['beats'] - ecs_version = mapping['ecs'] - err_trailer = f'stack: {stack_version}, beats: {beats_version}, ecs: {ecs_version}' + beats_version = mapping["beats"] + ecs_version = mapping["ecs"] + err_trailer = f"stack: {stack_version}, beats: {beats_version}, ecs: {ecs_version}" - beat_types, beat_schema, schema = self.get_beats_schema(data.index_or_dataview, - beats_version, ecs_version) + beat_types, _, schema = self.get_beats_schema(data.index_or_dataview, beats_version, ecs_version) try: - kql.parse(self.query, schema=schema, normalize_kql_keywords=RULES_CONFIG.normalize_kql_keywords) + kql.parse(self.query, schema=schema, normalize_kql_keywords=RULES_CONFIG.normalize_kql_keywords) # type: ignore[reportUnknownMemberType] except kql.KqlParseError as exc: message = exc.error_msg trailer = err_trailer if "Unknown field" in message and beat_types: trailer = f"\nTry adding event.module or event.dataset to specify beats module\n\n{trailer}" - return kql.KqlParseError(exc.error_msg, exc.line, exc.column, exc.source, - len(exc.caret.lstrip()), trailer=trailer) - except Exception as exc: - print(err_trailer) - return exc + return kql.KqlParseError( + exc.error_msg, # type: ignore[reportUnknownArgumentType] + exc.line, # type: ignore[reportUnknownArgumentType] + exc.column, # type: ignore[reportUnknownArgumentType] + exc.source, # type: ignore[reportUnknownArgumentType] + len(exc.caret.lstrip()), + trailer=trailer, + ) def validate_integration( - self, data: QueryRuleData, meta: RuleMeta, package_integrations: List[dict] - ) -> Union[KQL_ERROR_TYPES, None, TypeError]: + self, + data: QueryRuleData, + meta: RuleMeta, + package_integrations: list[dict[str, Any]], + ) -> KQL_ERROR_TYPES | None: """Validate the query, called from the parent which contains [metadata] information.""" if meta.query_schema_validation is False or meta.maturity == "deprecated": return @@ -207,14 +215,12 @@ def validate_integration( package = integration_data["package"] integration = integration_data["integration"] if integration: - package_schemas.setdefault(package, {}).setdefault(integration, {}) + package_schemas.setdefault(package, {}).setdefault(integration, {}) # type: ignore[reportUnknownMemberType] else: - package_schemas.setdefault(package, {}) + package_schemas.setdefault(package, {}) # type: ignore[reportUnknownMemberType] # Process each integration schema - for integration_schema_data in get_integration_schema_data( - data, meta, package_integrations - ): + for integration_schema_data in get_integration_schema_data(data, meta, package_integrations): package, integration = ( integration_schema_data["package"], integration_schema_data["integration"], @@ -240,9 +246,11 @@ def validate_integration( # Validate the query against the schema try: - kql.parse(self.query, - schema=integration_schema, - normalize_kql_keywords=RULES_CONFIG.normalize_kql_keywords) + kql.parse( # type: ignore[reportUnknownMemberType] + self.query, + schema=integration_schema, + normalize_kql_keywords=RULES_CONFIG.normalize_kql_keywords, + ) except kql.KqlParseError as exc: if exc.error_msg == "Unknown field": field = extract_error_field(self.query, exc) @@ -260,26 +268,24 @@ def validate_integration( "integration": integration, } if data.get("notify", False): - print( - f"\nWarning: `{field}` in `{data.name}` not found in schema. {trailer}" - ) + print(f"\nWarning: `{field}` in `{data.name}` not found in schema. 
{trailer}") else: return kql.KqlParseError( - exc.error_msg, - exc.line, - exc.column, - exc.source, + exc.error_msg, # type: ignore[reportUnknownArgumentType] + exc.line, # type: ignore[reportUnknownArgumentType] + exc.column, # type: ignore[reportUnknownArgumentType] + exc.source, # type: ignore[reportUnknownArgumentType] len(exc.caret.lstrip()), - exc.trailer, + exc.trailer, # type: ignore[reportUnknownArgumentType] ) # Check error fields against schemas of different packages or different integrations - for field, error_data in list(error_fields.items()): - error_package, error_integration = ( + for field, error_data in list(error_fields.items()): # type: ignore + error_package, error_integration = ( # type: ignore error_data["package"], error_data["integration"], ) - for package, integrations_or_schema in package_schemas.items(): + for package, integrations_or_schema in package_schemas.items(): # type: ignore if error_integration is None: # Compare against the schema directly if there's no integration if error_package != package and field in integrations_or_schema: @@ -287,24 +293,23 @@ def validate_integration( break else: # Compare against integration schemas - for integration, schema in integrations_or_schema.items(): - check_alt_schema = ( - error_package != package or # noqa: W504 - (error_package == package and error_integration != integration) + for integration, schema in integrations_or_schema.items(): # type: ignore + check_alt_schema = error_package != package or ( # type: ignore + error_package == package and error_integration != integration ) if check_alt_schema and field in schema: del error_fields[field] # Raise the first error if error_fields: - _, error_data = next(iter(error_fields.items())) + _, error_data = next(iter(error_fields.items())) # type: ignore[reportUnknownVariableType] return kql.KqlParseError( - error_data["error"].error_msg, - error_data["error"].line, - error_data["error"].column, - error_data["error"].source, - len(error_data["error"].caret.lstrip()), - error_data["trailer"], + error_data["error"].error_msg, # type: ignore[reportUnknownArgumentType] + error_data["error"].line, # type: ignore[reportUnknownArgumentType] + error_data["error"].column, # type: ignore[reportUnknownArgumentType] + error_data["error"].source, # type: ignore[reportUnknownArgumentType] + len(error_data["error"].caret.lstrip()), # type: ignore[reportUnknownArgumentType] + error_data["trailer"], # type: ignore[reportUnknownArgumentType] ) @@ -312,70 +317,75 @@ class EQLValidator(QueryValidator): """Specific fields for EQL query event types.""" @cached_property - def ast(self) -> eql.ast.Expression: + def ast(self) -> eql.ast.Expression: # type: ignore[reportIncompatibleMethodOverride] latest_version = Version.parse(load_current_package_version(), optional_minor_and_patch=True) cfg = set_eql_config(str(latest_version)) with eql.parser.elasticsearch_syntax, eql.parser.ignore_missing_functions, eql.parser.skip_optimizations, cfg: - return eql.parse_query(self.query) + return eql.parse_query(self.query) # type: ignore[reportUnknownVariableType] - def text_fields(self, eql_schema: Union[ecs.KqlSchema2Eql, endgame.EndgameSchema]) -> List[str]: + def text_fields(self, eql_schema: ecs.KqlSchema2Eql | endgame.EndgameSchema) -> list[str]: """Return a list of fields of type text.""" - from kql.parser import elasticsearch_type_family - schema = eql_schema.kql_schema if isinstance(eql_schema, ecs.KqlSchema2Eql) else eql_schema.endgame_schema + from kql.parser import
elasticsearch_type_family # type: ignore[reportMissingTypeStubs] - return [f for f in self.unique_fields if elasticsearch_type_family(schema.get(f)) == 'text'] + schema = eql_schema.kql_schema if isinstance(eql_schema, ecs.KqlSchema2Eql) else eql_schema.endgame_schema # type: ignore + + return [f for f in self.unique_fields if elasticsearch_type_family(schema.get(f)) == "text"] # type: ignore @cached_property - def unique_fields(self) -> List[str]: - return list(set(str(f) for f in self.ast if isinstance(f, eql.ast.Field))) + def unique_fields(self) -> list[str]: # type: ignore[reportIncompatibleMethodOverride] + return list(set(str(f) for f in self.ast if isinstance(f, eql.ast.Field))) # type: ignore[reportUnknownVariableType] def auto_add_field(self, validation_checks_error: eql.errors.EqlParseError, index_or_dataview: str) -> None: """Auto add a missing field to the schema.""" field_name = extract_error_field(self.query, validation_checks_error) + if not field_name: + raise ValueError("No field name found") field_type = ecs.get_all_flattened_schema().get(field_name) update_auto_generated_schema(index_or_dataview, field_name, field_type) - def validate(self, data: "QueryRuleData", meta: RuleMeta, max_attempts: int = 10) -> None: + def validate(self, data: "QueryRuleData", meta: RuleMeta, max_attempts: int = 10) -> None: # type: ignore[reportIncompatibleMethodOverride] """Validate an EQL query while checking TOMLRule.""" if meta.query_schema_validation is False or meta.maturity == "deprecated": # syntax only, which is done via self.ast return - if isinstance(data, QueryRuleData) and data.language != "lucene": + if data.language != "lucene": packages_manifest = load_integrations_manifests() package_integrations = TOMLRuleContents.get_packaged_integrations(data, meta, packages_manifest) for _ in range(max_attempts): validation_checks = {"stack": None, "integrations": None} # validate the query against fields within beats - validation_checks["stack"] = self.validate_stack_combos(data, meta) + validation_checks["stack"] = self.validate_stack_combos(data, meta) # type: ignore[reportArgumentType] + + stack_check = validation_checks["stack"] if package_integrations: # validate the query against related integration fields - validation_checks["integrations"] = self.validate_integration(data, meta, package_integrations) + validation_checks["integrations"] = self.validate_integration(data, meta, package_integrations) # type: ignore[reportArgumentType] - if validation_checks["stack"] and not package_integrations: + if stack_check and not package_integrations: # if auto add, try auto adding and then validate again if ( - "Field not recognized" in validation_checks["stack"].error_msg - and RULES_CONFIG.auto_gen_schema_file # noqa: W503 + "Field not recognized" in str(stack_check) # type: ignore[reportUnknownMemberType] + and RULES_CONFIG.auto_gen_schema_file ): # auto add the field and re-validate - self.auto_add_field(validation_checks["stack"], data.index_or_dataview[0]) + self.auto_add_field(stack_check, data.index_or_dataview[0]) # type: ignore[reportArgumentType] else: - raise validation_checks["stack"] + raise stack_check - elif validation_checks["stack"] and validation_checks["integrations"]: + elif stack_check and validation_checks["integrations"]: # if auto add, try auto adding and then validate again if ( - "Field not recognized" in validation_checks["stack"].error_msg - and RULES_CONFIG.auto_gen_schema_file # noqa: W503 + "Field not recognized" in stack_check.error_msg # type: 
ignore[reportUnknownMemberType] + and RULES_CONFIG.auto_gen_schema_file ): # auto add the field and re-validate - self.auto_add_field(validation_checks["stack"], data.index_or_dataview[0]) + self.auto_add_field(stack_check, data.index_or_dataview[0]) # type: ignore[reportArgumentType] else: - click.echo(f"Stack Error Trace: {validation_checks["stack"]}") - click.echo(f"Integrations Error Trace: {validation_checks["integrations"]}") + click.echo(f"Stack Error Trace: {stack_check}") + click.echo(f"Integrations Error Trace: {validation_checks['integrations']}") raise ValueError("Error in both stack and integrations checks") else: @@ -385,7 +395,8 @@ def validate(self, data: "QueryRuleData", meta: RuleMeta, max_attempts: int = 10 raise ValueError(f"Maximum validation attempts exceeded for {data.rule_id} - {data.name}") rule_type_config_fields, rule_type_config_validation_failed = self.validate_rule_type_configurations( - data, meta + data, # type: ignore[reportArgumentType] + meta, ) if rule_type_config_validation_failed: raise ValueError( @@ -393,42 +404,55 @@ def validate(self, data: "QueryRuleData", meta: RuleMeta, max_attempts: int = 10 {rule_type_config_fields}""" ) - def validate_stack_combos(self, data: QueryRuleData, meta: RuleMeta) -> Union[EQL_ERROR_TYPES, None, ValueError]: + def validate_stack_combos(self, data: QueryRuleData, meta: RuleMeta) -> EQL_ERROR_TYPES | None | ValueError: """Validate the query against ECS and beats schemas across stack combinations.""" for stack_version, mapping in meta.get_validation_stack_versions().items(): - beats_version = mapping['beats'] - ecs_version = mapping['ecs'] - endgame_version = mapping['endgame'] - err_trailer = f'stack: {stack_version}, beats: {beats_version},' \ - f'ecs: {ecs_version}, endgame: {endgame_version}' - - beat_types, beat_schema, schema = self.get_beats_schema(data.index_or_dataview, - beats_version, ecs_version) + beats_version = mapping["beats"] + ecs_version = mapping["ecs"] + endgame_version = mapping["endgame"] + err_trailer = ( + f"stack: {stack_version}, beats: {beats_version},ecs: {ecs_version}, endgame: {endgame_version}" + ) + + beat_types, _, schema = self.get_beats_schema(data.index_or_dataview, beats_version, ecs_version) endgame_schema = self.get_endgame_schema(data.index_or_dataview, endgame_version) eql_schema = ecs.KqlSchema2Eql(schema) # validate query against the beats and eql schema - exc = self.validate_query_with_schema(data=data, schema=eql_schema, err_trailer=err_trailer, - beat_types=beat_types, min_stack_version=meta.min_stack_version) + exc = self.validate_query_with_schema( # type: ignore[reportUnknownVariableType] + data=data, + schema=eql_schema, + err_trailer=err_trailer, + beat_types=beat_types, + min_stack_version=meta.min_stack_version, # type: ignore[reportArgumentType] + ) if exc: return exc if endgame_schema: # validate query against the endgame schema - exc = self.validate_query_with_schema(data=data, schema=endgame_schema, err_trailer=err_trailer, - min_stack_version=meta.min_stack_version) + exc = self.validate_query_with_schema( + data=data, + schema=endgame_schema, + err_trailer=err_trailer, + min_stack_version=meta.min_stack_version, # type: ignore[reportArgumentType] + ) if exc: raise exc - def validate_integration(self, data: QueryRuleData, meta: RuleMeta, - package_integrations: List[dict]) -> Union[EQL_ERROR_TYPES, None, ValueError]: + def validate_integration( + self, + data: QueryRuleData, + meta: RuleMeta, + package_integrations: list[dict[str, Any]], + ) -> EQL_ERROR_TYPES | 
None | ValueError: """Validate an EQL query while checking TOMLRule against integration schemas.""" if meta.query_schema_validation is False or meta.maturity == "deprecated": # syntax only, which is done via self.ast return error_fields = {} - package_schemas = {} + package_schemas: dict[str, Any] = {} # Initialize package_schemas with a nested structure for integration_data in package_integrations: @@ -440,9 +464,7 @@ def validate_integration(self, data: QueryRuleData, meta: RuleMeta, package_schemas.setdefault(package, {}) # Process each integration schema - for integration_schema_data in get_integration_schema_data( - data, meta, package_integrations - ): + for integration_schema_data in get_integration_schema_data(data, meta, package_integrations): ecs_version = integration_schema_data["ecs_version"] package, integration = ( integration_schema_data["package"], @@ -477,11 +499,11 @@ def validate_integration(self, data: QueryRuleData, meta: RuleMeta, data=data, schema=eql_schema, err_trailer=err_trailer, - min_stack_version=meta.min_stack_version, + min_stack_version=meta.min_stack_version, # type: ignore[reportArgumentType] ) if isinstance(exc, eql.EqlParseError): - message = exc.error_msg + message = exc.error_msg # type: ignore[reportUnknownVariableType] if message == "Unknown field" or "Field not recognized" in message: field = extract_error_field(self.query, exc) trailer = ( @@ -497,15 +519,13 @@ def validate_integration(self, data: QueryRuleData, meta: RuleMeta, "integration": integration, } if data.get("notify", False): - print( - f"\nWarning: `{field}` in `{data.name}` not found in schema. {trailer}" - ) + print(f"\nWarning: `{field}` in `{data.name}` not found in schema. {trailer}") else: return exc # Check error fields against schemas of different packages or different integrations - for field, error_data in list(error_fields.items()): - error_package, error_integration = ( + for field, error_data in list(error_fields.items()): # type: ignore + error_package, error_integration = ( # type: ignore error_data["package"], error_data["integration"], ) @@ -517,27 +537,31 @@ def validate_integration(self, data: QueryRuleData, meta: RuleMeta, else: # Compare against integration schemas for integration, schema in integrations_or_schema.items(): - check_alt_schema = ( - error_package != package or # noqa: W504 - (error_package == package and error_integration != integration) + check_alt_schema = ( # type: ignore + error_package != package or (error_package == package and error_integration != integration) ) if check_alt_schema and field in schema: del error_fields[field] # raise the first error if error_fields: - _, data = next(iter(error_fields.items())) - exc = data["error"] - return exc - - def validate_query_with_schema(self, data: 'QueryRuleData', schema: Union[ecs.KqlSchema2Eql, endgame.EndgameSchema], - err_trailer: str, min_stack_version: str, beat_types: list = None) -> Union[ - EQL_ERROR_TYPES, ValueError, None]: + _, data = next(iter(error_fields.items())) # type: ignore + exc = data["error"] # type: ignore + return exc # type: ignore + + def validate_query_with_schema( + self, + data: "QueryRuleData", + schema: ecs.KqlSchema2Eql | endgame.EndgameSchema, + err_trailer: str, + min_stack_version: str, + beat_types: list[str] | None = None, + ) -> EQL_ERROR_TYPES | ValueError | None: """Validate the query against the schema.""" try: config = set_eql_config(min_stack_version) with config, schema, eql.parser.elasticsearch_syntax, eql.parser.ignore_missing_functions: - 
eql.parse_query(self.query) + _ = eql.parse_query(self.query) # type: ignore[reportUnknownMemberType] except eql.EqlParseError as exc: message = exc.error_msg trailer = err_trailer @@ -546,21 +570,25 @@ def validate_query_with_schema(self, data: 'QueryRuleData', schema: Union[ecs.Kq elif "Field not recognized" in message: text_fields = self.text_fields(schema) if text_fields: - fields_str = ', '.join(text_fields) + fields_str = ", ".join(text_fields) trailer = f"\neql does not support text fields: {fields_str}\n\n{trailer}" - return exc.__class__(exc.error_msg, exc.line, exc.column, exc.source, - len(exc.caret.lstrip()), trailer=trailer) + return exc.__class__( + exc.error_msg, # type: ignore[reportUnknownArgumentType] + exc.line, # type: ignore[reportUnknownArgumentType] + exc.column, # type: ignore[reportUnknownArgumentType] + exc.source, # type: ignore[reportUnknownArgumentType] + len(exc.caret.lstrip()), + trailer=trailer, + ) except Exception as exc: print(err_trailer) - return exc + return exc # type: ignore[reportReturnType] - def validate_rule_type_configurations(self, data: EQLRuleData, meta: RuleMeta) -> \ - Tuple[List[Optional[str]], bool]: + def validate_rule_type_configurations(self, data: EQLRuleData, meta: RuleMeta) -> tuple[list[str | None], bool]: """Validate EQL rule type configurations.""" if data.timestamp_field or data.event_category_override or data.tiebreaker_field: - # get a list of rule type configuration fields # Get a list of rule type configuration fields fields = ["timestamp_field", "event_category_override", "tiebreaker_field"] @@ -570,7 +598,7 @@ def validate_rule_type_configurations(self, data: EQLRuleData, meta: RuleMeta) - min_stack_version = meta.get("min_stack_version") if min_stack_version is None: min_stack_version = Version.parse(load_current_package_version(), optional_minor_and_patch=True) - ecs_version = get_stack_schemas()[str(min_stack_version)]['ecs'] + ecs_version = get_stack_schemas()[str(min_stack_version)]["ecs"] schema = ecs.get_schema(ecs_version) # return a list of rule type config field values and whether any are not in the schema @@ -584,31 +612,35 @@ class ESQLValidator(QueryValidator): """Validate specific fields for ESQL query event types.""" @cached_property - def ast(self): + def ast(self): # type: ignore[reportIncompatibleMethodOverride] return None @cached_property - def unique_fields(self) -> List[str]: + def unique_fields(self) -> list[str]: # type: ignore[reportIncompatibleMethodOverride] """Return a list of unique fields in the query.""" # return empty list for ES|QL rules until ast is available (friendlier than raising error) # raise NotImplementedError('ES|QL query parsing not yet supported') return [] - def validate(self, data: 'QueryRuleData', meta: RuleMeta) -> None: + def validate(self, _: "QueryRuleData", __: RuleMeta) -> None: # type: ignore[reportIncompatibleMethodOverride] """Validate an ESQL query while checking TOMLRule.""" # temporarily override to NOP until ES|QL query parsing is supported - def validate_integration(self, data: QueryRuleData, meta: RuleMeta, package_integrations: List[dict]) -> Union[ - ValidationError, None, ValueError]: + def validate_integration( + self, + _: QueryRuleData, + __: RuleMeta, + ___: list[dict[str, Any]], + ) -> ValidationError | None | ValueError: # return self.validate(data, meta) pass -def extract_error_field(source: str, exc: Union[eql.EqlParseError, kql.KqlParseError]) -> Optional[str]: +def extract_error_field(source: str, exc: eql.EqlParseError | kql.KqlParseError) -> str 
| None: """Extract the field name from an EQL or KQL parse error.""" lines = source.splitlines() - mod = -1 if exc.line == len(lines) else 0 - line = lines[exc.line + mod] - start = exc.column - stop = start + len(exc.caret.strip()) - return re.sub(r'^\W+|\W+$', '', line[start:stop]) + mod = -1 if exc.line == len(lines) else 0 # type: ignore[reportUnknownMemberType] + line = lines[exc.line + mod] # type: ignore[reportUnknownMemberType] + start = exc.column # type: ignore[reportUnknownMemberType] + stop = start + len(exc.caret.strip()) # type: ignore[reportUnknownVariableType] + return re.sub(r"^\W+|\W+$", "", line[start:stop]) # type: ignore[reportUnknownArgumentType] diff --git a/detection_rules/schemas/__init__.py b/detection_rules/schemas/__init__.py index 98506eeb2bd..f26e4d9d0b7 100644 --- a/detection_rules/schemas/__init__.py +++ b/detection_rules/schemas/__init__.py @@ -4,8 +4,7 @@ # 2.0. import json from collections import OrderedDict -from typing import List, Optional -from typing import OrderedDict as OrderedDictType +from typing import Callable, Any, OrderedDict as OrderedDictType import jsonschema from semver import Version @@ -28,22 +27,25 @@ ) RULES_CONFIG = parse_rules_config() -SCHEMA_DIR = get_etc_path("api_schemas") -migrations = {} +SCHEMA_DIR = get_etc_path(["api_schemas"]) +MigratedFuncT = Callable[..., Any] -def all_versions() -> List[str]: +migrations: dict[str, MigratedFuncT] = {} + + +def all_versions() -> list[str]: """Get all known stack versions.""" return [str(v) for v in sorted(migrations, key=lambda x: Version.parse(x, optional_minor_and_patch=True))] -def migrate(version: str): +def migrate(version: str) -> Callable[[MigratedFuncT], MigratedFuncT]: """Decorator to set a migration.""" # checks that the migrate decorator name is semi-semantic versioned # raises validation error from semver if not - Version.parse(version, optional_minor_and_patch=True) + _ = Version.parse(version, optional_minor_and_patch=True) - def wrapper(f): + def wrapper(f: MigratedFuncT) -> MigratedFuncT: assert version not in migrations migrations[version] = f return f @@ -52,7 +54,7 @@ def wrapper(f): @cached -def get_schema_file(version: Version, rule_type: str) -> dict: +def get_schema_file(version: Version, rule_type: str) -> dict[str, Any]: path = SCHEMA_DIR / str(version) / f"{version}.{rule_type}.json" if not path.exists(): @@ -61,13 +63,13 @@ def get_schema_file(version: Version, rule_type: str) -> dict: return json.loads(path.read_text(encoding="utf8")) -def strip_additional_properties(version: Version, api_contents: dict) -> dict: +def strip_additional_properties(version: Version, api_contents: dict[str, Any]) -> dict[str, Any]: """Remove all fields that the target schema doesn't recognize.""" - stripped = {} + stripped: dict[str, Any] = {} target_schema = get_schema_file(version, api_contents["type"]) - for field, field_schema in target_schema["properties"].items(): + for field, _ in target_schema["properties"].items(): if field in api_contents: stripped[field] = api_contents[field] @@ -76,7 +78,7 @@ def strip_additional_properties(version: Version, api_contents: dict) -> dict: return stripped -def strip_non_public_fields(min_stack_version: Version, data_dict: dict) -> dict: +def strip_non_public_fields(min_stack_version: Version, data_dict: dict[str, Any]) -> dict[str, Any]: """Remove all non public fields.""" for field, version_range in definitions.NON_PUBLIC_FIELDS.items(): if version_range[0] <= min_stack_version <= (version_range[1] or min_stack_version): @@ -86,23 +88,23 
@@ def strip_non_public_fields(min_stack_version: Version, data_dict: dict) -> dict @migrate("7.8") -def migrate_to_7_8(version: Version, api_contents: dict) -> dict: +def migrate_to_7_8(version: Version, api_contents: dict[str, Any]) -> dict[str, Any]: """Default migration for 7.8.""" return strip_additional_properties(version, api_contents) @migrate("7.9") -def migrate_to_7_9(version: Version, api_contents: dict) -> dict: +def migrate_to_7_9(version: Version, api_contents: dict[str, Any]) -> dict[str, Any]: """Default migration for 7.9.""" return strip_additional_properties(version, api_contents) @migrate("7.10") -def downgrade_threat_to_7_10(version: Version, api_contents: dict) -> dict: +def downgrade_threat_to_7_10(version: Version, api_contents: dict[str, Any]) -> dict[str, Any]: """Downgrade the threat mapping changes from 7.11 to 7.10.""" if "threat" in api_contents: v711_threats = api_contents.get("threat", []) - v710_threats = [] + v710_threats: list[Any] = [] for threat in v711_threats: # drop tactic without threat @@ -130,24 +132,24 @@ def downgrade_threat_to_7_10(version: Version, api_contents: dict) -> dict: @migrate("7.11") -def downgrade_threshold_to_7_11(version: Version, api_contents: dict) -> dict: +def downgrade_threshold_to_7_11(version: Version, api_contents: dict[str, Any]) -> dict[str, Any]: """Remove 7.12 threshold changes that don't impact the rule.""" if "threshold" in api_contents: - threshold = api_contents['threshold'] - threshold_field = threshold['field'] + threshold = api_contents["threshold"] + threshold_field = threshold["field"] # attempt to convert threshold field to a string if len(threshold_field) > 1: - raise ValueError('Cannot downgrade a threshold rule that has multiple threshold fields defined') + raise ValueError("Cannot downgrade a threshold rule that has multiple threshold fields defined") - if threshold.get('cardinality'): - raise ValueError('Cannot downgrade a threshold rule that has a defined cardinality') + if threshold.get("cardinality"): + raise ValueError("Cannot downgrade a threshold rule that has a defined cardinality") api_contents = api_contents.copy() api_contents["threshold"] = api_contents["threshold"].copy() # if cardinality was defined with no field or value - api_contents['threshold'].pop('cardinality', None) + api_contents["threshold"].pop("cardinality", None) api_contents["threshold"]["field"] = api_contents["threshold"]["field"][0] # finally, downgrade any additional properties that were added @@ -155,20 +157,20 @@ def downgrade_threshold_to_7_11(version: Version, api_contents: dict) -> dict: @migrate("7.12") -def migrate_to_7_12(version: Version, api_contents: dict) -> dict: +def migrate_to_7_12(version: Version, api_contents: dict[str, Any]) -> dict[str, Any]: """Default migration for 7.12.""" return strip_additional_properties(version, api_contents) @migrate("7.13") -def downgrade_ml_multijob_713(version: Version, api_contents: dict) -> dict: +def downgrade_ml_multijob_713(version: Version, api_contents: dict[str, Any]) -> dict[str, Any]: """Convert `machine_learning_job_id` as an array to a string for < 7.13.""" if "machine_learning_job_id" in api_contents: job_id = api_contents["machine_learning_job_id"] if isinstance(job_id, list): - if len(job_id) > 1: - raise ValueError('Cannot downgrade an ML rule with multiple jobs defined') + if len(job_id) > 1: # type: ignore[reportUnknownArgumentType] + raise ValueError("Cannot downgrade an ML rule with multiple jobs defined") api_contents = api_contents.copy() 
api_contents["machine_learning_job_id"] = job_id[0] @@ -178,149 +180,150 @@ def downgrade_ml_multijob_713(version: Version, api_contents: dict) -> dict: @migrate("7.14") -def migrate_to_7_14(version: Version, api_contents: dict) -> dict: +def migrate_to_7_14(version: Version, api_contents: dict[str, Any]) -> dict[str, Any]: """Default migration for 7.14.""" return strip_additional_properties(version, api_contents) @migrate("7.15") -def migrate_to_7_15(version: Version, api_contents: dict) -> dict: +def migrate_to_7_15(version: Version, api_contents: dict[str, Any]) -> dict[str, Any]: """Default migration for 7.15.""" return strip_additional_properties(version, api_contents) @migrate("7.16") -def migrate_to_7_16(version: Version, api_contents: dict) -> dict: +def migrate_to_7_16(version: Version, api_contents: dict[str, Any]) -> dict[str, Any]: """Default migration for 7.16.""" return strip_additional_properties(version, api_contents) @migrate("8.0") -def migrate_to_8_0(version: Version, api_contents: dict) -> dict: +def migrate_to_8_0(version: Version, api_contents: dict[str, Any]) -> dict[str, Any]: """Default migration for 8.0.""" return strip_additional_properties(version, api_contents) @migrate("8.1") -def migrate_to_8_1(version: Version, api_contents: dict) -> dict: +def migrate_to_8_1(version: Version, api_contents: dict[str, Any]) -> dict[str, Any]: """Default migration for 8.1.""" return strip_additional_properties(version, api_contents) @migrate("8.2") -def migrate_to_8_2(version: Version, api_contents: dict) -> dict: +def migrate_to_8_2(version: Version, api_contents: dict[str, Any]) -> dict[str, Any]: """Default migration for 8.2.""" return strip_additional_properties(version, api_contents) @migrate("8.3") -def migrate_to_8_3(version: Version, api_contents: dict) -> dict: +def migrate_to_8_3(version: Version, api_contents: dict[str, Any]) -> dict[str, Any]: """Default migration for 8.3.""" return strip_additional_properties(version, api_contents) @migrate("8.4") -def migrate_to_8_4(version: Version, api_contents: dict) -> dict: +def migrate_to_8_4(version: Version, api_contents: dict[str, Any]) -> dict[str, Any]: """Default migration for 8.4.""" return strip_additional_properties(version, api_contents) @migrate("8.5") -def migrate_to_8_5(version: Version, api_contents: dict) -> dict: +def migrate_to_8_5(version: Version, api_contents: dict[str, Any]) -> dict[str, Any]: """Default migration for 8.5.""" return strip_additional_properties(version, api_contents) @migrate("8.6") -def migrate_to_8_6(version: Version, api_contents: dict) -> dict: +def migrate_to_8_6(version: Version, api_contents: dict[str, Any]) -> dict[str, Any]: """Default migration for 8.6.""" return strip_additional_properties(version, api_contents) @migrate("8.7") -def migrate_to_8_7(version: Version, api_contents: dict) -> dict: +def migrate_to_8_7(version: Version, api_contents: dict[str, Any]) -> dict[str, Any]: """Default migration for 8.7.""" return strip_additional_properties(version, api_contents) @migrate("8.8") -def migrate_to_8_8(version: Version, api_contents: dict) -> dict: +def migrate_to_8_8(version: Version, api_contents: dict[str, Any]) -> dict[str, Any]: """Default migration for 8.8.""" return strip_additional_properties(version, api_contents) @migrate("8.9") -def migrate_to_8_9(version: Version, api_contents: dict) -> dict: +def migrate_to_8_9(version: Version, api_contents: dict[str, Any]) -> dict[str, Any]: """Default migration for 8.9.""" return strip_additional_properties(version, 
api_contents) @migrate("8.10") -def migrate_to_8_10(version: Version, api_contents: dict) -> dict: +def migrate_to_8_10(version: Version, api_contents: dict[str, Any]) -> dict[str, Any]: """Default migration for 8.10.""" return strip_additional_properties(version, api_contents) @migrate("8.11") -def migrate_to_8_11(version: Version, api_contents: dict) -> dict: +def migrate_to_8_11(version: Version, api_contents: dict[str, Any]) -> dict[str, Any]: """Default migration for 8.11.""" return strip_additional_properties(version, api_contents) @migrate("8.12") -def migrate_to_8_12(version: Version, api_contents: dict) -> dict: +def migrate_to_8_12(version: Version, api_contents: dict[str, Any]) -> dict[str, Any]: """Default migration for 8.12.""" return strip_additional_properties(version, api_contents) @migrate("8.13") -def migrate_to_8_13(version: Version, api_contents: dict) -> dict: +def migrate_to_8_13(version: Version, api_contents: dict[str, Any]) -> dict[str, Any]: """Default migration for 8.13.""" return strip_additional_properties(version, api_contents) @migrate("8.14") -def migrate_to_8_14(version: Version, api_contents: dict) -> dict: +def migrate_to_8_14(version: Version, api_contents: dict[str, Any]) -> dict[str, Any]: """Default migration for 8.14.""" return strip_additional_properties(version, api_contents) @migrate("8.15") -def migrate_to_8_15(version: Version, api_contents: dict) -> dict: +def migrate_to_8_15(version: Version, api_contents: dict[str, Any]) -> dict[str, Any]: """Default migration for 8.15.""" return strip_additional_properties(version, api_contents) @migrate("8.16") -def migrate_to_8_16(version: Version, api_contents: dict) -> dict: +def migrate_to_8_16(version: Version, api_contents: dict[str, Any]) -> dict[str, Any]: """Default migration for 8.16.""" return strip_additional_properties(version, api_contents) @migrate("8.17") -def migrate_to_8_17(version: Version, api_contents: dict) -> dict: +def migrate_to_8_17(version: Version, api_contents: dict[str, Any]) -> dict[str, Any]: """Default migration for 8.17.""" return strip_additional_properties(version, api_contents) @migrate("8.18") -def migrate_to_8_18(version: Version, api_contents: dict) -> dict: +def migrate_to_8_18(version: Version, api_contents: dict[str, Any]) -> dict[str, Any]: """Default migration for 8.18.""" return strip_additional_properties(version, api_contents) @migrate("9.0") -def migrate_to_9_0(version: Version, api_contents: dict) -> dict: +def migrate_to_9_0(version: Version, api_contents: dict[str, Any]) -> dict[str, Any]: """Default migration for 9.0.""" return strip_additional_properties(version, api_contents) -def downgrade(api_contents: dict, target_version: str, current_version: Optional[str] = None) -> dict: +def downgrade( + api_contents: dict[str, Any], target_version: str, current_version_val: str | None = None +) -> dict[str, Any]: """Downgrade a rule to a target stack version.""" from ..packaging import current_stack_version - if current_version is None: - current_version = current_stack_version() + current_version = current_version_val or current_stack_version() current = Version.parse(current_version, optional_minor_and_patch=True) target = Version.parse(target_version, optional_minor_and_patch=True) @@ -340,38 +343,40 @@ def downgrade(api_contents: dict, target_version: str, current_version: Optional @cached -def load_stack_schema_map() -> dict: +def load_stack_schema_map() -> dict[str, Any]: return RULES_CONFIG.stack_schema_map @cached -def get_stack_schemas(stack_version: 
Optional[str] = '0.0.0') -> OrderedDictType[str, dict]: +def get_stack_schemas(stack_version_val: str | None = "0.0.0") -> OrderedDictType[str, dict[str, Any]]: """ Return all ECS, beats, and custom stack versions for every stack version. Only versions >= specified stack version and <= package are returned. """ - stack_version = Version.parse(stack_version or '0.0.0', optional_minor_and_patch=True) + stack_version = Version.parse(stack_version_val or "0.0.0", optional_minor_and_patch=True) current_package = Version.parse(load_current_package_version(), optional_minor_and_patch=True) stack_map = load_stack_schema_map() - versions = {k: v for k, v in stack_map.items() if - (((mapped_version := Version.parse(k)) >= stack_version) - and (mapped_version <= current_package) and v)} # noqa: W503 + versions = { + k: v + for k, v in stack_map.items() + if (((mapped_version := Version.parse(k)) >= stack_version) and (mapped_version <= current_package) and v) + } if stack_version > current_package: - versions[stack_version] = {'beats': 'main', 'ecs': 'master'} + versions[stack_version] = {"beats": "main", "ecs": "master"} versions_reversed = OrderedDict(sorted(versions.items(), reverse=True)) return versions_reversed -def get_stack_versions(drop_patch=False) -> List[str]: +def get_stack_versions(drop_patch: bool = False) -> list[str]: """Get a list of stack versions supported (for the matrix).""" versions = list(load_stack_schema_map()) if drop_patch: - abridged_versions = [] + abridged_versions: list[str] = [] for version in versions: - abridged, _ = version.rsplit('.', 1) + abridged, _ = version.rsplit(".", 1) abridged_versions.append(abridged) return abridged_versions else: diff --git a/detection_rules/schemas/definitions.py b/detection_rules/schemas/definitions.py index b18cc9cca9f..0dea7cc19dd 100644 --- a/detection_rules/schemas/definitions.py +++ b/detection_rules/schemas/definitions.py @@ -4,256 +4,344 @@ # 2.0. 
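# A minimal illustrative sketch of the version-window filter that get_stack_schemas() applies above,
# assuming a hypothetical stack_map with full semver keys (values here are made up for illustration):
from semver import Version

stack_map = {"8.14.0": {"ecs": "8.11.0"}, "8.17.0": {"ecs": "8.11.0"}, "9.0.0": {"ecs": "9.0.0"}}
lower, upper = Version.parse("8.15.0"), Version.parse("8.18.0")
# keep only entries whose version falls inside [lower, upper], mirroring the comprehension above
window = {k: v for k, v in stack_map.items() if lower <= Version.parse(k) <= upper}
# -> {"8.17.0": {"ecs": "8.11.0"}}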
"""Custom shared definitions for schemas.""" + import os -from typing import Final, List, Literal +import re +from typing import Final, Literal, Annotated, Pattern, Any, Callable from marshmallow import fields, validate -from marshmallow_dataclass import NewType from semver import Version from detection_rules.config import CUSTOM_RULES_DIR -def elastic_timeline_template_id_validator(): +def elastic_timeline_template_id_validator() -> Callable[[Any], Any]: """Custom validator for Timeline Template IDs.""" - def validator(value): - if os.environ.get('DR_BYPASS_TIMELINE_TEMPLATE_VALIDATION') is not None: - fields.String().deserialize(value) - else: - validate.OneOf(list(TIMELINE_TEMPLATES))(value) - return validator + def validator_wrapper(value: Any) -> Any: + if os.environ.get("DR_BYPASS_TIMELINE_TEMPLATE_VALIDATION") is None: + template_ids = list(TIMELINE_TEMPLATES) + validator = validate.OneOf(template_ids) + validator(value) + return value + + return validator_wrapper -def elastic_timeline_template_title_validator(): +def elastic_timeline_template_title_validator() -> Callable[[Any], Any]: """Custom validator for Timeline Template Titles.""" - def validator(value): - if os.environ.get('DR_BYPASS_TIMELINE_TEMPLATE_VALIDATION') is not None: - fields.String().deserialize(value) - else: - validate.OneOf(TIMELINE_TEMPLATES.values())(value) - return validator + def validator_wrapper(value: Any) -> Any: + if os.environ.get("DR_BYPASS_TIMELINE_TEMPLATE_VALIDATION") is None: + template_titles = TIMELINE_TEMPLATES.values() + validator = validate.OneOf(template_titles) + validator(value) + return value + + return validator_wrapper -def elastic_rule_name_regexp(pattern): +def elastic_rule_name_regexp(pattern: Pattern[str]) -> Callable[[Any], Any]: """Custom validator for rule names.""" - def validator(value): + + regexp_validator = validate.Regexp(pattern) + + def validator_wrapper(value: Any) -> Any: if not CUSTOM_RULES_DIR: - validate.Regexp(pattern)(value) - else: - fields.String().deserialize(value) - return validator + regexp_validator(value) + return value + + return validator_wrapper ASSET_TYPE = "security_rule" SAVED_OBJECT_TYPE = "security-rule" -DATE_PATTERN = r'^\d{4}/\d{2}/\d{2}$' -MATURITY_LEVELS = ['development', 'experimental', 'beta', 'production', 'deprecated'] -OS_OPTIONS = ['windows', 'linux', 'macos'] -NAME_PATTERN = r'^[a-zA-Z0-9].+?[a-zA-Z0-9\[\]()]$' -PR_PATTERN = r'^$|\d+$' -SHA256_PATTERN = r'^[a-fA-F0-9]{64}$' -UUID_PATTERN = r'^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$' - -_version = r'\d+\.\d+(\.\d+[\w-]*)*' -CONDITION_VERSION_PATTERN = rf'^\^{_version}$' -VERSION_PATTERN = f'^{_version}$' -MINOR_SEMVER = r'^\d+\.\d+$' -BRANCH_PATTERN = f'{VERSION_PATTERN}|^master$' +DATE_PATTERN = re.compile(r"^\d{4}/\d{2}/\d{2}$") +MATURITY_LEVELS = ["development", "experimental", "beta", "production", "deprecated"] +OS_OPTIONS = ["windows", "linux", "macos"] + +NAME_PATTERN = re.compile(r"^[a-zA-Z0-9].+?[a-zA-Z0-9\[\]()]$") +PR_PATTERN = re.compile(r"^$|\d+$") +SHA256_PATTERN = re.compile(r"^[a-fA-F0-9]{64}$") +UUID_PATTERN = re.compile(r"^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$") + +_version = r"\d+\.\d+(\.\d+[\w-]*)*" +CONDITION_VERSION_PATTERN = re.compile(rf"^\^{_version}$") +VERSION_PATTERN = f"^{_version}$" +MINOR_SEMVER = re.compile(r"^\d+\.\d+$") +BRANCH_PATTERN = f"{VERSION_PATTERN}|^master$" ELASTICSEARCH_EQL_FEATURES = { - "allow_negation": (Version.parse('8.9.0'), None), - "allow_runs": (Version.parse('7.16.0'), None), - 
"allow_sample": (Version.parse('8.6.0'), None), - "elasticsearch_validate_optional_fields": (Version.parse('7.16.0'), None) + "allow_negation": (Version.parse("8.9.0"), None), + "allow_runs": (Version.parse("7.16.0"), None), + "allow_sample": (Version.parse("8.6.0"), None), + "elasticsearch_validate_optional_fields": (Version.parse("7.16.0"), None), } -NON_DATASET_PACKAGES = ['apm', - 'auditd_manager', - 'cloud_defend', - 'endpoint', - 'jamf_protect', - 'network_traffic', - 'system', - 'windows', - 'sentinel_one_cloud_funnel', - 'ti_rapid7_threat_command', - 'm365_defender', - 'panw', - 'crowdstrike'] +NON_DATASET_PACKAGES = [ + "apm", + "auditd_manager", + "cloud_defend", + "endpoint", + "jamf_protect", + "network_traffic", + "system", + "windows", + "sentinel_one_cloud_funnel", + "ti_rapid7_threat_command", + "m365_defender", + "panw", + "crowdstrike", +] NON_PUBLIC_FIELDS = { - "related_integrations": (Version.parse('8.3.0'), None), - "required_fields": (Version.parse('8.3.0'), None), - "setup": (Version.parse('8.3.0'), None) + "related_integrations": (Version.parse("8.3.0"), None), + "required_fields": (Version.parse("8.3.0"), None), + "setup": (Version.parse("8.3.0"), None), } -INTERVAL_PATTERN = r'^\d+[mshd]$' -TACTIC_URL = r'^https://attack.mitre.org/tactics/TA[0-9]+/$' -TECHNIQUE_URL = r'^https://attack.mitre.org/techniques/T[0-9]+/$' -SUBTECHNIQUE_URL = r'^https://attack.mitre.org/techniques/T[0-9]+/[0-9]+/$' -MACHINE_LEARNING = 'machine_learning' -QUERY = 'query' +INTERVAL_PATTERN = r"^\d+[mshd]$" +TACTIC_URL = r"^https://attack.mitre.org/tactics/TA[0-9]+/$" +TECHNIQUE_URL = r"^https://attack.mitre.org/techniques/T[0-9]+/$" +SUBTECHNIQUE_URL = r"^https://attack.mitre.org/techniques/T[0-9]+/[0-9]+/$" +MACHINE_LEARNING = "machine_learning" +QUERY = "query" QUERY_FIELD_OP_EXCEPTIONS = ["powershell.file.script_block_text"] # we had a bad rule ID make it in before tightening up the pattern, and so we have to let it bypass -KNOWN_BAD_RULE_IDS = Literal['119c8877-8613-416d-a98a-96b6664ee73a5'] -KNOWN_BAD_DEPRECATED_DATES = Literal['2021-03-03'] +KNOWN_BAD_RULE_IDS = Literal["119c8877-8613-416d-a98a-96b6664ee73a5"] +KNOWN_BAD_DEPRECATED_DATES = Literal["2021-03-03"] # Known Null values that cannot be handled in TOML due to lack of Null value support via compound dicts KNOWN_NULL_ENTRIES = [{"rule.actions": "frequency.throttle"}] -OPERATORS = ['equals'] - -TIMELINE_TEMPLATES: Final[dict] = { - 'db366523-f1c6-4c1f-8731-6ce5ed9e5717': 'Generic Endpoint Timeline', - '91832785-286d-4ebe-b884-1a208d111a70': 'Generic Network Timeline', - '76e52245-7519-4251-91ab-262fb1a1728c': 'Generic Process Timeline', - '495ad7a7-316e-4544-8a0f-9c098daee76e': 'Generic Threat Match Timeline', - '4d4c0b59-ea83-483f-b8c1-8c360ee53c5c': 'Comprehensive File Timeline', - 'e70679c2-6cde-4510-9764-4823df18f7db': 'Comprehensive Process Timeline', - '300afc76-072d-4261-864d-4149714bf3f1': 'Comprehensive Network Timeline', - '3e47ef71-ebfc-4520-975c-cb27fc090799': 'Comprehensive Registry Timeline', - '3e827bab-838a-469f-bd1e-5e19a2bff2fd': 'Alerts Involving a Single User Timeline', - '4434b91a-94ca-4a89-83cb-a37cdc0532b7': 'Alerts Involving a Single Host Timeline' +OPERATORS = ["equals"] + +TIMELINE_TEMPLATES: Final[dict[str, str]] = { + "db366523-f1c6-4c1f-8731-6ce5ed9e5717": "Generic Endpoint Timeline", + "91832785-286d-4ebe-b884-1a208d111a70": "Generic Network Timeline", + "76e52245-7519-4251-91ab-262fb1a1728c": "Generic Process Timeline", + "495ad7a7-316e-4544-8a0f-9c098daee76e": "Generic Threat Match Timeline", + 
"4d4c0b59-ea83-483f-b8c1-8c360ee53c5c": "Comprehensive File Timeline", + "e70679c2-6cde-4510-9764-4823df18f7db": "Comprehensive Process Timeline", + "300afc76-072d-4261-864d-4149714bf3f1": "Comprehensive Network Timeline", + "3e47ef71-ebfc-4520-975c-cb27fc090799": "Comprehensive Registry Timeline", + "3e827bab-838a-469f-bd1e-5e19a2bff2fd": "Alerts Involving a Single User Timeline", + "4434b91a-94ca-4a89-83cb-a37cdc0532b7": "Alerts Involving a Single Host Timeline", } EXPECTED_RULE_TAGS = [ - 'Data Source: Active Directory', - 'Data Source: Amazon Web Services', - 'Data Source: Auditd Manager', - 'Data Source: AWS', - 'Data Source: APM', - 'Data Source: Azure', - 'Data Source: CyberArk PAS', - 'Data Source: Elastic Defend', - 'Data Source: Elastic Defend for Containers', - 'Data Source: Elastic Endgame', - 'Data Source: GCP', - 'Data Source: Google Cloud Platform', - 'Data Source: Google Workspace', - 'Data Source: Kubernetes', - 'Data Source: Microsoft 365', - 'Data Source: Okta', - 'Data Source: PowerShell Logs', - 'Data Source: Sysmon Only', - 'Data Source: Zoom', - 'Domain: Cloud', - 'Domain: Container', - 'Domain: Endpoint', - 'Mitre Atlas: *', - 'OS: Linux', - 'OS: macOS', - 'OS: Windows', - 'Rule Type: BBR', - 'Resources: Investigation Guide', - 'Rule Type: Higher-Order Rule', - 'Rule Type: Machine Learning', - 'Rule Type: ML', - 'Tactic: Collection', - 'Tactic: Command and Control', - 'Tactic: Credential Access', - 'Tactic: Defense Evasion', - 'Tactic: Discovery', - 'Tactic: Execution', - 'Tactic: Exfiltration', - 'Tactic: Impact', - 'Tactic: Initial Access', - 'Tactic: Lateral Movement', - 'Tactic: Persistence', - 'Tactic: Privilege Escalation', - 'Tactic: Reconnaissance', - 'Tactic: Resource Development', - 'Threat: BPFDoor', - 'Threat: Cobalt Strike', - 'Threat: Lightning Framework', - 'Threat: Orbit', - 'Threat: Rootkit', - 'Threat: TripleCross', - 'Use Case: Active Directory Monitoring', - 'Use Case: Asset Visibility', - 'Use Case: Configuration Audit', - 'Use Case: Guided Onboarding', - 'Use Case: Identity and Access Audit', - 'Use Case: Log Auditing', - 'Use Case: Network Security Monitoring', - 'Use Case: Threat Detection', - 'Use Case: UEBA', - 'Use Case: Vulnerability' + "Data Source: Active Directory", + "Data Source: Amazon Web Services", + "Data Source: Auditd Manager", + "Data Source: AWS", + "Data Source: APM", + "Data Source: Azure", + "Data Source: CyberArk PAS", + "Data Source: Elastic Defend", + "Data Source: Elastic Defend for Containers", + "Data Source: Elastic Endgame", + "Data Source: GCP", + "Data Source: Google Cloud Platform", + "Data Source: Google Workspace", + "Data Source: Kubernetes", + "Data Source: Microsoft 365", + "Data Source: Okta", + "Data Source: PowerShell Logs", + "Data Source: Sysmon Only", + "Data Source: Zoom", + "Domain: Cloud", + "Domain: Container", + "Domain: Endpoint", + "Mitre Atlas: *", + "OS: Linux", + "OS: macOS", + "OS: Windows", + "Rule Type: BBR", + "Resources: Investigation Guide", + "Rule Type: Higher-Order Rule", + "Rule Type: Machine Learning", + "Rule Type: ML", + "Tactic: Collection", + "Tactic: Command and Control", + "Tactic: Credential Access", + "Tactic: Defense Evasion", + "Tactic: Discovery", + "Tactic: Execution", + "Tactic: Exfiltration", + "Tactic: Impact", + "Tactic: Initial Access", + "Tactic: Lateral Movement", + "Tactic: Persistence", + "Tactic: Privilege Escalation", + "Tactic: Reconnaissance", + "Tactic: Resource Development", + "Threat: BPFDoor", + "Threat: Cobalt Strike", + "Threat: Lightning Framework", 
+ "Threat: Orbit", + "Threat: Rootkit", + "Threat: TripleCross", + "Use Case: Active Directory Monitoring", + "Use Case: Asset Visibility", + "Use Case: Configuration Audit", + "Use Case: Guided Onboarding", + "Use Case: Identity and Access Audit", + "Use Case: Log Auditing", + "Use Case: Network Security Monitoring", + "Use Case: Threat Detection", + "Use Case: UEBA", + "Use Case: Vulnerability", ] -NonEmptyStr = NewType('NonEmptyStr', str, validate=validate.Length(min=1)) -MACHINE_LEARNING_PACKAGES = ['LMD', 'DGA', 'DED', 'ProblemChild', 'Beaconing', 'PAD'] -AlertSuppressionGroupBy = NewType('AlertSuppressionGroupBy', List[NonEmptyStr], validate=validate.Length(min=1, max=3)) -AlertSuppressionMissing = NewType('AlertSuppressionMissing', str, - validate=validate.OneOf(['suppress', 'doNotSuppress'])) -AlertSuppressionValue = NewType("AlertSupressionValue", int, validate=validate.Range(min=1)) -TimeUnits = Literal['s', 'm', 'h'] -BranchVer = NewType('BranchVer', str, validate=validate.Regexp(BRANCH_PATTERN)) -CardinalityFields = NewType('CardinalityFields', List[NonEmptyStr], validate=validate.Length(min=0, max=3)) -CodeString = NewType("CodeString", str) -ConditionSemVer = NewType('ConditionSemVer', str, validate=validate.Regexp(CONDITION_VERSION_PATTERN)) -Date = NewType('Date', str, validate=validate.Regexp(DATE_PATTERN)) -ExceptionEntryOperator = Literal['included', 'excluded'] -ExceptionEntryType = Literal['match', 'match_any', 'exists', 'list', 'wildcard', 'nested'] -ExceptionNamespaceType = Literal['single', 'agnostic'] -ExceptionItemEndpointTags = Literal['endpoint', 'os:windows', 'os:linux', 'os:macos'] -ExceptionContainerType = Literal['detection', 'endpoint', 'rule_default'] -ExceptionItemType = Literal['simple'] + +MACHINE_LEARNING_PACKAGES = ["LMD", "DGA", "DED", "ProblemChild", "Beaconing", "PAD"] + +CodeString = str +Markdown = CodeString + +TimeUnits = Literal["s", "m", "h"] +ExceptionEntryOperator = Literal["included", "excluded"] +ExceptionEntryType = Literal["match", "match_any", "exists", "list", "wildcard", "nested"] +ExceptionNamespaceType = Literal["single", "agnostic"] +ExceptionItemEndpointTags = Literal["endpoint", "os:windows", "os:linux", "os:macos"] +ExceptionContainerType = Literal["detection", "endpoint", "rule_default"] +ExceptionItemType = Literal["simple"] FilterLanguages = Literal["eql", "esql", "kuery", "lucene"] -Interval = NewType('Interval', str, validate=validate.Regexp(INTERVAL_PATTERN)) + InvestigateProviderQueryType = Literal["phrase", "range"] InvestigateProviderValueType = Literal["string", "boolean"] -Markdown = NewType("MarkdownField", CodeString) -Maturity = Literal['development', 'experimental', 'beta', 'production', 'deprecated'] -MaxSignals = NewType("MaxSignals", int, validate=validate.Range(min=1)) -NewTermsFields = NewType('NewTermsFields', List[NonEmptyStr], validate=validate.Length(min=1, max=3)) -Operator = Literal['equals'] -OSType = Literal['windows', 'linux', 'macos'] -PositiveInteger = NewType('PositiveInteger', int, validate=validate.Range(min=1)) -RiskScore = NewType("MaxSignals", int, validate=validate.Range(min=1, max=100)) -RuleName = NewType('RuleName', str, validate=elastic_rule_name_regexp(NAME_PATTERN)) -RuleType = Literal['query', 'saved_query', 'machine_learning', 'eql', 'esql', 'threshold', 'threat_match', 'new_terms'] -SemVer = NewType('SemVer', str, validate=validate.Regexp(VERSION_PATTERN)) -SemVerMinorOnly = NewType('SemVerFullStrict', str, validate=validate.Regexp(MINOR_SEMVER)) -Severity = Literal['low', 
'medium', 'high', 'critical'] -Sha256 = NewType('Sha256', str, validate=validate.Regexp(SHA256_PATTERN)) -SubTechniqueURL = NewType('SubTechniqueURL', str, validate=validate.Regexp(SUBTECHNIQUE_URL)) -StoreType = Literal['appState', 'globalState'] -TacticURL = NewType('TacticURL', str, validate=validate.Regexp(TACTIC_URL)) -TechniqueURL = NewType('TechniqueURL', str, validate=validate.Regexp(TECHNIQUE_URL)) -ThresholdValue = NewType("ThresholdValue", int, validate=validate.Range(min=1)) -TimelineTemplateId = NewType('TimelineTemplateId', str, validate=elastic_timeline_template_id_validator()) -TimelineTemplateTitle = NewType('TimelineTemplateTitle', str, validate=elastic_timeline_template_title_validator()) + +Operator = Literal["equals"] +OSType = Literal["windows", "linux", "macos"] + +Severity = Literal["low", "medium", "high", "critical"] +Maturity = Literal["development", "experimental", "beta", "production", "deprecated"] +RuleType = Literal["query", "saved_query", "machine_learning", "eql", "esql", "threshold", "threat_match", "new_terms"] +StoreType = Literal["appState", "globalState"] TransformTypes = Literal["osquery", "investigate"] -UUIDString = NewType('UUIDString', str, validate=validate.Regexp(UUID_PATTERN)) -BuildingBlockType = Literal['default'] +BuildingBlockType = Literal["default"] + +nonEmptyStringField = fields.String(validate=validate.Length(min=1)) +NonEmptyStr = Annotated[str, nonEmptyStringField] + +AlertSuppressionGroupBy = Annotated[ + list[NonEmptyStr], fields.List(nonEmptyStringField, validate=validate.Length(min=1, max=3)) +] +AlertSuppressionMissing = Annotated[str, fields.String(validate=validate.OneOf(["suppress", "doNotSuppress"]))] +AlertSuppressionValue = Annotated[int, fields.String(validate=validate.Range(min=1))] +BranchVer = Annotated[str, fields.String(validate=validate.Regexp(BRANCH_PATTERN))] +CardinalityFields = Annotated[ + list[NonEmptyStr], + fields.List(nonEmptyStringField, validate=validate.Length(min=0, max=3)), +] +ConditionSemVer = Annotated[str, fields.String(validate=validate.Regexp(CONDITION_VERSION_PATTERN))] +Date = Annotated[str, fields.String(validate=validate.Regexp(DATE_PATTERN))] +Interval = Annotated[str, fields.String(validate=validate.Regexp(INTERVAL_PATTERN))] +MaxSignals = Annotated[int, fields.Integer(validate=validate.Range(min=1))] +NewTermsFields = Annotated[list[NonEmptyStr], fields.List(nonEmptyStringField, validate=validate.Length(min=1, max=3))] +PositiveInteger = Annotated[int, fields.Integer(validate=validate.Range(min=1))] +RiskScore = Annotated[int, fields.Integer(validate=validate.Range(min=1, max=100))] +RuleName = Annotated[str, fields.String(validate=elastic_rule_name_regexp(NAME_PATTERN))] +SemVer = Annotated[str, fields.String(validate=validate.Regexp(VERSION_PATTERN))] +SemVerMinorOnly = Annotated[str, fields.String(validate=validate.Regexp(MINOR_SEMVER))] +Sha256 = Annotated[str, fields.String(validate=validate.Regexp(SHA256_PATTERN))] +SubTechniqueURL = Annotated[str, fields.String(validate=validate.Regexp(SUBTECHNIQUE_URL))] +TacticURL = Annotated[str, fields.String(validate=validate.Regexp(TACTIC_URL))] +TechniqueURL = Annotated[str, fields.String(validate=validate.Regexp(TECHNIQUE_URL))] +ThresholdValue = Annotated[int, fields.Integer(validate=validate.Range(min=1))] +TimelineTemplateId = Annotated[str, fields.String(validate=elastic_timeline_template_id_validator())] +TimelineTemplateTitle = Annotated[str, fields.String(validate=elastic_timeline_template_title_validator())] +UUIDString = 
Annotated[str, fields.String(validate=validate.Regexp(UUID_PATTERN))] # experimental machine learning features and releases -MachineLearningType = getattr(Literal, '__getitem__')(tuple(MACHINE_LEARNING_PACKAGES)) # noqa: E999 -MachineLearningTypeLower = getattr(Literal, '__getitem__')( - tuple(map(str.lower, MACHINE_LEARNING_PACKAGES))) # noqa: E999 -## +MachineLearningType = Literal[MACHINE_LEARNING_PACKAGES] +MACHINE_LEARNING_PACKAGES_LOWER = tuple(map(str.lower, MACHINE_LEARNING_PACKAGES)) +MachineLearningTypeLower = Literal[MACHINE_LEARNING_PACKAGES_LOWER] ActionTypeId = Literal[ - ".slack", ".slack_api", ".email", ".index", ".pagerduty", ".swimlane", ".webhook", ".servicenow", - ".servicenow-itom", ".servicenow-sir", ".jira", ".resilient", ".opsgenie", ".teams", ".torq", ".tines", - ".d3security" + ".slack", + ".slack_api", + ".email", + ".index", + ".pagerduty", + ".swimlane", + ".webhook", + ".servicenow", + ".servicenow-itom", + ".servicenow-sir", + ".jira", + ".resilient", + ".opsgenie", + ".teams", + ".torq", + ".tines", + ".d3security", ] EsDataTypes = Literal[ - 'binary', 'boolean', - 'keyword', 'constant_keyword', 'wildcard', - 'long', 'integer', 'short', 'byte', 'double', 'float', 'half_float', 'scaled_float', 'unsigned_long', - 'date', 'date_nanos', - 'alias', 'object', 'flatten', 'nested', 'join', - 'integer_range', 'float_range', 'long_range', 'double_range', 'date_range', 'ip_range', - 'ip', 'version', 'murmur3', 'aggregate_metric_double', 'histogram', - 'text', 'text_match_only', 'annotated-text', 'completion', 'search_as_you_type', 'token_count', - 'dense_vector', 'sparse_vector', 'rank_feature', 'rank_features', - 'geo_point', 'geo_shape', 'point', 'shape', - 'percolator' + "binary", + "boolean", + "keyword", + "constant_keyword", + "wildcard", + "long", + "integer", + "short", + "byte", + "double", + "float", + "half_float", + "scaled_float", + "unsigned_long", + "date", + "date_nanos", + "alias", + "object", + "flatten", + "nested", + "join", + "integer_range", + "float_range", + "long_range", + "double_range", + "date_range", + "ip_range", + "ip", + "version", + "murmur3", + "aggregate_metric_double", + "histogram", + "text", + "text_match_only", + "annotated-text", + "completion", + "search_as_you_type", + "token_count", + "dense_vector", + "sparse_vector", + "rank_feature", + "rank_features", + "geo_point", + "geo_shape", + "point", + "shape", + "percolator", ] # definitions for the integration to index mapping unit test case -IGNORE_IDS = ["eb079c62-4481-4d6e-9643-3ca499df7aaa", "699e9fdb-b77c-4c01-995c-1c15019b9c43", - "0c9a14d9-d65d-486f-9b5b-91e4e6b22bd0", "a198fbbd-9413-45ec-a269-47ae4ccf59ce", - "0c41e478-5263-4c69-8f9e-7dfd2c22da64", "aab184d3-72b3-4639-b242-6597c99d8bca", - "a61809f3-fb5b-465c-8bff-23a8a068ac60", "f3e22c8b-ea47-45d1-b502-b57b6de950b3", - "fcf18de8-ad7d-4d01-b3f7-a11d5b3883af"] -IGNORE_INDICES = ['.alerts-security.*', 'logs-*', 'metrics-*', 'traces-*', 'endgame-*', - 'filebeat-*', 'packetbeat-*', 'auditbeat-*', 'winlogbeat-*'] +IGNORE_IDS = [ + "eb079c62-4481-4d6e-9643-3ca499df7aaa", + "699e9fdb-b77c-4c01-995c-1c15019b9c43", + "0c9a14d9-d65d-486f-9b5b-91e4e6b22bd0", + "a198fbbd-9413-45ec-a269-47ae4ccf59ce", + "0c41e478-5263-4c69-8f9e-7dfd2c22da64", + "aab184d3-72b3-4639-b242-6597c99d8bca", + "a61809f3-fb5b-465c-8bff-23a8a068ac60", + "f3e22c8b-ea47-45d1-b502-b57b6de950b3", + "fcf18de8-ad7d-4d01-b3f7-a11d5b3883af", +] +IGNORE_INDICES = [ + ".alerts-security.*", + "logs-*", + "metrics-*", + "traces-*", + "endgame-*", + "filebeat-*", + 
"packetbeat-*", + "auditbeat-*", + "winlogbeat-*", +] diff --git a/detection_rules/schemas/stack_compat.py b/detection_rules/schemas/stack_compat.py index 0981f30cb5c..e4acae8c2ac 100644 --- a/detection_rules/schemas/stack_compat.py +++ b/detection_rules/schemas/stack_compat.py @@ -3,8 +3,8 @@ # 2.0; you may not use this file except in compliance with the Elastic License # 2.0. +from typing import Any from dataclasses import Field -from typing import Dict, List, Optional, Tuple from semver import Version @@ -12,22 +12,21 @@ @cached -def get_restricted_field(schema_field: Field) -> Tuple[Optional[Version], Optional[Version]]: +def get_restricted_field(schema_field: Field[Any]) -> tuple[Version | None, Version | None]: """Get an optional min and max compatible versions of a field (from a schema or dataclass).""" # nested get is to support schema fields being passed directly from dataclass or fields in schema class, since # marshmallow_dataclass passes the embedded metadata directly - min_compat = schema_field.metadata.get('metadata', schema_field.metadata).get('min_compat') - max_compat = schema_field.metadata.get('metadata', schema_field.metadata).get('max_compat') + min_compat = schema_field.metadata.get("metadata", schema_field.metadata).get("min_compat") + max_compat = schema_field.metadata.get("metadata", schema_field.metadata).get("max_compat") min_compat = Version.parse(min_compat, optional_minor_and_patch=True) if min_compat else None max_compat = Version.parse(max_compat, optional_minor_and_patch=True) if max_compat else None return min_compat, max_compat @cached -def get_restricted_fields(schema_fields: List[Field]) -> Dict[str, Tuple[Optional[Version], - Optional[Version]]]: +def get_restricted_fields(schema_fields: list[Field[Any]]) -> dict[str, tuple[Version | None, Version | None]]: """Get a list of optional min and max compatible versions of fields (from a schema or dataclass).""" - restricted = {} + restricted: dict[str, tuple[Version | None, Version | None]] = {} for _field in schema_fields: min_compat, max_compat = get_restricted_field(_field) if min_compat or max_compat: @@ -37,13 +36,15 @@ def get_restricted_fields(schema_fields: List[Field]) -> Dict[str, Tuple[Optiona @cached -def get_incompatible_fields(schema_fields: List[Field], package_version: Version) -> \ - Optional[Dict[str, tuple]]: +def get_incompatible_fields( + schema_fields: list[Field[Any]], + package_version: Version, +) -> dict[str, tuple[Version | None, Version | None]] | None: """Get a list of fields that are incompatible with the package version.""" if not schema_fields: return - incompatible = {} + incompatible: dict[str, tuple[Version | None, Version | None]] = {} restricted_fields = get_restricted_fields(schema_fields) for field_name, values in restricted_fields.items(): min_compat, max_compat = values diff --git a/detection_rules/utils.py b/detection_rules/utils.py index bbc8515f6d8..6845741442d 100644 --- a/detection_rules/utils.py +++ b/detection_rules/utils.py @@ -4,6 +4,7 @@ # 2.0. 
"""Util functions.""" + import base64 import contextlib import functools @@ -20,17 +21,15 @@ from dataclasses import is_dataclass, astuple from datetime import datetime, date, timezone from pathlib import Path -from typing import Dict, Union, Optional, Callable +from typing import Callable, Any, Iterator from string import Template import click -import pytoml -import eql.utils -from eql.utils import load_dump, stream_json_lines +import pytoml # type: ignore[reportMissingTypeStubs] +import eql.utils # type: ignore[reportMissingTypeStubs] +from eql.utils import load_dump # type: ignore[reportMissingTypeStubs] from github.Repository import Repository -import kql - CURR_DIR = Path(__file__).resolve().parent ROOT_DIR = CURR_DIR.parent @@ -38,25 +37,18 @@ INTEGRATION_RULE_DIR = ROOT_DIR / "rules" / "integrations" -class NonelessDict(dict): - """Wrapper around dict that doesn't populate None values.""" - - def __setitem__(self, key, value): - if value is not None: - dict.__setitem__(self, key, value) - - class DateTimeEncoder(json.JSONEncoder): - def default(self, obj): - if isinstance(obj, (date, datetime)): - return obj.isoformat() + def default(self, o: Any) -> Any: + if isinstance(o, (date, datetime)): + return o.isoformat() marshmallow_schemas = {} -def gopath() -> Optional[str]: - """Retrieve $GOPATH.""" +def gopath() -> str | None: + """Retrieve $GOPATH""" + env_path = os.getenv("GOPATH") if env_path: return env_path @@ -66,170 +58,155 @@ def gopath() -> Optional[str]: output = subprocess.check_output([go_bin, "env"], encoding="utf-8").splitlines() for line in output: if line.startswith("GOPATH="): - return line[len("GOPATH="):].strip('"') + return line[len("GOPATH=") :].strip('"') -def dict_hash(obj: dict) -> str: +def dict_hash(obj: dict[Any, Any]) -> str: """Hash a dictionary deterministically.""" - raw_bytes = base64.b64encode(json.dumps(obj, sort_keys=True).encode('utf-8')) + raw_bytes = base64.b64encode(json.dumps(obj, sort_keys=True).encode("utf-8")) return hashlib.sha256(raw_bytes).hexdigest() -def ensure_list_of_strings(value: str | list) -> list[str]: +def ensure_list_of_strings(value: str | list[str]) -> list[str]: """Ensure or convert a value is a list of strings.""" if isinstance(value, str): # Check if the string looks like a JSON list - if value.startswith('[') and value.endswith(']'): + if value.startswith("[") and value.endswith("]"): try: # Attempt to parse the string as a JSON list parsed_value = json.loads(value) if isinstance(parsed_value, list): - return [str(v) for v in parsed_value] + return [str(v) for v in parsed_value] # type: ignore[reportUnknownVariableType] except json.JSONDecodeError: pass # If it's not a JSON list, split by commas if present # Else return a list with the original string - return list(map(lambda x: x.strip().strip('"'), value.split(','))) - elif isinstance(value, list): - return [str(v) for v in value] + return list(map(lambda x: x.strip().strip('"'), value.split(","))) else: - return [] - - -def get_json_iter(f): - """Get an iterator over a JSON file.""" - first = f.read(2) - f.seek(0) - - if first[0] == '[' or first == "{\n": - return json.load(f) - else: - data = list(stream_json_lines(f)) - return data + return [str(v) for v in value] -def get_nested_value(dictionary, compound_key): - """Get a nested value from a dictionary.""" - keys = compound_key.split('.') +def get_nested_value(obj: Any, compound_key: str) -> Any: + """Get a nested value from a obj.""" + keys = compound_key.split(".") for key in keys: - if isinstance(dictionary, 
dict): - dictionary = dictionary.get(key) + if isinstance(obj, dict): + obj = obj.get(key) # type: ignore[reportUnknownVariableType] else: return None - return dictionary + return obj # type: ignore[reportUnknownVariableType] -def get_path(*paths) -> Path: +def get_path(paths: list[str]) -> Path: """Get a file by relative path.""" return ROOT_DIR.joinpath(*paths) -def get_etc_path(*paths) -> Path: +def get_etc_path(paths: list[str]) -> Path: """Load a file from the detection_rules/etc/ folder.""" return ETC_DIR.joinpath(*paths) -def get_etc_glob_path(*patterns) -> list: +def get_etc_glob_path(patterns: list[str]) -> list[str]: """Load a file from the detection_rules/etc/ folder.""" pattern = os.path.join(*patterns) return glob.glob(str(ETC_DIR / pattern)) -def get_etc_file(name, mode="r"): +def get_etc_file(name: str, mode: str = "r") -> str: """Load a file from the detection_rules/etc/ folder.""" - with open(get_etc_path(name), mode) as f: + with open(get_etc_path([name]), mode) as f: return f.read() -def load_etc_dump(*path): +def load_etc_dump(paths: list[str]) -> Any: """Load a json/yml/toml file from the detection_rules/etc/ folder.""" - return eql.utils.load_dump(str(get_etc_path(*path))) + return eql.utils.load_dump(str(get_etc_path(paths))) # type: ignore[reportUnknownVariableType] -def save_etc_dump(contents, *path, **kwargs): +def save_etc_dump(contents: dict[str, Any], path: list[str], sort_keys: bool = True, indent: int = 2): """Save a json/yml/toml file from the detection_rules/etc/ folder.""" - path = str(get_etc_path(*path)) - _, ext = os.path.splitext(path) - sort_keys = kwargs.pop('sort_keys', True) - indent = kwargs.pop('indent', 2) + path_joined = str(get_etc_path(path)) + _, ext = os.path.splitext(path_joined) if ext == ".json": - with open(path, "wt") as f: - json.dump(contents, f, cls=DateTimeEncoder, sort_keys=sort_keys, indent=indent, **kwargs) + with open(path_joined, "wt") as f: + json.dump(contents, f, cls=DateTimeEncoder, sort_keys=sort_keys, indent=indent) else: - return eql.utils.save_dump(contents, path) + return eql.utils.save_dump(contents, path) # type: ignore[reportUnknownVariableType] def set_all_validation_bypass(env_value: bool = False): """Set all validation bypass environment variables.""" - os.environ['DR_BYPASS_NOTE_VALIDATION_AND_PARSE'] = str(env_value) - os.environ['DR_BYPASS_BBR_LOOKBACK_VALIDATION'] = str(env_value) - os.environ['DR_BYPASS_TAGS_VALIDATION'] = str(env_value) - os.environ['DR_BYPASS_TIMELINE_TEMPLATE_VALIDATION'] = str(env_value) + os.environ["DR_BYPASS_NOTE_VALIDATION_AND_PARSE"] = str(env_value) + os.environ["DR_BYPASS_BBR_LOOKBACK_VALIDATION"] = str(env_value) + os.environ["DR_BYPASS_TAGS_VALIDATION"] = str(env_value) + os.environ["DR_BYPASS_TIMELINE_TEMPLATE_VALIDATION"] = str(env_value) -def set_nested_value(dictionary, compound_key, value): - """Set a nested value in a dictionary.""" - keys = compound_key.split('.') +def set_nested_value(obj: dict[str, Any], compound_key: str, value: Any): + """Set a nested value in a obj.""" + keys = compound_key.split(".") for key in keys[:-1]: - dictionary = dictionary.setdefault(key, {}) - dictionary[keys[-1]] = value + obj = obj.setdefault(key, {}) + obj[keys[-1]] = value -def gzip_compress(contents) -> bytes: +def gzip_compress(contents: str) -> bytes: gz_file = io.BytesIO() with gzip.GzipFile(mode="w", fileobj=gz_file) as f: - if not isinstance(contents, bytes): - contents = contents.encode("utf8") - f.write(contents) + if isinstance(contents, bytes): + encoded = contents + else: + 
encoded = contents.encode("utf8") + + _ = f.write(encoded) return gz_file.getvalue() -def read_gzip(path): - with gzip.GzipFile(path, mode='r') as gz: +def read_gzip(path: str | Path): + with gzip.GzipFile(str(path), mode="r") as gz: return gz.read().decode("utf8") @contextlib.contextmanager -def unzip(contents): # type: (bytes) -> zipfile.ZipFile +def unzip(contents: bytes) -> Iterator[zipfile.ZipFile]: """Get zipped contents.""" zipped = io.BytesIO(contents) archive = zipfile.ZipFile(zipped, mode="r") try: yield archive - finally: archive.close() -def unzip_and_save(contents, path, member=None, verbose=True): +def unzip_and_save(contents: bytes, path: str, member: str | None = None, verbose: bool = True): """Save unzipped from raw zipped contents.""" with unzip(contents) as archive: - if member: - archive.extract(member, path) + _ = archive.extract(member, path) else: archive.extractall(path) if verbose: - name_list = archive.namelist()[member] if not member else archive.namelist() - print('Saved files to {}: \n\t- {}'.format(path, '\n\t- '.join(name_list))) + name_list = archive.namelist() + print("Saved files to {}: \n\t- {}".format(path, "\n\t- ".join(name_list))) -def unzip_to_dict(zipped: zipfile.ZipFile, load_json=True) -> Dict[str, Union[dict, str]]: +def unzip_to_dict(zipped: zipfile.ZipFile, load_json: bool = True) -> dict[str, Any]: """Unzip and load contents to dict with filenames as keys.""" - bundle = {} + bundle: dict[str, Any] = {} for filename in zipped.namelist(): - if filename.endswith('/'): + if filename.endswith("/"): continue fp = Path(filename) contents = zipped.read(filename) - if load_json and fp.suffix == '.json': + if load_json and fp.suffix == ".json": contents = json.loads(contents) bundle[fp.name] = contents @@ -237,7 +214,12 @@ def unzip_to_dict(zipped: zipfile.ZipFile, load_json=True) -> Dict[str, Union[di return bundle -def event_sort(events, timestamp='@timestamp', date_format='%Y-%m-%dT%H:%M:%S.%f%z', asc=True): +def event_sort( + events: list[Any], + timestamp: str = "@timestamp", + date_format: str = "%Y-%m-%dT%H:%M:%S.%f%z", + order_asc: bool = True, +) -> list[Any]: """Sort events from elasticsearch by timestamp.""" def round_microseconds(t: str) -> str: @@ -247,7 +229,7 @@ def round_microseconds(t: str) -> str: # Return early if the timestamp string is empty return t - parts = t.split('.') + parts = t.split(".") if len(parts) == 2: # Remove trailing "Z" from microseconds part micro_seconds = parts[1].rstrip("Z") @@ -259,28 +241,19 @@ def round_microseconds(t: str) -> str: # Format the rounded value to always have 6 decimal places # Reconstruct the timestamp string with the rounded microseconds part - formatted_micro_seconds = f'{rounded_micro_seconds:0.6f}'.split(".")[-1] + formatted_micro_seconds = f"{rounded_micro_seconds:0.6f}".split(".")[-1] t = f"{parts[0]}.{formatted_micro_seconds}Z" return t - def _event_sort(event: dict) -> datetime: + def _event_sort(event: dict[str, Any]) -> datetime: """Calculates the sort key for an event as a datetime object.""" t = round_microseconds(event[timestamp]) # Return the timestamp as a datetime object for comparison return datetime.strptime(t, date_format) - return sorted(events, key=_event_sort, reverse=not asc) - - -def combine_sources(*sources): # type: (list[list]) -> list - """Combine lists of events from multiple sources.""" - combined = [] - for source in sources: - combined.extend(source.copy()) - - return event_sort(combined) + return sorted(events, key=_event_sort, reverse=not order_asc) def 
convert_time_span(span: str) -> int: @@ -290,55 +263,56 @@ def convert_time_span(span: str) -> int: return eql.ast.TimeRange(amount, unit).as_milliseconds() -def evaluate(rule, events, normalize_kql_keywords: bool = False): - """Evaluate a query against events.""" - evaluator = kql.get_evaluator(kql.parse(rule.query), normalize_kql_keywords=normalize_kql_keywords) - filtered = list(filter(evaluator, events)) - return filtered - - -def unix_time_to_formatted(timestamp): # type: (int|str) -> str +def unix_time_to_formatted(timestamp: int | float | str) -> str: """Converts unix time in seconds or milliseconds to the default format.""" if isinstance(timestamp, (int, float)): - if timestamp > 2 ** 32: + if timestamp > 2**32: timestamp = round(timestamp / 1000, 3) - return datetime.fromtimestamp(timestamp, timezone.utc).strftime('%Y-%m-%dT%H:%M:%S.%f')[:-3] + 'Z' + return datetime.fromtimestamp(timestamp, timezone.utc).strftime("%Y-%m-%dT%H:%M:%S.%f")[:-3] + "Z" + return timestamp -def normalize_timing_and_sort(events, timestamp='@timestamp', asc=True): +def normalize_timing_and_sort( + events: list[dict[str, Any]], + timestamp: str = "@timestamp", + order_asc: bool = True, +): """Normalize timestamp formats and sort events.""" for event in events: _timestamp = event[timestamp] if not isinstance(_timestamp, str): event[timestamp] = unix_time_to_formatted(_timestamp) - return event_sort(events, timestamp=timestamp, asc=asc) + return event_sort(events, timestamp=timestamp, order_asc=order_asc) -def freeze(obj): +def freeze(obj: Any) -> Any: """Helper function to make mutable objects immutable and hashable.""" if not isinstance(obj, type) and is_dataclass(obj): - obj = astuple(obj) + obj = astuple(obj) # type: ignore[reportUnknownVariableType] if isinstance(obj, (list, tuple)): - return tuple(freeze(o) for o in obj) + return tuple(freeze(o) for o in obj) # type: ignore[reportUnknownVariableType] elif isinstance(obj, dict): - return freeze(sorted(obj.items())) + items = obj.items() # type: ignore[reportUnknownVariableType] + return freeze(sorted(items)) # type: ignore[reportUnknownVariableType] else: return obj -_cache = {} +_cache: dict[int, dict[tuple[Any, Any], Any]] = {} -def cached(f): +# FIXME: should be replaced with `functools.cache` +# https://docs.python.org/3/library/functools.html#functools.cache +def cached(f: Callable[..., Any]) -> Callable[..., Any]: """Helper function to memoize functions.""" func_key = id(f) @functools.wraps(f) - def wrapped(*args, **kwargs): - _cache.setdefault(func_key, {}) + def wrapped(*args: Any, **kwargs: Any) -> Any: + _ = _cache.setdefault(func_key, {}) cache_key = freeze(args), freeze(kwargs) if cache_key not in _cache[func_key]: @@ -347,9 +321,9 @@ def wrapped(*args, **kwargs): return _cache[func_key][cache_key] def clear(): - _cache.pop(func_key, None) + _ = _cache.pop(func_key, None) - wrapped.clear = clear + wrapped.clear = clear # type: ignore[reportAttributeAccessIssue] return wrapped @@ -357,34 +331,34 @@ def clear_caches(): _cache.clear() -def rulename_to_filename(name: str, tactic_name: str = None, ext: str = '.toml') -> str: +def rulename_to_filename(name: str, tactic_name: str | None = None, ext: str = ".toml") -> str: """Convert a rule name to a filename.""" - name = re.sub(r'[^_a-z0-9]+', '_', name.strip().lower()).strip('_') + name = re.sub(r"[^_a-z0-9]+", "_", name.strip().lower()).strip("_") if tactic_name: - pre = rulename_to_filename(name=tactic_name, ext='') - name = f'{pre}_{name}' - return name + ext or '' + pre = 
rulename_to_filename(name=tactic_name, ext="") + name = f"{pre}_{name}" + return name + ext or "" -def load_rule_contents(rule_file: Path, single_only=False) -> list: +def load_rule_contents(rule_file: Path, single_only: bool = False) -> list[Any]: """Load a rule file from multiple formats.""" _, extension = os.path.splitext(rule_file) raw_text = rule_file.read_text() - if extension in ('.ndjson', '.jsonl'): + if extension in (".ndjson", ".jsonl"): # kibana exported rule object is ndjson with the export metadata on the last line contents = [json.loads(line) for line in raw_text.splitlines()] - if len(contents) > 1 and 'exported_count' in contents[-1]: + if len(contents) > 1 and "exported_count" in contents[-1]: contents.pop(-1) if single_only and len(contents) > 1: - raise ValueError('Multiple rules not allowed') + raise ValueError("Multiple rules not allowed") return contents or [{}] - elif extension == '.toml': - rule = pytoml.loads(raw_text) - elif extension.lower() in ('yaml', 'yml'): + elif extension == ".toml": + rule = pytoml.loads(raw_text) # type: ignore[reportUnknownVariableType] + elif extension.lower() in ("yaml", "yml"): rule = load_dump(str(rule_file)) else: return [] @@ -392,20 +366,27 @@ def load_rule_contents(rule_file: Path, single_only=False) -> list: if isinstance(rule, dict): return [rule] elif isinstance(rule, list): - return rule + return rule # type: ignore[reportUnknownVariableType] else: raise ValueError(f"Expected a list or dictionary in {rule_file}") -def load_json_from_branch(repo: Repository, file_path: str, branch: Optional[str]): +def load_json_from_branch(repo: Repository, file_path: str, branch: str): """Load JSON file from a specific branch.""" - content_file = repo.get_contents(file_path, ref=branch) - return json.loads(content_file.decoded_content.decode("utf-8")) + content_files = repo.get_contents(file_path, ref=branch) + if isinstance(content_files, list): + raise ValueError("Multiple files found") -def compare_versions(base_json: dict, branch_json: dict) -> list[tuple[str, str, int, int]]: + content_file = content_files + content = content_file.decoded_content + data = content.decode("utf-8") + return json.loads(data) + + +def compare_versions(base_json: dict[str, Any], branch_json: dict[str, Any]) -> list[tuple[str, str, int, int]]: """Compare versions of two lock version file JSON objects.""" - changes = [] + changes: list[tuple[str, str, int, int]] = [] for key in base_json: if key in branch_json: base_version = base_json[key].get("version") @@ -418,7 +399,7 @@ def compare_versions(base_json: dict, branch_json: dict) -> list[tuple[str, str, def check_double_bumps(changes: list[tuple[str, str, int, int]]) -> list[tuple[str, str, int, int]]: """Check for double bumps in version changes of the result of compare versions of a version lock file.""" - double_bumps = [] + double_bumps: list[tuple[str, str, int, int]] = [] for key, name, removed, added in changes: # Determine the modulo dynamically based on the highest number of digits max_digits = max(len(str(removed)), len(str(added))) @@ -429,7 +410,11 @@ def check_double_bumps(changes: list[tuple[str, str, int, int]]) -> list[tuple[s def check_version_lock_double_bumps( - repo: Repository, file_path: str, base_branch: str, branch: str = "", local_file: Path = None + repo: Repository, + file_path: str, + base_branch: str, + branch: str = "", + local_file: Path | None = None, ) -> list[tuple[str, str, int, int]]: """Check for double bumps in version changes of the result of compare versions of a 
version lock file.""" base_json = load_json_from_branch(repo, file_path, base_branch) @@ -445,13 +430,13 @@ def check_version_lock_double_bumps( return double_bumps -def format_command_options(ctx): +def format_command_options(ctx: click.Context): """Echo options for a click command.""" formatter = ctx.make_formatter() - opts = [] + opts: list[tuple[str, str]] = [] for param in ctx.command.get_params(ctx): - if param.name == 'help': + if param.name == "help": continue rv = param.get_help_record(ctx) @@ -459,15 +444,18 @@ def format_command_options(ctx): opts.append(rv) if opts: - with formatter.section('Options'): + with formatter.section("Options"): formatter.write_dl(opts) return formatter.getvalue() -def make_git(*prefix_args) -> Optional[Callable]: +def make_git(*prefix_args: Any) -> Callable[..., str]: git_exe = shutil.which("git") - prefix_args = [str(arg) for arg in prefix_args] + prefix_arg_strs = [str(arg) for arg in prefix_args] + + if "-C" not in prefix_arg_strs: + prefix_arg_strs = ["-C", str(ROOT_DIR)] + prefix_arg_strs if not git_exe: click.secho("Unable to find git", err=True, fg="red") @@ -476,58 +464,56 @@ def make_git(*prefix_args) -> Optional[Callable]: if ctx is not None: ctx.exit(1) - return - - def git(*args, print_output=False): - nonlocal prefix_args + raise ValueError("Git not found") - if '-C' not in prefix_args: - prefix_args = ['-C', get_path()] + prefix_args - - full_args = [git_exe] + prefix_args + [str(arg) for arg in args] - if print_output: - return subprocess.check_call(full_args) + def git(*args: Any) -> str: + arg_strs = [str(arg) for arg in args] + full_args = [git_exe] + prefix_arg_strs + arg_strs return subprocess.check_output(full_args, encoding="utf-8").rstrip() return git -def git(*args, **kwargs): +def git(*args: Any, **kwargs: Any) -> str | int: """Find and run a one-off Git command.""" - return make_git()(*args, **kwargs) + g = make_git() + return g(*args, **kwargs) + + +FuncT = Callable[..., Any] -def add_params(*params): +def add_params(*params: Any): """Add parameters to a click command.""" - def decorator(f): - if not hasattr(f, '__click_params__'): - f.__click_params__ = [] - f.__click_params__.extend(params) + def decorator(f: FuncT) -> FuncT: + if not hasattr(f, "__click_params__"): + f.__click_params__ = [] # type: ignore[reportFunctionMemberAccess] + f.__click_params__.extend(params) # type: ignore[reportFunctionMemberAccess] return f return decorator -class Ndjson(list): +class Ndjson(list[dict[str, Any]]): """Wrapper for ndjson data.""" def to_string(self, sort_keys: bool = False): """Format contents list to ndjson string.""" - return '\n'.join(json.dumps(c, sort_keys=sort_keys) for c in self) + '\n' + return "\n".join(json.dumps(c, sort_keys=sort_keys) for c in self) + "\n" @classmethod - def from_string(cls, ndjson_string: str, **kwargs): + def from_string(cls, ndjson_string: str, **kwargs: Any) -> "Ndjson": """Load ndjson string to a list.""" contents = [json.loads(line, **kwargs) for line in ndjson_string.strip().splitlines()] return Ndjson(contents) - def dump(self, filename: Path, sort_keys=False): + def dump(self, filename: Path, sort_keys: bool = False): """Save contents to an ndjson file.""" - filename.write_text(self.to_string(sort_keys=sort_keys)) + _ = filename.write_text(self.to_string(sort_keys=sort_keys)) @classmethod - def load(cls, filename: Path, **kwargs): + def load(cls, filename: Path, **kwargs: Any): """Load content from an ndjson file.""" return cls.from_string(filename.read_text(), **kwargs) @@ -535,19 +521,18 
@@ def load(cls, filename: Path, **kwargs): class PatchedTemplate(Template): """String template with updated methods from future versions.""" - def get_identifiers(self): + def get_identifiers(self) -> list[str]: """Returns a list of the valid identifiers in the template, in the order they first appear, ignoring any invalid identifiers.""" # https://github.com/python/cpython/blob/3b4f8fc83dcea1a9d0bc5bd33592e5a3da41fa71/Lib/string.py#LL157-L171C19 - ids = [] + ids: list[str] = [] for mo in self.pattern.finditer(self.template): - named = mo.group('named') or mo.group('braced') - if named is not None and named not in ids: + named = mo.group("named") or mo.group("braced") + if named and named not in ids: # add a named group only the first time it appears ids.append(named) - elif named is None and mo.group('invalid') is None and mo.group('escaped') is None: + elif not named and mo.group("invalid") is None and mo.group("escaped") is None: # If all the groups are None, there must be # another group we're not expecting - raise ValueError('Unrecognized named group in pattern', - self.pattern) + raise ValueError("Unrecognized named group in pattern", self.pattern) return ids diff --git a/detection_rules/version_lock.py b/detection_rules/version_lock.py index 36c62633bb6..bc4d9bd15ce 100644 --- a/detection_rules/version_lock.py +++ b/detection_rules/version_lock.py @@ -3,17 +3,17 @@ # 2.0; you may not use this file except in compliance with the Elastic License # 2.0. """Helper utilities to manage the version lock.""" + from copy import deepcopy from dataclasses import dataclass from pathlib import Path -from typing import ClassVar, Dict, List, Optional, Union +from typing import ClassVar, Literal, Any import click from semver import Version from .config import parse_rules_config from .mixins import LockDataclassMixin, MarshmallowDataclassMixin -from .rule_loader import RuleCollection from .schemas import definitions from .utils import cached @@ -35,30 +35,33 @@ class BaseEntry: @dataclass(frozen=True) class PreviousEntry(BaseEntry): - # this is Optional for resiliency in already tagged branches missing this field. 
This means we should strictly # validate elsewhere - max_allowable_version: Optional[int] + max_allowable_version: int | None = None @dataclass(frozen=True) class VersionLockFileEntry(MarshmallowDataclassMixin, BaseEntry): """Schema for a rule entry in the version lock.""" - min_stack_version: Optional[definitions.SemVerMinorOnly] - previous: Optional[Dict[definitions.SemVerMinorOnly, PreviousEntry]] + + min_stack_version: definitions.SemVerMinorOnly | None = None + previous: dict[definitions.SemVerMinorOnly, PreviousEntry] | None = None @dataclass(frozen=True) class VersionLockFile(LockDataclassMixin): """Schema for the full version lock file.""" - data: Dict[Union[definitions.UUIDString, definitions.KNOWN_BAD_RULE_IDS], VersionLockFileEntry] + + data: dict[definitions.UUIDString | definitions.KNOWN_BAD_RULE_IDS, VersionLockFileEntry] file_path: ClassVar[Path] = RULES_CONFIG.version_lock_file def __contains__(self, rule_id: str): """Check if a rule is in the map by comparing IDs.""" return rule_id in self.data - def __getitem__(self, item) -> VersionLockFileEntry: + def __getitem__( + self, item: definitions.UUIDString | Literal["119c8877-8613-416d-a98a-96b6664ee73a5"] + ) -> VersionLockFileEntry: """Return entries by rule id.""" if item not in self.data: raise KeyError(item) @@ -68,7 +71,8 @@ def __getitem__(self, item) -> VersionLockFileEntry: @dataclass(frozen=True) class DeprecatedRulesEntry(MarshmallowDataclassMixin): """Schema for rule entry in the deprecated rules file.""" - deprecation_date: Union[definitions.Date, definitions.KNOWN_BAD_DEPRECATED_DATES] + + deprecation_date: definitions.Date | definitions.KNOWN_BAD_DEPRECATED_DATES rule_name: definitions.RuleName stack_version: definitions.SemVer @@ -76,14 +80,17 @@ class DeprecatedRulesEntry(MarshmallowDataclassMixin): @dataclass(frozen=True) class DeprecatedRulesFile(LockDataclassMixin): """Schema for the full deprecated rules file.""" - data: Dict[Union[definitions.UUIDString, definitions.KNOWN_BAD_RULE_IDS], DeprecatedRulesEntry] + + data: dict[definitions.UUIDString | definitions.KNOWN_BAD_RULE_IDS, DeprecatedRulesEntry] file_path: ClassVar[Path] = RULES_CONFIG.deprecated_rules_file def __contains__(self, rule_id: str): """Check if a rule is in the map by comparing IDs.""" return rule_id in self.data - def __getitem__(self, item) -> DeprecatedRulesEntry: + def __getitem__( + self, item: definitions.UUIDString | Literal["119c8877-8613-416d-a98a-96b6664ee73a5"] + ) -> DeprecatedRulesEntry: """Return entries by rule id.""" if item not in self.data: raise KeyError(item) @@ -91,7 +98,7 @@ def __getitem__(self, item) -> DeprecatedRulesEntry: @cached -def load_versions() -> dict: +def load_versions() -> dict[str, Any]: """Load and validate the default version.lock file.""" version_lock_file = VersionLockFile.load_from_file() return version_lock_file.to_dict() @@ -100,20 +107,20 @@ def load_versions() -> dict: # for tagged branches which existed before the types were added and validation enforced, we will need to manually add # them to allow them to pass validation. 
These will only ever currently be loaded via the RuleCollection.load_git_tag # method, which is primarily for generating diffs across releases, so there is no risk to versioning -def add_rule_types_to_lock(lock_contents: dict, rule_map: Dict[str, dict]): +def add_rule_types_to_lock(lock_contents: dict[str, Any], rule_map: dict[str, Any]): """Add the rule type to entries in the lock file, if missing.""" for rule_id, lock in lock_contents.items(): rule = rule_map.get(rule_id, {}) # this defaults to query if the rule is not found - it is just for validation so should not impact - rule_type = rule.get('rule', {}).get('type', 'query') + rule_type = rule.get("rule", {}).get("type", "query") # the type is a bit less important than the structure to pass validation - lock['type'] = rule_type + lock["type"] = rule_type - if 'previous' in lock: - for _, prev_lock in lock['previous'].items(): - prev_lock['type'] = rule_type + if "previous" in lock: + for _, prev_lock in lock["previous"].items(): + prev_lock["type"] = rule_type return lock_contents @@ -121,16 +128,21 @@ class VersionLock: """Version handling for rule files and collections.""" - def __init__(self, version_lock_file: Optional[Path] = None, deprecated_lock_file: Optional[Path] = None, - version_lock: Optional[dict] = None, deprecated_lock: Optional[dict] = None, - name: Optional[str] = None, invalidated: Optional[bool] = False): - + def __init__( + self, + version_lock_file: Path | None = None, + deprecated_lock_file: Path | None = None, + version_lock: dict[str, Any] | None = None, + deprecated_lock: dict[str, Any] | None = None, + name: str | None = None, + invalidated: bool | None = False, + ): if invalidated: err_msg = "This VersionLock configuration is not valid when configured to bypass_version_lock."
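# bypass_version_lock disables version locking entirely, so constructing a VersionLock in that
# configuration is treated as a programming error rather than being silently ignored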
raise NotImplementedError(err_msg) - assert (version_lock_file or version_lock), 'Must provide version lock file or contents' - assert (deprecated_lock_file or deprecated_lock), 'Must provide deprecated lock file or contents' + assert version_lock_file or version_lock, "Must provide version lock file or contents" + assert deprecated_lock_file or deprecated_lock, "Must provide deprecated lock file or contents" self.name = name self.version_lock_file = version_lock_file @@ -147,12 +159,12 @@ def __init__(self, version_lock_file: Optional[Path] = None, deprecated_lock_fil self.deprecated_lock = DeprecatedRulesFile.from_dict(dict(data=deprecated_lock)) @staticmethod - def save_file(path: Path, lock_file: Union[VersionLockFile, DeprecatedRulesFile]): - assert path, f'{path} not set' + def save_file(path: Path, lock_file: VersionLockFile | DeprecatedRulesFile): + assert path, f"{path} not set" lock_file.save_to_file(path) - print(f'Updated {path} file') + print(f"Updated {path} file") - def get_locked_version(self, rule_id: str, min_stack_version: Optional[str] = None) -> Optional[int]: + def get_locked_version(self, rule_id: str, min_stack_version: str | None = None) -> int | None: if rule_id in self.version_lock: latest_version_info = self.version_lock[rule_id] if latest_version_info.previous and latest_version_info.previous.get(min_stack_version): @@ -161,7 +173,7 @@ def get_locked_version(self, rule_id: str, min_stack_version: Optional[str] = No stack_version_info = latest_version_info return stack_version_info.version - def get_locked_hash(self, rule_id: str, min_stack_version: Optional[str] = None) -> Optional[str]: + def get_locked_hash(self, rule_id: str, min_stack_version: str | None = None) -> str | None: """Get the version info matching the min_stack_version if present.""" if rule_id in self.version_lock: latest_version_info = self.version_lock[rule_id] @@ -172,17 +184,26 @@ def get_locked_hash(self, rule_id: str, min_stack_version: Optional[str] = None) existing_sha256: str = stack_version_info.sha256 return existing_sha256 - def manage_versions(self, rules: RuleCollection, - exclude_version_update=False, save_changes=False, - verbose=True, buffer_int: int = 100) -> (List[str], List[str], List[str]): + def manage_versions( + self, + rules: Any, # type: ignore[reportRedeclaration] + exclude_version_update: bool = False, + save_changes: bool = False, + verbose: bool = True, + buffer_int: int = 100, + ) -> tuple[list[definitions.UUIDString], list[str], list[str]]: """Update the contents of the version.lock file and optionally save changes.""" from .packaging import current_stack_version + from .rule_loader import RuleCollection + from .rule import TOMLRule + + rules: RuleCollection = rules version_lock_hash = self.version_lock.sha256() lock_file_contents = deepcopy(self.version_lock.to_dict()) current_deprecated_lock = deepcopy(self.deprecated_lock.to_dict()) - verbose_echo = click.echo if verbose else (lambda x: None) + verbose_echo = click.echo if verbose else (lambda _: None) # type: ignore[reportUnknownVariableType] already_deprecated = set(current_deprecated_lock) deprecated_rules = set(rules.deprecated.id_map) @@ -195,22 +216,22 @@ def manage_versions(self, rules: RuleCollection, if not (new_rules or changed_rules or newly_deprecated): return list(changed_rules), list(new_rules), list(newly_deprecated) - verbose_echo('Rule changes detected!') - changes = [] + verbose_echo("Rule changes detected!") + + changes: list[str] = [] - def log_changes(r, route_taken, new_rule_version, 
*msg): - new = [f' {route_taken}: {r.id}, new version: {new_rule_version}'] - new.extend([f' - {m}' for m in msg if m]) + def log_changes(r: TOMLRule, route_taken: str, new_rule_version: Any, *msg: str): + new = [f" {route_taken}: {r.id}, new version: {new_rule_version}"] + new.extend([f" - {m}" for m in msg if m]) changes.extend(new) for rule in rules: if rule.contents.metadata.maturity == "production" or rule.id in newly_deprecated: # assume that older stacks are always locked first - min_stack = Version.parse(rule.contents.get_supported_version(), - optional_minor_and_patch=True) + min_stack = Version.parse(rule.contents.get_supported_version(), optional_minor_and_patch=True) lock_from_rule = rule.contents.lock_info(bump=not exclude_version_update) - lock_from_file: dict = lock_file_contents.setdefault(rule.id, {}) + lock_from_file = lock_file_contents.setdefault(rule.id, {}) # scenarios to handle, assuming older stacks are always locked first: # 1) no breaking changes ever made or the first time a rule is created @@ -218,40 +239,42 @@ def log_changes(r, route_taken, new_rule_version, *msg): # 3) on the latest stack, locking in a breaking change # 4) on an old stack, after a breaking change has been made latest_locked_stack_version = rule.contents.convert_supported_version( - lock_from_file.get("min_stack_version")) + lock_from_file.get("min_stack_version") + ) # strip version down to only major.minor to compare against lock file versioning stripped_version = f"{min_stack.major}.{min_stack.minor}" if not lock_from_file or min_stack == latest_locked_stack_version: - route = 'A' + route = "A" # 1) no breaking changes ever made or the first time a rule is created # 2) on the latest, after a breaking change has been locked lock_from_file.update(lock_from_rule) - new_version = lock_from_rule['version'] + new_version = lock_from_rule["version"] # add the min_stack_version to the lock if it's explicitly set if rule.contents.metadata.min_stack_version is not None: lock_from_file["min_stack_version"] = stripped_version - log_msg = f'min_stack_version added: {min_stack}' + log_msg = f"min_stack_version added: {min_stack}" log_changes(rule, route, new_version, log_msg) elif min_stack > latest_locked_stack_version: - route = 'B' + route = "B" # 3) on the latest stack, locking in a breaking change - stripped_latest_locked_stack_version = f"{latest_locked_stack_version.major}." 
\ - f"{latest_locked_stack_version.minor}" + stripped_latest_locked_stack_version = ( + f"{latest_locked_stack_version.major}.{latest_locked_stack_version.minor}" + ) # preserve buffer space to support forked version spacing if exclude_version_update: buffer_int -= 1 lock_from_rule["version"] = lock_from_file["version"] + buffer_int previous_lock_info = { - "max_allowable_version": lock_from_rule['version'] - 1, + "max_allowable_version": lock_from_rule["version"] - 1, "rule_name": lock_from_file["rule_name"], "sha256": lock_from_file["sha256"], "version": lock_from_file["version"], - "type": lock_from_file["type"] + "type": lock_from_file["type"], } lock_from_file.setdefault("previous", {}) @@ -260,19 +283,22 @@ def log_changes(r, route_taken, new_rule_version, *msg): # overwrite the "latest" part of the lock at the top level lock_from_file.update(lock_from_rule, min_stack_version=stripped_version) - new_version = lock_from_rule['version'] + new_version = lock_from_rule["version"] log_changes( - rule, route, new_version, - f'previous {stripped_latest_locked_stack_version} saved as \ - version: {previous_lock_info["version"]}', - f'current min_stack updated to {stripped_version}' + rule, + route, + new_version, + f"previous {stripped_latest_locked_stack_version} saved as \ + version: {previous_lock_info['version']}", + f"current min_stack updated to {stripped_version}", ) elif min_stack < latest_locked_stack_version: - route = 'C' + route = "C" # 4) on an old stack, after a breaking change has been made (updated fork) - assert stripped_version in lock_from_file.get("previous", {}), \ + assert stripped_version in lock_from_file.get("previous", {}), ( f"Expected {rule.id} @ v{stripped_version} in the rule lock" + ) # TODO: Figure out whether we support locking old versions and if we want to # "leave room" by skipping versions when breaking changes are made. 
@@ -280,22 +306,28 @@ def log_changes(r, route_taken, new_rule_version, *msg): # since it's a good summary of everything that happens previous_entry = lock_from_file["previous"][stripped_version] - max_allowable_version = previous_entry['max_allowable_version'] + max_allowable_version = previous_entry["max_allowable_version"] # if version bump collides with future bump: fail # if space: change and log - info_from_rule = (lock_from_rule['sha256'], lock_from_rule['version']) - info_from_file = (previous_entry['sha256'], previous_entry['version']) + info_from_rule = (lock_from_rule["sha256"], lock_from_rule["version"]) + info_from_file = (previous_entry["sha256"], previous_entry["version"]) - if lock_from_rule['version'] > max_allowable_version: - raise ValueError(f'Forked rule: {rule.id} - {rule.name} has changes that will force it to ' - f'exceed the max allowable version of {max_allowable_version}') + if lock_from_rule["version"] > max_allowable_version: + raise ValueError( + f"Forked rule: {rule.id} - {rule.name} has changes that will force it to " + f"exceed the max allowable version of {max_allowable_version}" + ) if info_from_rule != info_from_file: lock_from_file["previous"][stripped_version].update(lock_from_rule) new_version = lock_from_rule["version"] - log_changes(rule, route, 'unchanged', - f'previous version {stripped_version} updated version to {new_version}') + log_changes( + rule, + route, + "unchanged", + f"previous version {stripped_version} updated version to {new_version}", + ) continue else: raise RuntimeError("Unreachable code") @@ -305,20 +337,21 @@ def log_changes(r, route_taken, new_rule_version, *msg): current_deprecated_lock[rule.id] = { "rule_name": rule.name, "stack_version": current_stack_version(), - "deprecation_date": rule.contents.metadata['deprecation_date'] + "deprecation_date": rule.contents.metadata["deprecation_date"], } if save_changes or verbose: - click.echo(f' - {len(changed_rules)} changed rules') - click.echo(f' - {len(new_rules)} new rules') - click.echo(f' - {len(newly_deprecated)} newly deprecated rules') + click.echo(f" - {len(changed_rules)} changed rules") + click.echo(f" - {len(new_rules)} new rules") + click.echo(f" - {len(newly_deprecated)} newly deprecated rules") if not save_changes: verbose_echo( - 'run `build-release --update-version-lock` to update version.lock.json and deprecated_rules.json') + "run `build-release --update-version-lock` to update version.lock.json and deprecated_rules.json" + ) return list(changed_rules), list(new_rules), list(newly_deprecated) - click.echo('Detailed changes: \n' + '\n'.join(changes)) + click.echo("Detailed changes: \n" + "\n".join(changes)) # reset local version lock self.version_lock = VersionLockFile.from_dict(dict(data=lock_file_contents)) @@ -326,13 +359,13 @@ def log_changes(r, route_taken, new_rule_version, *msg): new_hash = self.version_lock.sha256() - if version_lock_hash != new_hash: + if version_lock_hash != new_hash and self.version_lock_file: self.save_file(self.version_lock_file, self.version_lock) - if newly_deprecated: + if newly_deprecated and self.deprecated_lock_file: self.save_file(self.deprecated_lock_file, self.deprecated_lock) - return changed_rules, list(new_rules), newly_deprecated + return list(changed_rules), list(new_rules), list(newly_deprecated) name = str(RULES_CONFIG.version_lock_file) diff --git a/hunting/__main__.py b/hunting/__main__.py index 4ce320566a0..3bff686e574 100644 --- a/hunting/__main__.py +++ b/hunting/__main__.py @@ -8,9 +8,10 @@ from collections import 
Counter from dataclasses import asdict from pathlib import Path +from typing import Any import click -from tabulate import tabulate +from tabulate import tabulate # type: ignore[reportMissingModuleSource] from detection_rules.misc import parse_user_config @@ -18,8 +19,7 @@ from .markdown import MarkdownGenerator from .run import QueryRunner from .search import QueryIndex -from .utils import (filter_elasticsearch_params, get_hunt_path, load_all_toml, - load_toml, update_index_yml) +from .utils import filter_elasticsearch_params, get_hunt_path, load_all_toml, load_toml, update_index_yml @click.group() @@ -28,20 +28,19 @@ def hunting(): pass -@hunting.command('generate-markdown') -@click.argument('path', required=False) -def generate_markdown(path: Path = None): +@hunting.command("generate-markdown") +@click.argument("path", required=False) +def generate_markdown(path: Path | None = None): """Convert TOML hunting queries to Markdown format.""" markdown_generator = MarkdownGenerator(HUNTING_DIR) if path: - path = Path(path) - if path.is_file() and path.suffix == '.toml': + if path.is_file() and path.suffix == ".toml": click.echo(f"Generating Markdown for single file: {path}") markdown_generator.process_file(path) elif (HUNTING_DIR / path).is_dir(): click.echo(f"Generating Markdown for folder: {path}") - markdown_generator.process_folder(path) + markdown_generator.process_folder(str(path)) else: raise ValueError(f"Invalid path provided: {path}") else: @@ -52,7 +51,7 @@ def generate_markdown(path: Path = None): markdown_generator.update_index_md() -@hunting.command('refresh-index') +@hunting.command("refresh-index") def refresh_index(): """Refresh the index.yml file from TOML files and then refresh the index.md file.""" click.echo("Refreshing the index.yml and index.md files.") @@ -62,12 +61,12 @@ def refresh_index(): click.echo("Index refresh complete.") -@hunting.command('search') -@click.option('--tactic', type=str, default=None, help="Search by MITRE tactic ID (e.g., TA0001)") -@click.option('--technique', type=str, default=None, help="Search by MITRE technique ID (e.g., T1078)") -@click.option('--sub-technique', type=str, default=None, help="Search by MITRE sub-technique ID (e.g., T1078.001)") -@click.option('--data-source', type=str, default=None, help="Filter by data_source like 'aws', 'macos', or 'linux'") -@click.option('--keyword', type=str, default=None, help="Search by keyword in name, description, and notes") +@hunting.command("search") +@click.option("--tactic", type=str, default=None, help="Search by MITRE tactic ID (e.g., TA0001)") +@click.option("--technique", type=str, default=None, help="Search by MITRE technique ID (e.g., T1078)") +@click.option("--sub-technique", type=str, default=None, help="Search by MITRE sub-technique ID (e.g., T1078.001)") +@click.option("--data-source", type=str, default=None, help="Filter by data_source like 'aws', 'macos', or 'linux'") +@click.option("--keyword", type=str, default=None, help="Search by keyword in name, description, and notes") def search_queries(tactic: str, technique: str, sub_technique: str, data_source: str, keyword: str): """Search for queries based on MITRE tactic, technique, sub-technique, or data_source.""" @@ -90,13 +89,13 @@ def search_queries(tactic: str, technique: str, sub_technique: str, data_source: click.secho(f"\nFound {len(results)} matching queries:\n", fg="green", bold=True) # Prepare the data for tabulate - table_data = [] + table_data: list[str | Any] = [] for result in results: # Customize output to include 
technique, data_source, and UUID - data_source_str = result['data_source'] - mitre_str = ", ".join(result['mitre']) - uuid = result['uuid'] - table_data.append([result['name'], uuid, result['path'], data_source_str, mitre_str]) + data_source_str = result["data_source"] + mitre_str = ", ".join(result["mitre"]) + uuid = result["uuid"] + table_data.append([result["name"], uuid, result["path"], data_source_str, mitre_str]) # Output results using tabulate table_headers = ["Name", "UUID", "Location", "Data Source", "MITRE"] @@ -106,12 +105,17 @@ def search_queries(tactic: str, technique: str, sub_technique: str, data_source: click.secho("No matching queries found.", fg="red", bold=True) -@hunting.command('view-hunt') -@click.option('--uuid', type=str, help="View a specific hunt by UUID.") -@click.option('--path', type=str, help="View a specific hunt by file path.") -@click.option('--format', 'output_format', default='toml', type=click.Choice(['toml', 'json'], case_sensitive=False), - help="Output format (toml or json).") -@click.option('--query-only', is_flag=True, help="Only display the query content.") +@hunting.command("view-hunt") +@click.option("--uuid", type=str, help="View a specific hunt by UUID.") +@click.option("--path", type=str, help="View a specific hunt by file path.") +@click.option( + "--format", + "output_format", + default="toml", + type=click.Choice(["toml", "json"], case_sensitive=False), + help="Output format (toml or json).", +) +@click.option("--query-only", is_flag=True, help="Only display the query content.") def view_hunt(uuid: str, path: str, output_format: str, query_only: bool): """View a specific hunt by UUID or file path in the specified format (TOML or JSON).""" @@ -121,6 +125,9 @@ def view_hunt(uuid: str, path: str, output_format: str, query_only: bool): if error_message: raise click.ClickException(error_message) + if not hunt_path: + raise ValueError("No hunt path found") + # Load the TOML data hunt = load_toml(hunt_path) @@ -134,17 +141,20 @@ def view_hunt(uuid: str, path: str, output_format: str, query_only: bool): return # Output the hunt in the requested format - if output_format == 'toml': + if output_format == "toml": click.echo(hunt_path.read_text()) - elif output_format == 'json': + elif output_format == "json": hunt_dict = asdict(hunt) click.echo(json.dumps(hunt_dict, indent=4)) -@hunting.command('hunt-summary') -@click.option('--breakdown', type=click.Choice(['platform', 'integration', 'language'], - case_sensitive=False), default='platform', - help="Specify how to break down the summary: 'platform', 'integration', or 'language'.") +@hunting.command("hunt-summary") +@click.option( + "--breakdown", + type=click.Choice(["platform", "integration", "language"], case_sensitive=False), + default="platform", + help="Specify how to break down the summary: 'platform', 'integration', or 'language'.", +) def hunt_summary(breakdown: str): """ Generate a summary of hunt queries, broken down by platform, integration, or language. 
@@ -155,9 +165,9 @@ def hunt_summary(breakdown: str): all_hunts = load_all_toml(HUNTING_DIR) # Use Counter for more concise counting - platform_counter = Counter() - integration_counter = Counter() - language_counter = Counter() + platform_counter: Counter[str] = Counter() + integration_counter: Counter[str] = Counter() + language_counter: Counter[str] = Counter() for hunt, path in all_hunts: # Get the platform based on the folder name @@ -168,28 +178,30 @@ def hunt_summary(breakdown: str): integration_counter.update(hunt.integration) # Count languages, renaming 'SQL' to 'OSQuery' - languages = ['OSQuery' if lang == 'SQL' else lang for lang in hunt.language] + languages = ["OSQuery" if lang == "SQL" else lang for lang in hunt.language] language_counter.update(languages) # Prepare and display the table based on the selected breakdown - if breakdown == 'platform': + if breakdown == "platform": table_data = [[platform, count] for platform, count in platform_counter.items()] table_headers = ["Platform (Folder)", "Hunt Count"] - elif breakdown == 'integration': + elif breakdown == "integration": table_data = [[integration, count] for integration, count in integration_counter.items()] table_headers = ["Integration", "Hunt Count"] - elif breakdown == 'language': + elif breakdown == "language": table_data = [[language, count] for language, count in language_counter.items()] table_headers = ["Language", "Hunt Count"] + else: + raise ValueError(f"Unsupported breakdown value: {breakdown}") click.echo(tabulate(table_data, headers=table_headers, tablefmt="fancy_grid")) -@hunting.command('run-query') -@click.option('--uuid', help="The UUID of the hunting query to run.") -@click.option('--file-path', help="The file path of the hunting query to run.") -@click.option('--all', 'run_all', is_flag=True, help="Run all eligible queries in the file.") -@click.option('--wait-time', 'wait_time', default=180, help="Time to wait for query completion.") +@hunting.command("run-query") +@click.option("--uuid", help="The UUID of the hunting query to run.") +@click.option("--file-path", help="The file path of the hunting query to run.") +@click.option("--all", "run_all", is_flag=True, help="Run all eligible queries in the file.") +@click.option("--wait-time", "wait_time", default=180, help="Time to wait for query completion.") def run_query(uuid: str, file_path: str, run_all: bool, wait_time: int): """Run a hunting query by UUID or file path. 
Only ES|QL queries are supported.""" @@ -200,6 +212,9 @@ def run_query(uuid: str, file_path: str, run_all: bool, wait_time: int): click.echo(error_message) return + if not hunt_path: + raise ValueError("No hunt path found") + # Load the user configuration config = parse_user_config() if not config: @@ -234,7 +249,7 @@ def run_query(uuid: str, file_path: str, run_all: bool, wait_time: int): click.secho("Available queries:", fg="blue", bold=True) for i, query in eligible_queries.items(): click.secho(f"\nQuery {i + 1}:", fg="green", bold=True) - click.echo(query_runner._format_query(query)) + click.echo(query_runner.format_query(query)) click.secho("\n" + "-" * 120, fg="yellow") # Handle query selection diff --git a/hunting/definitions.py b/hunting/definitions.py index ef52da689f6..40968e16727 100644 --- a/hunting/definitions.py +++ b/hunting/definitions.py @@ -6,7 +6,6 @@ import re from dataclasses import dataclass, field from pathlib import Path -from typing import Optional, List # Define the hunting directory path HUNTING_DIR = Path(__file__).parent @@ -16,25 +15,24 @@ ATTACK_URL = "https://attack.mitre.org/techniques/" # Static mapping for specific integrations -STATIC_INTEGRATION_LINK_MAP = { - 'aws_bedrock.invocation': 'aws_bedrock' -} +STATIC_INTEGRATION_LINK_MAP = {"aws_bedrock.invocation": "aws_bedrock"} @dataclass class Hunt: """Dataclass to represent a hunt.""" + author: str description: str - integration: List[str] + integration: list[str] uuid: str name: str - language: List[str] + language: list[str] license: str - query: List[str] - notes: Optional[List[str]] = field(default_factory=list) - mitre: List[str] = field(default_factory=list) - references: Optional[List[str]] = field(default_factory=list) + query: list[str] + notes: list[str] | None = field(default_factory=list) # type: ignore[reportUnknownVariableType] + mitre: list[str] = field(default_factory=list) # type: ignore[reportUnknownVariableType] + references: list[str] | None = field(default_factory=list) # type: ignore[reportUnknownVariableType] def __post_init__(self): """Post-initialization to determine which validation to apply.""" @@ -42,7 +40,7 @@ def __post_init__(self): raise ValueError(f"Hunt: {self.name} - Query field must be provided.") # Loop through each query in the array - for idx, q in enumerate(self.query): + for q in self.query: query_start = q.strip().lower() # Only validate queries that start with "from" (ESQL queries) @@ -55,8 +53,8 @@ def validate_esql_query(self, query: str) -> None: if self.author == "Elastic": # Regex patterns for checking "stats by" and "| keep" - stats_by_pattern = re.compile(r'\bstats\b.*?\bby\b', re.DOTALL) - keep_pattern = re.compile(r'\| keep', re.DOTALL) + stats_by_pattern = re.compile(r"\bstats\b.*?\bby\b", re.DOTALL) + keep_pattern = re.compile(r"\| keep", re.DOTALL) # Check if either "stats by" or "| keep" exists in the query if not stats_by_pattern.search(query) and not keep_pattern.search(query): diff --git a/hunting/markdown.py b/hunting/markdown.py index 19c0c57995e..8573d159d26 100644 --- a/hunting/markdown.py +++ b/hunting/markdown.py @@ -11,6 +11,7 @@ class MarkdownGenerator: """Class to generate or update Markdown documentation from TOML or YAML files.""" + def __init__(self, base_path: Path): """Initialize with the base path and load the hunting index.""" self.base_path = base_path @@ -18,7 +19,7 @@ def __init__(self, base_path: Path): def process_file(self, file_path: Path) -> None: """Process a single TOML file and generate its Markdown representation.""" - 
if not file_path.is_file() or file_path.suffix != '.toml': + if not file_path.is_file() or file_path.suffix != ".toml": raise ValueError(f"The provided path is not a valid TOML file: {file_path}") click.echo(f"Processing specific TOML file: {file_path}") @@ -83,7 +84,7 @@ def convert_toml_to_markdown(self, hunt_config: Hunt, file_path: Path) -> str: def save_markdown(self, markdown_path: Path, content: str) -> None: """Save the Markdown content to a file.""" - markdown_path.write_text(content, encoding="utf-8") + _ = markdown_path.write_text(content, encoding="utf-8") click.echo(f"Markdown generated: {markdown_path}") def update_or_add_entry(self, hunt_config: Hunt, toml_path: Path) -> None: @@ -92,9 +93,9 @@ def update_or_add_entry(self, hunt_config: Hunt, toml_path: Path) -> None: uuid = hunt_config.uuid entry = { - 'name': hunt_config.name, - 'path': f"./{toml_path.resolve().relative_to(self.base_path).as_posix()}", - 'mitre': hunt_config.mitre + "name": hunt_config.name, + "path": f"./{toml_path.resolve().relative_to(self.base_path).as_posix()}", + "mitre": hunt_config.mitre, } if folder_name not in self.hunting_index: @@ -112,16 +113,16 @@ def create_docs_folder(self, file_path: Path) -> Path: def generate_integration_links(self, integrations: list[str]) -> list[str]: """Generate integration links for the documentation.""" - base_url = 'https://docs.elastic.co/integrations' - generated = [] + base_url = "https://docs.elastic.co/integrations" + generated: list[str] = [] for integration in integrations: if integration in STATIC_INTEGRATION_LINK_MAP: link_str = STATIC_INTEGRATION_LINK_MAP[integration] else: - link_str = integration.replace('.', '/') - link = f'{base_url}/{link_str}' + link_str = integration.replace(".", "/") + link = f"{base_url}/{link_str}" validate_link(link) - generated.append(f'[{integration}]({link})') + generated.append(f"[{integration}]({link})") return generated def update_index_md(self) -> None: @@ -135,10 +136,10 @@ def update_index_md(self) -> None: for folder, files in sorted(self.hunting_index.items()): index_content += f"\n\n## {folder}\n" - for file_info in sorted(files.values(), key=lambda x: x['name']): - md_path = file_info['path'].replace('queries', 'docs').replace('.toml', '.md') + for file_info in sorted(files.values(), key=lambda x: x["name"]): + md_path = file_info["path"].replace("queries", "docs").replace(".toml", ".md") index_content += f"- [{file_info['name']}]({md_path}) (ES|QL)\n" index_md_path = self.base_path / "index.md" - index_md_path.write_text(index_content, encoding="utf-8") + _ = index_md_path.write_text(index_content, encoding="utf-8") click.echo(f"Index Markdown updated at: {index_md_path}") diff --git a/hunting/run.py b/hunting/run.py index 17ce2991474..8dfa74b4d83 100644 --- a/hunting/run.py +++ b/hunting/run.py @@ -5,6 +5,7 @@ import re import textwrap +from typing import Any from pathlib import Path import click @@ -15,7 +16,7 @@ class QueryRunner: - def __init__(self, es_config: dict): + def __init__(self, es_config: dict[str, Any]): """Initialize the QueryRunner with Elasticsearch config.""" self.es_config = es_config @@ -25,13 +26,13 @@ def load_hunting_file(self, file_path: Path): def preprocess_query(self, query: str) -> str: """Pre-process the query by removing comments and adding a LIMIT.""" - query = re.sub(r'//.*', '', query) - if not re.search(r'LIMIT', query, re.IGNORECASE): + query = re.sub(r"//.*", "", query) + if not re.search(r"LIMIT", query, re.IGNORECASE): query += " | LIMIT 10" click.echo("No LIMIT detected in 
query. Added LIMIT 10 to truncate output.") return query - def run_individual_query(self, query: str, wait_timeout: int): + def run_individual_query(self, query: str, _: int): """Run a single query with the Elasticsearch config.""" es = get_elasticsearch_client(**self.es_config) query = self.preprocess_query(query) @@ -42,7 +43,13 @@ def run_individual_query(self, query: str, wait_timeout: int): # Start the query synchronously response = es.esql.query(query=query) - self.process_results(response) + + response_data = response.body + if response_data.get("values"): + click.secho("Query matches found!", fg="red", bold=True) + else: + click.secho("No matches found!", fg="green", bold=True) + except Exception as e: # handle missing index error if "Unknown index" in str(e): @@ -52,25 +59,17 @@ def run_individual_query(self, query: str, wait_timeout: int): else: click.secho(f"Error running query: {str(e)}", fg="red") - def run_all_queries(self, queries: dict, wait_timeout: int): + def run_all_queries(self, queries: dict[int, Any], wait_timeout: int): """Run all eligible queries in the hunting file.""" click.secho("Running all eligible queries...", fg="green", bold=True) for i, query in queries.items(): click.secho(f"\nRunning Query {i + 1}:", fg="green", bold=True) - click.echo(self._format_query(query)) + click.echo(self.format_query(query)) self.run_individual_query(query, wait_timeout) click.secho("\n" + "-" * 120, fg="yellow") - def process_results(self, response): - """Process the Elasticsearch query results and display the outcome.""" - response_data = response.body - if response_data.get('values'): - click.secho("Query matches found!", fg="red", bold=True) - else: - click.secho("No matches found!", fg="green", bold=True) - - def _format_query(self, query: str) -> str: + def format_query(self, query: str) -> str: """Format the query with word wrapping for better readability.""" - lines = query.split('\n') - wrapped_lines = [textwrap.fill(line, width=120, subsequent_indent=' ') for line in lines] - return '\n'.join(wrapped_lines) + lines = query.split("\n") + wrapped_lines = [textwrap.fill(line, width=120, subsequent_indent=" ") for line in lines] + return "\n".join(wrapped_lines) diff --git a/hunting/search.py b/hunting/search.py index 615e8cd222d..101827abb56 100644 --- a/hunting/search.py +++ b/hunting/search.py @@ -5,6 +5,8 @@ from pathlib import Path +from typing import Any + import click from detection_rules.attack import tactics_map, technique_lookup from .utils import load_index_file, load_all_toml @@ -15,10 +17,10 @@ def __init__(self, base_path: Path): """Initialize with the base path and load the index.""" self.base_path = base_path self.hunting_index = load_index_file() - self.mitre_technique_ids = set() + self.mitre_technique_ids: set[str] = set() self.reverse_tactics_map = {v: k for k, v in tactics_map.items()} - def _process_mitre_filter(self, mitre_filter: tuple): + def _process_mitre_filter(self, mitre_filter: tuple[str, ...]): """Process the MITRE filter to gather all matching techniques.""" for filter_item in mitre_filter: if filter_item in self.reverse_tactics_map: @@ -26,29 +28,30 @@ def _process_mitre_filter(self, mitre_filter: tuple): elif filter_item in technique_lookup: self._process_technique_id(filter_item) - def _process_tactic_id(self, filter_item): + def _process_tactic_id(self, filter_item: str): """Helper method to process a tactic ID.""" tactic_name = self.reverse_tactics_map[filter_item] click.echo(f"Found tactic ID {filter_item} (Tactic Name: 
{tactic_name}). Searching for associated techniques.") for tech_id, details in technique_lookup.items(): - kill_chain_phases = details.get('kill_chain_phases', []) - if any(tactic_name.lower().replace(' ', '-') == phase['phase_name'] for phase in kill_chain_phases): + kill_chain_phases = details.get("kill_chain_phases", []) + if any(tactic_name.lower().replace(" ", "-") == phase["phase_name"] for phase in kill_chain_phases): self.mitre_technique_ids.add(tech_id) - def _process_technique_id(self, filter_item): + def _process_technique_id(self, filter_item: str): """Helper method to process a technique or sub-technique ID.""" self.mitre_technique_ids.add(filter_item) - if '.' not in filter_item: + if "." not in filter_item: sub_techniques = { - sub_tech_id for sub_tech_id in technique_lookup - if sub_tech_id.startswith(f"{filter_item}.") + sub_tech_id for sub_tech_id in technique_lookup if sub_tech_id.startswith(f"{filter_item}.") } self.mitre_technique_ids.update(sub_techniques) - def search(self, mitre_filter: tuple = (), data_source: str = None, keyword: str = None) -> list: + def search( + self, mitre_filter: tuple[str, ...] = (), data_source: str | None = None, keyword: str | None = None + ) -> list[dict[str, Any]]: """Search the index based on MITRE techniques, data source, or keyword.""" - results = [] + results: list[dict[str, Any]] = [] # Step 1: If data source is provided, filter by data source first if data_source: @@ -65,8 +68,9 @@ def search(self, mitre_filter: tuple = (), data_source: str = None, keyword: str self._process_mitre_filter(mitre_filter) if results: # Filter existing results further by MITRE if data source results already exist - results = [result for result in results if - any(tech in self.mitre_technique_ids for tech in result['mitre'])] + results = [ + result for result in results if any(tech in self.mitre_technique_ids for tech in result["mitre"]) + ] else: # Otherwise, perform a fresh search based on MITRE filter results = self._search_index(mitre_filter) @@ -83,9 +87,9 @@ def search(self, mitre_filter: tuple = (), data_source: str = None, keyword: str return self._handle_no_results(results, mitre_filter, data_source, keyword) - def _search_index(self, mitre_filter: tuple = ()) -> list: + def _search_index(self, mitre_filter: tuple[str, ...] 
= ()) -> list[dict[str, Any]]: """Private method to search the index based on MITRE filter.""" - results = [] + results: list[dict[str, Any]] = [] # Load all TOML data for detailed fields hunting_content = load_all_toml(self.base_path) @@ -96,23 +100,23 @@ def _search_index(self, mitre_filter: tuple = ()) -> list: # Prepare the result with full hunt content fields matches = hunt_content.__dict__.copy() - matches['mitre'] = hunt_content.mitre - matches['data_source'] = hunt_content.integration - matches['uuid'] = hunt_content.uuid - matches['path'] = file_path + matches["mitre"] = hunt_content.mitre + matches["data_source"] = hunt_content.integration + matches["uuid"] = hunt_content.uuid + matches["path"] = file_path results.append(matches) return results - def _search_keyword(self, keyword: str) -> list: + def _search_keyword(self, keyword: str) -> list[dict[str, Any]]: """Private method to search description, name, notes, and references fields for a keyword.""" - results = [] + results: list[dict[str, Any]] = [] hunting_content = load_all_toml(self.base_path) for hunt_content, file_path in hunting_content: # Assign blank if notes or references are missing - notes = '::'.join(hunt_content.notes) if hunt_content.notes else '' - references = '::'.join(hunt_content.references) if hunt_content.references else '' + notes = "::".join(hunt_content.notes) if hunt_content.notes else "" + references = "::".join(hunt_content.references) if hunt_content.references else "" # Combine name, description, notes, and references for the search combined_content = f"{hunt_content.name}::{hunt_content.description}::{notes}::{references}" @@ -120,18 +124,18 @@ def _search_keyword(self, keyword: str) -> list: if keyword.lower() in combined_content.lower(): # Copy hunt_content data and prepare the result matches = hunt_content.__dict__.copy() - matches['mitre'] = hunt_content.mitre - matches['data_source'] = hunt_content.integration - matches['uuid'] = hunt_content.uuid - matches['path'] = file_path + matches["mitre"] = hunt_content.mitre + matches["data_source"] = hunt_content.integration + matches["uuid"] = hunt_content.uuid + matches["path"] = file_path results.append(matches) return results - def _filter_by_data_source(self, data_source: str) -> list: + def _filter_by_data_source(self, data_source: str) -> list[dict[str, Any]]: """Filter the index by data source, checking both the actual files and the index.""" - results = [] - seen_uuids = set() # Track UUIDs to avoid duplicates + results: list[dict[str, Any]] = [] + seen_uuids: set[str] = set() # Track UUIDs to avoid duplicates # Load all TOML data for detailed fields hunting_content = load_all_toml(self.base_path) @@ -142,41 +146,48 @@ def _filter_by_data_source(self, data_source: str) -> list: if hunt_content.uuid not in seen_uuids: # Prepare the result with full hunt content fields matches = hunt_content.__dict__.copy() - matches['mitre'] = hunt_content.mitre - matches['data_source'] = hunt_content.integration - matches['uuid'] = hunt_content.uuid - matches['path'] = file_path + matches["mitre"] = hunt_content.mitre + matches["data_source"] = hunt_content.integration + matches["uuid"] = hunt_content.uuid + matches["path"] = file_path results.append(matches) seen_uuids.add(hunt_content.uuid) # Step 2: Check the index for generic data sources (e.g., 'aws', 'linux') if data_source in self.hunting_index: - for query_uuid, query_data in self.hunting_index[data_source].items(): + for query_uuid, _ in self.hunting_index[data_source].items(): if query_uuid not in 
seen_uuids: # Find corresponding TOML content for this query - hunt_content = next((hunt for hunt, path in hunting_content if hunt.uuid == query_uuid), None) - if hunt_content: + h = next(((hunt, path) for hunt, path in hunting_content if hunt.uuid == query_uuid), None) + if h: + hunt_content, path = h # Prepare the result with full hunt content fields matches = hunt_content.__dict__.copy() - matches['mitre'] = hunt_content.mitre - matches['data_source'] = hunt_content.integration - matches['uuid'] = hunt_content.uuid - matches['path'] = file_path + matches["mitre"] = hunt_content.mitre + matches["data_source"] = hunt_content.integration + matches["uuid"] = hunt_content.uuid + matches["path"] = path results.append(matches) seen_uuids.add(query_uuid) return results - def _matches_keyword(self, result: dict, keyword: str) -> bool: + def _matches_keyword(self, result: dict[str, Any], keyword: str) -> bool: """Check if the result matches the keyword in name, description, or notes.""" # Combine relevant fields for keyword search - notes = '::'.join(result.get('notes', [])) if 'notes' in result else '' - references = '::'.join(result.get('references', [])) if 'references' in result else '' + notes = "::".join(result.get("notes", [])) if "notes" in result else "" + references = "::".join(result.get("references", [])) if "references" in result else "" combined_content = f"{result['name']}::{result['description']}::{notes}::{references}" return keyword.lower() in combined_content.lower() - def _handle_no_results(self, results: list, mitre_filter=None, data_source=None, keyword=None) -> list: + def _handle_no_results( + self, + results: list[dict[str, Any]], + mitre_filter: tuple[str, ...] | None = None, + data_source: str | None = None, + keyword: str | None = None, + ) -> list[dict[str, Any]]: """Handle cases where no results are found.""" if not results: if mitre_filter and not self.mitre_technique_ids: diff --git a/hunting/utils.py b/hunting/utils.py index c704d16245d..4eadaa82dff 100644 --- a/hunting/utils.py +++ b/hunting/utils.py @@ -6,7 +6,7 @@ import inspect import tomllib from pathlib import Path -from typing import Union +from typing import Any import click import urllib3 @@ -17,17 +17,17 @@ from .definitions import HUNTING_DIR, Hunt -def get_hunt_path(uuid: str, file_path: str) -> (Path, str): +def get_hunt_path(uuid: str, file_path: str) -> tuple[Path | None, str | None]: """Resolve the path of the hunting query using either a UUID or file path.""" if uuid: # Load the index and find the hunt by UUID index_data = load_index_file() - for data_source, hunts in index_data.items(): + for _, hunts in index_data.items(): if uuid in hunts: hunt_data = hunts[uuid] # Combine the relative path from the index with the HUNTING_DIR - hunt_path = HUNTING_DIR / hunt_data['path'] + hunt_path = HUNTING_DIR / hunt_data["path"] return hunt_path.resolve(), None return None, f"No hunt found for UUID: {uuid}" @@ -41,20 +41,20 @@ def get_hunt_path(uuid: str, file_path: str) -> (Path, str): return None, "Either UUID or file path must be provided." 
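As a rough usage sketch of the helpers above (the UUID is a placeholder, and this assumes the hunting package is importable; get_hunt_path and load_toml are the functions defined in hunting/utils.py):

# Minimal sketch: resolve a hunt by UUID and load it into the Hunt dataclass.
# The UUID below is a placeholder, not a real hunt.
from hunting.utils import get_hunt_path, load_toml

hunt_path, error = get_hunt_path(uuid="00000000-0000-0000-0000-000000000000", file_path="")
if error or not hunt_path:
    raise SystemExit(error or "No hunt path found")
hunt = load_toml(hunt_path)  # returns a Hunt dataclass instance
print(hunt.name, hunt.uuid, hunt.language)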
-def load_index_file() -> dict: +def load_index_file() -> dict[str, Any]: """Load the hunting index.yml file.""" index_file = HUNTING_DIR / "index.yml" if not index_file.exists(): click.echo(f"No index.yml found at {index_file}.") return {} - with open(index_file, 'r') as f: + with open(index_file, "r") as f: hunting_index = yaml.safe_load(f) return hunting_index -def load_toml(source: Union[Path, str]) -> Hunt: +def load_toml(source: Path | str) -> Hunt: """Load and validate TOML content as Hunt dataclass.""" if isinstance(source, Path): if not source.is_file(): @@ -69,19 +69,19 @@ def load_toml(source: Union[Path, str]) -> Hunt: return Hunt(**toml_dict["hunt"]) -def load_all_toml(base_path: Path): +def load_all_toml(base_path: Path) -> list[tuple[Hunt, Path]]: """Load all TOML files from the directory and return a list of Hunt configurations and their paths.""" - hunts = [] + hunts: list[tuple[Hunt, Path]] = [] for toml_file in base_path.rglob("*.toml"): hunt_config = load_toml(toml_file) hunts.append((hunt_config, toml_file)) return hunts -def save_index_file(base_path: Path, directories: dict) -> None: +def save_index_file(base_path: Path, directories: dict[str, Any]): """Save the updated index.yml file.""" index_file = base_path / "index.yml" - with open(index_file, 'w') as f: + with open(index_file, "w") as f: yaml.safe_dump(directories, f, default_flow_style=False, sort_keys=False) print(f"Index YAML updated at: {index_file}") @@ -89,7 +89,7 @@ def save_index_file(base_path: Path, directories: dict) -> None: def validate_link(link: str): """Validate and return the link.""" http = urllib3.PoolManager() - response = http.request('GET', link) + response = http.request("GET", link) if response.status != 200: raise ValueError(f"Invalid link: {link}") @@ -109,9 +109,9 @@ def update_index_yml(base_path: Path) -> None: uuid = hunt_config.uuid entry = { - 'name': hunt_config.name, - 'path': f"./{toml_file.relative_to(base_path).as_posix()}", - 'mitre': hunt_config.mitre + "name": hunt_config.name, + "path": f"./{toml_file.relative_to(base_path).as_posix()}", + "mitre": hunt_config.mitre, } # Check if the folder_name exists and if it's a list, convert it to a dictionary @@ -120,14 +120,14 @@ def update_index_yml(base_path: Path) -> None: else: if isinstance(directories[folder_name], list): # Convert the list to a dictionary, using UUIDs as keys - directories[folder_name] = {item['uuid']: item for item in directories[folder_name]} + directories[folder_name] = {item["uuid"]: item for item in directories[folder_name]} directories[folder_name][uuid] = entry # Save the updated index.yml save_index_file(base_path, directories) -def filter_elasticsearch_params(config: dict) -> dict: +def filter_elasticsearch_params(config: dict[str, Any]) -> dict[str, Any]: """Filter out unwanted keys from the config by inspecting the Elasticsearch client constructor.""" # Get the parameter names from the Elasticsearch class constructor es_params = inspect.signature(get_elasticsearch_client).parameters diff --git a/pyproject.toml b/pyproject.toml index 010820e799c..2f5454d3683 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "detection_rules" -version = "1.2.15" +version = "1.3.15" description = "Detection Rules is the home for rules used by Elastic Security. This repository is used for the development, maintenance, testing, validation, and release of rules for Elastic Security’s Detection Engine." 
readme = "README.md" requires-python = ">=3.12" @@ -19,32 +19,43 @@ classifiers = [ "Topic :: Utilities" ] dependencies = [ - "Click~=8.1.7", - "elasticsearch~=8.12.1", - "eql==0.9.19", - "jsl==0.2.4", - "jsonschema>=4.21.1", - "marko==2.0.3", - "marshmallow-dataclass[union]~=8.6.0", - "marshmallow-jsonschema~=0.13.0", - "marshmallow-union~=0.1.15", - "marshmallow~=3.21.1", - "pywin32 ; platform_system=='Windows'", - "pytoml==0.1.21", - "PyYAML~=6.0.1", - "requests~=2.31.0", - "toml==0.10.2", - "typing-inspect==0.9.0", - "typing-extensions==4.10.0", - "XlsxWriter~=3.2.0", - "semver==3.0.2", - "PyGithub==2.2.0", - "detection-rules-kql @ git+https://github.com/elastic/detection-rules.git#subdirectory=lib/kql", - "detection-rules-kibana @ git+https://github.com/elastic/detection-rules.git#subdirectory=lib/kibana", - "setuptools==75.2.0" + "Click~=8.1.7", + "elasticsearch~=8.12.1", + "eql==0.9.19", + "jsl==0.2.4", + "jsonschema>=4.21.1", + "marko==2.0.3", + "marshmallow-dataclass[union]>=8.7", + "marshmallow-jsonschema~=0.13.0", + "marshmallow-union~=0.1.15", + "marshmallow~=3.26.1", + "pywin32 ; platform_system=='Windows'", + # FIXME: pytoml is outdated and should not be used + "pytoml==0.1.21", + "PyYAML~=6.0.1", + "requests~=2.31.0", + "toml==0.10.2", + "typing-inspect==0.9.0", + "typing-extensions>=4.12", + "XlsxWriter~=3.2.0", + "semver==3.0.2", + "PyGithub==2.2.0", + "detection-rules-kql @ git+https://github.com/elastic/detection-rules.git#subdirectory=lib/kql", + "detection-rules-kibana @ git+https://github.com/elastic/detection-rules.git#subdirectory=lib/kibana", + "setuptools==75.2.0" ] [project.optional-dependencies] -dev = ["pep8-naming==0.13.0", "flake8==7.0.0", "pyflakes==3.2.0", "pytest>=8.1.1", "nodeenv==1.8.0", "pre-commit==3.6.2"] +dev = [ + "pep8-naming==0.13.0", + "flake8==7.0.0", + "pyflakes==3.2.0", + "pytest>=8.1.1", + "nodeenv==1.8.0", + "pre-commit==3.6.2", + "ruff>=0.11", + "pyright>=1.1", +] + hunting = ["tabulate==0.9.0"] [project.urls] @@ -53,15 +64,38 @@ hunting = ["tabulate==0.9.0"] "Research" = "https://www.elastic.co/security-labs" "Elastic" = "https://www.elastic.co" +[build-system] +requires = ["setuptools", "wheel", "setuptools_scm"] +build-backend = "setuptools.build_meta" + [tool.setuptools] package-data = {"kql" = ["*.g"]} packages = ["detection_rules", "hunting"] [tool.pytest.ini_options] filterwarnings = [ - "ignore::DeprecationWarning" + "ignore::DeprecationWarning" ] -[build-system] -requires = ["setuptools", "wheel", "setuptools_scm"] -build-backend = "setuptools.build_meta" +[tool.ruff] +line-length = 120 +indent-width = 4 +include = [ + "pyproject.toml", + "detection_rules/**/*.py", + "hunting/**/*.py", + "tests/**/*.py", +] +show-fixes = true + +[tool.pyright] +include = [ + "detection_rules/", + "hunting/", +] +exclude = [ + "tests/", +] +reportMissingTypeStubs = true +reportUnusedCallResult = "error" +typeCheckingMode = "strict" diff --git a/tests/__init__.py b/tests/__init__.py index 43f48cf2144..71598f578e6 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -4,21 +4,22 @@ # 2.0. 
"""Detection Rules tests.""" + import glob import json import os -from detection_rules.utils import combine_sources +from detection_rules.eswrap import combine_sources CURRENT_DIR = os.path.dirname(os.path.abspath(__file__)) -DATA_DIR = os.path.join(CURRENT_DIR, 'data') -TP_DIR = os.path.join(DATA_DIR, 'true_positives') -FP_DIR = os.path.join(DATA_DIR, 'false_positives') +DATA_DIR = os.path.join(CURRENT_DIR, "data") +TP_DIR = os.path.join(DATA_DIR, "true_positives") +FP_DIR = os.path.join(DATA_DIR, "false_positives") def get_fp_dirs(): """Get a list of fp dir names.""" - return glob.glob(os.path.join(FP_DIR, '*')) + return glob.glob(os.path.join(FP_DIR, "*")) def get_fp_data_files(): @@ -26,31 +27,31 @@ def get_fp_data_files(): data = {} for fp_dir in get_fp_dirs(): fp_dir_name = os.path.basename(fp_dir) - relative_dir_name = os.path.join('false_positives', fp_dir_name) + relative_dir_name = os.path.join("false_positives", fp_dir_name) data[fp_dir_name] = combine_sources(*get_data_files(relative_dir_name).values()) return data -def get_data_files_list(*folder, ext='ndjson', recursive=False): +def get_data_files_list(*folder, ext="ndjson", recursive=False): """Get TP or FP file list.""" folder = os.path.sep.join(folder) data_dir = [DATA_DIR, folder] if recursive: - data_dir.append('**') + data_dir.append("**") - data_dir.append('*.{}'.format(ext)) + data_dir.append("*.{}".format(ext)) return glob.glob(os.path.join(*data_dir), recursive=recursive) -def get_data_files(*folder, ext='ndjson', recursive=False): +def get_data_files(*folder, ext="ndjson", recursive=False): """Get data from data files.""" data_files = {} for data_file in get_data_files_list(*folder, ext=ext, recursive=recursive): - with open(data_file, 'r') as f: + with open(data_file, "r") as f: file_name = os.path.splitext(os.path.basename(data_file))[0] - if ext in ('.ndjson', '.jsonl'): + if ext in (".ndjson", ".jsonl"): data = f.readlines() data_files[file_name] = [json.loads(d) for d in data] else: @@ -62,5 +63,5 @@ def get_data_files(*folder, ext='ndjson', recursive=False): def get_data_file(*folder): file = os.path.join(DATA_DIR, os.path.sep.join(folder)) if os.path.exists(file): - with open(file, 'r') as f: + with open(file, "r") as f: return json.load(f) diff --git a/tests/base.py b/tests/base.py index b56ede1c5c2..2ff0283131c 100644 --- a/tests/base.py +++ b/tests/base.py @@ -4,6 +4,7 @@ # 2.0. 
"""Shared resources for tests.""" + import os import unittest from pathlib import Path @@ -19,7 +20,7 @@ RULE_LOADER_FAIL_MSG = None RULE_LOADER_FAIL_RAISED = False -CUSTOM_RULES_DIR = os.getenv('CUSTOM_RULES_DIR', None) +CUSTOM_RULES_DIR = os.getenv("CUSTOM_RULES_DIR", None) RULES_CONFIG = parse_rules_config() @@ -28,7 +29,7 @@ def load_rules() -> RuleCollection: if CUSTOM_RULES_DIR: rc = RuleCollection() path = Path(CUSTOM_RULES_DIR) - assert path.exists(), f'Custom rules directory {path} does not exist' + assert path.exists(), f"Custom rules directory {path} does not exist" rc.load_directories(directories=RULES_CONFIG.rule_dirs) rc.freeze() return rc @@ -36,7 +37,7 @@ def load_rules() -> RuleCollection: def default_bbr(rc: RuleCollection) -> RuleCollection: - rules = [r for r in rc.rules if 'rules_building_block' in r.path.parent.parts] + rules = [r for r in rc.rules if "rules_building_block" in r.path.parent.parts] return RuleCollection(rules=rules) @@ -70,8 +71,8 @@ def setUpClass(cls): cls.rules_config = RULES_CONFIG @staticmethod - def rule_str(rule: Union[DeprecatedRule, TOMLRule], trailer=' ->') -> str: - return f'{rule.id} - {rule.name}{trailer or ""}' + def rule_str(rule: Union[DeprecatedRule, TOMLRule], trailer=" ->") -> str: + return f"{rule.id} - {rule.name}{trailer or ''}" def setUp(self) -> None: global RULE_LOADER_FAIL, RULE_LOADER_FAIL_MSG, RULE_LOADER_FAIL_RAISED @@ -81,9 +82,9 @@ def setUp(self) -> None: # raise a dedicated test failure for the loader if not RULE_LOADER_FAIL_RAISED: RULE_LOADER_FAIL_RAISED = True - with self.subTest('Test that the rule loader loaded with no validation or other failures.'): - self.fail(f'Rule loader failure: \n{RULE_LOADER_FAIL_MSG}') + with self.subTest("Test that the rule loader loaded with no validation or other failures."): + self.fail(f"Rule loader failure: {RULE_LOADER_FAIL_MSG}") - self.skipTest('Rule loader failure') + self.skipTest("Rule loader failure") else: super().setUp() diff --git a/tests/kuery/test_dsl.py b/tests/kuery/test_dsl.py index 4af3217ebc0..9f566d03a6b 100644 --- a/tests/kuery/test_dsl.py +++ b/tests/kuery/test_dsl.py @@ -51,10 +51,9 @@ def test_and_query(self): def test_not_query(self): self.validate("not field:value", {"must_not": [{"match": {"field": "value"}}]}) self.validate("field:(not value)", {"must_not": [{"match": {"field": "value"}}]}) - self.validate("field:(a and not b)", { - "filter": [{"match": {"field": "a"}}], - "must_not": [{"match": {"field": "b"}}] - }) + self.validate( + "field:(a and not b)", {"filter": [{"match": {"field": "a"}}], "must_not": [{"match": {"field": "b"}}]} + ) self.validate( "not field:value and not field2:value2", {"must_not": [{"match": {"field": "value"}}, {"match": {"field2": "value2"}}]}, @@ -74,13 +73,10 @@ def test_not_query(self): optimize=False, ) - self.validate("not (field:value and field2:value2)", - { - "must_not": [ - {"match": {"field": "value"}}, - {"match": {"field2": "value2"}} - ] - }) + self.validate( + "not (field:value and field2:value2)", + {"must_not": [{"match": {"field": "value"}}, {"match": {"field2": "value2"}}]}, + ) def test_optimizations(self): self.validate( @@ -120,21 +116,9 @@ def test_optimizations(self): self.validate( "a:(v1 or v2 or v3) and b:(v4 or v5)", { - "should": [ - {"match": {"a": "v1"}}, - {"match": {"a": "v2"}}, - {"match": {"a": "v3"}} - ], + "should": [{"match": {"a": "v1"}}, {"match": {"a": "v2"}}, {"match": {"a": "v3"}}], "filter": [ - { - "bool": { - "should": [ - {"match": {"b": "v4"}}, - {"match": {"b": "v5"}} - ], - 
"minimum_should_match": 1 - } - } + {"bool": {"should": [{"match": {"b": "v4"}}, {"match": {"b": "v5"}}], "minimum_should_match": 1}} ], "minimum_should_match": 1, }, diff --git a/tests/kuery/test_eql2kql.py b/tests/kuery/test_eql2kql.py index de0e4404cb7..17b84029523 100644 --- a/tests/kuery/test_eql2kql.py +++ b/tests/kuery/test_eql2kql.py @@ -9,7 +9,6 @@ class TestEql2Kql(unittest.TestCase): - def validate(self, kql_source, eql_source): self.assertEqual(kql_source, str(kql.from_eql(eql_source))) @@ -55,8 +54,8 @@ def test_ip_checks(self): def test_wildcard_field(self): with eql.parser.elasticsearch_validate_optional_fields: - self.validate('field:value-*', 'field : "value-*"') - self.validate('field:value-?', 'field : "value-?"') + self.validate("field:value-*", 'field : "value-*"') + self.validate("field:value-?", 'field : "value-?"') with eql.parser.elasticsearch_validate_optional_fields, self.assertRaises(AssertionError): self.validate('field:"value-*"', 'field == "value-*"') diff --git a/tests/kuery/test_evaluator.py b/tests/kuery/test_evaluator.py index 97033e97b03..1474450629b 100644 --- a/tests/kuery/test_evaluator.py +++ b/tests/kuery/test_evaluator.py @@ -9,23 +9,15 @@ class EvaluatorTests(unittest.TestCase): - document = { "number": 1, "boolean": True, "ip": "192.168.16.3", "string": "hello world", - "string_list": ["hello world", "example"], "number_list": [1, 2, 3], "boolean_list": [True, False], - "structured": [ - { - "a": [ - {"b": 1} - ] - } - ], + "structured": [{"a": [{"b": 1}]}], } def evaluate(self, source_text, document=None): @@ -36,87 +28,87 @@ def evaluate(self, source_text, document=None): return evaluator(document) def test_single_value(self): - self.assertTrue(self.evaluate('number:1')) + self.assertTrue(self.evaluate("number:1")) self.assertTrue(self.evaluate('number:"1"')) - self.assertTrue(self.evaluate('boolean:true')) + self.assertTrue(self.evaluate("boolean:true")) self.assertTrue(self.evaluate('string:"hello world"')) - self.assertFalse(self.evaluate('number:0')) - self.assertFalse(self.evaluate('boolean:false')) + self.assertFalse(self.evaluate("number:0")) + self.assertFalse(self.evaluate("boolean:false")) self.assertFalse(self.evaluate('string:"missing"')) def test_list_value(self): - self.assertTrue(self.evaluate('number_list:1')) - self.assertTrue(self.evaluate('number_list:2')) - self.assertTrue(self.evaluate('number_list:3')) + self.assertTrue(self.evaluate("number_list:1")) + self.assertTrue(self.evaluate("number_list:2")) + self.assertTrue(self.evaluate("number_list:3")) - self.assertTrue(self.evaluate('boolean_list:true')) - self.assertTrue(self.evaluate('boolean_list:false')) + self.assertTrue(self.evaluate("boolean_list:true")) + self.assertTrue(self.evaluate("boolean_list:false")) self.assertTrue(self.evaluate('string_list:"hello world"')) - self.assertTrue(self.evaluate('string_list:example')) + self.assertTrue(self.evaluate("string_list:example")) - self.assertFalse(self.evaluate('number_list:4')) + self.assertFalse(self.evaluate("number_list:4")) self.assertFalse(self.evaluate('string_list:"missing"')) def test_and_values(self): - self.assertTrue(self.evaluate('number_list:(1 and 2)')) - self.assertTrue(self.evaluate('boolean_list:(false and true)')) + self.assertTrue(self.evaluate("number_list:(1 and 2)")) + self.assertTrue(self.evaluate("boolean_list:(false and true)")) self.assertFalse(self.evaluate('string:("missing" and "hello world")')) - self.assertFalse(self.evaluate('number:(0 and 1)')) - 
self.assertFalse(self.evaluate('boolean:(false and true)')) + self.assertFalse(self.evaluate("number:(0 and 1)")) + self.assertFalse(self.evaluate("boolean:(false and true)")) def test_not_value(self): - self.assertTrue(self.evaluate('number_list:1')) - self.assertFalse(self.evaluate('not number_list:1')) - self.assertFalse(self.evaluate('number_list:(not 1)')) + self.assertTrue(self.evaluate("number_list:1")) + self.assertFalse(self.evaluate("not number_list:1")) + self.assertFalse(self.evaluate("number_list:(not 1)")) def test_or_values(self): - self.assertTrue(self.evaluate('number:(0 or 1)')) - self.assertTrue(self.evaluate('number:(1 or 2)')) - self.assertTrue(self.evaluate('boolean:(false or true)')) + self.assertTrue(self.evaluate("number:(0 or 1)")) + self.assertTrue(self.evaluate("number:(1 or 2)")) + self.assertTrue(self.evaluate("boolean:(false or true)")) self.assertTrue(self.evaluate('string:("missing" or "hello world")')) - self.assertFalse(self.evaluate('number:(0 or 3)')) + self.assertFalse(self.evaluate("number:(0 or 3)")) def test_and_expr(self): - self.assertTrue(self.evaluate('number:1 and boolean:true')) + self.assertTrue(self.evaluate("number:1 and boolean:true")) - self.assertFalse(self.evaluate('number:1 and boolean:false')) + self.assertFalse(self.evaluate("number:1 and boolean:false")) def test_or_expr(self): - self.assertTrue(self.evaluate('number:1 or boolean:false')) - self.assertFalse(self.evaluate('number:0 or boolean:false')) + self.assertTrue(self.evaluate("number:1 or boolean:false")) + self.assertFalse(self.evaluate("number:0 or boolean:false")) def test_range(self): - self.assertTrue(self.evaluate('number < 2')) - self.assertFalse(self.evaluate('number > 2')) + self.assertTrue(self.evaluate("number < 2")) + self.assertFalse(self.evaluate("number > 2")) def test_cidr_match(self): - self.assertTrue(self.evaluate('ip:192.168.0.0/16')) + self.assertTrue(self.evaluate("ip:192.168.0.0/16")) - self.assertFalse(self.evaluate('ip:10.0.0.0/8')) + self.assertFalse(self.evaluate("ip:10.0.0.0/8")) def test_quoted_wildcard(self): self.assertFalse(self.evaluate("string:'*'")) self.assertFalse(self.evaluate("string:'?'")) def test_wildcard(self): - self.assertTrue(self.evaluate('string:hello*')) - self.assertTrue(self.evaluate('string:*world')) - self.assertFalse(self.evaluate('string:foobar*')) + self.assertTrue(self.evaluate("string:hello*")) + self.assertTrue(self.evaluate("string:*world")) + self.assertFalse(self.evaluate("string:foobar*")) def test_field_exists(self): - self.assertTrue(self.evaluate('number:*')) - self.assertTrue(self.evaluate('boolean:*')) - self.assertTrue(self.evaluate('ip:*')) - self.assertTrue(self.evaluate('string:*')) - self.assertTrue(self.evaluate('string_list:*')) - self.assertTrue(self.evaluate('number_list:*')) - self.assertTrue(self.evaluate('boolean_list:*')) - - self.assertFalse(self.evaluate('a:*')) + self.assertTrue(self.evaluate("number:*")) + self.assertTrue(self.evaluate("boolean:*")) + self.assertTrue(self.evaluate("ip:*")) + self.assertTrue(self.evaluate("string:*")) + self.assertTrue(self.evaluate("string_list:*")) + self.assertTrue(self.evaluate("number_list:*")) + self.assertTrue(self.evaluate("boolean_list:*")) + + self.assertFalse(self.evaluate("a:*")) def test_flattening(self): self.assertTrue(self.evaluate("structured.a.b:*")) diff --git a/tests/kuery/test_kql2eql.py b/tests/kuery/test_kql2eql.py index bfa9a242589..446fe216b37 100644 --- a/tests/kuery/test_kql2eql.py +++ b/tests/kuery/test_kql2eql.py @@ -10,7 +10,6 @@ 
class TestKql2Eql(unittest.TestCase): - def validate(self, kql_source, eql_source, schema=None): self.assertEqual(kql.to_eql(kql_source, schema=schema), eql.parse_expression(eql_source)) @@ -54,7 +53,7 @@ def test_list_of_values(self): self.validate("a:(0 or 1 and 2 or (3 and 4))", "a == 0 or a == 1 and a == 2 or (a == 3 and a == 4)") def test_lone_value(self): - for value in ["1", "-1.4", "true", "\"string test\""]: + for value in ["1", "-1.4", "true", '"string test"']: with self.assertRaisesRegex(kql.KqlParseError, "Value not tied to field"): kql.to_eql(value) @@ -71,15 +70,15 @@ def test_schema(self): } self.validate("top.numF : 1", "top.numF == 1", schema=schema) - self.validate("top.numF : \"1\"", "top.numF == 1", schema=schema) + self.validate('top.numF : "1"', "top.numF == 1", schema=schema) self.validate("top.keyword : 1", "top.keyword == '1'", schema=schema) - self.validate("top.keyword : \"hello\"", "top.keyword == 'hello'", schema=schema) + self.validate('top.keyword : "hello"', "top.keyword == 'hello'", schema=schema) self.validate("dest:192.168.255.255", "dest == '192.168.255.255'", schema=schema) self.validate("dest:192.168.0.0/16", "cidrMatch(dest, '192.168.0.0/16')", schema=schema) - self.validate("dest:\"192.168.0.0/16\"", "cidrMatch(dest, '192.168.0.0/16')", schema=schema) + self.validate('dest:"192.168.0.0/16"', "cidrMatch(dest, '192.168.0.0/16')", schema=schema) with self.assertRaises(eql.EqlSemanticError): - self.validate("top.text : \"hello\"", "top.text == 'hello'", schema=schema) + self.validate('top.text : "hello"', "top.text == 'hello'", schema=schema) with self.assertRaises(eql.EqlSemanticError): self.validate("top.text : 1 ", "top.text == '1'", schema=schema) @@ -93,7 +92,7 @@ def test_schema(self): with self.assertRaisesRegex(kql.KqlParseError, "Unable to convert nested query to EQL"): kql.to_eql("top:{middle:{bool: true}}", schema=schema) - invalid_ips = ["192.168.0.256", "192.168.0.256/33", "1", "\"1\""] + invalid_ips = ["192.168.0.256", "192.168.0.256/33", "1", '"1"'] for ip in invalid_ips: with self.assertRaisesRegex(kql.KqlParseError, r"Value doesn't match dest's type: ip"): kql.to_eql("dest:{ip}".format(ip=ip), schema=schema) diff --git a/tests/kuery/test_lint.py b/tests/kuery/test_lint.py index 31953cbd792..d6602658627 100644 --- a/tests/kuery/test_lint.py +++ b/tests/kuery/test_lint.py @@ -8,17 +8,16 @@ class LintTests(unittest.TestCase): - def validate(self, source, linted, *args): self.assertEqual(kql.lint(source), linted, *args) def test_lint_field(self): self.validate("a : b", "a:b") - self.validate("\"a\": b", "a:b") - self.validate("a : \"b\"", "a:b") + self.validate('"a": b', "a:b") + self.validate('a : "b"', "a:b") self.validate("a : (b)", "a:b") self.validate("a:1.234", "a:1.234") - self.validate("a:\"1.234\"", "a:1.234") + self.validate('a:"1.234"', "a:1.234") def test_upper_tokens(self): queries = [ @@ -80,7 +79,7 @@ def test_double_negate(self): self.validate("not (not (a:(not b) or c:(not d)))", "not a:b or not c:d") def test_ip(self): - self.validate("a:ff02\\:\\:fb", "a:\"ff02::fb\"") + self.validate("a:ff02\\:\\:fb", 'a:"ff02::fb"') def test_compound(self): self.validate("a:1 and b:2 and not (c:3 or c:4)", "a:1 and b:2 and not c:(3 or 4)") diff --git a/tests/kuery/test_parser.py b/tests/kuery/test_parser.py index 444d55f1bd1..d7903110fac 100644 --- a/tests/kuery/test_parser.py +++ b/tests/kuery/test_parser.py @@ -16,7 +16,6 @@ class ParserTests(unittest.TestCase): - def validate(self, source, tree, *args, **kwargs): 
kwargs.setdefault("optimize", False) self.assertEqual(kql.parse(source, *args, **kwargs), tree) @@ -28,14 +27,14 @@ def test_keyword(self): "b": "long", } - self.validate('a.text:hello', FieldComparison(Field("a.text"), String("hello")), schema=schema) - self.validate('a.keyword:hello', FieldComparison(Field("a.keyword"), String("hello")), schema=schema) + self.validate("a.text:hello", FieldComparison(Field("a.text"), String("hello")), schema=schema) + self.validate("a.keyword:hello", FieldComparison(Field("a.keyword"), String("hello")), schema=schema) self.validate('a.text:"hello"', FieldComparison(Field("a.text"), String("hello")), schema=schema) self.validate('a.keyword:"hello"', FieldComparison(Field("a.keyword"), String("hello")), schema=schema) - self.validate('a.text:1', FieldComparison(Field("a.text"), String("1")), schema=schema) - self.validate('a.keyword:1', FieldComparison(Field("a.keyword"), String("1")), schema=schema) + self.validate("a.text:1", FieldComparison(Field("a.text"), String("1")), schema=schema) + self.validate("a.keyword:1", FieldComparison(Field("a.keyword"), String("1")), schema=schema) self.validate('a.text:"1"', FieldComparison(Field("a.text"), String("1")), schema=schema) self.validate('a.keyword:"1"', FieldComparison(Field("a.keyword"), String("1")), schema=schema) @@ -43,10 +42,10 @@ def test_keyword(self): def test_conversion(self): schema = {"num": "long", "text": "text"} - self.validate('num:1', FieldComparison(Field("num"), Number(1)), schema=schema) + self.validate("num:1", FieldComparison(Field("num"), Number(1)), schema=schema) self.validate('num:"1"', FieldComparison(Field("num"), Number(1)), schema=schema) - self.validate('text:1', FieldComparison(Field("text"), String("1")), schema=schema) + self.validate("text:1", FieldComparison(Field("text"), String("1")), schema=schema) self.validate('text:"1"', FieldComparison(Field("text"), String("1")), schema=schema) def test_list_equals(self): @@ -57,11 +56,11 @@ def test_number_exists(self): def test_multiple_types_success(self): schema = {"common.a": "keyword", "common.b": "keyword"} - self.validate("common.* : \"hello\"", FieldComparison(Field("common.*"), String("hello")), schema=schema) + self.validate('common.* : "hello"', FieldComparison(Field("common.*"), String("hello")), schema=schema) def test_multiple_types_fail(self): with self.assertRaises(kql.KqlParseError): - kql.parse("common.* : \"hello\"", schema={"common.a": "keyword", "common.b": "ip"}) + kql.parse('common.* : "hello"', schema={"common.a": "keyword", "common.b": "ip"}) def test_number_wildcard_fail(self): with self.assertRaises(kql.KqlParseError): @@ -81,7 +80,7 @@ def test_type_family_fail(self): def test_date(self): schema = {"@time": "date"} - self.validate('@time <= now-10d', FieldRange(Field("@time"), "<=", String("now-10d")), schema=schema) + self.validate("@time <= now-10d", FieldRange(Field("@time"), "<=", String("now-10d")), schema=schema) with self.assertRaises(kql.KqlParseError): kql.parse("@time > 5", schema=schema) diff --git a/tests/test_all_rules.py b/tests/test_all_rules.py index d93d4df4270..da31b7bee59 100644 --- a/tests/test_all_rules.py +++ b/tests/test_all_rules.py @@ -4,6 +4,7 @@ # 2.0. 
"""Test that all rules have valid metadata and syntax.""" + import os import re import unittest @@ -20,12 +21,20 @@ import kql from detection_rules import attack from detection_rules.config import load_current_package_version -from detection_rules.integrations import (find_latest_compatible_version, - load_integrations_manifests, - load_integrations_schemas) +from detection_rules.integrations import ( + find_latest_compatible_version, + load_integrations_manifests, + load_integrations_schemas, +) from detection_rules.packaging import current_stack_version -from detection_rules.rule import (AlertSuppressionMapping, EQLRuleData, QueryRuleData, QueryValidator, - ThresholdAlertSuppression, TOMLRuleContents) +from detection_rules.rule import ( + AlertSuppressionMapping, + EQLRuleData, + QueryRuleData, + QueryValidator, + ThresholdAlertSuppression, + TOMLRuleContents, +) from detection_rules.rule_loader import FILE_PATTERN, RULES_CONFIG from detection_rules.rule_validators import EQLValidator, KQLValidator from detection_rules.schemas import definitions, get_min_supported_stack_version, get_stack_schemas @@ -42,34 +51,37 @@ class TestValidRules(BaseRuleTest): def test_schema_and_dupes(self): """Ensure that every rule matches the schema and there are no duplicates.""" - self.assertGreaterEqual(len(self.all_rules), 1, 'No rules were loaded from rules directory!') + self.assertGreaterEqual(len(self.all_rules), 1, "No rules were loaded from rules directory!") def test_file_names(self): """Test that the file names meet the requirement.""" file_pattern = FILE_PATTERN - self.assertIsNone(re.match(file_pattern, 'NotValidRuleFile.toml'), - f'Incorrect pattern for verifying rule names: {file_pattern}') - self.assertIsNone(re.match(file_pattern, 'still_not_a_valid_file_name.not_json'), - f'Incorrect pattern for verifying rule names: {file_pattern}') + self.assertIsNone( + re.match(file_pattern, "NotValidRuleFile.toml"), + f"Incorrect pattern for verifying rule names: {file_pattern}", + ) + self.assertIsNone( + re.match(file_pattern, "still_not_a_valid_file_name.not_json"), + f"Incorrect pattern for verifying rule names: {file_pattern}", + ) for rule in self.all_rules: file_name = str(rule.path.name) - self.assertIsNotNone(re.match(file_pattern, file_name), f'Invalid file name for {rule.path}') + self.assertIsNotNone(re.match(file_pattern, file_name), f"Invalid file name for {rule.path}") def test_all_rule_queries_optimized(self): """Ensure that every rule query is in optimized form.""" for rule in self.all_rules: - if ( - rule.contents.data.get("language") == "kuery" and not any( - item in rule.contents.data.query for item in definitions.QUERY_FIELD_OP_EXCEPTIONS - ) + if rule.contents.data.get("language") == "kuery" and not any( + item in rule.contents.data.query for item in definitions.QUERY_FIELD_OP_EXCEPTIONS ): source = rule.contents.data.query tree = kql.parse(source, optimize=False, normalize_kql_keywords=RULES_CONFIG.normalize_kql_keywords) optimized = tree.optimize(recursive=True) - err_message = f'\n{self.rule_str(rule)} Query not optimized for rule\n' \ - f'Expected: {optimized}\nActual: {source}' + err_message = ( + f"\n{self.rule_str(rule)} Query not optimized for rule\nExpected: {optimized}\nActual: {source}" + ) self.assertEqual(tree, optimized, err_message) def test_duplicate_file_names(self): @@ -99,7 +111,7 @@ def test_bbr_validation(self): "rule_id": str(uuid.uuid4()), "severity": "low", "type": "query", - "timestamp_override": "event.ingested" + "timestamp_override": "event.ingested", } 
def build_rule(query, bbr_type="default", from_field="now-120m", interval="60m"): @@ -107,7 +119,7 @@ def build_rule(query, bbr_type="default", from_field="now-120m", interval="60m") "creation_date": "1970/01/01", "updated_date": "1970/01/01", "min_stack_version": load_current_package_version(), - "integration": ["cloud_defend"] + "integration": ["cloud_defend"], } data = base_fields.copy() data["query"] = query @@ -133,61 +145,65 @@ def build_rule(query, bbr_type="default", from_field="now-120m", interval="60m") def test_max_signals_note(self): """Ensure the max_signals note is present when max_signals > 1000.""" - max_signal_standard_setup = 'For information on troubleshooting the maximum alerts warning '\ - 'please refer to this [guide]'\ - '(https://www.elastic.co/guide/en/security/current/alerts-ui-monitor.html#troubleshoot-max-alerts).' # noqa: E501 + max_signal_standard_setup = ( + "For information on troubleshooting the maximum alerts warning " + "please refer to this [guide]" + "(https://www.elastic.co/guide/en/security/current/alerts-ui-monitor.html#troubleshoot-max-alerts)." + ) # noqa: E501 for rule in self.all_rules: if rule.contents.data.max_signals and rule.contents.data.max_signals > 1000: - error_message = f'{self.rule_str(rule)} max_signals cannot exceed 1000.' - self.fail(f'{self.rule_str(rule)} max_signals cannot exceed 1000.') + error_message = f"{self.rule_str(rule)} max_signals cannot exceed 1000." + self.fail(f"{self.rule_str(rule)} max_signals cannot exceed 1000.") if rule.contents.data.max_signals and rule.contents.data.max_signals == 1000: - error_message = f'{self.rule_str(rule)} note required for max_signals == 1000' + error_message = f"{self.rule_str(rule)} note required for max_signals == 1000" self.assertIsNotNone(rule.contents.data.setup, error_message) if max_signal_standard_setup not in rule.contents.data.setup: - self.fail(f'{self.rule_str(rule)} expected max_signals note missing\n\n' - f'Expected: {max_signal_standard_setup}\n\n' - f'Actual: {rule.contents.data.setup}') + self.fail( + f"{self.rule_str(rule)} expected max_signals note missing\n\n" + f"Expected: {max_signal_standard_setup}\n\n" + f"Actual: {rule.contents.data.setup}" + ) def test_from_filed_value(self): - """ Add "from" Field Validation for All Rules""" + """Add "from" Field Validation for All Rules""" failures = [] - valid_format = re.compile(r'^now-\d+[yMwdhHms]$') + valid_format = re.compile(r"^now-\d+[yMwdhHms]$") for rule in self.all_rules: - from_field = rule.contents.data.get('from_') + from_field = rule.contents.data.get("from_") if from_field is not None: if not valid_format.match(from_field): - err_msg = f'{self.rule_str(rule)} has invalid value {from_field}' + err_msg = f"{self.rule_str(rule)} has invalid value {from_field}" failures.append(err_msg) if failures: fail_msg = """ The following rules have invalid 'from' filed value \n """ - self.fail(fail_msg + '\n'.join(failures)) + self.fail(fail_msg + "\n".join(failures)) def test_index_or_data_view_id_present(self): """Ensure that either 'index' or 'data_view_id' is present for prebuilt rules.""" failures = [] machine_learning_packages = [val.lower() for val in definitions.MACHINE_LEARNING_PACKAGES] for rule in self.all_rules: - rule_type = rule.contents.data.get('language') - rule_integrations = rule.contents.metadata.get('integration') or [] - if rule_type == 'esql': + rule_type = rule.contents.data.get("language") + rule_integrations = rule.contents.metadata.get("integration") or [] + if rule_type == "esql": continue # the 
index is part of the query and would be validated in the query - elif rule.contents.data.type == 'machine_learning' or rule_integrations in machine_learning_packages: + elif rule.contents.data.type == "machine_learning" or rule_integrations in machine_learning_packages: continue # Skip all rules of machine learning type or rules that are part of machine learning packages - elif rule.contents.data.type == 'threat_match': + elif rule.contents.data.type == "threat_match": continue # Skip all rules of threat_match type else: - index = rule.contents.data.get('index') - data_view_id = rule.contents.data.get('data_view_id') + index = rule.contents.data.get("index") + data_view_id = rule.contents.data.get("data_view_id") if index is None and data_view_id is None: - err_msg = f'{self.rule_str(rule)} does not have either index or data_view_id' + err_msg = f"{self.rule_str(rule)} does not have either index or data_view_id" failures.append(err_msg) if failures: fail_msg = """ The following prebuilt rules do not have either 'index' or 'data_view_id' \n """ - self.fail(fail_msg + '\n'.join(failures)) + self.fail(fail_msg + "\n".join(failures)) class TestThreatMappings(BaseRuleTest): @@ -205,14 +221,15 @@ def test_technique_deprecations(self): if threat_mapping: for entry in threat_mapping: - for technique in (entry.technique or []): + for technique in entry.technique or []: if technique.id in revoked + deprecated: - revoked_techniques[technique.id] = replacement_map.get(technique.id, - 'DEPRECATED - DO NOT USE') + revoked_techniques[technique.id] = replacement_map.get( + technique.id, "DEPRECATED - DO NOT USE" + ) if revoked_techniques: - old_new_mapping = "\n".join(f'Actual: {k} -> Expected {v}' for k, v in revoked_techniques.items()) - self.fail(f'{self.rule_str(rule)} Using deprecated ATT&CK techniques: \n{old_new_mapping}') + old_new_mapping = "\n".join(f"Actual: {k} -> Expected {v}" for k, v in revoked_techniques.items()) + self.fail(f"{self.rule_str(rule)} Using deprecated ATT&CK techniques: \n{old_new_mapping}") def test_tactic_to_technique_correlations(self): """Ensure rule threat info is properly related to a single tactic and technique.""" @@ -225,52 +242,73 @@ def test_tactic_to_technique_correlations(self): mismatched = [t.id for t in techniques if t.id not in attack.matrix[tactic.name]] if mismatched: - self.fail(f'mismatched ATT&CK techniques for rule: {self.rule_str(rule)} ' - f'{", ".join(mismatched)} not under: {tactic["name"]}') + self.fail( + f"mismatched ATT&CK techniques for rule: {self.rule_str(rule)} " + f"{', '.join(mismatched)} not under: {tactic['name']}" + ) # tactic expected_tactic = attack.tactics_map[tactic.name] - self.assertEqual(expected_tactic, tactic.id, - f'ATT&CK tactic mapping error for rule: {self.rule_str(rule)}\n' - f'expected: {expected_tactic} for {tactic.name}\n' - f'actual: {tactic.id}') - - tactic_reference_id = tactic.reference.rstrip('/').split('/')[-1] - self.assertEqual(tactic.id, tactic_reference_id, - f'ATT&CK tactic mapping error for rule: {self.rule_str(rule)}\n' - f'tactic ID {tactic.id} does not match the reference URL ID ' - f'{tactic.reference}') + self.assertEqual( + expected_tactic, + tactic.id, + f"ATT&CK tactic mapping error for rule: {self.rule_str(rule)}\n" + f"expected: {expected_tactic} for {tactic.name}\n" + f"actual: {tactic.id}", + ) + + tactic_reference_id = tactic.reference.rstrip("/").split("/")[-1] + self.assertEqual( + tactic.id, + tactic_reference_id, + f"ATT&CK tactic mapping error for rule: {self.rule_str(rule)}\n" + f"tactic 
ID {tactic.id} does not match the reference URL ID " + f"{tactic.reference}", + ) # techniques for technique in techniques: - expected_technique = attack.technique_lookup[technique.id]['name'] - self.assertEqual(expected_technique, technique.name, - f'ATT&CK technique mapping error for rule: {self.rule_str(rule)}\n' - f'expected: {expected_technique} for {technique.id}\n' - f'actual: {technique.name}') - - technique_reference_id = technique.reference.rstrip('/').split('/')[-1] - self.assertEqual(technique.id, technique_reference_id, - f'ATT&CK technique mapping error for rule: {self.rule_str(rule)}\n' - f'technique ID {technique.id} does not match the reference URL ID ' - f'{technique.reference}') + expected_technique = attack.technique_lookup[technique.id]["name"] + self.assertEqual( + expected_technique, + technique.name, + f"ATT&CK technique mapping error for rule: {self.rule_str(rule)}\n" + f"expected: {expected_technique} for {technique.id}\n" + f"actual: {technique.name}", + ) + + technique_reference_id = technique.reference.rstrip("/").split("/")[-1] + self.assertEqual( + technique.id, + technique_reference_id, + f"ATT&CK technique mapping error for rule: {self.rule_str(rule)}\n" + f"technique ID {technique.id} does not match the reference URL ID " + f"{technique.reference}", + ) # sub-techniques sub_techniques = technique.subtechnique or [] if sub_techniques: for sub_technique in sub_techniques: - expected_sub_technique = attack.technique_lookup[sub_technique.id]['name'] - self.assertEqual(expected_sub_technique, sub_technique.name, - f'ATT&CK sub-technique mapping error for rule: {self.rule_str(rule)}\n' - f'expected: {expected_sub_technique} for {sub_technique.id}\n' - f'actual: {sub_technique.name}') - - sub_technique_reference_id = '.'.join( - sub_technique.reference.rstrip('/').split('/')[-2:]) - self.assertEqual(sub_technique.id, sub_technique_reference_id, - f'ATT&CK sub-technique mapping error for rule: {self.rule_str(rule)}\n' - f'sub-technique ID {sub_technique.id} does not match the reference URL ID ' # noqa: E501 - f'{sub_technique.reference}') + expected_sub_technique = attack.technique_lookup[sub_technique.id]["name"] + self.assertEqual( + expected_sub_technique, + sub_technique.name, + f"ATT&CK sub-technique mapping error for rule: {self.rule_str(rule)}\n" + f"expected: {expected_sub_technique} for {sub_technique.id}\n" + f"actual: {sub_technique.name}", + ) + + sub_technique_reference_id = ".".join( + sub_technique.reference.rstrip("/").split("/")[-2:] + ) + self.assertEqual( + sub_technique.id, + sub_technique_reference_id, + f"ATT&CK sub-technique mapping error for rule: {self.rule_str(rule)}\n" + f"sub-technique ID {sub_technique.id} does not match the reference URL ID " # noqa: E501 + f"{sub_technique.reference}", + ) def test_duplicated_tactics(self): """Check that a tactic is only defined once.""" @@ -280,11 +318,13 @@ def test_duplicated_tactics(self): duplicates = sorted(set(t for t in tactics if tactics.count(t) > 1)) if duplicates: - self.fail(f'{self.rule_str(rule)} duplicate tactics defined for {duplicates}. ' - f'Flatten to a single entry per tactic') + self.fail( + f"{self.rule_str(rule)} duplicate tactics defined for {duplicates}. 
" + f"Flatten to a single entry per tactic" + ) -@unittest.skipIf(os.environ.get('DR_BYPASS_TAGS_VALIDATION') is not None, "Skipping tag validation") +@unittest.skipIf(os.environ.get("DR_BYPASS_TAGS_VALIDATION") is not None, "Skipping tag validation") class TestRuleTags(BaseRuleTest): """Test tags data for rules.""" @@ -297,59 +337,62 @@ def test_casing_and_spacing(self): rule_tags = rule.contents.data.tags if rule_tags: - invalid_tags = {t: expected_case[t.casefold()] for t in rule_tags - if t.casefold() in list(expected_case) and t != expected_case[t.casefold()]} + invalid_tags = { + t: expected_case[t.casefold()] + for t in rule_tags + if t.casefold() in list(expected_case) and t != expected_case[t.casefold()] + } if invalid_tags: - error_msg = f'{self.rule_str(rule)} Invalid casing for expected tags\n' - error_msg += f'Actual tags: {", ".join(invalid_tags)}\n' - error_msg += f'Expected tags: {", ".join(invalid_tags.values())}' + error_msg = f"{self.rule_str(rule)} Invalid casing for expected tags\n" + error_msg += f"Actual tags: {', '.join(invalid_tags)}\n" + error_msg += f"Expected tags: {', '.join(invalid_tags.values())}" self.fail(error_msg) def test_required_tags(self): """Test that expected tags are present within rules.""" required_tags_map = { - 'logs-endpoint.events.*': {'all': ['Domain: Endpoint', 'Data Source: Elastic Defend']}, - 'endgame-*': {'all': ['Data Source: Elastic Endgame']}, - 'logs-aws*': {'all': ['Data Source: AWS', 'Data Source: Amazon Web Services', 'Domain: Cloud']}, - 'logs-azure*': {'all': ['Data Source: Azure', 'Domain: Cloud']}, - 'logs-o365*': {'all': ['Data Source: Microsoft 365', 'Domain: Cloud']}, - 'logs-okta*': {'all': ['Data Source: Okta']}, - 'logs-gcp*': {'all': ['Data Source: Google Cloud Platform', 'Data Source: GCP', 'Domain: Cloud']}, - 'logs-google_workspace*': {'all': ['Data Source: Google Workspace', 'Domain: Cloud']}, - 'logs-cloud_defend.alerts-*': {'all': ['Data Source: Elastic Defend for Containers', 'Domain: Container']}, - 'logs-cloud_defend*': {'all': ['Data Source: Elastic Defend for Containers', 'Domain: Container']}, - 'logs-kubernetes.*': {'all': ['Data Source: Kubernetes']}, - 'apm-*-transaction*': {'all': ['Data Source: APM']}, - 'traces-apm*': {'all': ['Data Source: APM']}, - '.alerts-security.*': {'all': ['Rule Type: Higher-Order Rule']}, - 'logs-cyberarkpas.audit*': {'all': ['Data Source: CyberArk PAS']}, - 'logs-endpoint.alerts-*': {'all': ['Data Source: Elastic Defend']}, - 'logs-windows.sysmon_operational-*': {'all': ['Data Source: Sysmon']}, - 'logs-windows.powershell*': {'all': ['Data Source: PowerShell Logs']}, - 'logs-system.security*': {'all': ['Data Source: Windows Security Event Logs']}, - 'logs-system.forwarded*': {'all': ['Data Source: Windows Security Event Logs']}, - 'logs-system.system*': {'all': ['Data Source: Windows System Event Logs']}, - 'logs-sentinel_one_cloud_funnel.*': {'all': ['Data Source: SentinelOne']}, - 'logs-fim.event-*': {'all': ['Data Source: File Integrity Monitoring']}, - 'logs-m365_defender.event-*': {'all': ['Data Source: Microsoft Defender for Endpoint']}, - 'logs-crowdstrike.fdr*': {'all': ['Data Source: Crowdstrike']} + "logs-endpoint.events.*": {"all": ["Domain: Endpoint", "Data Source: Elastic Defend"]}, + "endgame-*": {"all": ["Data Source: Elastic Endgame"]}, + "logs-aws*": {"all": ["Data Source: AWS", "Data Source: Amazon Web Services", "Domain: Cloud"]}, + "logs-azure*": {"all": ["Data Source: Azure", "Domain: Cloud"]}, + "logs-o365*": {"all": ["Data Source: Microsoft 365", 
"Domain: Cloud"]}, + "logs-okta*": {"all": ["Data Source: Okta"]}, + "logs-gcp*": {"all": ["Data Source: Google Cloud Platform", "Data Source: GCP", "Domain: Cloud"]}, + "logs-google_workspace*": {"all": ["Data Source: Google Workspace", "Domain: Cloud"]}, + "logs-cloud_defend.alerts-*": {"all": ["Data Source: Elastic Defend for Containers", "Domain: Container"]}, + "logs-cloud_defend*": {"all": ["Data Source: Elastic Defend for Containers", "Domain: Container"]}, + "logs-kubernetes.*": {"all": ["Data Source: Kubernetes"]}, + "apm-*-transaction*": {"all": ["Data Source: APM"]}, + "traces-apm*": {"all": ["Data Source: APM"]}, + ".alerts-security.*": {"all": ["Rule Type: Higher-Order Rule"]}, + "logs-cyberarkpas.audit*": {"all": ["Data Source: CyberArk PAS"]}, + "logs-endpoint.alerts-*": {"all": ["Data Source: Elastic Defend"]}, + "logs-windows.sysmon_operational-*": {"all": ["Data Source: Sysmon"]}, + "logs-windows.powershell*": {"all": ["Data Source: PowerShell Logs"]}, + "logs-system.security*": {"all": ["Data Source: Windows Security Event Logs"]}, + "logs-system.forwarded*": {"all": ["Data Source: Windows Security Event Logs"]}, + "logs-system.system*": {"all": ["Data Source: Windows System Event Logs"]}, + "logs-sentinel_one_cloud_funnel.*": {"all": ["Data Source: SentinelOne"]}, + "logs-fim.event-*": {"all": ["Data Source: File Integrity Monitoring"]}, + "logs-m365_defender.event-*": {"all": ["Data Source: Microsoft Defender for Endpoint"]}, + "logs-crowdstrike.fdr*": {"all": ["Data Source: Crowdstrike"]}, } for rule in self.all_rules: rule_tags = rule.contents.data.tags - error_msg = f'{self.rule_str(rule)} Missing tags:\nActual tags: {", ".join(rule_tags)}' + error_msg = f"{self.rule_str(rule)} Missing tags:\nActual tags: {', '.join(rule_tags)}" consolidated_optional_tags = [] is_missing_any_tags = False missing_required_tags = set() if isinstance(rule.contents.data, QueryRuleData): - for index in rule.contents.data.get('index') or []: + for index in rule.contents.data.get("index") or []: expected_tags = required_tags_map.get(index, {}) - expected_all = expected_tags.get('all', []) - expected_any = expected_tags.get('any', []) + expected_all = expected_tags.get("all", []) + expected_any = expected_tags.get("any", []) existing_any_tags = [t for t in rule_tags if t in expected_any] if expected_any: @@ -360,8 +403,8 @@ def test_required_tags(self): is_missing_any_tags = expected_any and not set(expected_any) & set(existing_any_tags) consolidated_optional_tags = [t for t in consolidated_optional_tags if t not in missing_required_tags] - error_msg += f'\nMissing all of: {", ".join(missing_required_tags)}' if missing_required_tags else '' - error_msg += f'\nMissing any of: {", " .join(consolidated_optional_tags)}' if is_missing_any_tags else '' + error_msg += f"\nMissing all of: {', '.join(missing_required_tags)}" if missing_required_tags else "" + error_msg += f"\nMissing any of: {', '.join(consolidated_optional_tags)}" if is_missing_any_tags else "" if missing_required_tags or is_missing_any_tags: self.fail(error_msg) @@ -370,11 +413,11 @@ def test_bbr_tags(self): """Test that "Rule Type: BBR" tag is present for all BBR rules.""" invalid_bbr_rules = [] for rule in self.bbr: - if 'Rule Type: BBR' not in rule.contents.data.tags: + if "Rule Type: BBR" not in rule.contents.data.tags: invalid_bbr_rules.append(self.rule_str(rule)) if invalid_bbr_rules: - error_rules = '\n'.join(invalid_bbr_rules) - self.fail(f'The following building block rule(s) have missing tag: Rule Type: 
BBR:\n{error_rules}') + error_rules = "\n".join(invalid_bbr_rules) + self.fail(f"The following building block rule(s) have missing tag: Rule Type: BBR:\n{error_rules}") def test_primary_tactic_as_tag(self): """Test that the primary tactic is present as a tag.""" @@ -386,7 +429,7 @@ def test_primary_tactic_as_tag(self): for rule in self.all_rules: rule_tags = rule.contents.data.tags - if 'Continuous Monitoring' in rule_tags or rule.contents.data.type == 'machine_learning': + if "Continuous Monitoring" in rule_tags or rule.contents.data.type == "machine_learning": continue threat = rule.contents.data.threat @@ -406,37 +449,33 @@ def test_primary_tactic_as_tag(self): if missing or missing_from_threat: err_msg = self.rule_str(rule) if missing: - err_msg += f'\n expected: {missing}' + err_msg += f"\n expected: {missing}" if missing_from_threat: - err_msg += f'\n unexpected (or missing from threat mapping): {missing_from_threat}' + err_msg += f"\n unexpected (or missing from threat mapping): {missing_from_threat}" invalid.append(err_msg) if invalid: - err_msg = '\n'.join(invalid) - self.fail(f'Rules with misaligned tags and tactics:\n{err_msg}') + err_msg = "\n".join(invalid) + self.fail(f"Rules with misaligned tags and tactics:\n{err_msg}") def test_os_tags(self): """Test that OS tags are present within rules.""" - required_tags_map = { - 'linux': 'OS: Linux', - 'macos': 'OS: macOS', - 'windows': 'OS: Windows' - } + required_tags_map = {"linux": "OS: Linux", "macos": "OS: macOS", "windows": "OS: Windows"} invalid = [] for rule in self.all_rules: dir_name = rule.path.parent.name # if directory name is linux, macos, or windows, # ensure the rule has the corresponding tag - if dir_name in ['linux', 'macos', 'windows']: + if dir_name in ["linux", "macos", "windows"]: if required_tags_map[dir_name] not in rule.contents.data.tags: err_msg = self.rule_str(rule) - err_msg += f'\n expected: {required_tags_map[dir_name]}' + err_msg += f"\n expected: {required_tags_map[dir_name]}" invalid.append(err_msg) if invalid: - err_msg = '\n'.join(invalid) - self.fail(f'Rules with missing OS tags:\n{err_msg}') + err_msg = "\n".join(invalid) + self.fail(f"Rules with missing OS tags:\n{err_msg}") def test_ml_rule_type_tags(self): """Test that ML rule type tags are present within rules.""" @@ -445,36 +484,36 @@ def test_ml_rule_type_tags(self): for rule in self.all_rules: rule_tags = rule.contents.data.tags - if rule.contents.data.type == 'machine_learning': - if 'Rule Type: Machine Learning' not in rule_tags: + if rule.contents.data.type == "machine_learning": + if "Rule Type: Machine Learning" not in rule_tags: err_msg = self.rule_str(rule) - err_msg += '\n expected: Rule Type: Machine Learning' + err_msg += "\n expected: Rule Type: Machine Learning" invalid.append(err_msg) - if 'Rule Type: ML' not in rule_tags: + if "Rule Type: ML" not in rule_tags: err_msg = self.rule_str(rule) - err_msg += '\n expected: Rule Type: ML' + err_msg += "\n expected: Rule Type: ML" invalid.append(err_msg) if invalid: - err_msg = '\n'.join(invalid) - self.fail(f'Rules with misaligned ML rule type tags:\n{err_msg}') + err_msg = "\n".join(invalid) + self.fail(f"Rules with misaligned ML rule type tags:\n{err_msg}") def test_investigation_guide_tag(self): """Test that investigation guide tags are present within rules.""" invalid = [] for rule in self.all_rules: - note = rule.contents.data.get('note') + note = rule.contents.data.get("note") if note is not None: - results = re.search(r'Investigating', note, re.M) + results = 
re.search(r"Investigating", note, re.M) if results is not None: # check if investigation guide tag is present - if 'Resources: Investigation Guide' not in rule.contents.data.tags: + if "Resources: Investigation Guide" not in rule.contents.data.tags: err_msg = self.rule_str(rule) - err_msg += '\n expected: Resources: Investigation Guide' + err_msg += "\n expected: Resources: Investigation Guide" invalid.append(err_msg) if invalid: - err_msg = '\n'.join(invalid) - self.fail(f'Rules with missing Investigation tag:\n{err_msg}') + err_msg = "\n".join(invalid) + self.fail(f"Rules with missing Investigation tag:\n{err_msg}") def test_tag_prefix(self): """Ensure all tags have a prefix from an expected list.""" @@ -483,10 +522,13 @@ def test_tag_prefix(self): for rule in self.all_rules: rule_tags = rule.contents.data.tags expected_prefixes = set([tag.split(":")[0] + ":" for tag in definitions.EXPECTED_RULE_TAGS]) - [invalid.append(f"{self.rule_str(rule)}-{tag}") for tag in rule_tags - if not any(prefix in tag for prefix in expected_prefixes)] + [ + invalid.append(f"{self.rule_str(rule)}-{tag}") + for tag in rule_tags + if not any(prefix in tag for prefix in expected_prefixes) + ] if invalid: - self.fail(f'Rules with invalid tags:\n{invalid}') + self.fail(f"Rules with invalid tags:\n{invalid}") def test_no_duplicate_tags(self): """Ensure no rules have duplicate tags.""" @@ -498,7 +540,7 @@ def test_no_duplicate_tags(self): invalid.append(self.rule_str(rule)) if invalid: - self.fail(f'Rules with duplicate tags:\n{invalid}') + self.fail(f"Rules with duplicate tags:\n{invalid}") class TestRuleTimelines(BaseRuleTest): @@ -517,14 +559,15 @@ def test_timeline_has_title(self): self.fail(missing_err) if timeline_id: - unknown_id = f'{self.rule_str(rule)} Unknown timeline_id: {timeline_id}.' - unknown_id += f' replace with {", ".join(TIMELINE_TEMPLATES)} ' \ - f'or update this unit test with acceptable ids' + unknown_id = f"{self.rule_str(rule)} Unknown timeline_id: {timeline_id}." 
+ unknown_id += ( + f" replace with {', '.join(TIMELINE_TEMPLATES)} or update this unit test with acceptable ids" + ) self.assertIn(timeline_id, list(TIMELINE_TEMPLATES), unknown_id) - unknown_title = f'{self.rule_str(rule)} unknown timeline_title: {timeline_title}' - unknown_title += f' replace with {", ".join(TIMELINE_TEMPLATES.values())}' - unknown_title += ' or update this unit test with acceptable titles' + unknown_title = f"{self.rule_str(rule)} unknown timeline_title: {timeline_title}" + unknown_title += f" replace with {', '.join(TIMELINE_TEMPLATES.values())}" + unknown_title += " or update this unit test with acceptable titles" self.assertEqual(timeline_title, TIMELINE_TEMPLATES[timeline_id], unknown_title) @@ -546,39 +589,48 @@ def test_rule_file_name_tactic(self): threat = rule.contents.data.threat authors = rule.contents.data.author - if threat and 'Elastic' in authors: + if threat and "Elastic" in authors: primary_tactic = threat[0].tactic.name - tactic_str = primary_tactic.lower().replace(' ', '_') + tactic_str = primary_tactic.lower().replace(" ", "_") - if tactic_str != filename[:len(tactic_str)]: - bad_name_rules.append(f'{rule.id} - {Path(rule.path).name} -> expected: {tactic_str}') + if tactic_str != filename[: len(tactic_str)]: + bad_name_rules.append(f"{rule.id} - {Path(rule.path).name} -> expected: {tactic_str}") if bad_name_rules: - error_msg = 'filename does not start with the primary tactic - update the tactic or the rule filename' - rule_err_str = '\n'.join(bad_name_rules) - self.fail(f'{error_msg}:\n{rule_err_str}') + error_msg = "filename does not start with the primary tactic - update the tactic or the rule filename" + rule_err_str = "\n".join(bad_name_rules) + self.fail(f"{error_msg}:\n{rule_err_str}") def test_bbr_in_correct_dir(self): """Ensure that BBR are in the correct directory.""" for rule in self.bbr: # Is the rule a BBR - self.assertEqual(rule.contents.data.building_block_type, 'default', - f'{self.rule_str(rule)} should have building_block_type = "default"') + self.assertEqual( + rule.contents.data.building_block_type, + "default", + f'{self.rule_str(rule)} should have building_block_type = "default"', + ) # Is the rule in the rules_building_block directory - self.assertEqual(rule.path.parent.name, 'rules_building_block', - f'{self.rule_str(rule)} should be in the rules_building_block directory') + self.assertEqual( + rule.path.parent.name, + "rules_building_block", + f"{self.rule_str(rule)} should be in the rules_building_block directory", + ) def test_non_bbr_in_correct_dir(self): """Ensure that non-BBR are not in BBR directory.""" - proper_directory = 'rules_building_block' + proper_directory = "rules_building_block" for rule in self.all_rules: - if rule.path.parent.name == 'rules_building_block': - self.assertIn(rule, self.bbr, f'{self.rule_str(rule)} should be in the {proper_directory}') + if rule.path.parent.name == "rules_building_block": + self.assertIn(rule, self.bbr, f"{self.rule_str(rule)} should be in the {proper_directory}") else: # Is the rule of type BBR and not in the correct directory - self.assertEqual(rule.contents.data.building_block_type, None, - f'{self.rule_str(rule)} should be in {proper_directory}') + self.assertEqual( + rule.contents.data.building_block_type, + None, + f"{self.rule_str(rule)} should be in {proper_directory}", + ) class TestRuleMetadata(BaseRuleTest): @@ -589,14 +641,14 @@ def test_updated_date_newer_than_creation(self): invalid = [] for rule in self.all_rules: - created = 
rule.contents.metadata.creation_date.split('/') - updated = rule.contents.metadata.updated_date.split('/') + created = rule.contents.metadata.creation_date.split("/") + updated = rule.contents.metadata.updated_date.split("/") if updated < created: invalid.append(rule) if invalid: - rules_str = '\n '.join(self.rule_str(r, trailer=None) for r in invalid) - err_msg = f'The following rules have an updated_date older than the creation_date\n {rules_str}' + rules_str = "\n ".join(self.rule_str(r, trailer=None) for r in invalid) + err_msg = f"The following rules have an updated_date older than the creation_date\n {rules_str}" self.fail(err_msg) @unittest.skipIf(RULES_CONFIG.bypass_version_lock, "Skipping deprecated version lock check") @@ -614,33 +666,39 @@ def test_deprecated_rules(self): misplaced_rules.append(r) else: for rules_path in rules_paths: - if "_deprecated" in r.path.relative_to(rules_path).parts \ - and r.contents.metadata.maturity != "deprecated": + if ( + "_deprecated" in r.path.relative_to(rules_path).parts + and r.contents.metadata.maturity != "deprecated" + ): misplaced_rules.append(r) break - misplaced = '\n'.join(f'{self.rule_str(r)} {r.contents.metadata.maturity}' for r in misplaced_rules) - err_str = f'The following rules are stored in _deprecated but are not marked as deprecated:\n{misplaced}' + misplaced = "\n".join(f"{self.rule_str(r)} {r.contents.metadata.maturity}" for r in misplaced_rules) + err_str = f"The following rules are stored in _deprecated but are not marked as deprecated:\n{misplaced}" self.assertListEqual(misplaced_rules, [], err_str) for rule in self.deprecated_rules: meta = rule.contents.metadata deprecated_rules[rule.id] = rule - err_msg = f'{self.rule_str(rule)} cannot be deprecated if it has not been version locked. ' \ - f'Convert to `development` or delete the rule file instead' + err_msg = ( + f"{self.rule_str(rule)} cannot be deprecated if it has not been version locked. " + f"Convert to `development` or delete the rule file instead" + ) self.assertIn(rule.id, versions, err_msg) rule_path = rule.path.relative_to(rules_path) - err_msg = f'{self.rule_str(rule)} deprecated rules should be stored in ' \ - f'"{rule_path.parent / "_deprecated"}" folder' - self.assertEqual('_deprecated', rule_path.parts[-2], err_msg) + err_msg = ( + f"{self.rule_str(rule)} deprecated rules should be stored in " + f'"{rule_path.parent / "_deprecated"}" folder' + ) + self.assertEqual("_deprecated", rule_path.parts[-2], err_msg) - err_msg = f'{self.rule_str(rule)} missing deprecation date' - self.assertIsNotNone(meta['deprecation_date'], err_msg) + err_msg = f"{self.rule_str(rule)} missing deprecation date" + self.assertIsNotNone(meta["deprecation_date"], err_msg) - err_msg = f'{self.rule_str(rule)} deprecation_date and updated_date should match' - self.assertEqual(meta['deprecation_date'], meta['updated_date'], err_msg) + err_msg = f"{self.rule_str(rule)} deprecation_date and updated_date should match" + self.assertEqual(meta["deprecation_date"], meta["updated_date"], err_msg) # skip this so the lock file can be shared across branches # @@ -656,16 +714,16 @@ def test_deprecated_rules(self): # will exist in the deprecated_rules.json file and not be in the _deprecated folder - this is expected. # However, that should not occur except by exception - the proper way to handle this situation is to # "fork" the existing rule by adding a new min_stack_version. 
- if PACKAGE_STACK_VERSION < Version.parse(entry['stack_version'], optional_minor_and_patch=True): + if PACKAGE_STACK_VERSION < Version.parse(entry["stack_version"], optional_minor_and_patch=True): continue - rule_str = f'{rule_id} - {entry["rule_name"]} ->' + rule_str = f"{rule_id} - {entry['rule_name']} ->" self.assertIn(rule_id, deprecated_rules, f'{rule_str} is logged in "deprecated_rules.json" but is missing') def test_deprecated_rules_modified(self): """Test to ensure deprecated rules are not modified.""" - rules_path = get_path("rules", "_deprecated") + rules_path = get_path(["rules", "_deprecated"]) # Use git diff to check if the file(s) has been modified in rules/_deprecated directory detection_rules_git = make_git() @@ -675,13 +733,12 @@ def test_deprecated_rules_modified(self): if result: self.fail(f"Deprecated rules {result} has been modified") - @unittest.skipIf(os.getenv('GITHUB_EVENT_NAME') == 'push', - "Skipping this test when not running on pull requests.") + @unittest.skipIf(os.getenv("GITHUB_EVENT_NAME") == "push", "Skipping this test when not running on pull requests.") def test_rule_change_has_updated_date(self): """Test to ensure modified rules have updated_date field updated.""" - rules_path = get_path("rules") - rules_bbr_path = get_path("rules_building_block") + rules_path = get_path(["rules"]) + rules_bbr_path = get_path(["rules_building_block"]) # Use git diff to check if the file(s) has been modified in rules/ rules_build_block/ directories. # For now this checks even rules/_deprecated any modification there will fail @@ -689,120 +746,150 @@ def test_rule_change_has_updated_date(self): # is not required as there is a specific test for deprecated rules. detection_rules_git = make_git() - result = detection_rules_git("diff", "--diff-filter=M", "origin/main", "--name-only", - rules_path, rules_bbr_path) + result = detection_rules_git( + "diff", "--diff-filter=M", "origin/main", "--name-only", rules_path, rules_bbr_path + ) # If the output is not empty, then file(s) have changed in the directory(s) if result: modified_rules = result.splitlines() failed_rules = [] for modified_rule_path in modified_rules: - diff_output = detection_rules_git('diff', 'origin/main', modified_rule_path) - if not re.search(r'\+\s*updated_date =', diff_output): + diff_output = detection_rules_git("diff", "origin/main", modified_rule_path) + if not re.search(r"\+\s*updated_date =", diff_output): # Rule has been modified but updated_date has not been changed, add to list of failed rules - failed_rules.append(f'{modified_rule_path}') + failed_rules.append(f"{modified_rule_path}") if failed_rules: fail_msg = """ The following rules in the below path(s) have been modified but updated_date has not been changed \n """ - self.fail(fail_msg + '\n'.join(failed_rules)) + self.fail(fail_msg + "\n".join(failed_rules)) - @unittest.skipIf(PACKAGE_STACK_VERSION < Version.parse("8.3.0"), - "Test only applicable to 8.3+ stacks regarding related integrations build time field.") + @unittest.skipIf( + PACKAGE_STACK_VERSION < Version.parse("8.3.0"), + "Test only applicable to 8.3+ stacks regarding related integrations build time field.", + ) def test_integration_tag(self): """Test integration rules defined by metadata tag.""" failures = [] non_dataset_packages = definitions.NON_DATASET_PACKAGES + ["winlog"] packages_manifest = load_integrations_manifests() - valid_integration_folders = [p.name for p in list(Path(INTEGRATION_RULE_DIR).glob("*")) if p.name != 'endpoint'] + valid_integration_folders = [p.name for 
p in list(Path(INTEGRATION_RULE_DIR).glob("*")) if p.name != "endpoint"] for rule in self.all_rules: # TODO: temp bypass for esql rules; once parsed, we should be able to look for indexes via `FROM` - if not rule.contents.data.get('index'): + if not rule.contents.data.get("index"): continue - if isinstance(rule.contents.data, QueryRuleData) and rule.contents.data.language != 'lucene': - rule_integrations = rule.contents.metadata.get('integration') or [] + if isinstance(rule.contents.data, QueryRuleData) and rule.contents.data.language != "lucene": + rule_integrations = rule.contents.metadata.get("integration") or [] rule_integrations = [rule_integrations] if isinstance(rule_integrations, str) else rule_integrations - rule_promotion = rule.contents.metadata.get('promotion') + rule_promotion = rule.contents.metadata.get("promotion") data = rule.contents.data meta = rule.contents.metadata package_integrations = TOMLRuleContents.get_packaged_integrations(data, meta, packages_manifest) package_integrations_list = list(set([integration["package"] for integration in package_integrations])) - indices = data.get('index') or [] + indices = data.get("index") or [] for rule_integration in rule_integrations: - if ("even.dataset" in rule.contents.data.query and not package_integrations and # noqa: W504 - not rule_promotion and rule_integration not in definitions.NON_DATASET_PACKAGES): # noqa: W504 - err_msg = f'{self.rule_str(rule)} {rule_integration} tag, but integration not \ - found in manifests/schemas.' + if ( + "event.dataset" in rule.contents.data.query + and not package_integrations # noqa: W504 + and not rule_promotion + and rule_integration not in definitions.NON_DATASET_PACKAGES + ): # noqa: W504 + err_msg = f"{self.rule_str(rule)} {rule_integration} tag, but integration not \ + found in manifests/schemas." + failures.append(err_msg) # checks if the rule path matches the intended integration # excludes BBR rules - if rule_integration in valid_integration_folders and \ - not hasattr(rule.contents.data, 'building_block_type'): + if rule_integration in valid_integration_folders and not hasattr( + rule.contents.data, "building_block_type" + ): if rule.path.parent.name not in rule_integrations: - err_msg = f'{self.rule_str(rule)} {rule_integration} tag, path is {rule.path.parent.name}' + err_msg = f"{self.rule_str(rule)} {rule_integration} tag, path is {rule.path.parent.name}" failures.append(err_msg) # checks if an index pattern exists if the package integration tag exists # and is of pattern logs-{integration}* integration_string = "|".join(indices) if not re.search(f"logs-{rule_integration}*", integration_string): - if rule_integration == "windows" and re.search("winlog", integration_string) or \ - any(ri in [*map(str.lower, definitions.MACHINE_LEARNING_PACKAGES)] - for ri in rule_integrations): + if ( + rule_integration == "windows" + and re.search("winlog", integration_string) + or any( + ri in [*map(str.lower, definitions.MACHINE_LEARNING_PACKAGES)] + for ri in rule_integrations + ) + ): continue - elif rule_integration == "apm" and \ - re.search("apm-*-transaction*|traces-apm*", integration_string): + elif rule_integration == "apm" and re.search( + "apm-*-transaction*|traces-apm*", integration_string + ): continue - elif rule.contents.data.type == 'threat_match': + elif rule.contents.data.type == "threat_match": continue - err_msg = f'{self.rule_str(rule)} {rule_integration} tag, index pattern missing or incorrect.'
+ err_msg = f"{self.rule_str(rule)} {rule_integration} tag, index pattern missing or incorrect." failures.append(err_msg) # checks if event.dataset exists in query object and a tag exists in metadata # checks if metadata tag matches from a list of integrations in EPR if package_integrations and sorted(rule_integrations) != sorted(package_integrations_list): - err_msg = f'{self.rule_str(rule)} integration tags: {rule_integrations} != ' \ - f'package integrations: {package_integrations_list}' + err_msg = ( + f"{self.rule_str(rule)} integration tags: {rule_integrations} != " + f"package integrations: {package_integrations_list}" + ) failures.append(err_msg) else: # checks if rule has index pattern integration and the integration tag exists # ignore the External Alerts rule, Threat Indicator Matching Rules, Guided onboarding - if any([re.search("|".join(non_dataset_packages), i, re.IGNORECASE) - for i in rule.contents.data.get('index') or []]): - if not rule.contents.metadata.integration and rule.id not in definitions.IGNORE_IDS and \ - rule.contents.data.type not in definitions.MACHINE_LEARNING: - err_msg = f'substrings {non_dataset_packages} found in '\ - f'{self.rule_str(rule)} rule index patterns are {rule.contents.data.index},' \ - f'but no integration tag found' + if any( + [ + re.search("|".join(non_dataset_packages), i, re.IGNORECASE) + for i in rule.contents.data.get("index") or [] + ] + ): + if ( + not rule.contents.metadata.integration + and rule.id not in definitions.IGNORE_IDS + and rule.contents.data.type not in definitions.MACHINE_LEARNING + ): + err_msg = ( + f"substrings {non_dataset_packages} found in " + f"{self.rule_str(rule)} rule index patterns are {rule.contents.data.index}," + f"but no integration tag found" + ) failures.append(err_msg) # checks for a defined index pattern, the related integration exists in metadata expected_integrations, missing_integrations = set(), set() - ignore_ml_packages = any(ri in [*map(str.lower, definitions.MACHINE_LEARNING_PACKAGES)] - for ri in rule_integrations) + ignore_ml_packages = any( + ri in [*map(str.lower, definitions.MACHINE_LEARNING_PACKAGES)] for ri in rule_integrations + ) for index in indices: - if index in definitions.IGNORE_INDICES or ignore_ml_packages or \ - rule.id in definitions.IGNORE_IDS or rule.contents.data.type == 'threat_match': + if ( + index in definitions.IGNORE_INDICES + or ignore_ml_packages + or rule.id in definitions.IGNORE_IDS + or rule.contents.data.type == "threat_match" + ): continue # Outlier integration log pattern to identify integration - if index == 'apm-*-transaction*': - index_map = ['apm'] + if index == "apm-*-transaction*": + index_map = ["apm"] else: # Split by hyphen to get the second part of index - index_part1, _, index_part2 = index.partition('-') + index_part1, _, index_part2 = index.partition("-") # Use regular expression to extract alphanumeric words, which is integration name - parsed_integration = re.search(r'\b\w+\b', index_part2 or index_part1) + parsed_integration = re.search(r"\b\w+\b", index_part2 or index_part1) index_map = [parsed_integration.group(0) if parsed_integration else None] if not index_map: - self.fail(f'{self.rule_str(rule)} Could not determine the integration from Index {index}') + self.fail(f"{self.rule_str(rule)} Could not determine the integration from Index {index}") expected_integrations.update(index_map) missing_integrations.update(expected_integrations.difference(set(rule_integrations))) if missing_integrations: - error_msg = f'{self.rule_str(rule)} Missing 
integration metadata: {", ".join(missing_integrations)}' + error_msg = f"{self.rule_str(rule)} Missing integration metadata: {', '.join(missing_integrations)}" failures.append(error_msg) if failures: @@ -811,7 +898,7 @@ def test_integration_tag(self): Try updating the integrations manifest file: - `python -m detection_rules dev integrations build-manifests`\n """ - self.fail(err_msg + '\n'.join(failures)) + self.fail(err_msg + "\n".join(failures)) def test_invalid_queries(self): invalid_queries_eql = [ @@ -836,7 +923,7 @@ def test_invalid_queries(self): "p7r", "p12", "asc", "jks", "p7b", "signature", "gpg", "pgp.sig", "sst", "pgp", "gpgz", "pfx", "crt", "p8", "sig", "pkcs7", "jceks", "pkcs8", "psc1", "p7c", "csr", "cer", "spc", "ps2xml") - """ + """, ] valid_queries_eql = [ @@ -851,8 +938,7 @@ def test_invalid_queries(self): "token","assig", "pssc", "keystore", "pub", "pgp.asc", "ps1xml", "pem", "gpg.sig", "der", "key", "p7r", "p12", "asc", "jks", "p7b", "signature", "gpg", "pgp.sig", "sst", "pgp", "gpgz", "pfx", "p8", "sig", "pkcs7", "jceks", "pkcs8", "psc1", "p7c", "csr", "cer", "spc", "ps2xml") - """ - + """, ] invalid_queries_kql = [ @@ -875,8 +961,7 @@ def test_invalid_queries(self): """, """ event.dataset:"google_workspace.admin" and event.action:"CREATE_DATA_TRANSFER_REQUEST" - """ - + """, ] base_fields_eql = { @@ -889,7 +974,7 @@ def test_invalid_queries(self): "risk_score": 21, "rule_id": str(uuid.uuid4()), "severity": "low", - "type": "eql" + "type": "eql", } base_fields_kql = { @@ -902,7 +987,7 @@ def test_invalid_queries(self): "risk_score": 21, "rule_id": str(uuid.uuid4()), "severity": "low", - "type": "query" + "type": "query", } def build_rule(query: str, query_language: str): @@ -912,7 +997,7 @@ def build_rule(query: str, query_language: str): "updated_date": "1970/01/01", "query_schema_validation": True, "maturity": "production", - "min_stack_version": load_current_package_version() + "min_stack_version": load_current_package_version(), } if query_language == "eql": data = base_fields_eql.copy() @@ -921,6 +1006,7 @@ def build_rule(query: str, query_language: str): data["query"] = query obj = {"metadata": metadata, "rule": data} return TOMLRuleContents.from_dict(obj) + # eql for query in valid_queries_eql: build_rule(query, "eql") @@ -957,7 +1043,7 @@ def test_event_dataset(self): data = rule.contents.data meta = rule.contents.metadata if meta.query_schema_validation is not False or meta.maturity != "deprecated": - if isinstance(data, QueryRuleData) and data.language != 'lucene': + if isinstance(data, QueryRuleData) and data.language != "lucene": packages_manifest = load_integrations_manifests() pkg_integrations = TOMLRuleContents.get_packaged_integrations(data, meta, packages_manifest) @@ -965,9 +1051,9 @@ def test_event_dataset(self): if pkg_integrations: # validate the query against related integration fields - validation_integrations_check = test_validator.validate_integration(data, - meta, - pkg_integrations) + validation_integrations_check = test_validator.validate_integration( + data, meta, pkg_integrations + ) if validation_integrations_check and "event.dataset" in rule.contents.data.query: raise validation_integrations_check @@ -976,63 +1062,37 @@ def test_event_dataset(self): class TestIntegrationRules(BaseRuleTest): """Test integration rules.""" - @unittest.skip("8.3+ Stacks Have Related Integrations Feature") - def test_integration_guide(self): - """Test that rules which require a config note are using standard verbiage.""" - config = '## Setup\n\n' - 
beats_integration_pattern = config + 'The {} Fleet integration, Filebeat module, or similarly ' \ - 'structured data is required to be compatible with this rule.' - render = beats_integration_pattern.format - integration_notes = { - 'aws': render('AWS'), - 'azure': render('Azure'), - 'cyberarkpas': render('CyberArk Privileged Access Security (PAS)'), - 'gcp': render('GCP'), - 'google_workspace': render('Google Workspace'), - 'o365': render('Office 365 Logs'), - 'okta': render('Okta'), - } - - for rule in self.all_rules: - integration = rule.contents.metadata.integration - note_str = integration_notes.get(integration) - - if note_str: - error_message = f'{self.rule_str(rule)} note required for config information' - self.assertIsNotNone(rule.contents.data.note, error_message) - - if note_str not in rule.contents.data.note: - self.fail(f'{self.rule_str(rule)} expected {integration} config missing\n\n' - f'Expected: {note_str}\n\n' - f'Actual: {rule.contents.data.note}') - def test_rule_demotions(self): """Test to ensure a locked rule is not dropped to development, only deprecated""" versions = loaded_version_lock.version_lock failures = [] for rule in self.all_rules: - if rule.id in versions and rule.contents.metadata.maturity not in ('production', 'deprecated'): - err_msg = f'{self.rule_str(rule)} a version locked rule can only go from production to deprecated\n' - err_msg += f'Actual: {rule.contents.metadata.maturity}' + if rule.id in versions and rule.contents.metadata.maturity not in ("production", "deprecated"): + err_msg = f"{self.rule_str(rule)} a version locked rule can only go from production to deprecated\n" + err_msg += f"Actual: {rule.contents.metadata.maturity}" failures.append(err_msg) if failures: - err_msg = '\n'.join(failures) - self.fail(f'The following rules have been improperly demoted:\n{err_msg}') + err_msg = "\n".join(failures) + self.fail(f"The following rules have been improperly demoted:\n{err_msg}") def test_all_min_stack_rules_have_comment(self): failures = [] for rule in self.all_rules: if rule.contents.metadata.min_stack_version and not rule.contents.metadata.min_stack_comments: - failures.append(f'{self.rule_str(rule)} missing `metadata.min_stack_comments`. min_stack_version: ' - f'{rule.contents.metadata.min_stack_version}') + failures.append( + f"{self.rule_str(rule)} missing `metadata.min_stack_comments`. 
min_stack_version: " + f"{rule.contents.metadata.min_stack_version}" + ) if failures: - err_msg = '\n'.join(failures) - self.fail(f'The following ({len(failures)}) rules have a `min_stack_version` defined but missing comments:' - f'\n{err_msg}') + err_msg = "\n".join(failures) + self.fail( + f"The following ({len(failures)}) rules have a `min_stack_version` defined but missing comments:" + f"\n{err_msg}" + ) def test_ml_integration_jobs_exist(self): """Test that machine learning jobs exist in the integration.""" @@ -1044,36 +1104,37 @@ def test_ml_integration_jobs_exist(self): for rule in self.all_rules: if rule.contents.data.type == "machine_learning": - ml_integration_name = next((i for i in rule.contents.metadata.integration - if i in ml_integration_names), None) + ml_integration_name = next( + (i for i in rule.contents.metadata.integration if i in ml_integration_names), None + ) if ml_integration_name: if "machine_learning_job_id" not in dir(rule.contents.data): - failures.append(f'{self.rule_str(rule)} missing `machine_learning_job_id`') + failures.append(f"{self.rule_str(rule)} missing `machine_learning_job_id`") else: rule_job_id = rule.contents.data.machine_learning_job_id ml_schema = integration_schemas.get(ml_integration_name) min_version = Version.parse( rule.contents.metadata.min_stack_version or load_current_package_version(), - optional_minor_and_patch=True + optional_minor_and_patch=True, ) latest_compat_ver = find_latest_compatible_version( package=ml_integration_name, integration="", rule_stack_version=min_version, - packages_manifest=integration_manifests + packages_manifest=integration_manifests, ) compat_integration_schema = ml_schema[latest_compat_ver[0]] - if rule_job_id not in compat_integration_schema['jobs']: + if rule_job_id not in compat_integration_schema["jobs"]: failures.append( - f'{self.rule_str(rule)} machine_learning_job_id `{rule_job_id}` not found ' - f'in version `{latest_compat_ver[0]}` of `{ml_integration_name}` integration. ' - f'existing jobs: {compat_integration_schema["jobs"]}' + f"{self.rule_str(rule)} machine_learning_job_id `{rule_job_id}` not found " + f"in version `{latest_compat_ver[0]}` of `{ml_integration_name}` integration. " + f"existing jobs: {compat_integration_schema['jobs']}" ) if failures: - err_msg = '\n'.join(failures) + err_msg = "\n".join(failures) self.fail( - f'The following ({len(failures)}) rules are missing a valid `machine_learning_job_id`:\n{err_msg}' + f"The following ({len(failures)}) rules are missing a valid `machine_learning_job_id`:\n{err_msg}" ) @@ -1095,24 +1156,24 @@ def test_event_override(self): # QueryRuleData should inherently ignore machine learning rules if isinstance(rule.contents.data, QueryRuleData): rule_language = rule.contents.data.language - has_event_ingested = rule.contents.data.get('timestamp_override') == 'event.ingested' + has_event_ingested = rule.contents.data.get("timestamp_override") == "event.ingested" rule_str = self.rule_str(rule, trailer=None) if not has_event_ingested: # TODO: determine if we expand this to ES|QL # ignores any rule that does not use EQL or KQL queries specifically # this does not avoid rule types where variants of KQL are used (e.g. 
new terms) - if rule_language not in ('eql', 'kuery') or getattr(rule.contents.data, 'is_sequence', False): + if rule_language not in ("eql", "kuery") or getattr(rule.contents.data, "is_sequence", False): continue else: - errors.append(f'{rule_str} - rule must have `timestamp_override: event.ingested`') + errors.append(f"{rule_str} - rule must have `timestamp_override: event.ingested`") if errors: - self.fail('The following rules are invalid:\n' + '\n'.join(errors)) + self.fail("The following rules are invalid:\n" + "\n".join(errors)) def test_required_lookback(self): """Ensure endpoint rules have the proper lookback time.""" - long_indexes = {'logs-endpoint.events.*'} + long_indexes = {"logs-endpoint.events.*"} missing = [] for rule in self.all_rules: @@ -1123,8 +1184,8 @@ def test_required_lookback(self): missing.append(rule) if missing: - rules_str = '\n '.join(self.rule_str(r, trailer=None) for r in missing) - err_msg = f'The following rules should have a longer `from` defined, due to indexes used\n {rules_str}' + rules_str = "\n ".join(self.rule_str(r, trailer=None) for r in missing) + err_msg = f"The following rules should have a longer `from` defined, due to indexes used\n {rules_str}" self.fail(err_msg) def test_eql_lookback(self): @@ -1134,8 +1195,8 @@ def test_eql_lookback(self): ten_minutes = 10 * 60 * 1000 for rule in self.all_rules: - if rule.contents.data.type == 'eql' and rule.contents.data.max_span: - if rule.contents.data.look_back == 'unknown': + if rule.contents.data.type == "eql" and rule.contents.data.max_span: + if rule.contents.data.look_back == "unknown": unknowns.append(self.rule_str(rule, trailer=None)) else: look_back = rule.contents.data.look_back @@ -1143,16 +1204,17 @@ def test_eql_lookback(self): expected = look_back + ten_minutes if expected < max_span: - invalids.append(f'{self.rule_str(rule)} lookback: {look_back}, maxspan: {max_span}, ' - f'expected: >={expected}') + invalids.append( + f"{self.rule_str(rule)} lookback: {look_back}, maxspan: {max_span}, expected: >={expected}" + ) if unknowns: - warn_str = '\n'.join(unknowns) - warnings.warn(f'Unable to determine lookbacks for the following rules:\n{warn_str}') + warn_str = "\n".join(unknowns) + warnings.warn(f"Unable to determine lookbacks for the following rules:\n{warn_str}") if invalids: - invalids_str = '\n'.join(invalids) - self.fail(f'The following rules have longer max_spans than lookbacks:\n{invalids_str}') + invalids_str = "\n".join(invalids) + self.fail(f"The following rules have longer max_spans than lookbacks:\n{invalids_str}") def test_eql_interval_to_maxspan(self): """Check the ratio of interval to maxspan for eql rules.""" @@ -1160,33 +1222,34 @@ def test_eql_interval_to_maxspan(self): five_minutes = 5 * 60 * 1000 for rule in self.all_rules: - if rule.contents.data.type == 'eql': + if rule.contents.data.type == "eql": interval = rule.contents.data.interval or five_minutes maxspan = rule.contents.data.max_span ratio = rule.contents.data.interval_ratio # we want to test for at least a ratio of: interval >= 1/2 maxspan # but we only want to make an exception and cap the ratio at 5m interval (2.5m maxspan) - if maxspan and maxspan > (five_minutes / 2) and ratio and ratio < .5: + if maxspan and maxspan > (five_minutes / 2) and ratio and ratio < 0.5: expected = maxspan // 2 - err_msg = f'{self.rule_str(rule)} interval: {interval}, maxspan: {maxspan}, expected: >={expected}' + err_msg = f"{self.rule_str(rule)} interval: {interval}, maxspan: {maxspan}, expected: >={expected}" 
invalids.append(err_msg) if invalids: - invalids_str = '\n'.join(invalids) - self.fail(f'The following rules have intervals too short for their given max_spans (ms):\n{invalids_str}') + invalids_str = "\n".join(invalids) + self.fail(f"The following rules have intervals too short for their given max_spans (ms):\n{invalids_str}") class TestLicense(BaseRuleTest): """Test rule license.""" - @unittest.skipIf(os.environ.get('CUSTOM_RULES_DIR'), 'Skipping test for custom rules.') + + @unittest.skipIf(os.environ.get("CUSTOM_RULES_DIR"), "Skipping test for custom rules.") def test_elastic_license_only_v2(self): """Test to ensure that production rules with the elastic license are only v2.""" for rule in self.all_rules: rule_license = rule.contents.data.license - if 'elastic license' in rule_license.lower(): - err_msg = f'{self.rule_str(rule)} If Elastic License is used, only v2 should be used' - self.assertEqual(rule_license, 'Elastic License v2', err_msg) + if "elastic license" in rule_license.lower(): + err_msg = f"{self.rule_str(rule)} If Elastic License is used, only v2 should be used" + self.assertEqual(rule_license, "Elastic License v2", err_msg) class TestIncompatibleFields(BaseRuleTest): @@ -1199,11 +1262,11 @@ def test_rule_backports_for_restricted_fields(self): for rule in self.all_rules: invalid = rule.contents.check_restricted_fields_compatibility() if invalid: - invalid_rules.append(f'{self.rule_str(rule)} {invalid}') + invalid_rules.append(f"{self.rule_str(rule)} {invalid}") if invalid_rules: - invalid_str = '\n'.join(invalid_rules) - err_msg = 'The following rules have min_stack_versions lower than allowed for restricted fields:\n' + invalid_str = "\n".join(invalid_rules) + err_msg = "The following rules have min_stack_versions lower than allowed for restricted fields:\n" err_msg += invalid_str self.fail(err_msg) @@ -1231,12 +1294,14 @@ def test_build_fields_min_stack(self): # old and unneeded checks. (i.e. 
8.3.0 < 8.9.0 min supported, so it is irrelevant now) if start_ver is not None and current_stack_ver >= start_ver >= min_supported_stack_version: if min_stack is None or not Version.parse(min_stack) >= start_ver: - errors.append(f'{build_field} >= {start_ver}') + errors.append(f"{build_field} >= {start_ver}") if errors: - err_str = ', '.join(errors) - invalids.append(f'{self.rule_str(rule)} uses a rule type with build fields requiring min_stack_versions' - f' to be set: {err_str}') + err_str = ", ".join(errors) + invalids.append( + f"{self.rule_str(rule)} uses a rule type with build fields requiring min_stack_versions" + f" to be set: {err_str}" + ) if invalids: self.fail(invalids) @@ -1249,9 +1314,9 @@ def test_rule_risk_score_severity_mismatch(self): invalid_list = [] risk_severity = { "critical": (74, 100), # updated range for critical - "high": (48, 73), # updated range for high - "medium": (22, 47), # updated range for medium - "low": (0, 21), # updated range for low + "high": (48, 73), # updated range for high + "medium": (22, 47), # updated range for medium + "low": (0, 21), # updated range for low } for rule in self.all_rules: severity = rule.contents.data.severity @@ -1260,11 +1325,11 @@ def test_rule_risk_score_severity_mismatch(self): # Check if the risk_score falls within the range for the severity level min_score, max_score = risk_severity[severity] if not min_score <= risk_score <= max_score: - invalid_list.append(f'{self.rule_str(rule)} Severity: {severity}, Risk Score: {risk_score}') + invalid_list.append(f"{self.rule_str(rule)} Severity: {severity}, Risk Score: {risk_score}") if invalid_list: - invalid_str = '\n'.join(invalid_list) - err_msg = 'The following rules have mismatches between Severity and Risk Score field values:\n' + invalid_str = "\n".join(invalid_list) + err_msg = "The following rules have mismatches between Severity and Risk Score field values:\n" err_msg += invalid_str self.fail(err_msg) @@ -1278,9 +1343,9 @@ def test_note_contains_triage_and_analysis(self): for rule in self.all_rules: if ( - not rule.contents.data.is_elastic_rule or # noqa: W504 - rule.contents.data.building_block_type or # noqa: W504 - rule.contents.data.severity in ("medium", "low") + not rule.contents.data.is_elastic_rule # noqa: W504 + or rule.contents.data.building_block_type # noqa: W504 + or rule.contents.data.severity in ("medium", "low") ): # dont enforce continue @@ -1299,16 +1364,16 @@ def test_investigation_guide_uses_rule_name(self): """Check if investigation guide uses rule name in the title.""" errors = [] for rule in self.all_rules.rules: - note = rule.contents.data.get('note') + note = rule.contents.data.get("note") if note is not None: # Check if `### Investigating` is present and if so, # check if it is followed by the rule name. 
- if '### Investigating' in note: - results = re.search(rf'### Investigating\s+{re.escape(rule.name)}', note, re.I | re.M) + if "### Investigating" in note: + results = re.search(rf"### Investigating\s+{re.escape(rule.name)}", note, re.I | re.M) if results is None: - errors.append(f'{self.rule_str(rule)} investigation guide does not use rule name in the title') + errors.append(f"{self.rule_str(rule)} investigation guide does not use rule name in the title") if errors: - self.fail('\n'.join(errors)) + self.fail("\n".join(errors)) class TestNoteMarkdownPlugins(BaseRuleTest): @@ -1316,50 +1381,57 @@ class TestNoteMarkdownPlugins(BaseRuleTest): def test_note_has_osquery_warning(self): """Test that all rules with osquery entries have the default notification of stack compatibility.""" - osquery_note_pattern = ('> **Note**:\n> This investigation guide uses the [Osquery Markdown Plugin]' - '(https://www.elastic.co/guide/en/security/current/invest-guide-run-osquery.html) ' - 'introduced in Elastic Stack version 8.5.0. Older Elastic Stack versions will display ' - 'unrendered Markdown in this guide.') + osquery_note_pattern = ( + "> **Note**:\n> This investigation guide uses the [Osquery Markdown Plugin]" + "(https://www.elastic.co/guide/en/security/current/invest-guide-run-osquery.html) " + "introduced in Elastic Stack version 8.5.0. Older Elastic Stack versions will display " + "unrendered Markdown in this guide." + ) invest_note_pattern = ( - '> This investigation guide uses the [Investigate Markdown Plugin]' - '(https://www.elastic.co/guide/en/security/current/interactive-investigation-guides.html)' - ' introduced in Elastic Stack version 8.8.0. Older Elastic Stack versions will display ' - 'unrendered Markdown in this guide.') + "> This investigation guide uses the [Investigate Markdown Plugin]" + "(https://www.elastic.co/guide/en/security/current/interactive-investigation-guides.html)" + " introduced in Elastic Stack version 8.8.0. Older Elastic Stack versions will display " + "unrendered Markdown in this guide." 
+ ) for rule in self.all_rules: - if not rule.contents.get('transform'): + if not rule.contents.get("transform"): continue - osquery = rule.contents.transform.get('osquery') + osquery = rule.contents.transform.get("osquery") if osquery and osquery_note_pattern not in rule.contents.data.note: - self.fail(f'{self.rule_str(rule)} Investigation guides using the Osquery Markdown must contain ' - f'the following note:\n{osquery_note_pattern}') + self.fail( + f"{self.rule_str(rule)} Investigation guides using the Osquery Markdown must contain " + f"the following note:\n{osquery_note_pattern}" + ) - investigate = rule.contents.transform.get('investigate') + investigate = rule.contents.transform.get("investigate") if investigate and invest_note_pattern not in rule.contents.data.note: - self.fail(f'{self.rule_str(rule)} Investigation guides using the Investigate Markdown must contain ' - f'the following note:\n{invest_note_pattern}') + self.fail( + f"{self.rule_str(rule)} Investigation guides using the Investigate Markdown must contain " + f"the following note:\n{invest_note_pattern}" + ) def test_plugin_placeholders_match_entries(self): """Test that the number of plugin entries match their respective placeholders in note.""" for rule in self.all_rules: - has_transform = rule.contents.get('transform') is not None - has_note = rule.contents.data.get('note') is not None + has_transform = rule.contents.get("transform") is not None + has_note = rule.contents.data.get("note") is not None note = rule.contents.data.note if has_transform: if not has_note: - self.fail(f'{self.rule_str(rule)} transformed defined with no note') + self.fail(f"{self.rule_str(rule)} transformed defined with no note") else: if not has_note: continue note_template = PatchedTemplate(note) - identifiers = [i for i in note_template.get_identifiers() if '_' in i] + identifiers = [i for i in note_template.get_identifiers() if "_" in i] if not has_transform: if identifiers: - self.fail(f'{self.rule_str(rule)} note contains plugin placeholders with no transform entries') + self.fail(f"{self.rule_str(rule)} note contains plugin placeholders with no transform entries") else: continue @@ -1369,38 +1441,40 @@ def test_plugin_placeholders_match_entries(self): note_counts = defaultdict(int) for identifier in identifiers: # "$" is used for other things, so this verifies the pattern of a trailing "_" followed by ints - if '_' not in identifier: + if "_" not in identifier: continue - dash_index = identifier.rindex('_') - if dash_index == len(identifier) or not identifier[dash_index + 1:].isdigit(): + dash_index = identifier.rindex("_") + if dash_index == len(identifier) or not identifier[dash_index + 1 :].isdigit(): continue - plugin, _ = identifier.split('_') + plugin, _ = identifier.split("_") if plugin in transform_counts: note_counts[plugin] += 1 - err_msg = f'{self.rule_str(rule)} plugin entry count mismatch between transform and note' + err_msg = f"{self.rule_str(rule)} plugin entry count mismatch between transform and note" self.assertDictEqual(transform_counts, note_counts, err_msg) def test_if_plugins_explicitly_defined(self): """Check if plugins are explicitly defined with the pattern in note vs using transform.""" for rule in self.all_rules: - note = rule.contents.data.get('note') + note = rule.contents.data.get("note") if note is not None: - results = re.search(r'(!{osquery|!{investigate)', note, re.I | re.M) - err_msg = f'{self.rule_str(rule)} investigation guide plugin pattern detected! 
Use Transform' + results = re.search(r"(!{osquery|!{investigate)", note, re.I | re.M) + err_msg = f"{self.rule_str(rule)} investigation guide plugin pattern detected! Use Transform" self.assertIsNone(results, err_msg) class TestAlertSuppression(BaseRuleTest): """Test rule alert suppression.""" - @unittest.skipIf(PACKAGE_STACK_VERSION < Version.parse("8.8.0"), - "Test only applicable to 8.8+ stacks for rule alert suppression feature.") + @unittest.skipIf( + PACKAGE_STACK_VERSION < Version.parse("8.8.0"), + "Test only applicable to 8.8+ stacks for rule alert suppression feature.", + ) def test_group_field_in_schemas(self): """Test to ensure the fields are defined is in ECS/Beats/Integrations schema.""" for rule in self.all_rules: - if rule.contents.data.get('alert_suppression'): + if rule.contents.data.get("alert_suppression"): if isinstance(rule.contents.data.alert_suppression, AlertSuppressionMapping): group_by_fields = rule.contents.data.alert_suppression.group_by elif isinstance(rule.contents.data.alert_suppression, ThresholdAlertSuppression): @@ -1411,8 +1485,8 @@ def test_group_field_in_schemas(self): else: min_stack_version = Version.parse(min_stack_version) integration_tag = rule.contents.metadata.get("integration") - ecs_version = get_stack_schemas()[str(min_stack_version)]['ecs'] - beats_version = get_stack_schemas()[str(min_stack_version)]['beats'] + ecs_version = get_stack_schemas()[str(min_stack_version)]["ecs"] + beats_version = get_stack_schemas()[str(min_stack_version)]["beats"] queryvalidator = QueryValidator(rule.contents.data.query) _, _, schema = queryvalidator.get_beats_schema([], beats_version, ecs_version) if integration_tag: @@ -1426,20 +1500,23 @@ def test_group_field_in_schemas(self): schema.update(**int_schema[data_source]) for fld in group_by_fields: if fld not in schema.keys(): - self.fail(f"{self.rule_str(rule)} alert suppression field {fld} not \ - found in ECS, Beats, or non-ecs schemas") + self.fail( + f"{self.rule_str(rule)} alert suppression field {fld} not \ + found in ECS, Beats, or non-ecs schemas" + ) - @unittest.skipIf(PACKAGE_STACK_VERSION < Version.parse("8.14.0") or # noqa: W504 - PACKAGE_STACK_VERSION >= Version.parse("8.18.0"), # noqa: W504 - "Test is applicable to 8.14 --> 8.17 stacks for eql non-sequence rule alert suppression feature.") + @unittest.skipIf( + PACKAGE_STACK_VERSION < Version.parse("8.14.0") # noqa: W504 + or PACKAGE_STACK_VERSION >= Version.parse("8.18.0"), # noqa: W504 + "Test is applicable to 8.14 --> 8.17 stacks for eql non-sequence rule alert suppression feature.", + ) def test_eql_non_sequence_support_only(self): for rule in self.all_rules: if ( - isinstance(rule.contents.data, EQLRuleData) and rule.contents.data.get("alert_suppression") + isinstance(rule.contents.data, EQLRuleData) + and rule.contents.data.get("alert_suppression") and rule.contents.data.is_sequence # noqa: W503 ): # is_sequence method not yet available during schema validation # so we have to check in a unit test - self.fail( - f"{self.rule_str(rule)} Sequence rules cannot have alert suppression" - ) + self.fail(f"{self.rule_str(rule)} Sequence rules cannot have alert suppression") diff --git a/tests/test_gh_workflows.py b/tests/test_gh_workflows.py index 10983d2057c..7225ac5e75e 100644 --- a/tests/test_gh_workflows.py +++ b/tests/test_gh_workflows.py @@ -6,27 +6,25 @@ """Tests for GitHub workflow functionality.""" import unittest -from pathlib import Path - import yaml from detection_rules.schemas import get_stack_versions, RULES_CONFIG -from 
detection_rules.utils import get_path +from detection_rules.utils import ROOT_DIR -GITHUB_FILES = Path(get_path()) / '.github' -GITHUB_WORKFLOWS = GITHUB_FILES / 'workflows' +GITHUB_FILES = ROOT_DIR / ".github" +GITHUB_WORKFLOWS = GITHUB_FILES / "workflows" class TestWorkflows(unittest.TestCase): """Test GitHub workflow functionality.""" - @unittest.skipIf(RULES_CONFIG.bypass_version_lock, 'Version lock bypassed') + @unittest.skipIf(RULES_CONFIG.bypass_version_lock, "Version lock bypassed") def test_matrix_to_lock_version_defaults(self): """Test that the default versions in the lock-versions workflow mirror those from the schema-map.""" - lock_workflow_file = GITHUB_WORKFLOWS / 'lock-versions.yml' + lock_workflow_file = GITHUB_WORKFLOWS / "lock-versions.yml" lock_workflow = yaml.safe_load(lock_workflow_file.read_text()) - lock_versions = lock_workflow[True]['workflow_dispatch']['inputs']['branches']['default'].split(',') + lock_versions = lock_workflow[True]["workflow_dispatch"]["inputs"]["branches"]["default"].split(",") matrix_versions = get_stack_versions(drop_patch=True) - err_msg = 'lock-versions workflow default does not match current matrix in stack-schema-map' + err_msg = "lock-versions workflow default does not match current matrix in stack-schema-map" self.assertListEqual(lock_versions, matrix_versions[:-1], err_msg) diff --git a/tests/test_hunt_data.py b/tests/test_hunt_data.py index d4a6b3e9eb9..74288ade7d0 100644 --- a/tests/test_hunt_data.py +++ b/tests/test_hunt_data.py @@ -4,6 +4,7 @@ # 2.0. """Test for hunt toml files.""" + import unittest from hunting.definitions import HUNTING_DIR @@ -33,9 +34,7 @@ def test_toml_loading(self): config = load_toml(example_toml) self.assertEqual(config.author, "Elastic") self.assertEqual(config.integration, "aws_bedrock.invocation") - self.assertEqual( - config.name, "Denial of Service or Resource Exhaustion Attacks Detection" - ) + self.assertEqual(config.name, "Denial of Service or Resource Exhaustion Attacks Detection") self.assertEqual(config.language, "ES|QL") def test_load_toml_files(self): @@ -53,9 +52,7 @@ def test_load_toml_files(self): def test_markdown_existence(self): """Ensure each TOML file has a corresponding Markdown file in the docs directory.""" for toml_file in HUNTING_DIR.rglob("*.toml"): - expected_markdown_path = ( - toml_file.parent.parent / "docs" / toml_file.with_suffix(".md").name - ) + expected_markdown_path = toml_file.parent.parent / "docs" / toml_file.with_suffix(".md").name self.assertTrue( expected_markdown_path.exists(), @@ -65,9 +62,7 @@ def test_markdown_existence(self): def test_toml_existence(self): """Ensure each Markdown file has a corresponding TOML file in the queries directory.""" for markdown_file in HUNTING_DIR.rglob("*/docs/*.md"): - expected_toml_path = ( - markdown_file.parent.parent / "queries" / markdown_file.with_suffix(".toml").name - ) + expected_toml_path = markdown_file.parent.parent / "queries" / markdown_file.with_suffix(".toml").name self.assertTrue( expected_toml_path.exists(), @@ -87,12 +82,14 @@ def test_mitre_techniques_present(self): """Ensure each query has at least one MITRE technique.""" for folder, queries in self.hunting_index.items(): for query_uuid, query_data in queries.items(): - self.assertTrue(query_data.get('mitre'), - f"No MITRE techniques found for query: {query_data.get('name', query_uuid)}") + self.assertTrue( + query_data.get("mitre"), + f"No MITRE techniques found for query: {query_data.get('name', query_uuid)}", + ) def test_valid_structure(self): """Ensure 
each query entry has a valid structure.""" - required_fields = ['name', 'path', 'mitre'] + required_fields = ["name", "path", "mitre"] for folder, queries in self.hunting_index.items(): for query_uuid, query_data in queries.items(): @@ -111,8 +108,7 @@ def test_all_files_in_index(self): missing_index_entries.append(query_uuid) self.assertFalse( - missing_index_entries, - f"Missing index entries for the following queries: {missing_index_entries}" + missing_index_entries, f"Missing index entries for the following queries: {missing_index_entries}" ) diff --git a/tests/test_packages.py b/tests/test_packages.py index e0adab2fac9..54129720f90 100644 --- a/tests/test_packages.py +++ b/tests/test_packages.py @@ -4,14 +4,14 @@ # 2.0. """Test that the packages are built correctly.""" + import unittest import uuid from semver import Version from marshmallow import ValidationError from detection_rules import rule_loader -from detection_rules.schemas.registry_package import (RegistryPackageManifestV1, - RegistryPackageManifestV3) +from detection_rules.schemas.registry_package import RegistryPackageManifestV1, RegistryPackageManifestV3 from detection_rules.packaging import PACKAGE_FILE, Package from tests.base import BaseRuleTest @@ -35,17 +35,13 @@ def get_rule_contents(): "risk_score": 21, "rule_id": str(uuid.uuid4()), "severity": "low", - "type": "query" + "type": "query", } return contents - rules = [rule_loader.TOMLRule('test.toml', get_rule_contents()) for i in range(count)] + rules = [rule_loader.TOMLRule("test.toml", get_rule_contents()) for i in range(count)] version_info = { - rule.id: { - 'rule_name': rule.name, - 'sha256': rule.contents.get_hash(), - 'version': version - } for rule in rules + rule.id: {"rule_name": rule.name, "sha256": rule.contents.get_hash(), "version": version} for rule in rules } return rules, version_info @@ -53,19 +49,19 @@ def get_rule_contents(): def test_package_loader_production_config(self): """Test that packages are loading correctly.""" - @unittest.skipIf(rule_loader.RULES_CONFIG.bypass_version_lock, 'Version lock bypassed') + @unittest.skipIf(rule_loader.RULES_CONFIG.bypass_version_lock, "Version lock bypassed") def test_package_loader_default_configs(self): """Test configs in detection_rules/etc/packages.yaml.""" Package.from_config(rule_collection=self.rc, config=package_configs) - @unittest.skipIf(rule_loader.RULES_CONFIG.bypass_version_lock, 'Version lock bypassed') + @unittest.skipIf(rule_loader.RULES_CONFIG.bypass_version_lock, "Version lock bypassed") def test_package_summary(self): """Test the generation of the package summary.""" rules = self.rc - package = Package(rules, 'test-package') + package = Package(rules, "test-package") package.generate_summary_and_changelog(package.changed_ids, package.new_ids, package.removed_ids) - @unittest.skipIf(rule_loader.RULES_CONFIG.bypass_version_lock, 'Version lock bypassed') + @unittest.skipIf(rule_loader.RULES_CONFIG.bypass_version_lock, "Version lock bypassed") def test_rule_versioning(self): """Test that all rules are properly versioned and tracked""" self.maxDiff = None @@ -75,22 +71,22 @@ def test_rule_versioning(self): # test that no rules have versions defined for rule in rules: - self.assertGreaterEqual(rule.contents.autobumped_version, 1, '{} - {}: version is not being set in package') + self.assertGreaterEqual(rule.contents.autobumped_version, 1, "{} - {}: version is not being set in package") original_hashes.append(rule.contents.get_hash()) - package = Package(rules, 'test-package') + package = 
Package(rules, "test-package") # test that all rules have versions defined # package.bump_versions(save_changes=False) for rule in package.rules: - self.assertGreaterEqual(rule.contents.autobumped_version, 1, '{} - {}: version is not being set in package') + self.assertGreaterEqual(rule.contents.autobumped_version, 1, "{} - {}: version is not being set in package") # test that rules validate with version for rule in package.rules: post_bump_hashes.append(rule.contents.get_hash()) # test that no hashes changed as a result of the version bumps - self.assertListEqual(original_hashes, post_bump_hashes, 'Version bumping modified the hash of a rule') + self.assertListEqual(original_hashes, post_bump_hashes, "Version bumping modified the hash of a rule") class TestRegistryPackage(unittest.TestCase): @@ -98,11 +94,11 @@ class TestRegistryPackage(unittest.TestCase): @classmethod def setUpClass(cls) -> None: - - assert 'registry_data' in package_configs, f'Missing registry_data in {PACKAGE_FILE}' - cls.registry_config = package_configs['registry_data'] - stack_version = Version.parse(cls.registry_config['conditions']['kibana.version'].strip("^"), - optional_minor_and_patch=True) + assert "registry_data" in package_configs, f"Missing registry_data in {PACKAGE_FILE}" + cls.registry_config = package_configs["registry_data"] + stack_version = Version.parse( + cls.registry_config["conditions"]["kibana.version"].strip("^"), optional_minor_and_patch=True + ) if stack_version >= Version.parse("8.12.0"): RegistryPackageManifestV3.from_dict(cls.registry_config) else: @@ -111,7 +107,7 @@ def setUpClass(cls) -> None: def test_registry_package_config(self): """Test that the registry package is validating properly.""" registry_config = self.registry_config.copy() - registry_config['version'] += '7.1.1.' + registry_config["version"] += "7.1.1." with self.assertRaises(ValidationError): RegistryPackageManifestV1.from_dict(registry_config) diff --git a/tests/test_python_library.py b/tests/test_python_library.py index 3d0ab8b09b6..7a02f8900a1 100644 --- a/tests/test_python_library.py +++ b/tests/test_python_library.py @@ -59,9 +59,7 @@ def test_eql_in_set(self): with self.assertRaisesRegex(ValueError, expected_error_message): rc.load_dict(eql_rule) # Change to appropriate destination.address field - eql_rule["rule"][ - "query" - ] = """ + eql_rule["rule"]["query"] = """ sequence by host.id, process.entity_id with maxspan = 10s [network where destination.address in ("192.168.1.1", "::1")] """ diff --git a/tests/test_rules_remote.py b/tests/test_rules_remote.py index e422239ce62..24402061feb 100644 --- a/tests/test_rules_remote.py +++ b/tests/test_rules_remote.py @@ -10,7 +10,7 @@ # from detection_rules.remote_validation import RemoteValidator -@unittest.skipIf(get_default_config() is None, 'Skipping remote validation due to missing config') +@unittest.skipIf(get_default_config() is None, "Skipping remote validation due to missing config") class TestRemoteRules(BaseRuleTest): """Test rules against a remote Elastic stack instance.""" diff --git a/tests/test_schemas.py b/tests/test_schemas.py index 5adb5a77073..5f0426c056e 100644 --- a/tests/test_schemas.py +++ b/tests/test_schemas.py @@ -4,6 +4,7 @@ # 2.0. 
"""Test stack versioned schemas.""" + import copy import unittest import uuid @@ -42,7 +43,7 @@ def setUpClass(cls): "tactic": { "id": "TA0001", "name": "Execution", - "reference": "https://attack.mitre.org/tactics/TA0001/" + "reference": "https://attack.mitre.org/tactics/TA0001/", }, "technique": [ { @@ -52,25 +53,25 @@ def setUpClass(cls): } ], } - ] + ], } cls.v79_kql = dict(cls.v78_kql, author=["Elastic"], license="Elastic License v2") cls.v711_kql = copy.deepcopy(cls.v79_kql) # noinspection PyTypeChecker - cls.v711_kql["threat"][0]["technique"][0]["subtechnique"] = [{ - "id": "T1059.001", - "name": "PowerShell", - "reference": "https://attack.mitre.org/techniques/T1059/001/" - }] + cls.v711_kql["threat"][0]["technique"][0]["subtechnique"] = [ + {"id": "T1059.001", "name": "PowerShell", "reference": "https://attack.mitre.org/techniques/T1059/001/"} + ] # noinspection PyTypeChecker - cls.v711_kql["threat"].append({ - "framework": "MITRE ATT&CK", - "tactic": { - "id": "TA0008", - "name": "Lateral Movement", - "reference": "https://attack.mitre.org/tactics/TA0008/" - }, - }) + cls.v711_kql["threat"].append( + { + "framework": "MITRE ATT&CK", + "tactic": { + "id": "TA0008", + "name": "Lateral Movement", + "reference": "https://attack.mitre.org/tactics/TA0008/", + }, + } + ) cls.v79_threshold_contents = { "author": ["Elastic"], @@ -88,14 +89,14 @@ def setUpClass(cls): }, "type": "threshold", } - cls.v712_threshold_rule = dict(copy.deepcopy(cls.v79_threshold_contents), threshold={ - 'field': ['destination.bytes', 'process.args'], - 'value': 75, - 'cardinality': [{ - 'field': 'user.name', - 'value': 2 - }] - }) + cls.v712_threshold_rule = dict( + copy.deepcopy(cls.v79_threshold_contents), + threshold={ + "field": ["destination.bytes", "process.args"], + "value": 75, + "cardinality": [{"field": "user.name", "value": 2}], + }, + ) def test_query_downgrade_7_x(self): """Downgrade a standard KQL rule.""" @@ -142,21 +143,21 @@ def test_threshold_downgrade_7_x(self): return api_contents = self.v712_threshold_rule - self.assertDictEqual(downgrade(api_contents, '7.13'), api_contents) - self.assertDictEqual(downgrade(api_contents, '7.13.1'), api_contents) + self.assertDictEqual(downgrade(api_contents, "7.13"), api_contents) + self.assertDictEqual(downgrade(api_contents, "7.13.1"), api_contents) - exc_msg = 'Cannot downgrade a threshold rule that has multiple threshold fields defined' + exc_msg = "Cannot downgrade a threshold rule that has multiple threshold fields defined" with self.assertRaisesRegex(ValueError, exc_msg): - downgrade(api_contents, '7.9') + downgrade(api_contents, "7.9") v712_threshold_contents_single_field = copy.deepcopy(api_contents) - v712_threshold_contents_single_field['threshold']['field'].pop() + v712_threshold_contents_single_field["threshold"]["field"].pop() with self.assertRaisesRegex(ValueError, "Cannot downgrade a threshold rule that has a defined cardinality"): downgrade(v712_threshold_contents_single_field, "7.9") v712_no_cardinality = copy.deepcopy(v712_threshold_contents_single_field) - v712_no_cardinality['threshold'].pop('cardinality') + v712_no_cardinality["threshold"].pop("cardinality") self.assertEqual(downgrade(v712_no_cardinality, "7.9"), self.v79_threshold_contents) with self.assertRaises(ValueError): @@ -191,14 +192,14 @@ def test_eql_validation(self): "risk_score": 21, "rule_id": str(uuid.uuid4()), "severity": "low", - "type": "eql" + "type": "eql", } def build_rule(query): metadata = { "creation_date": "1970/01/01", "updated_date": "1970/01/01", - 
"min_stack_version": load_current_package_version() + "min_stack_version": load_current_package_version(), } data = base_fields.copy() data["query"] = query @@ -209,12 +210,22 @@ def build_rule(query): process where process.name == "cmd.exe" """) - example_text_fields = ['client.as.organization.name.text', 'client.user.full_name.text', - 'client.user.name.text', 'destination.as.organization.name.text', - 'destination.user.full_name.text', 'destination.user.name.text', - 'error.message', 'error.stack_trace.text', 'file.path.text', - 'file.target_path.text', 'host.os.full.text', 'host.os.name.text', - 'host.user.full_name.text', 'host.user.name.text'] + example_text_fields = [ + "client.as.organization.name.text", + "client.user.full_name.text", + "client.user.name.text", + "destination.as.organization.name.text", + "destination.user.full_name.text", + "destination.user.name.text", + "error.message", + "error.stack_trace.text", + "file.path.text", + "file.target_path.text", + "host.os.full.text", + "host.os.name.text", + "host.user.full_name.text", + "host.user.name.text", + ] for text_field in example_text_fields: with self.assertRaises(eql.parser.EqlSchemaError): build_rule(f""" @@ -247,7 +258,7 @@ def setUpClass(cls): "rule_name": "Remote File Download via PowerShell", "sha256": "8679cd72bf85b67dde3dcfdaba749ed1fa6560bca5efd03ed41c76a500ce31d6", "type": "eql", - "version": 4 + "version": 4, }, "34fde489-94b0-4500-a76f-b8a157cf9269": { "min_stack_version": "8.2", @@ -256,14 +267,14 @@ def setUpClass(cls): "rule_name": "Telnet Port Activity", "sha256": "3dd4a438c915920e6ddb0a5212603af5d94fb8a6b51a32f223d930d7e3becb89", "type": "query", - "version": 9 + "version": 9, } }, "rule_name": "Telnet Port Activity", "sha256": "b0bdfa73639226fb83eadc0303ad1801e0707743f96a36209aa58228d3bf6a89", "type": "query", - "version": 10 - } + "version": 10, + }, } def test_version_lock_no_previous(self): @@ -271,7 +282,7 @@ def test_version_lock_no_previous(self): version_lock_contents = copy.deepcopy(self.version_lock_contents) VersionLockFile.from_dict(dict(data=version_lock_contents)) - @unittest.skipIf(RULES_CONFIG.bypass_version_lock, 'Version lock bypassed') + @unittest.skipIf(RULES_CONFIG.bypass_version_lock, "Version lock bypassed") def test_version_lock_has_nested_previous(self): """Fail field validation on version lock with nested previous fields""" version_lock_contents = copy.deepcopy(self.version_lock_contents) @@ -287,6 +298,6 @@ class TestVersions(unittest.TestCase): def test_stack_schema_map(self): """Test to ensure that an entry exists in the stack-schema-map for the current package version.""" package_version = Version.parse(load_current_package_version(), optional_minor_and_patch=True) - stack_map = utils.load_etc_dump('stack-schema-map.yaml') - err_msg = f'There is no entry defined for the current package ({package_version}) in the stack-schema-map' + stack_map = utils.load_etc_dump(["stack-schema-map.yaml"]) + err_msg = f"There is no entry defined for the current package ({package_version}) in the stack-schema-map" self.assertIn(package_version, [Version.parse(v) for v in stack_map], err_msg) diff --git a/tests/test_specific_rules.py b/tests/test_specific_rules.py index 2236745a488..d04e8d45a15 100644 --- a/tests/test_specific_rules.py +++ b/tests/test_specific_rules.py @@ -5,7 +5,6 @@ import unittest from copy import deepcopy -from pathlib import Path import eql.ast @@ -40,7 +39,7 @@ class TestEndpointQuery(BaseRuleTest): def test_os_and_platform_in_query(self): """Test that all endpoint 
rules have an os defined and linux includes platform.""" for rule in self.all_rules: - if not rule.contents.data.get('language') in ('eql', 'kuery'): + if rule.contents.data.get("language") not in ("eql", "kuery"): continue if rule.path.parent.name not in ("windows", "macos", "linux"): # skip cross-platform for now @@ -75,14 +74,13 @@ def test_history_window_start(self): for rule in self.all_rules: if rule.contents.data.type == "new_terms": - # validate history window start field exists and is correct - assert ( - rule.contents.data.new_terms.history_window_start - ), "new terms field found with no history_window_start field defined" - assert ( - rule.contents.data.new_terms.history_window_start[0].field == "history_window_start" - ), f"{rule.contents.data.new_terms.history_window_start} should be 'history_window_start'" + assert rule.contents.data.new_terms.history_window_start, ( + "new terms field found with no history_window_start field defined" + ) + assert rule.contents.data.new_terms.history_window_start[0].field == "history_window_start", ( + f"{rule.contents.data.new_terms.history_window_start} should be 'history_window_start'" + ) @unittest.skipIf( PACKAGE_STACK_VERSION < Version.parse("8.4.0"), "Test only applicable to 8.4+ stacks for new terms feature." @@ -91,9 +89,9 @@ def test_new_terms_field_exists(self): # validate new terms and history window start fields are correct for rule in self.all_rules: if rule.contents.data.type == "new_terms": - assert ( - rule.contents.data.new_terms.field == "new_terms_fields" - ), f"{rule.contents.data.new_terms.field} should be 'new_terms_fields' for new_terms rule type" + assert rule.contents.data.new_terms.field == "new_terms_fields", ( + f"{rule.contents.data.new_terms.field} should be 'new_terms_fields' for new_terms rule type" + ) @unittest.skipIf( PACKAGE_STACK_VERSION < Version.parse("8.4.0"), "Test only applicable to 8.4+ stacks for new terms feature." @@ -115,9 +113,9 @@ def test_new_terms_fields(self): else min_stack_version ) - assert ( - min_stack_version >= feature_min_stack - ), f"New Terms rule types only compatible with {feature_min_stack}+" + assert min_stack_version >= feature_min_stack, ( + f"New Terms rule types only compatible with {feature_min_stack}+" + ) ecs_version = get_stack_schemas()[str(min_stack_version)]["ecs"] beats_version = get_stack_schemas()[str(min_stack_version)]["beats"] @@ -142,9 +140,9 @@ def test_new_terms_fields(self): for policy_template in integration_schema.keys(): schema.update(**integration_schemas[tag][latest_tag_compat_ver][policy_template]) for new_terms_field in rule.contents.data.new_terms.value: - assert ( - new_terms_field in schema.keys() - ), f"{new_terms_field} not found in ECS, Beats, or non-ecs schemas" + assert new_terms_field in schema.keys(), ( + f"{new_terms_field} not found in ECS, Beats, or non-ecs schemas" + ) @unittest.skipIf( PACKAGE_STACK_VERSION < Version.parse("8.4.0"), "Test only applicable to 8.4+ stacks for new terms feature." 
@@ -167,9 +165,9 @@ def test_new_terms_max_limit(self): else min_stack_version ) if feature_min_stack <= min_stack_version < feature_min_stack_extended_fields: - assert ( - len(rule.contents.data.new_terms.value) == 1 - ), f"new terms have a max limit of 1 for stack versions below {feature_min_stack_extended_fields}" + assert len(rule.contents.data.new_terms.value) == 1, ( + f"new terms have a max limit of 1 for stack versions below {feature_min_stack_extended_fields}" + ) @unittest.skipIf( PACKAGE_STACK_VERSION < Version.parse("8.6.0"), "Test only applicable to 8.4+ stacks for new terms feature." @@ -179,9 +177,9 @@ def test_new_terms_fields_unique(self): # validate fields are unique for rule in self.all_rules: if rule.contents.data.type == "new_terms": - assert len(set(rule.contents.data.new_terms.value)) == len( - rule.contents.data.new_terms.value - ), f"new terms fields values are not unique - {rule.contents.data.new_terms.value}" + assert len(set(rule.contents.data.new_terms.value)) == len(rule.contents.data.new_terms.value), ( + f"new terms fields values are not unique - {rule.contents.data.new_terms.value}" + ) class TestESQLRules(BaseRuleTest): @@ -190,7 +188,7 @@ class TestESQLRules(BaseRuleTest): def run_esql_test(self, esql_query, expectation, message): """Test that the query validation is working correctly.""" rc = RuleCollection() - file_path = Path(get_path("tests", "data", "command_control_dummy_production_rule.toml")) + file_path = get_path(["tests", "data", "command_control_dummy_production_rule.toml"]) original_production_rule = load_rule_contents(file_path) # Test that a ValidationError is raised if the query doesn't match the schema diff --git a/tests/test_toml_formatter.py b/tests/test_toml_formatter.py index 4787354fa83..177874f6e98 100644 --- a/tests/test_toml_formatter.py +++ b/tests/test_toml_formatter.py @@ -7,32 +7,38 @@ import json import os import unittest +from pathlib import Path import pytoml from detection_rules.rule_formatter import nested_normalize, toml_write from detection_rules.utils import get_etc_path -tmp_file = 'tmp_file.toml' +tmp_file = "tmp_file.toml" class TestRuleTomlFormatter(unittest.TestCase): """Test that the custom toml formatting is not compromising the integrity of the data.""" - with open(get_etc_path("test_toml.json"), "r") as f: - test_data = json.load(f) + + maxDiff = None + + def setUp(self): + with open(get_etc_path(["test_toml.json"]), "r") as f: + self.test_data = json.load(f) def compare_formatted(self, data, callback=None, kwargs=None): """Compare formatted vs expected.""" try: - toml_write(copy.deepcopy(data), tmp_file) + tmp_path = Path(tmp_file) + toml_write(copy.deepcopy(data), tmp_path) - with open(tmp_file, 'r') as f: - formatted_contents = pytoml.load(f) + formatted_data = tmp_path.read_text() + formatted_contents = pytoml.loads(formatted_data) # callbacks such as nested normalize leave in line breaks, so this must be manually done - query = data.get('rule', {}).get('query') + query = data.get("rule", {}).get("query") if query: - data['rule']['query'] = query.strip() + data["rule"]["query"] = query.strip() original = json.dumps(copy.deepcopy(data), sort_keys=True) @@ -41,15 +47,15 @@ def compare_formatted(self, data, callback=None, kwargs=None): formatted_contents = callback(formatted_contents, **kwargs) # callbacks such as nested normalize leave in line breaks, so this must be manually done - query = formatted_contents.get('rule', {}).get('query') + query = formatted_contents.get("rule", {}).get("query") if query: - 
formatted_contents['rule']['query'] = query.strip() + formatted_contents["rule"]["query"] = query.strip() formatted = json.dumps(formatted_contents, sort_keys=True) - self.assertEqual(original, formatted, 'Formatting may be modifying contents') - + self.assertEqual(original, formatted, "Formatting may be modifying contents") finally: - os.remove(tmp_file) + if os.path.exists(tmp_file): + os.remove(tmp_file) def compare_test_data(self, test_dicts, callback=None): """Compare test data against expected.""" @@ -67,6 +73,7 @@ def test_formatter_rule(self): def test_formatter_deep(self): """Test that the data remains unchanged from formatting.""" self.compare_test_data(self.test_data[1:]) + # # def test_format_of_all_rules(self): # """Test all rules.""" diff --git a/tests/test_transform_fields.py b/tests/test_transform_fields.py index 53352584181..2b7a0fe2c84 100644 --- a/tests/test_transform_fields.py +++ b/tests/test_transform_fields.py @@ -4,6 +4,7 @@ # 2.0. """Test fields in TOML [transform].""" + import copy import unittest from textwrap import dedent @@ -45,7 +46,7 @@ def load_rule() -> TOMLRule: "license": "Elastic License v2", "from": "now-9m", "name": "Test Suspicious Print Spooler SPL File Created", - "note": 'Test note', + "note": "Test note", "references": ["https://safebreach.com/Post/How-we-bypassed-CVE-2020-1048-Patch-and-got-CVE-2020-1337"], "risk_score": 47, "rule_id": "43716252-4a45-4694-aff0-5245b7b6c7cd", @@ -91,7 +92,8 @@ def load_rule() -> TOMLRule: def test_transform_guide_markdown_plugins(self) -> None: sample_rule = self.load_rule() rule_dict = sample_rule.contents.to_dict() - osquery_toml = dedent(""" + osquery_toml = dedent( + """ [transform] [[transform.osquery]] label = "Osquery - Retrieve DNS Cache" @@ -108,9 +110,11 @@ def test_transform_guide_markdown_plugins(self) -> None: [[transform.osquery]] label = "Retrieve Service Unisgned Executables with Virustotal Link" query = "SELECT concat('https://www.virustotal.com/gui/file/', sha1) AS VtLink, name, description, start_type, status, pid, services.path FROM services JOIN authenticode ON services.path = authenticode.path OR services.module_path = authenticode.path JOIN hash ON services.path = hash.path WHERE authenticode.result != 'trusted'" - """.strip()) # noqa: E501 + """.strip() + ) # noqa: E501 - sample_note = dedent(""" + sample_note = dedent( + """ ## Triage and analysis ### Investigating Unusual Process For a Windows Host @@ -135,15 +139,16 @@ def test_transform_guide_markdown_plugins(self) -> None: - $osquery_2 - $osquery_3 - Retrieve the files' SHA-256 hash values using the PowerShell `Get-FileHash` cmdlet and search for the existence and reputation of the hashes in resources like VirusTotal, Hybrid-Analysis, CISCO Talos, Any.run, etc. 
- """.strip()) # noqa: E501 + """.strip() + ) # noqa: E501 transform = pytoml.loads(osquery_toml) - rule_dict['rule']['note'] = sample_note + rule_dict["rule"]["note"] = sample_note rule_dict.update(**transform) new_rule_contents = TOMLRuleContents.from_dict(rule_dict) new_rule = TOMLRule(path=sample_rule.path, contents=new_rule_contents) - rendered_note = new_rule.contents.to_api_format()['note'] + rendered_note = new_rule.contents.to_api_format()["note"] for pattern in self.osquery_patterns: self.assertIn(pattern, rendered_note) @@ -152,7 +157,7 @@ def test_plugin_conversion(self): """Test the conversion function to ensure parsing is correct.""" sample_rule = self.load_rule() rule_dict = sample_rule.contents.to_dict() - rule_dict['rule']['note'] = "$osquery_0" + rule_dict["rule"]["note"] = "$osquery_0" for pattern in self.osquery_patterns: transform = guide_plugin_convert_(contents=pattern) @@ -160,6 +165,6 @@ def test_plugin_conversion(self): rule_dict_copy.update(**transform) new_rule_contents = TOMLRuleContents.from_dict(rule_dict_copy) new_rule = TOMLRule(path=sample_rule.path, contents=new_rule_contents) - rendered_note = new_rule.contents.to_api_format()['note'] + rendered_note = new_rule.contents.to_api_format()["note"] self.assertIn(pattern, rendered_note) diff --git a/tests/test_utils.py b/tests/test_utils.py index 4788906420f..eecab63df71 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -4,6 +4,7 @@ # 2.0. """Test util time functions.""" + import random import time import unittest @@ -17,24 +18,24 @@ class TestTimeUtils(unittest.TestCase): """Test util time functions.""" @staticmethod - def get_events(timestamp_field='@timestamp'): + def get_events(timestamp_field="@timestamp"): """Get test data.""" date_formats = { - 'epoch_millis': lambda x: int(round(time.time(), 3) + x) * 1000, - 'epoch_second': lambda x: round(time.time()) + x, - 'unix_micros': lambda x: time.time() + x, - 'unix_millis': lambda x: round(time.time(), 3) + x, - 'strict_date_optional_time': lambda x: '2020-05-13T04:36:' + str(15 + x) + '.394Z' + "epoch_millis": lambda x: int(round(time.time(), 3) + x) * 1000, + "epoch_second": lambda x: round(time.time()) + x, + "unix_micros": lambda x: time.time() + x, + "unix_millis": lambda x: round(time.time(), 3) + x, + "strict_date_optional_time": lambda x: "2020-05-13T04:36:" + str(15 + x) + ".394Z", } def _get_data(func): data = [ - {timestamp_field: func(0), 'foo': 'bar', 'id': 1}, - {timestamp_field: func(1), 'foo': 'bar', 'id': 2}, - {timestamp_field: func(2), 'foo': 'bar', 'id': 3}, - {timestamp_field: func(3), 'foo': 'bar', 'id': 4}, - {timestamp_field: func(4), 'foo': 'bar', 'id': 5}, - {timestamp_field: func(5), 'foo': 'bar', 'id': 6} + {timestamp_field: func(0), "foo": "bar", "id": 1}, + {timestamp_field: func(1), "foo": "bar", "id": 2}, + {timestamp_field: func(2), "foo": "bar", "id": 3}, + {timestamp_field: func(3), "foo": "bar", "id": 4}, + {timestamp_field: func(4), "foo": "bar", "id": 5}, + {timestamp_field: func(5), "foo": "bar", "id": 6}, ] random.shuffle(data) return data @@ -43,8 +44,8 @@ def _get_data(func): def assert_sort(self, normalized_events, date_format): """Assert normalize and sort.""" - order = [e['id'] for e in normalized_events] - self.assertListEqual([1, 2, 3, 4, 5, 6], order, 'Sorting failed for date_format: {}'.format(date_format)) + order = [e["id"] for e in normalized_events] + self.assertListEqual([1, 2, 3, 4, 5, 6], order, "Sorting failed for date_format: {}".format(date_format)) def test_time_normalize(self): """Test 
normalize_timing_from_date_format.""" @@ -57,8 +58,8 @@ def test_event_class_normalization(self): """Test that events are normalized properly within Events.""" events_data = self.get_events() for date_format, events in events_data.items(): - normalized = Events({'winlogbeat': events}) - self.assert_sort(normalized.events['winlogbeat'], date_format) + normalized = Events({"winlogbeat": events}) + self.assert_sort(normalized.events["winlogbeat"], date_format) def test_schema_multifields(self): """Tests that schemas are loading multifields correctly.""" @@ -88,15 +89,15 @@ def increment(*args, **kwargs): self.assertEqual(increment(), 1) self.assertEqual(increment(["hello", "world"]), 2) - self.assertEqual(increment({"hello": [("world", )]}), 3) - self.assertEqual(increment({"hello": [("world", )]}), 3) + self.assertEqual(increment({"hello": [("world",)]}), 3) + self.assertEqual(increment({"hello": [("world",)]}), 3) self.assertEqual(increment(), 1) self.assertEqual(increment(["hello", "world"]), 2) - self.assertEqual(increment({"hello": [("world", )]}), 3) + self.assertEqual(increment({"hello": [("world",)]}), 3) increment.clear() - self.assertEqual(increment({"hello": [("world", )]}), 4) + self.assertEqual(increment({"hello": [("world",)]}), 4) self.assertEqual(increment(["hello", "world"]), 5) self.assertEqual(increment(), 6) self.assertEqual(increment(None), 7) diff --git a/tests/test_version_locking.py b/tests/test_version_locking.py index 37e0e5e420e..2da56cbdc01 100644 --- a/tests/test_version_locking.py +++ b/tests/test_version_locking.py @@ -16,14 +16,14 @@ class TestVersionLock(unittest.TestCase): """Test version locking.""" - @unittest.skipIf(RULES_CONFIG.bypass_version_lock, 'Version lock bypassed') + @unittest.skipIf(RULES_CONFIG.bypass_version_lock, "Version lock bypassed") def test_previous_entries_gte_current_min_stack(self): """Test that all previous entries for all locks in the version lock are >= the current min_stack.""" errors = {} min_version = get_min_supported_stack_version() for rule_id, lock in loaded_version_lock.version_lock.to_dict().items(): - if 'previous' in lock: - prev_vers = [Version.parse(v, optional_minor_and_patch=True) for v in list(lock['previous'])] + if "previous" in lock: + prev_vers = [Version.parse(v, optional_minor_and_patch=True) for v in list(lock["previous"])] outdated = [f"{v.major}.{v.minor}" for v in prev_vers if v < min_version] if outdated: errors[rule_id] = outdated @@ -31,7 +31,9 @@ def test_previous_entries_gte_current_min_stack(self): # This should only ever happen when bumping the backport matrix support up, which is based on the # stack-schema-map if errors: - err_str = '\n'.join(f'{k}: {", ".join(v)}' for k, v in errors.items()) - self.fail(f'The following version.lock entries have previous locked versions which are lower than the ' - f'currently supported min_stack ({min_version}). To address this, run the ' - f'`dev trim-version-lock {min_version}` command.\n\n{err_str}') + err_str = "\n".join(f"{k}: {', '.join(v)}" for k, v in errors.items()) + self.fail( + f"The following version.lock entries have previous locked versions which are lower than the " + f"currently supported min_stack ({min_version}). To address this, run the " + f"`dev trim-version-lock {min_version}` command.\n\n{err_str}" + )