diff --git a/.circleci/config.yml b/.circleci/config.yml index 70e8932f..d47d3db5 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -227,6 +227,19 @@ jobs: fi docker save -o ~/openaev/images/collector-splunk-es openaev/collector-splunk-es:${CIRCLE_SHA1} docker save -o ~/openaev/images/collector-splunk-es-ubi9 openaev/collector-splunk-es:${CIRCLE_SHA1}-ubi9 + - run: + working_directory: ~/openaev/netwitness + name: Build Docker image openaev/collector-netwitness + command: | + if [[ "${CIRCLE_BRANCH}" == "main" ]]; then + docker build --pull --progress=plain -t openaev/collector-netwitness:${CIRCLE_SHA1} --build-arg PYOAEV_GIT_BRANCH_OVERRIDE="${CIRCLE_BRANCH}" . + docker build --pull --progress=plain -t openaev/collector-netwitness:${CIRCLE_SHA1}-ubi9 -f Dockerfile_ubi9 --build-arg PYOAEV_GIT_BRANCH_OVERRIDE="${CIRCLE_BRANCH}" . + else + docker build --pull --progress=plain -t openaev/collector-netwitness:${CIRCLE_SHA1} . + docker build --pull --progress=plain -t openaev/collector-netwitness:${CIRCLE_SHA1}-ubi9 -f Dockerfile_ubi9 . 
+ fi + docker save -o ~/openaev/images/collector-netwitness openaev/collector-netwitness:${CIRCLE_SHA1} + docker save -o ~/openaev/images/collector-netwitness-ubi9 openaev/collector-netwitness:${CIRCLE_SHA1}-ubi9 - run: working_directory: ~/openaev/openaev name: Build Docker image openaev/collector-openaev @@ -414,6 +427,13 @@ jobs: docker tag openaev/collector-splunk-es:${CIRCLE_SHA1}-ubi9 openaev/collector-splunk-es:${IMAGETAG}-ubi9 docker tag openaev/collector-splunk-es:${CIRCLE_SHA1}-ubi9 openbas/collector-splunk-es:${IMAGETAG}-ubi9 + docker image load < collector-netwitness + docker tag openaev/collector-netwitness:${CIRCLE_SHA1} openaev/collector-netwitness:${IMAGETAG} + docker tag openaev/collector-netwitness:${CIRCLE_SHA1} openbas/collector-netwitness:${IMAGETAG} + docker image load < collector-netwitness-ubi9 + docker tag openaev/collector-netwitness:${CIRCLE_SHA1}-ubi9 openaev/collector-netwitness:${IMAGETAG}-ubi9 + docker tag openaev/collector-netwitness:${CIRCLE_SHA1}-ubi9 openbas/collector-netwitness:${IMAGETAG}-ubi9 + docker image load < collector-openaev docker tag openaev/collector-openaev:${CIRCLE_SHA1} openaev/collector-openaev:${IMAGETAG} docker tag openaev/collector-openaev:${CIRCLE_SHA1} openbas/collector-openbas:${IMAGETAG} @@ -497,6 +517,10 @@ jobs: docker push openaev/collector-splunk-es:${IMAGETAG}-ubi9 docker push openbas/collector-splunk-es:${IMAGETAG} docker push openbas/collector-splunk-es:${IMAGETAG}-ubi9 + docker push openaev/collector-netwitness:${IMAGETAG} + docker push openaev/collector-netwitness:${IMAGETAG}-ubi9 + docker push openbas/collector-netwitness:${IMAGETAG} + docker push openbas/collector-netwitness:${IMAGETAG}-ubi9 docker push openaev/collector-openaev:${IMAGETAG} docker push openaev/collector-openaev:${IMAGETAG}-ubi9 docker push openbas/collector-openbas:${IMAGETAG} @@ -544,6 +568,8 @@ jobs: docker tag openaev/collector-sentinelone:${IMAGETAG} openbas/collector-sentinelone:latest docker tag 
openaev/collector-splunk-es:${IMAGETAG} openaev/collector-splunk-es:latest docker tag openaev/collector-splunk-es:${IMAGETAG} openbas/collector-splunk-es:latest + docker tag openaev/collector-netwitness:${IMAGETAG} openaev/collector-netwitness:latest + docker tag openaev/collector-netwitness:${IMAGETAG} openbas/collector-netwitness:latest docker tag openaev/collector-openaev:${IMAGETAG} openaev/collector-openaev:latest docker tag openaev/collector-openaev:${IMAGETAG} openbas/collector-openbas:latest docker tag openaev/collector-microsoft-azure:${IMAGETAG} openaev/collector-microsoft-azure:latest @@ -577,6 +603,8 @@ jobs: docker push openbas/collector-sentinelone:latest docker push openaev/collector-splunk-es:latest docker push openbas/collector-splunk-es:latest + docker push openaev/collector-netwitness:latest + docker push openbas/collector-netwitness:latest docker push openaev/collector-openaev:latest docker push openbas/collector-openbas:latest docker push openaev/collector-microsoft-azure:latest diff --git a/netwitness/.build.env b/netwitness/.build.env new file mode 100644 index 00000000..d5dbd10c --- /dev/null +++ b/netwitness/.build.env @@ -0,0 +1,2 @@ +COLLECTOR_CMD=src + diff --git a/netwitness/.dockerignore b/netwitness/.dockerignore new file mode 100644 index 00000000..ac8f8398 --- /dev/null +++ b/netwitness/.dockerignore @@ -0,0 +1,11 @@ +# Configuration files +config.yml + +# Build artifacts +dist + +# Cache directories +__pycache__ +.ruff_cache +.mypy_cache +.pytest_cache diff --git a/netwitness/.gitignore b/netwitness/.gitignore new file mode 100644 index 00000000..c95f9119 --- /dev/null +++ b/netwitness/.gitignore @@ -0,0 +1,10 @@ +config.yml + +# Build artifacts +dist + +# Cache directories +__pycache__ +.ruff_cache +.mypy_cache +.pytest_cache diff --git a/netwitness/Dockerfile b/netwitness/Dockerfile new file mode 100644 index 00000000..f810023b --- /dev/null +++ b/netwitness/Dockerfile @@ -0,0 +1,33 @@ + +FROM python:3.13-alpine AS builder + +# 
poetry version available on Ubuntu 24.04 +RUN pip3 install poetry==2.1.3 + +RUN apk update && apk upgrade + +ARG installdir=/collector +ADD . ${installdir} +RUN cd ${installdir} && poetry build + +FROM python:3.13-alpine AS runner + +# Declare the build argument +ARG PYOAEV_GIT_BRANCH_OVERRIDE + +ARG installdir=/collector +COPY --from=builder ${installdir} ${installdir} +RUN cd ${installdir}/dist && pip3 install --no-cache-dir "$(ls *.whl)[prod]" + +RUN if [[ ${PYOAEV_GIT_BRANCH_OVERRIDE} ]] ; then \ + echo "Forcing specific version of client-python" && \ + apk add --no-cache git && \ + pip install pip3-autoremove && \ + pip-autoremove pyoaev -y && \ + pip install git+https://github.com/OpenAEV-Platform/client-python@${PYOAEV_GIT_BRANCH_OVERRIDE} ; \ + fi + +# necessary for icon location +WORKDIR ${installdir} + +CMD ["NetWitnessCollector"] diff --git a/netwitness/Dockerfile_ubi9 b/netwitness/Dockerfile_ubi9 new file mode 100644 index 00000000..0c1b4bf5 --- /dev/null +++ b/netwitness/Dockerfile_ubi9 @@ -0,0 +1,43 @@ +FROM registry.access.redhat.com/ubi9/ubi-minimal AS base + +RUN set -eux; \ + microdnf -y --setopt=install_weak_deps=0 install python3.12; \ + microdnf clean all; + + +FROM base AS builder + +RUN set -eux; \ + microdnf -y --setopt=install_weak_deps=0 install python3.12-pip; \ + pip3.12 install poetry==2.1.3; \ + microdnf -y remove python3.12-pip; \ + microdnf clean all; + +WORKDIR /collector +COPY ./ ./ +RUN set -eux; \ + poetry build + + +FROM base AS runner + +ARG PYOAEV_GIT_BRANCH_OVERRIDE="" + +WORKDIR /collector +COPY --from=builder /collector/ ./ + +RUN set -eux; \ + microdnf -y --setopt=install_weak_deps=0 install python3.12-pip; \ + (cd dist && pip3.12 install --no-cache-dir "$(ls *.whl)[prod]"); \ + if [ -n "${PYOAEV_GIT_BRANCH_OVERRIDE}" ] ; then \ + echo "Forcing specific version of client-python"; \ + microdnf -y --setopt=install_weak_deps=0 install git-core; \ + pip3.12 install pip3-autoremove; \ + pip-autoremove pyoaev -y; \ + pip3.12 
install git+https://github.com/OpenAEV-Platform/client-python@${PYOAEV_GIT_BRANCH_OVERRIDE}; \ + microdnf -y remove git-core; \ + fi; \ + microdnf -y remove python3.12-pip; \ + microdnf clean all; + +CMD ["NetWitnessCollector"] diff --git a/netwitness/README.md b/netwitness/README.md new file mode 100644 index 00000000..77ea24bb --- /dev/null +++ b/netwitness/README.md @@ -0,0 +1,137 @@ +# OpenAEV NetWitness Collector + +A NetWitness integration for OpenAEV that validates detection expectations by querying the NetWitness Core SDK and matching the returned sessions against expected outcomes. + +**Note**: Requires network access to a NetWitness Core service (Broker or Concentrator) with the RESTful API enabled. + +## Overview + +This collector validates OpenAEV expectations by querying your NetWitness environment for matching sessions via the Core SDK query API (NWQL). When OpenAEV runs security exercises, this collector automatically checks whether the expected activity was observed by NetWitness, providing visibility into your detection capabilities. + +The collector builds an NWQL query from the attack signatures, executes it against the Core SDK, and parses the returned session metadata, with support for IP-based matching and parent process tracking through the URL meta. 
+ +## Features + +- **Detection Validation**: Runs NetWitness Core SDK queries to verify detections +- **IP-based Matching**: Supports both source and destination IPv4 / IPv6 address matching (`ip.src` / `ip.dst`) +- **Parent Process Tracking**: Extracts inject/agent identifiers from parent process names and matches them against the `url` meta +- **Flexible Authentication**: HTTP basic authentication (Core SDK) or a bearer token (NetWitness Platform API) +- **Retry Mechanism**: Built-in retry logic with a configurable offset to handle ingestion latency +- **Trace Generation**: Creates traces with links back to NetWitness Investigate +- **Flexible Configuration**: Support for YAML, environment variables, and multiple deployment scenarios + +## Required permissions + +The NetWitness collector requires credentials (a Core service user or token) with: +- Permission to run queries against the Core SDK (`/sdk?msg=query`) + +See the NetWitness documentation on the [Core RESTful API](https://community.netwitness.com/s/article/SDKCommands). + +## Requirements + +- OpenAEV Platform +- A NetWitness Core service (Broker on port 50103 or Concentrator on port 50105) with the RESTful API reachable +- Python 3.11+ (for manual deployment) +- A user account or token with permission to query the Core SDK + +## Configuration + +There are a number of configuration options, which are set either in `docker-compose.yml` (for Docker) or in `config.yml` (for manual deployment). + +The collector loads configuration from a single source, selected in this order (the first one found wins; sources are not merged): +1. `.env` file (`src/.env`), if present +2. YAML configuration file (`src/config.yml`), if present +3. Environment variables + +Any value not provided by the selected source falls back to its default. 
+ +### OpenAEV environment variables + +Below are the parameters you'll need to set for OpenAEV: + +| Parameter | config.yml | Docker environment variable | Mandatory | Description | +|-------------------|-------------------|-----------------------------|-----------|-------------------------------------------------------| +| OpenAEV URL | openaev.url | `OPENAEV_URL` | Yes | The URL of the OpenAEV platform. | +| OpenAEV Token | openaev.token | `OPENAEV_TOKEN` | Yes | The default admin token set in the OpenAEV platform. | +| OpenAEV Tenant ID | openaev.tenant_id | `OPENAEV_TENANT_ID` | No | Identifier of the tenant within the OpenAEV platform. | + +> Warning +> +> The `tenant_id` parameter is a new configuration option. A period of backward compatibility is ensured: if this key is not defined, +> existing configurations will not be affected, and the default value will be `None`. However, if a value is provided, it will be +> validated by Pydantic and must conform to a valid UUID format, otherwise a validation error will be returned. + +### Base collector environment variables + +Below are the parameters you'll need to set for running the collector properly: + +| Parameter | config.yml | Docker environment variable | Default | Mandatory | Description | +|------------------|---------------------|-----------------------------|-------------------------------------------------|-----------|-----------------------------------------------------------------------------------------------| +| Collector ID | collector.id | `COLLECTOR_ID` | netwitness--0b13e3f7-5c9e-46f5-acc4-33032e9b... | Yes | A unique `UUIDv4` identifier for this collector instance. | +| Collector Name | collector.name | `COLLECTOR_NAME` | NetWitness | No | Name of the collector. | +| Collector Period | collector.period | `COLLECTOR_PERIOD` | PT1M | No | Collection interval (ISO 8601 format). | +| Log Level | collector.log_level | `COLLECTOR_LOG_LEVEL` | error | No | Determines the verbosity of the logs. 
Options are `debug`, `info`, `warn`, or `error`. | +| Platform | collector.platform | `COLLECTOR_PLATFORM` | SIEM | No | Type of security platform this collector works for. One of: `EDR, XDR, SIEM, SOAR, NDR, ISPM` | + +### Collector extra parameters environment variables + +Below are the parameters you'll need to set for the collector: + +| Parameter | config.yml | Docker environment variable | Default | Mandatory | Description | +|--------------|-------------------------|-----------------------------|--------------------------------------|-----------|-----------------------------------------------------------------------------------------| +| Base URL | netwitness.base_url | `NETWITNESS_BASE_URL` | https://netwitness.company.com:50103 | Yes | Base URL of a NetWitness Core service (Broker/Concentrator). | +| Username | netwitness.username | `NETWITNESS_USERNAME` | | No* | Username for HTTP basic authentication to the Core SDK. | +| Password | netwitness.password | `NETWITNESS_PASSWORD` | | No* | Password for HTTP basic authentication. | +| Token | netwitness.token | `NETWITNESS_TOKEN` | | No* | Bearer token for the NetWitness Platform API (optional). | +| Max Results | netwitness.max_results | `NETWITNESS_MAX_RESULTS` | 100 | No | Maximum number of sessions to return per query. | +| Console URL | netwitness.console_url | `NETWITNESS_CONSOLE_URL` | | No | NetWitness console URL used to build trace links (defaults to base_url). | +| Verify SSL | netwitness.verify_ssl | `NETWITNESS_VERIFY_SSL` | true | No | Whether to verify the NetWitness TLS certificate. | +| Time Window | netwitness.time_window | `NETWITNESS_TIME_WINDOW` | PT1H | No | Default search window when no date signatures are provided (ISO 8601 format). | +| Offset | netwitness.offset | `NETWITNESS_OFFSET` | PT30S | No | Delay between retry attempts to account for ingestion latency (ISO 8601 format). 
| +| Max Retry | netwitness.max_retry | `NETWITNESS_MAX_RETRY` | 3 | No | Maximum number of retry attempts after the initial query fails or returns no results. | + +> \* Authentication is required: provide either `NETWITNESS_TOKEN` or both `NETWITNESS_USERNAME` and `NETWITNESS_PASSWORD`. + +### Example Configuration Files + +#### YAML Configuration (`src/config.yml`) +```yaml +openaev: + url: "https://your-openaev-instance.com" + token: "your-openaev-token" +# tenant_id: "your-openaev-tenant-id" +collector: + id: "netwitness--your-unique-uuid" + name: "NetWitness Production" + period: "PT10M" + log_level: "info" + +netwitness: + base_url: "https://your-netwitness.company.com:50103" + username: "api" + password: "your-password" + max_results: 100 + verify_ssl: true + offset: "PT45S" + max_retry: 5 +``` + +#### Environment Variables +```bash +export OPENAEV_URL="https://your-openaev-instance.com" +export OPENAEV_TOKEN="your-openaev-token" +export OPENAEV_TENANT_ID="your-openaev-tenant-id" +export COLLECTOR_ID="netwitness--your-unique-uuid" +export NETWITNESS_BASE_URL="https://your-netwitness.company.com:50103" +export NETWITNESS_USERNAME="api" +export NETWITNESS_PASSWORD="your-password" +``` + +## API endpoints used + +- **Authentication**: HTTP basic (Core SDK) or bearer token +- **Query**: `GET /sdk?msg=query&query=&force-content-type=application/json` +- **NWQL meta used for matching**: `ip.src`, `ip.dst`, `url`, `time` +- **Reference**: [NetWitness Core SDK commands](https://community.netwitness.com/s/article/SDKCommands) + +> **Note**: The required permissions and endpoints listed above are based on the current code and documentation. NetWitness may change API requirements at any time. Always check the official NetWitness documentation for the latest requirements before deploying. 
diff --git a/netwitness/docker-compose.yml b/netwitness/docker-compose.yml new file mode 100644 index 00000000..fc1b4334 --- /dev/null +++ b/netwitness/docker-compose.yml @@ -0,0 +1,16 @@ +version: "3" +services: + collector-netwitness: + image: openaev/collector-netwitness:rolling + environment: + - OPENAEV_URL=http://localhost + - OPENAEV_TOKEN=ChangeMe + - OPENAEV_TENANT_ID=ChangeMe + - COLLECTOR_ID=ChangeMe + - NETWITNESS_BASE_URL=https://change.me:50103 + # Authentication: provide username/password (Core SDK) or NETWITNESS_TOKEN (bearer) + - NETWITNESS_USERNAME=ChangeMe + - NETWITNESS_PASSWORD=ChangeMe + # - NETWITNESS_TOKEN=ChangeMe + - NETWITNESS_MAX_RESULTS=100 + restart: always diff --git a/netwitness/manifest-metadata.json b/netwitness/manifest-metadata.json new file mode 100644 index 00000000..b9e55ef2 --- /dev/null +++ b/netwitness/manifest-metadata.json @@ -0,0 +1,18 @@ +{ + "title": "NetWitness", + "slug": "openaev_netwitness", + "description": "Collect responses from NetWitness", + "short_description": "Collect responses from NetWitness", + "use_cases": ["Security response"], + "verified": false, + "last_verified_date": "", + "playbook_supported": false, + "max_confidence_level": 80, + "support_version": "", + "subscription_link": "https://www.netwitness.com/", + "source_code": "", + "manager_supported": true, + "container_version": "rolling", + "container_image": "openaev/collector-netwitness", + "container_type": "COLLECTOR" +} \ No newline at end of file diff --git a/netwitness/pyproject.toml b/netwitness/pyproject.toml new file mode 100644 index 00000000..878a0001 --- /dev/null +++ b/netwitness/pyproject.toml @@ -0,0 +1,127 @@ +[build-system] +requires = ["poetry-core"] +build-backend = "poetry.core.masonry.api" + +[tool.poetry] +packages = [{ include = "src" }, { include = "tests" }] + +[project] +name = "NetWitnessCollector" +version = "2.260615.0" +description = "Collector for NetWitness." 
+readme = "README.md" + +requires-python = ">=3.11,<3.14" +dynamic = ["dependencies"] + +[tool.poetry.dependencies] +pyoaev = [ + {markers = "extra == 'prod' and extra != 'local' and extra != 'current'",version = "2.260615.0"}, + {markers = "extra == 'local' and extra != 'current' and extra != 'prod'",path = "../../client-python", develop = true}, +] +pydantic = "^2.11.7" +pydantic-settings = "^2.11.0" +requests = "~2.33.0" + +[tool.poetry.extras] +prod = ["pyoaev"] +local = ["pyoaev"] + +[tool.poetry.group.dev.dependencies] +isort = "^6.0.1" +ruff = "^0.12.11" +mypy = "^1.17.1" +black = "^25.1.0" +flake8 = "^7.3.0" +pip-audit = "^2.9.0" +pre-commit = "^4.3.0" + +[tool.poetry.group.test.dependencies] +pytest = "^9.0.0" +polyfactory = "^2.22.2" + +[project.scripts] +NetWitnessCollector = "src.__main__:main" + +[tool.pytest.ini_options] +testpaths = ["./tests"] + +[tool.ruff] +exclude = [ + ".bzr", + ".direnv", + ".eggs", + ".git", + ".git-rewrite", + ".hg", + ".ipynb_checkpoints", + ".mypy_cache", + ".nox", + ".pants.d", + ".pyenv", + ".pytest_cache", + ".pytype", + ".ruff_cache", + ".tox", + ".venv", + ".vscode", + "__pypackages__", + "_build", + "buck-out", + "build", + "dist", + "node_modules", + "site-packages", + "venv", +] + +target-version = "py312" + +[tool.ruff.lint] +# Never enforce `I001` (unsorted import). Already handled by isort +# Never enforce `E501` (line length violations). Already handled by black +# Never enforce `F821` (Undefined name `null`). 
incorrect issue with notebook +# Never enforce `D213` (Multi-line docstring summary should start at the second line) conflict with our docstring convention +# Never enforce `D211` (NoBlankLinesBeforeClass) +# Never enforce `G004` (logging-f-string) Logging statement uses f-string +# Never enforce `TRY003` (Avoid specifying long messages outside the exception class): not useful +# Never enforce `D104` (Missing docstring in public package) +# Never enforce `D407` (Missing dashed underline after section) +# Never enforce `D408` (Section underline should be in the line following the section's name) +# Never enforce `D409` (Section underline should match the length of its name) +ignore = [ + "I001", + "D203", + "E501", + "F821", + "D205", + "D213", + "D211", + "G004", + "TRY003", + "D104", + "D407", + "D408", + "D409", +] +select = ["E", "F", "W", "D", "G", "T", "B", "C", "N", "I", "S"] + +[tool.mypy] +strict = true +exclude = [ + '^tests', + '^docs', + '^build', + '^dist', + '^venv', + '^site-packages', + '^__pypackages__', + '^.venv', +] +plugins = ["pydantic.mypy"] + +[tool.cmw] +install-command = "poetry install --extras local" +config-dump-command = "poetry run python src --dump-config-schema" +icon-path = "src/img/netwitness-logo.png" + diff --git a/netwitness/src/.env.sample b/netwitness/src/.env.sample new file mode 100644 index 00000000..5bdbaf88 --- /dev/null +++ b/netwitness/src/.env.sample @@ -0,0 +1,20 @@ +### More env vars are defined in the example docker-compose.yml file +### with appropriate defaults. + +# base URL to reach the OpenAEV server +# note this URL must be routable from inside the container +# so `localhost` will most likely not work +OPENAEV_URL=ChangeMe +# admin account API token from the OpenAEV server +OPENAEV_TOKEN=ChangeMe +OPENAEV_TENANT_ID=ChangeMe + +# collector ID must be a unique string, e.g. 
UUIDv4 +COLLECTOR_ID=ChangeMe + +NETWITNESS_BASE_URL=ChangeMe +# Authentication: provide a username/password pair (Core SDK) OR a bearer token +NETWITNESS_USERNAME=ChangeMe +NETWITNESS_PASSWORD=ChangeMe +# NETWITNESS_TOKEN=ChangeMe +NETWITNESS_MAX_RESULTS=100 diff --git a/netwitness/src/__init__.py b/netwitness/src/__init__.py new file mode 100644 index 00000000..1527c20e --- /dev/null +++ b/netwitness/src/__init__.py @@ -0,0 +1,3 @@ +from src.models import ConfigLoader + +__all__ = ["ConfigLoader"] diff --git a/netwitness/src/__main__.py b/netwitness/src/__main__.py new file mode 100644 index 00000000..45c85172 --- /dev/null +++ b/netwitness/src/__main__.py @@ -0,0 +1,33 @@ +"""Main entry point for the collector.""" + +import logging +import os +import sys + +from src.collector import Collector +from src.collector.exception import CollectorConfigError + +LOG_PREFIX = "[Main]" + + +def main() -> None: + """Define the main function to run the collector.""" + logger = logging.getLogger(__name__) + + try: + logger.info(f"{LOG_PREFIX} Starting NetWitness collector...") + collector = Collector() + collector.start() + except KeyboardInterrupt: + logger.info(f"{LOG_PREFIX} Collector stopped by user (Ctrl+C)") + os._exit(0) + except CollectorConfigError as e: + logger.error(f"{LOG_PREFIX} Configuration error: {e}") + sys.exit(2) + except Exception as e: + logger.exception(f"{LOG_PREFIX} Fatal error starting collector: {e}") + sys.exit(1) + + +if __name__ == "__main__": + main() diff --git a/netwitness/src/collector/__init__.py b/netwitness/src/collector/__init__.py new file mode 100644 index 00000000..897ddf7f --- /dev/null +++ b/netwitness/src/collector/__init__.py @@ -0,0 +1,3 @@ +from src.collector.collector import Collector + +__all__ = ["Collector"] diff --git a/netwitness/src/collector/collector.py b/netwitness/src/collector/collector.py new file mode 100644 index 00000000..e65c791f --- /dev/null +++ b/netwitness/src/collector/collector.py @@ -0,0 +1,142 @@ +"""Core 
collector.""" + +import os + +from pyoaev.daemons import CollectorDaemon # type: ignore[import-untyped] +from pyoaev.helpers import OpenAEVDetectionHelper # type: ignore[import-untyped] +from src.services.expectation_service import NetWitnessExpectationService +from src.services.trace_service import NetWitnessTraceService +from src.services.utils import NetWitnessConfig + +from .exception import ( + CollectorConfigError, + CollectorProcessingError, + CollectorSetupError, +) +from .expectation_handler import GenericExpectationHandler +from .expectation_manager import GenericExpectationManager + +LOG_PREFIX = "[Collector]" +COLLECTOR_TYPE = "openaev_netwitness" + + +class Collector(CollectorDaemon): # type: ignore[misc] + """Generic Collector using service provider pattern. + + This collector is use-case agnostic and works with any service provider. + """ + + def __init__(self) -> None: + """Initialize the collector. + + Raises: + CollectorConfigError: If collector initialization fails. + + """ + try: + self.config = NetWitnessConfig() + self.config_instance = self.config.load + + super().__init__( + configuration=self.config_instance.to_daemon_config(), + callback=self._process_callback, + collector_type=COLLECTOR_TYPE, + ) + + self.logger.info( # type: ignore[has-type] + f"{LOG_PREFIX} NetWitness Collector initialized successfully" + ) + + except Exception as err: + import logging + + logging.basicConfig(level=logging.ERROR) + self.logger = logging.getLogger(__name__) + self.logger.error(f"{LOG_PREFIX} Failed to initialize collector: {err}") + raise CollectorConfigError( + f"Failed to initialize the collector: {err}" + ) from err + + def _setup(self) -> None: + """Set up the collector. + + Initializes NetWitness services, expectation handler, expectation manager, + and OpenAEV detection helper. Sets up the collector for processing expectations. + + Raises: + CollectorSetupError: If collector setup fails. 
+ + """ + try: + self.logger.info(f"{LOG_PREFIX} Starting collector setup...") + + super()._setup() + + self.logger.debug(f"{LOG_PREFIX} Initializing NetWitness services...") + + self.netwitness_service = NetWitnessExpectationService(self.config_instance) + + self.trace_service = NetWitnessTraceService(self.config_instance) + + self.expectation_handler = GenericExpectationHandler( + self.netwitness_service + ) + + self.expectation_manager = GenericExpectationManager( + oaev_api=self.api, + collector_id=self.get_id(), + expectation_handler=self.expectation_handler, + trace_service=self.trace_service, + ) + + supported_signatures = self.netwitness_service.get_supported_signatures() + self.oaev_detection_helper = OpenAEVDetectionHelper( + logger=self.logger, + relevant_signatures_types=supported_signatures, + ) + + self.logger.info(f"{LOG_PREFIX} Collector setup completed successfully") + self.logger.info( + f"{LOG_PREFIX} Supported signatures: {[sig.value for sig in supported_signatures]}" + ) + + service_info = self.netwitness_service.get_service_info() + self.logger.debug(f"{LOG_PREFIX} Service info: {service_info}") + + except Exception as err: + self.logger.error(f"{LOG_PREFIX} Collector setup failed: {err}") + raise CollectorSetupError(f"Failed to setup the collector: {err}") from err + + def _process_callback(self) -> None: + """Process the callback for expectation processing. + + Executes a single processing cycle, handling expectations through the + expectation manager and logging results. Handles keyboard interrupts + and system exits gracefully. + + Raises: + CollectorProcessingError: If processing cycle fails. 
+ + """ + try: + self.logger.info(f"{LOG_PREFIX} Starting processing cycle...") + self.logger.debug( + f"{LOG_PREFIX} Processing expectations using NetWitness services" + ) + + results = self.expectation_manager.process_expectations( + detection_helper=self.oaev_detection_helper + ) + + self.logger.info( + f"{LOG_PREFIX} Processing cycle completed: {results.processed} total, " + f"{results.valid} valid, {results.invalid} invalid, " + f"{results.skipped} skipped" + ) + + except (KeyboardInterrupt, SystemExit): + self.logger.info(f"{LOG_PREFIX} Collector stopping...") + os._exit(0) + except Exception as e: + self.logger.error(f"{LOG_PREFIX} Error during processing cycle: {str(e)}") + raise CollectorProcessingError(f"Processing error: {str(e)}") from e diff --git a/netwitness/src/collector/exception.py b/netwitness/src/collector/exception.py new file mode 100644 index 00000000..9dafb505 --- /dev/null +++ b/netwitness/src/collector/exception.py @@ -0,0 +1,73 @@ +"""Custom exceptions for the collector.""" + + +class CollectorError(Exception): + """Base exception for the collector.""" + + pass + + +class CollectorConfigError(CollectorError): + """Exception raised when there is an error in the collector configuration.""" + + pass + + +class CollectorSetupError(CollectorError): + """Exception raised when there is an error setting up the collector.""" + + pass + + +class CollectorProcessingError(CollectorError): + """Exception raised when there is an error processing data in the collector.""" + + pass + + +class ExpectationHandlerError(CollectorError): + """Exception raised when there is an error in expectation handling.""" + + pass + + +class ExpectationProcessingError(CollectorError): + """Exception raised when there is an error processing expectations.""" + + pass + + +class ExpectationUpdateError(CollectorError): + """Exception raised when there is an error updating expectations.""" + + pass + + +class BulkUpdateError(ExpectationUpdateError): + """Exception raised when 
there is an error during bulk update operations.""" + + pass + + +class APIError(CollectorError): + """Exception raised when there is an error with API operations.""" + + pass + + +class TracingError(CollectorError): + """Exception raised when there is an error with tracing operations.""" + + pass + + +class TraceSubmissionError(TracingError): + """Exception raised when there is an error submitting traces.""" + + pass + + +class TraceCreationError(TracingError): + """Exception raised when there is an error creating traces.""" + + pass diff --git a/netwitness/src/collector/expectation_handler.py b/netwitness/src/collector/expectation_handler.py new file mode 100644 index 00000000..e8c5a23a --- /dev/null +++ b/netwitness/src/collector/expectation_handler.py @@ -0,0 +1,202 @@ +"""Generic Expectation Handler.""" + +import logging +from typing import Any + +from pyoaev.apis.inject_expectation.model import ( # type: ignore[import-untyped] + DetectionExpectation, + PreventionExpectation, +) +from pyoaev.helpers import OpenAEVDetectionHelper # type: ignore[import-untyped] +from pyoaev.signatures.types import SignatureTypes # type: ignore[import-untyped] + +from .exception import ExpectationHandlerError +from .expectation_service_provider import ExpectationServiceProvider +from .models import ExpectationResult +from .signature_registry import ExpectationHandlerType, get_registry + +LOG_PREFIX = "[CollectorExpectationHandler]" + + +class GenericExpectationHandler: + """Generic expectation handler that delegates to service providers. + + This handler is completely agnostic to the specific use case and + delegates all processing logic to the injected service provider. + """ + + def __init__(self, service_provider: ExpectationServiceProvider) -> None: + """Initialize the generic handler. + + Args: + service_provider: Service provider implementing business logic. 
+ + """ + self.logger = logging.getLogger(__name__) + self.service_provider = service_provider + + self.logger.debug(f"{LOG_PREFIX} Initializing generic expectation handler") + self._register_with_registry() + self.logger.info( + f"{LOG_PREFIX} Generic expectation handler initialized successfully" + ) + + def _register_with_registry(self) -> None: + """Register handler capabilities with the signature registry. + + Registers detection and prevention handlers with the signature registry + for all supported signature types from the service provider. + + Raises: + Exception: If registration with registry fails. + + """ + try: + registry = get_registry() + supported_signatures = self.service_provider.get_supported_signatures() + + registry.register_handler( + handler_type=ExpectationHandlerType.DETECTION, + handler_func=self.handle_expectation, + signature_types=supported_signatures, + ) + + registry.register_handler( + handler_type=ExpectationHandlerType.PREVENTION, + handler_func=self.handle_expectation, + signature_types=supported_signatures, + ) + + self.logger.info( + f"{LOG_PREFIX} Registered handler for {len(supported_signatures)} signature types: {[sig.value for sig in supported_signatures]}" + ) + + except Exception as e: + self.logger.error( + f"{LOG_PREFIX} Failed to register handler with registry: {e}" + ) + raise + + def handle_expectation( + self, + expectation: Any, + detection_helper: OpenAEVDetectionHelper, + ) -> ExpectationResult: + """Handle an expectation by delegating to the service provider. + + Args: + expectation: The expectation to process. + detection_helper: OpenAEV detection helper instance. + + Returns: + ExpectationResult containing processing results. + + Raises: + Exception: If expectation handling fails. 
+ + """ + expectation_id = ( + str(expectation.inject_expectation_id) + if hasattr(expectation, "inject_expectation_id") + else "unknown" + ) + + try: + if isinstance(expectation, DetectionExpectation): + self.logger.debug( + f"{LOG_PREFIX} Processing detection expectation: {expectation_id}" + ) + result = self.service_provider.handle_detection_expectation( + expectation, detection_helper + ) + elif isinstance(expectation, PreventionExpectation): + self.logger.debug( + f"{LOG_PREFIX} Processing prevention expectation: {expectation_id}" + ) + result = self.service_provider.handle_prevention_expectation( + expectation, detection_helper + ) + else: + self.logger.warning( + f"{LOG_PREFIX} Unsupported expectation type for {expectation_id}: {type(expectation)}" + ) + result = ExpectationResult( + expectation_id=expectation_id, + is_valid=False, + expectation=expectation, + error_message="Unsupported expectation type", + ) + + return result + + except Exception as e: + self.logger.error( + f"{LOG_PREFIX} Error handling expectation {expectation_id}: {e}" + ) + raise + + def handle_batch_expectations( + self, + expectations: list[Any], + detection_helper: OpenAEVDetectionHelper, + ) -> list[ExpectationResult]: + """Handle a batch of expectations by delegating to the service provider. + + Post-processes results to ensure completeness by filling in missing + expectation IDs and expectation objects. + + Args: + expectations: List of expectations to process. + detection_helper: OpenAEV detection helper instance. + + Returns: + List of ExpectationResult objects. + + Raises: + ExpectationHandlerError: If batch processing fails. 
+ + """ + try: + self.logger.info( + f"{LOG_PREFIX} Starting batch processing of {len(expectations)} expectations" + ) + + results = self.service_provider.handle_batch_expectations( + expectations, detection_helper + ) + + # Post-process results to ensure completeness + self.logger.debug(f"{LOG_PREFIX} Post-processing batch results...") + for i, result in enumerate(results): + if result.expectation is None and i < len(expectations): + result.expectation = expectations[i] + if not result.expectation_id and result.expectation: + result.expectation_id = str( + result.expectation.inject_expectation_id + ) + + valid_count = sum(1 for r in results if r.is_valid) + invalid_count = len(results) - valid_count + + self.logger.info( + f"{LOG_PREFIX} Batch processing completed: {valid_count} valid, {invalid_count} invalid out of {len(results)} total" + ) + + return results + + except Exception as e: + self.logger.error(f"{LOG_PREFIX} Batch processing failed: {e}") + raise ExpectationHandlerError(f"Error in batch processing: {e}") from e + + def get_supported_signatures(self) -> list[SignatureTypes]: + """Get supported signature types from service provider. + + Returns: + List of SignatureTypes supported by the service provider. 
+ + """ + signatures = self.service_provider.get_supported_signatures() + self.logger.debug( + f"{LOG_PREFIX} Supported signatures: {[sig.value for sig in signatures]}" + ) + return signatures diff --git a/netwitness/src/collector/expectation_manager.py b/netwitness/src/collector/expectation_manager.py new file mode 100644 index 00000000..a71ef509 --- /dev/null +++ b/netwitness/src/collector/expectation_manager.py @@ -0,0 +1,538 @@ +"""Generic Expectation Manager.""" + +import logging +import time +from datetime import datetime, timedelta +from typing import Any + +from pyoaev.apis.inject_expectation.model import ( # type: ignore[import-untyped] + DetectionExpectation, + PreventionExpectation, +) +from pyoaev.client import OpenAEV # type: ignore[import-untyped] +from pyoaev.helpers import OpenAEVDetectionHelper # type: ignore[import-untyped] +from pyoaev.signatures.types import SignatureTypes # type: ignore[import-untyped] + +from .exception import ( + APIError, + BulkUpdateError, + ExpectationProcessingError, + ExpectationUpdateError, +) +from .expectation_handler import GenericExpectationHandler +from .models import ExpectationResult, ProcessingSummary +from .trace_manager import TraceManager +from .trace_service_provider import TraceServiceProvider + +LOG_PREFIX = "[CollectorExpectationManager]" + +# Constants +FETCH_TIMEOUT_MINUTES = 5 +SLEEP_INTERVAL_SECONDS = 30 +PROGRESS_LOG_INTERVAL = 10 + + +class GenericExpectationManager: + """Generic expectation manager that works with any service provider. + + This manager is completely agnostic to the specific use case and + delegates all processing logic to the injected service providers. + """ + + def __init__( + self, + oaev_api: OpenAEV, + collector_id: str, + expectation_handler: GenericExpectationHandler, + trace_service: TraceServiceProvider | None = None, + ) -> None: + """Initialize generic expectation manager. + + Args: + oaev_api: OpenAEV API client. + collector_id: ID of the collector. 
+ expectation_handler: Handler for processing expectations. + trace_service: Optional service for creating traces. + + Raises: + ValueError: If required parameters are None or empty. + + """ + if not oaev_api: + raise ValueError("oaev_api cannot be None") + if not collector_id: + raise ValueError("collector_id cannot be empty") + if not expectation_handler: + raise ValueError("expectation_handler cannot be None") + + self.logger = logging.getLogger(__name__) + self.oaev_api = oaev_api + self.collector_id = collector_id + self.expectation_handler = expectation_handler + self.trace_manager = TraceManager( + oaev_api=oaev_api, + collector_id=collector_id, + trace_service=trace_service, + ) + + self.logger.info( + f"{LOG_PREFIX} Expectation manager initialized for collector: {collector_id}" + ) + + def process_expectations( + self, detection_helper: OpenAEVDetectionHelper + ) -> ProcessingSummary: + """Process all expectations using the injected handler. + + Fetches expectations from OpenAEV, processes them through the handler, + updates expectations in OpenAEV, and creates traces. + + Args: + detection_helper: OpenAEV detection helper. + + Returns: + ProcessingSummary containing processing results. + + Raises: + ExpectationProcessingError: If processing fails. 
+ + """ + try: + self.logger.info(f"{LOG_PREFIX} Starting expectation processing cycle") + + self.logger.debug(f"{LOG_PREFIX} Fetching expectations from OpenAEV...") + expectations = self._fetch_expectations_with_timeout() + + if not expectations: + self.logger.warning(f"{LOG_PREFIX} No expectations found to process") + return ProcessingSummary(processed=0, valid=0, invalid=0, skipped=0) + + supported_expectations = [ + exp + for exp in expectations + if isinstance(exp, (DetectionExpectation, PreventionExpectation)) + ] + + total_expectations = len(expectations) + supported_count = len(supported_expectations) + skipped_count = total_expectations - supported_count + + self.logger.info( + f"{LOG_PREFIX} Found {total_expectations} total expectations: " + f"{supported_count} supported, {skipped_count} skipped" + ) + + if skipped_count > 0: + self.logger.debug( + f"{LOG_PREFIX} Skipped {skipped_count} unsupported expectation types" + ) + + self.logger.debug( + f"{LOG_PREFIX} Processing expectations through handler..." 
+ ) + results = self.expectation_handler.handle_batch_expectations( + supported_expectations, detection_helper + ) + + self.logger.debug(f"{LOG_PREFIX} Updating expectations in OpenAEV...") + self._bulk_update_expectations(results) + + self.logger.debug(f"{LOG_PREFIX} Creating and submitting traces...") + self.trace_manager.create_and_submit_traces(results) + + valid_count = sum(1 for r in results if r.is_valid) + invalid_count = len(results) - valid_count + + summary = ProcessingSummary( + processed=len(results), + valid=valid_count, + invalid=invalid_count, + skipped=skipped_count, + ) + + self.logger.info( + f"{LOG_PREFIX} Expectation processing: processed {total_expectations} items -> {len(results)} results" + ) + + self.logger.info( + f"{LOG_PREFIX} Processing cycle completed: {valid_count} valid, " + f"{invalid_count} invalid, {skipped_count} skipped" + ) + + return summary + + except (BulkUpdateError, APIError) as e: + self.logger.error(f"{LOG_PREFIX} API operation failed: {e}") + raise ExpectationProcessingError(f"API error during processing: {e}") from e + except Exception as e: + self.logger.error(f"{LOG_PREFIX} Unexpected error during processing: {e}") + raise ExpectationProcessingError( + f"Unexpected error processing expectations: {e}" + ) from e + + def _bulk_update_expectations(self, results: list[ExpectationResult]) -> None: + """Bulk update expectations in OpenAEV. + + Prepares bulk data from results and attempts to update expectations + using the OpenAEV bulk update API. + + Args: + results: List of ExpectationResult objects to update. + + Raises: + BulkUpdateError: If bulk update fails. + + """ + if not results: + self.logger.debug( + f"{LOG_PREFIX} No results to update, skipping bulk update" + ) + return + + try: + self.logger.debug( + f"{LOG_PREFIX} Preparing bulk data for {len(results)} results..." 
+ ) + bulk_data = self._prepare_bulk_data(results) + + if bulk_data: + self.logger.debug( + f"{LOG_PREFIX} Attempting bulk update of {len(bulk_data)} expectations..." + ) + self._attempt_bulk_update(bulk_data) + else: + self.logger.debug( + f"{LOG_PREFIX} No valid bulk data prepared, skipping update" + ) + + except Exception as e: + self.logger.error(f"{LOG_PREFIX} Bulk update failed: {e}") + raise BulkUpdateError(f"Error in bulk update: {e}") from e + + def _prepare_bulk_data( + self, results: list[ExpectationResult] + ) -> dict[str, dict[str, Any]]: + """Prepare bulk data from results. + + Transforms ExpectationResult objects into dictionary format + required by the OpenAEV bulk update API. + + Args: + results: List of ExpectationResult objects. + + Returns: + Dictionary mapping expectation IDs to update data. + + """ + bulk_data = {} + skipped_count = 0 + + for result in results: + try: + expectation_id = result.expectation_id + if not expectation_id: + skipped_count += 1 + self.logger.debug( + f"{LOG_PREFIX} Skipping result without expectation_id" + ) + continue + + is_valid = result.is_valid + expectation = result.expectation + if expectation: + result_text = self._get_result_text(expectation, is_valid) + bulk_data[expectation_id] = { + "collector_id": self.collector_id, + "result": result_text, + "is_success": is_valid, + } + self.logger.debug( + f"{LOG_PREFIX} Prepared update for expectation {expectation_id}: " + f"result='{result_text}', success={is_valid}" + ) + else: + skipped_count += 1 + self.logger.debug( + f"{LOG_PREFIX} Skipping result {expectation_id} without expectation object" + ) + except Exception as e: + skipped_count += 1 + self.logger.warning(f"{LOG_PREFIX} Error processing result: {e}") + + if skipped_count > 0: + self.logger.debug( + f"{LOG_PREFIX} Skipped {skipped_count} results during bulk data preparation" + ) + return bulk_data + + def _get_result_text( + self, expectation: DetectionExpectation | PreventionExpectation, is_valid: bool + 
) -> str: + """Get result text based on expectation type and validity. + + Args: + expectation: The expectation object (Detection or Prevention). + is_valid: Whether the expectation was successfully validated. + + Returns: + Human-readable result text for the expectation. + + """ + try: + base_text = ( + "Detected" + if isinstance(expectation, DetectionExpectation) + else "Prevented" + ) + result_text = base_text if is_valid else f"Not {base_text}" + + self.logger.debug( + f"{LOG_PREFIX} Generated result text: '{result_text}' for {type(expectation).__name__}" + ) + return result_text + except Exception as e: + self.logger.warning(f"{LOG_PREFIX} Error generating result text: {e}") + return "Unknown" + + def _attempt_bulk_update(self, bulk_data: dict[str, dict[str, Any]]) -> None: + """Attempt bulk update with fallback to individual updates. + + Tries to use the bulk update API first, then falls back to individual + updates if the bulk operation fails. + + Args: + bulk_data: Dictionary of expectation updates to apply. + + Raises: + BulkUpdateError: If both bulk and individual updates fail. + + """ + try: + self.logger.debug(f"{LOG_PREFIX} Attempting bulk update via OpenAEV API...") + self.oaev_api.inject_expectation.bulk_update( + inject_expectation_input_by_id=bulk_data + ) + self.logger.info( + f"{LOG_PREFIX} Successfully bulk updated {len(bulk_data)} expectations" + ) + + except Exception as bulk_error: + self.logger.warning( + f"{LOG_PREFIX} Bulk update failed, falling back to individual updates: {bulk_error}" + ) + try: + self._fallback_individual_updates(bulk_data) + except Exception as fallback_error: + raise BulkUpdateError( + f"Both bulk and individual updates failed: {fallback_error}" + ) from fallback_error + + def _fallback_individual_updates( + self, bulk_data: dict[str, dict[str, Any]] + ) -> None: + """Fallback to individual expectation updates. + + Updates expectations one by one when bulk update fails. 
+ + Args: + bulk_data: Dictionary of expectation updates to apply. + + """ + self.logger.info( + f"{LOG_PREFIX} Attempting individual updates for {len(bulk_data)} expectations" + ) + success_count = 0 + error_count = 0 + + for expectation_id, update_data in bulk_data.items(): + try: + self._update_expectation(expectation_id, update_data) + success_count += 1 + except (APIError, ExpectationUpdateError) as e: + error_count += 1 + self.logger.error( + f"{LOG_PREFIX} Failed to update expectation {expectation_id}: {e}" + ) + except Exception as e: + error_count += 1 + self.logger.error( + f"{LOG_PREFIX} Unexpected error updating expectation {expectation_id}: {e}" + ) + + self.logger.info( + f"{LOG_PREFIX} Individual updates completed: {success_count} successful, {error_count} failed" + ) + + def _update_expectation( + self, expectation_id: str, update_data: dict[str, Any] + ) -> None: + """Update a single expectation. + + Args: + expectation_id: ID of the expectation to update. + update_data: Update data to apply to the expectation. + + Raises: + ExpectationUpdateError: If the update fails. + + """ + self.logger.debug( + f"{LOG_PREFIX} Updating individual expectation: {expectation_id}" + ) + + try: + self.oaev_api.inject_expectation.update( + inject_expectation_id=expectation_id, + inject_expectation=update_data, + ) + self.logger.debug( + f"{LOG_PREFIX} Successfully updated expectation {expectation_id}" + ) + + except Exception as individual_error: + raise ExpectationUpdateError( + f"Failed to update expectation {expectation_id}: {individual_error}" + ) from individual_error + + def _fetch_expectations_with_timeout( + self, + ) -> list[DetectionExpectation | PreventionExpectation]: + """Keep fetching expectations until we get ones with end_date or 5min timeout. + + Continuously fetches expectations from OpenAEV until either: + 1. Expectations with end_date signatures are found, or + 2. The 5-minute timeout is reached. 
+ + Returns: + List of expectations that meet the criteria. + + """ + start_time = datetime.utcnow() + timeout = timedelta(minutes=FETCH_TIMEOUT_MINUTES) + attempt_count = 0 + + self.logger.debug( + f"{LOG_PREFIX} Fetching expectations for collector: {self.collector_id}" + ) + + while (datetime.utcnow() - start_time) < timeout: + attempt_count += 1 + elapsed = datetime.utcnow() - start_time + + self.logger.debug( + f"{LOG_PREFIX} Expectation fetch attempt {attempt_count} (elapsed: {elapsed.total_seconds():.1f}s)" + ) + + try: + expectations = ( + self.oaev_api.inject_expectation.expectations_models_for_source( + source_id=self.collector_id + ) + ) + except Exception as e: + self.logger.warning( + f"{LOG_PREFIX} Error fetching expectations: {e}, retrying..." + ) + self._interruptible_sleep(SLEEP_INTERVAL_SECONDS) + continue + + if not expectations: + self.logger.debug( + f"{LOG_PREFIX} No expectations found, waiting {SLEEP_INTERVAL_SECONDS}s before retry..." + ) + self._interruptible_sleep(SLEEP_INTERVAL_SECONDS) + continue + + self.logger.debug( + f"{LOG_PREFIX} Found {len(expectations)} expectations, checking for end_date..." + ) + + has_end_date = self._check_for_end_date(expectations) + + if has_end_date: + self.logger.info( + f"{LOG_PREFIX} Found {len(expectations)} expectations with end_date after {attempt_count} attempts" + ) + return expectations # type: ignore[no-any-return] + + self.logger.debug( + f"{LOG_PREFIX} No end_date found in expectations, waiting {SLEEP_INTERVAL_SECONDS}s before retry..." 
+ ) + self._interruptible_sleep(SLEEP_INTERVAL_SECONDS) + + self.logger.warning( + f"{LOG_PREFIX} Timeout reached after {attempt_count} attempts ({timeout.total_seconds()}s)" + ) + + try: + final_expectations = ( + self.oaev_api.inject_expectation.expectations_models_for_source( + source_id=self.collector_id + ) + ) + except Exception as e: + self.logger.error(f"{LOG_PREFIX} Final expectations fetch failed: {e}") + return [] + + if final_expectations: + self.logger.info( + f"{LOG_PREFIX} Processing {len(final_expectations)} expectations without end_date requirement" + ) + return final_expectations or [] + + def _check_for_end_date( + self, expectations: list[DetectionExpectation | PreventionExpectation] + ) -> bool: + """Check if any expectation has end_date signature. + + Args: + expectations: List of expectations to check. + + Returns: + True if any expectation contains an end_date signature. + + """ + try: + for expectation in expectations: + if hasattr(expectation, "inject_expectation_signatures"): + for signature in expectation.inject_expectation_signatures: + if signature.type == SignatureTypes.SIG_TYPE_END_DATE: + return True + return False + except Exception as e: + self.logger.debug(f"{LOG_PREFIX} Error checking for end_date: {e}") + return False + + def _interruptible_sleep(self, seconds: int) -> None: + """Sleep for the given seconds, but check for interrupts every second. + + Provides interruptible sleep that responds to KeyboardInterrupt (Ctrl+C) + and logs progress for longer sleep periods. + + Args: + seconds: Number of seconds to sleep. + + """ + if seconds <= 0: + return + + self.logger.debug( + f"{LOG_PREFIX} Sleeping for {seconds} seconds (interruptible)..." 
+ ) + + for i in range(seconds): + try: + time.sleep(1) + + if ( + seconds >= SLEEP_INTERVAL_SECONDS + and (i + 1) % PROGRESS_LOG_INTERVAL == 0 + ): + self.logger.debug( + f"{LOG_PREFIX} Sleep progress: {i + 1}/{seconds} seconds" + ) + except KeyboardInterrupt: + import sys + + self.logger.info(f"{LOG_PREFIX} Sleep interrupted by user (Ctrl+C)") + sys.exit(0) diff --git a/netwitness/src/collector/expectation_service_provider.py b/netwitness/src/collector/expectation_service_provider.py new file mode 100644 index 00000000..01351681 --- /dev/null +++ b/netwitness/src/collector/expectation_service_provider.py @@ -0,0 +1,74 @@ +"""Protocol defining the interface for expectation service providers.""" + +from typing import Any, Protocol + +from pyoaev.apis.inject_expectation.model import ( # type: ignore[import-untyped] + DetectionExpectation, + PreventionExpectation, +) +from pyoaev.helpers import OpenAEVDetectionHelper # type: ignore[import-untyped] +from pyoaev.signatures.types import SignatureTypes # type: ignore[import-untyped] + +from .models import ExpectationResult + + +class ExpectationServiceProvider(Protocol): + """Protocol defining the interface for expectation service providers.""" + + def get_supported_signatures(self) -> list[SignatureTypes]: + """Get list of signature types this provider supports. + + Returns: + List of SignatureTypes that this provider can handle. + + """ + ... + + def handle_detection_expectation( + self, + expectation: DetectionExpectation, + detection_helper: OpenAEVDetectionHelper, + ) -> ExpectationResult: + """Handle a detection expectation. + + Args: + expectation: The detection expectation to process. + detection_helper: OpenAEV detection helper instance. + + Returns: + ExpectationResult containing the processing outcome. + + """ + ... + + def handle_prevention_expectation( + self, + expectation: PreventionExpectation, + detection_helper: OpenAEVDetectionHelper, + ) -> ExpectationResult: + """Handle a prevention expectation. 
+ + Args: + expectation: The prevention expectation to process. + detection_helper: OpenAEV detection helper instance. + + Returns: + ExpectationResult containing the processing outcome. + + """ + ... + + def handle_batch_expectations( + self, expectations: list[Any], detection_helper: OpenAEVDetectionHelper + ) -> list[ExpectationResult]: + """Handle a batch of expectations efficiently. + + Args: + expectations: List of expectations to process in batch. + detection_helper: OpenAEV detection helper instance. + + Returns: + List of ExpectationResult objects for each processed expectation. + + """ + ... diff --git a/netwitness/src/collector/models.py b/netwitness/src/collector/models.py new file mode 100644 index 00000000..316ae118 --- /dev/null +++ b/netwitness/src/collector/models.py @@ -0,0 +1,168 @@ +"""Pydantic models for collector data structures.""" + +from typing import Any + +from pydantic import BaseModel, Field, field_validator + + +class ExpectationTrace(BaseModel): + """Pydantic model for expectation trace data. + + This model represents the structure of trace data that gets sent to the + OpenAEV API for expectation tracking and validation. + """ + + inject_expectation_trace_expectation: str = Field( + description="The expectation ID this trace is associated with" + ) + inject_expectation_trace_source_id: str = Field( + description="The collector/source ID that generated this trace" + ) + inject_expectation_trace_alert_name: str = Field( + description="Name of the alert that was matched" + ) + inject_expectation_trace_alert_link: str = Field( + description="Link to the alert in the source system" + ) + inject_expectation_trace_date: str = Field( + description="Date when the trace was created (ISO format string)" + ) + + @field_validator("inject_expectation_trace_expectation") + @classmethod + def expectation_must_not_be_empty(cls, v: str) -> str: + """Validate that expectation ID is not empty. + + Args: + v: The expectation ID value to validate. 
+ + Returns: + The trimmed expectation ID. + + Raises: + ValueError: If the expectation ID is empty or whitespace only. + + """ + if not v or not v.strip(): + raise ValueError("Expectation ID cannot be empty") + return v.strip() + + @field_validator("inject_expectation_trace_source_id") + @classmethod + def source_id_must_not_be_empty(cls, v: str) -> str: + """Validate that source ID is not empty. + + Args: + v: The source ID value to validate. + + Returns: + The trimmed source ID. + + Raises: + ValueError: If the source ID is empty or whitespace only. + + """ + if not v or not v.strip(): + raise ValueError("Source ID cannot be empty") + return v.strip() + + @field_validator("inject_expectation_trace_alert_name") + @classmethod + def alert_name_must_not_be_empty(cls, v: str) -> str: + """Validate that alert name is not empty. + + Args: + v: The alert name value to validate. + + Returns: + The trimmed alert name. + + Raises: + ValueError: If the alert name is empty or whitespace only. + + """ + if not v or not v.strip(): + raise ValueError("Alert name cannot be empty") + return v.strip() + + @field_validator("inject_expectation_trace_alert_link") + @classmethod + def alert_link_must_not_be_empty(cls, v: str) -> str: + """Validate that alert link is not empty. + + Args: + v: The alert link value to validate. + + Returns: + The trimmed alert link. + + Raises: + ValueError: If the alert link is empty or whitespace only. + + """ + if not v or not v.strip(): + raise ValueError("Alert link cannot be empty") + return v.strip() + + @field_validator("inject_expectation_trace_date") + @classmethod + def date_must_not_be_empty(cls, v: str) -> str: + """Validate that date is not empty. + + Args: + v: The date value to validate. + + Returns: + The trimmed date string. + + Raises: + ValueError: If the date is empty or whitespace only. 
+ + """ + if not v or not v.strip(): + raise ValueError("Trace date cannot be empty") + return v.strip() + + def to_api_dict(self) -> dict[str, str]: + """Convert the model to a dictionary suitable for API submission. + + This method ensures all values are strings as expected by the API, + replacing the manual sanitization logic in the expectation manager. + + Returns: + Dict with all values converted to strings. + + """ + return { + key: str(value) if value is not None else "" + for key, value in self.model_dump().items() + } + + +class ExpectationResult(BaseModel): + """Model for expectation processing results.""" + + expectation_id: str = Field(..., description="ID of the processed expectation") + is_valid: bool = Field(..., description="Whether the expectation was validated") + expectation: Any | None = Field(None, description="The original expectation object") + matched_alerts: list[dict[str, Any]] | None = Field( + None, description="List of alerts that matched this expectation" + ) + error_message: str | None = Field( + None, description="Error message if processing failed" + ) + processing_time: float | None = Field( + None, description="Time taken to process this expectation in seconds" + ) + + +class ProcessingSummary(BaseModel): + """Model for expectation processing summary.""" + + processed: int = Field(..., description="Total number of expectations processed") + valid: int = Field(..., description="Number of valid expectations") + invalid: int = Field(..., description="Number of invalid expectations") + skipped: int = Field(..., description="Number of skipped expectations") + total_processing_time: float | None = Field( + None, description="Total processing time in seconds" + ) diff --git a/netwitness/src/collector/signature_registry.py b/netwitness/src/collector/signature_registry.py new file mode 100644 index 00000000..bbaa3d08 --- /dev/null +++ b/netwitness/src/collector/signature_registry.py @@ -0,0 +1,159 @@ +"""Signature Registry for dynamic 
expectation handling.""" + +from enum import Enum +from typing import Any, Callable + +from pyoaev.signatures.types import SignatureTypes # type: ignore[import-untyped] + +from .models import ExpectationResult + + +class ExpectationHandlerType(Enum): + """Types of expectation handlers.""" + + DETECTION = "detection" + PREVENTION = "prevention" + + +class SignatureRegistry: + """Simple registry for managing signature subscriptions and expectation handlers. + + This registry allows components to dynamically register: + - Which signature types they're interested in + - How to handle different types of expectations + + Keeps it simple by using basic data structures and clear interfaces. + """ + + def __init__(self) -> None: + """Initialize the registry. + + Creates empty data structures for managing signature subscriptions + and expectation handlers. + """ + self._subscribed_signatures: set[SignatureTypes] = set() + self._handlers: dict[ + ExpectationHandlerType, Callable[[Any, Any], ExpectationResult] + ] = {} + self._handler_signatures: dict[ExpectationHandlerType, set[SignatureTypes]] = {} + + def subscribe_to_signatures(self, signature_types: list[SignatureTypes]) -> None: + """Subscribe to specific signature types. + + Args: + signature_types: List of signature types to subscribe to. + + """ + self._subscribed_signatures.update(signature_types) + + def register_handler( + self, + handler_type: ExpectationHandlerType, + handler_func: Callable[[Any, Any], ExpectationResult], + signature_types: list[SignatureTypes], + ) -> None: + """Register an expectation handler for specific signature types. + + Args: + handler_type: Type of handler (detection/prevention). + handler_func: Function to handle expectations. + signature_types: Signature types this handler supports. 
+ + """ + self._handlers[handler_type] = handler_func + self._handler_signatures[handler_type] = set(signature_types) + + self.subscribe_to_signatures(signature_types) + + def get_subscribed_signatures(self) -> list[SignatureTypes]: + """Get all subscribed signature types. + + Returns: + List of subscribed signature types. + + """ + return list(self._subscribed_signatures) + + def has_handler_for_signatures( + self, + handler_type: ExpectationHandlerType, + signature_types: list[SignatureTypes], + ) -> bool: + """Check if a handler supports the given signature types. + + Args: + handler_type: Type of handler to check. + signature_types: Signature types to check. + + Returns: + True if handler supports any of the signature types. + + """ + if handler_type not in self._handler_signatures: + return False + + handler_sigs = self._handler_signatures[handler_type] + return any(sig in handler_sigs for sig in signature_types) + + def get_handler( + self, handler_type: ExpectationHandlerType + ) -> Callable[[Any, Any], ExpectationResult]: + """Get handler function for the given type. + + Args: + handler_type: Type of handler to retrieve. + + Returns: + Handler function. + + Raises: + KeyError: If no handler registered for the type. + + """ + if handler_type not in self._handlers: + raise KeyError(f"No handler registered for type: {handler_type}") + return self._handlers[handler_type] + + def is_signature_supported(self, signature_type: SignatureTypes) -> bool: + """Check if a signature type is supported by any registered handler. + + Args: + signature_type: Signature type to check. + + Returns: + True if supported. + + """ + return signature_type in self._subscribed_signatures + + def get_handler_types(self) -> list[ExpectationHandlerType]: + """Get all registered handler types. + + Returns: + List of registered handler types. + + """ + return list(self._handlers.keys()) + + def clear(self) -> None: + """Clear all registrations. 
+ + Removes all signature subscriptions and handler registrations. + Useful for testing and cleanup scenarios. + """ + self._subscribed_signatures.clear() + self._handlers.clear() + self._handler_signatures.clear() + + +_registry = SignatureRegistry() + + +def get_registry() -> SignatureRegistry: + """Get the global signature registry instance. + + Returns: + The global registry instance. + + """ + return _registry diff --git a/netwitness/src/collector/trace_manager.py b/netwitness/src/collector/trace_manager.py new file mode 100644 index 00000000..a05a5b00 --- /dev/null +++ b/netwitness/src/collector/trace_manager.py @@ -0,0 +1,201 @@ +"""Trace Manager for handling expectation traces. + +This module provides the TraceManager class which handles all trace-related operations +for expectation processing. It separates trace concerns from the main expectation manager. +""" + +import logging +from typing import Any + +from pyoaev.client import OpenAEV # type: ignore[import-untyped] + +from .exception import TraceCreationError, TraceSubmissionError, TracingError +from .models import ExpectationResult +from .trace_service_provider import TraceServiceProvider + +LOG_PREFIX = "[CollectorTraceManager]" + + +class TraceManager: + """Manages trace creation and submission for expectations. + + This manager handles all trace-related operations, including creating traces + from expectation results and submitting them to the OpenAEV API. + """ + + def __init__( + self, + oaev_api: OpenAEV, + collector_id: str, + trace_service: TraceServiceProvider | None = None, + ) -> None: + """Initialize trace manager. + + Args: + oaev_api: OpenAEV API client. + collector_id: ID of the collector. + trace_service: Service for creating traces from results.
+ + """ + self.logger = logging.getLogger(__name__) + self.oaev_api = oaev_api + self.collector_id = collector_id + self.trace_service = trace_service + + self.logger.info( + f"{LOG_PREFIX} Trace manager initialized for collector: {collector_id}" + ) + if trace_service: + self.logger.debug( + f"{LOG_PREFIX} Trace service available for trace creation" + ) + else: + self.logger.debug( + f"{LOG_PREFIX} No trace service provided - traces will be skipped" + ) + + def create_and_submit_traces(self, results: list[ExpectationResult]) -> None: + """Create and submit traces from expectation results. + + Creates traces from the provided expectation results using the trace service + and submits them to the OpenAEV API. + + Args: + results: List of ExpectationResult objects. + + Raises: + TracingError: If trace creation or submission fails. + + """ + try: + if not self.trace_service: + self.logger.debug( + f"{LOG_PREFIX} No trace service provided, skipping trace creation" + ) + return + + self.logger.debug( + f"{LOG_PREFIX} Creating traces from {len(results)} expectation results..." + ) + traces = self.trace_service.create_traces_from_results( + results, self.collector_id + ) + + if not traces: + self.logger.info(f"{LOG_PREFIX} No traces created from results") + return + + self.logger.info( + f"{LOG_PREFIX} Created {len(traces)} traces, submitting to OpenAEV..." + ) + self._submit_traces(traces) + + except Exception as e: + self.logger.error( + f"{LOG_PREFIX} Error creating and submitting traces: {e} (Context: results_count={len(results)}, collector_id={self.collector_id})" + ) + raise TracingError(f"Error creating and submitting traces: {e}") from e + + def _submit_traces(self, traces: list[Any]) -> None: + """Submit traces to the OpenAEV API. + + Converts traces to API format and submits them using bulk creation. + Falls back to individual creation if bulk submission fails. + + Args: + traces: List of trace objects to submit. 
+ + Raises: + TraceSubmissionError: If trace submission fails. + + """ + try: + self.logger.debug(f"{LOG_PREFIX} Converting traces to API format...") + trace_dicts = [trace.to_api_dict() for trace in traces] + + if not trace_dicts: + self.logger.warning( + f"{LOG_PREFIX} No trace dictionaries generated from traces" + ) + return + + self.logger.debug( + f"{LOG_PREFIX} Submitting {len(trace_dicts)} trace dictionaries to OpenAEV" + ) + self.logger.debug( + f"{LOG_PREFIX} Trace data preview: {trace_dicts[:2] if len(trace_dicts) > 2 else trace_dicts}" + ) + + response = self.oaev_api.inject_expectation_trace.bulk_create( + payload={"expectation_traces": trace_dicts} + ) + + self.logger.info( + f"{LOG_PREFIX} Successfully created {len(trace_dicts)} expectation traces" + ) + self.logger.debug(f"{LOG_PREFIX} OpenAEV response: {response}") + + except Exception as e: + self.logger.error(f"{LOG_PREFIX} Bulk trace submission failed: {e}") + try: + self.logger.info( + f"{LOG_PREFIX} Attempting individual trace creation as fallback..." + ) + self._fallback_individual_trace_creation(traces) + except TraceCreationError as fallback_error: + self.logger.error( + f"{LOG_PREFIX} Fallback trace creation also failed: {fallback_error}" + ) + raise TraceSubmissionError(f"Error submitting traces: {e}") from e + + def _fallback_individual_trace_creation(self, traces: list[Any]) -> None: + """Fallback method to create traces individually if bulk creation fails. + + Creates traces one by one when bulk creation fails, providing + resilience for trace submission. + + Args: + traces: List of trace objects to create individually. + + Raises: + TraceCreationError: If all individual trace creations fail. 
+ + """ + try: + self.logger.info( + f"{LOG_PREFIX} Creating {len(traces)} traces individually as fallback" + ) + success_count = 0 + error_count = 0 + + for i, trace in enumerate(traces, 1): + try: + self.logger.debug( + f"{LOG_PREFIX} Creating individual trace {i}/{len(traces)}" + ) + r = self.oaev_api.inject_expectation_trace.create( + trace.to_api_dict() + ) + success_count += 1 + self.logger.debug( + f"{LOG_PREFIX} Individual trace {i} created successfully" + ) + self.logger.debug(f"{LOG_PREFIX} single Response: {r}") + except Exception as individual_error: + error_count += 1 + self.logger.error( + f"{LOG_PREFIX} Failed to create individual trace {i}: {individual_error}" + ) + + self.logger.info( + f"{LOG_PREFIX} Individual trace creation completed: {success_count} successful, {error_count} failed" + ) + + if success_count == 0: + raise TraceCreationError("All individual trace creations failed") + + except Exception as e: + self.logger.error( + f"{LOG_PREFIX} Error in fallback trace creation: {e} (Context: traces_count={len(traces)}, success_count={success_count})" + ) + raise TraceCreationError(f"Error in fallback trace creation: {e}") from e diff --git a/netwitness/src/collector/trace_service_provider.py b/netwitness/src/collector/trace_service_provider.py new file mode 100644 index 00000000..99f229be --- /dev/null +++ b/netwitness/src/collector/trace_service_provider.py @@ -0,0 +1,24 @@ +"""Protocol for trace creation services.""" + +from typing import Protocol + +from .models import ExpectationResult, ExpectationTrace + + +class TraceServiceProvider(Protocol): + """Protocol for trace creation services.""" + + def create_traces_from_results( + self, results: list[ExpectationResult], collector_id: str + ) -> list[ExpectationTrace]: + """Create trace data from processing results. + + Args: + results: List of ExpectationResult objects to create traces from. + collector_id: ID of the collector creating the traces. 
+ + Returns: + List of ExpectationTrace objects for successful expectations. + + """ + ... diff --git a/netwitness/src/config.yml.sample b/netwitness/src/config.yml.sample new file mode 100644 index 00000000..a11e3685 --- /dev/null +++ b/netwitness/src/config.yml.sample @@ -0,0 +1,24 @@ +openaev: + url: "http://change.me" + token: "ChangeMe" +# tenant_id: "ChangeMe" + +collector: + id: "ChangeMe" + name: "NetWitness" + period: 'PT1M' + log_level: "error" + +netwitness: + base_url: "https://netwitness.company.com:50103" + # Authentication: provide a username/password pair (Core SDK, primary) + # OR a bearer token (NetWitness Platform API). + username: "ChangeMe" + password: "ChangeMe" + # token: "ChangeMe" + max_results: 100 + # console_url: "https://netwitness.company.com" + verify_ssl: true + time_window: "PT1H" + offset: "PT30S" + max_retry: 3 diff --git a/netwitness/src/img/netwitness-logo.png b/netwitness/src/img/netwitness-logo.png new file mode 100644 index 00000000..127434b5 Binary files /dev/null and b/netwitness/src/img/netwitness-logo.png differ diff --git a/netwitness/src/models/__init__.py b/netwitness/src/models/__init__.py new file mode 100644 index 00000000..167e5d71 --- /dev/null +++ b/netwitness/src/models/__init__.py @@ -0,0 +1,3 @@ +from src.models.configs.config_loader import ConfigLoader + +__all__ = ["ConfigLoader"] diff --git a/netwitness/src/models/configs/__init__.py b/netwitness/src/models/configs/__init__.py new file mode 100644 index 00000000..65d61d66 --- /dev/null +++ b/netwitness/src/models/configs/__init__.py @@ -0,0 +1,13 @@ +from src.models.configs.base_settings import ConfigBaseSettings +from src.models.configs.collector_configs import ( + _ConfigLoaderCollector, + _ConfigLoaderOAEV, +) +from src.models.configs.netwitness_configs import _ConfigLoaderNetWitness + +__all__ = [ + "ConfigBaseSettings", + "_ConfigLoaderCollector", + "_ConfigLoaderOAEV", + "_ConfigLoaderNetWitness", +] diff --git 
class ConfigBaseSettings(BaseSettings):
    """Shared pydantic-settings base for every configuration model.

    Subclasses inherit the options below; in particular ``frozen=True``
    prevents attributes from being modified after initialization.
    """

    model_config = SettingsConfigDict(
        # Immutability and tolerance for unknown keys.
        frozen=True,
        extra="ignore",
        # String hygiene: surrounding whitespace is stripped and empty
        # strings are rejected.
        str_strip_whitespace=True,
        str_min_length=1,
        # Environment variable names are split on "_" at most once.
        # NOTE(review): presumably so COLLECTOR_LOG_LEVEL resolves to
        # collector.log_level - confirm against the pydantic-settings docs.
        env_nested_delimiter="_",
        env_nested_max_split=1,
        # Allow both alias and field name for input.
        validate_by_name=True,
        validate_by_alias=True,
    )
class _ConfigLoaderCollector(ConfigBaseSettings):
    """Settings shared by collectors: identity, platform, logging, schedule.

    ``id`` and ``name`` have no default here, so they must be supplied unless
    a subclass redeclares them with one. All other fields fall back to the
    defaults below.
    """

    # Required identification fields (no alias, no default).
    id: str
    name: str

    platform: str | None = Field(
        default="SIEM",
        alias="COLLECTOR_PLATFORM",
        description="Platform type for the collector (e.g., EDR, SIEM, etc.).",
    )
    log_level: LogLevelToLower | None = Field(
        default="error",
        alias="COLLECTOR_LOG_LEVEL",
        description="Determines the verbosity of the logs.",
    )
    period: timedelta | None = Field(
        default=timedelta(minutes=1),
        alias="COLLECTOR_PERIOD",
        description="Duration between two scheduled runs of the collector (ISO 8601 format).",
    )
    icon_filepath: str | None = Field(
        default="src/img/netwitness-logo.png",
        alias="COLLECTOR_ICON_FILEPATH",
        description="Path to the icon file of the collector.",
    )
class ConfigLoader(ConfigBaseSettings):
    """Top-level configuration of the NetWitness collector.

    Groups the OpenAEV, collector and NetWitness sections, chooses where the
    values are read from, and can flatten itself into the ``Configuration``
    object used by the daemon.
    """

    openaev: _ConfigLoaderOAEV = Field(
        description="OpenAEV configurations.",
        default_factory=_ConfigLoaderOAEV,  # type: ignore[unused-ignore]
    )
    collector: ConfigLoaderCollector = Field(
        description="Collector configurations.",
        default_factory=ConfigLoaderCollector,  # type: ignore[unused-ignore]
    )
    netwitness: _ConfigLoaderNetWitness = Field(
        description="NetWitness configurations.",
        default_factory=_ConfigLoaderNetWitness,
    )

    @classmethod
    def settings_customise_sources(
        cls,
        settings_cls: type[BaseSettings],
        init_settings: PydanticBaseSettingsSource,
        env_settings: PydanticBaseSettingsSource,
        dotenv_settings: PydanticBaseSettingsSource,
        file_secret_settings: PydanticBaseSettingsSource,
    ) -> tuple[PydanticBaseSettingsSource]:
        """Select the single settings source to load from.

        Exactly one source is returned - sources are not layered. The first
        match wins:
            1. ``.env`` file, if it exists
            2. ``config.yml`` file, if it exists
            3. Environment variables otherwise

        Both files are looked up in the directory two levels above the
        directory containing this module.

        Args:
            settings_cls: The settings class being configured.
            init_settings: Initialization settings source (unused).
            env_settings: Environment variables settings source (unused; a
                new ``EnvSettingsSource`` is built instead).
            dotenv_settings: .env file settings source (unused).
            file_secret_settings: File secrets settings source (unused).

        Returns:
            One-element tuple containing the selected settings source.

        """
        config_dir = Path(__file__).parents[2]

        dotenv_file = config_dir / ".env"
        if dotenv_file.exists():
            dotenv_source = DotEnvSettingsSource(
                settings_cls,
                env_file=dotenv_file,
                env_ignore_empty=True,
                env_file_encoding="utf-8",
            )
            return (dotenv_source,)

        yaml_file = config_dir / "config.yml"
        if yaml_file.exists():
            yaml_source = YamlConfigSettingsSource(
                settings_cls,
                yaml_file=yaml_file,
                yaml_file_encoding="utf-8",
            )
            return (yaml_source,)

        env_source = EnvSettingsSource(settings_cls, env_ignore_empty=True)
        return (env_source,)

    def to_daemon_config(self) -> Configuration:
        """Flatten the nested settings into a daemon ``Configuration``.

        Every value is exposed under a ``<section>_<field>`` key as a
        ``{"data": ...}`` hint. Secrets are unwrapped to plain strings, and
        the collector period is converted to whole seconds.

        Returns:
            A ``Configuration`` built from the flattened hints, with this
            instance attached as ``config_base_model``.

        """
        openaev = self.openaev
        collector = self.collector
        netwitness = self.netwitness

        # SecretStr values are unwrapped; unset secrets stay None.
        token = netwitness.token.get_secret_value() if netwitness.token else None
        password = (
            netwitness.password.get_secret_value() if netwitness.password else None
        )
        period_seconds = int(collector.period.total_seconds())  # type: ignore[union-attr]

        hints = {
            # OpenAEV configuration (flattened)
            "openaev_url": {"data": str(openaev.url)},
            "openaev_token": {"data": openaev.token},
            "openaev_tenant_id": {"data": openaev.tenant_id},
            # Collector configuration (flattened)
            "collector_id": {"data": collector.id},
            "collector_name": {"data": collector.name},
            "collector_platform": {"data": collector.platform},
            "collector_log_level": {"data": collector.log_level},
            "collector_period": {"data": period_seconds, "is_number": True},
            "collector_icon_filepath": {"data": collector.icon_filepath},
            # NetWitness configuration (flattened)
            "netwitness_base_url": {"data": str(netwitness.base_url)},
            "netwitness_token": {"data": token},
            "netwitness_username": {"data": netwitness.username},
            "netwitness_password": {"data": password},
            "netwitness_max_results": {"data": netwitness.max_results},
            "netwitness_console_url": {"data": netwitness.console_url},
            "netwitness_time_window": {"data": netwitness.time_window},
            "netwitness_max_retry": {"data": netwitness.max_retry},
            "netwitness_offset": {"data": netwitness.offset},
            "netwitness_verify_ssl": {"data": netwitness.verify_ssl},
        }
        return Configuration(config_hints=hints, config_base_model=self)
class _ConfigLoaderNetWitness(ConfigBaseSettings):
    """Connection, authentication and timing settings for NetWitness.

    Authentication is either a bearer token or a username/password pair;
    ``_validate_auth`` rejects a configuration that provides neither.
    """

    # Overrides the ``frozen`` option set on the base settings class.
    model_config = {"frozen": False}

    base_url: str = Field(
        default="https://netwitness.company.com:50103",
        alias="NETWITNESS_BASE_URL",
        description="Base URL of a NetWitness Core service (e.g., Broker on port 50103).",
    )
    token: Optional[SecretStr] = Field(
        default=None,
        alias="NETWITNESS_TOKEN",
        description="Bearer token for the NetWitness Platform API (optional).",
    )
    username: Optional[str] = Field(
        default=None,
        alias="NETWITNESS_USERNAME",
        description="Username for HTTP basic authentication to the Core SDK.",
    )
    password: Optional[SecretStr] = Field(
        default=None,
        alias="NETWITNESS_PASSWORD",
        description="Password for HTTP basic authentication.",
    )
    max_results: int = Field(
        default=100,
        alias="NETWITNESS_MAX_RESULTS",
        description="Maximum number of sessions to return per query.",
    )
    console_url: Optional[str] = Field(
        default=None,
        alias="NETWITNESS_CONSOLE_URL",
        description="NetWitness console URL used to build trace links (defaults to base_url).",
    )
    verify_ssl: bool = Field(
        default=True,
        alias="NETWITNESS_VERIFY_SSL",
        description="Whether to verify the NetWitness TLS certificate.",
    )
    time_window: Optional[timedelta] = Field(
        default=timedelta(hours=1),
        alias="NETWITNESS_TIME_WINDOW",
        description="Time window for searches when no dates are provided.",
    )
    max_retry: int = Field(
        default=3,
        alias="NETWITNESS_MAX_RETRY",
        description="Maximum number of retry attempts for API calls.",
    )
    offset: timedelta = Field(
        default=timedelta(seconds=30),
        alias="NETWITNESS_OFFSET",
        description="Time offset between retry attempts.",
    )

    @model_validator(mode="after")
    def _validate_auth(self) -> "_ConfigLoaderNetWitness":
        """Require a token, or a username together with a password.

        Returns:
            The validated configuration instance.

        Raises:
            ValueError: If neither authentication method is fully configured.

        """
        has_token = bool(self.token)
        has_basic_credentials = bool(self.username and self.password)
        if has_token or has_basic_credentials:
            return self

        raise ValueError(
            "NetWitness authentication requires either NETWITNESS_TOKEN or both "
            "NETWITNESS_USERNAME and NETWITNESS_PASSWORD"
        )
class NetWitnessClientAPI:
    """NetWitness API client for fetching sessions via the Core SDK query API.

    The client builds an NWQL query from expectation signatures, sends it to
    ``<base_url>/sdk`` and retries with a progressively wider time window
    while no sessions are returned.
    """

    def __init__(self, config: ConfigLoader | None = None) -> None:
        """Initialize the NetWitness API client.

        Args:
            config: Configuration loader instance for API client settings.

        Raises:
            NetWitnessValidationError: If config is None, has an invalid
                structure, or provides no authentication.
            NetWitnessSessionError: If session creation fails.

        """
        if config is None:
            raise NetWitnessValidationError("Config is required for API client")

        self.logger = logging.getLogger(__name__)
        self.config = config

        try:
            settings = self.config.netwitness
            self.base_url = str(settings.base_url).rstrip("/")
            self.token = settings.token.get_secret_value() if settings.token else None
            self.username = settings.username
            self.password = (
                settings.password.get_secret_value() if settings.password else None
            )
            self.max_results = settings.max_results
            self.console_url = settings.console_url
            # Kept as (float) seconds; converted to int where it is used.
            self.offset = settings.offset.total_seconds()
            self.max_retry = settings.max_retry
            self.verify_ssl = settings.verify_ssl
        except AttributeError as e:
            raise NetWitnessValidationError(f"Invalid config structure: {e}") from e

        # A missing or falsy time window falls back to the module default.
        self.time_window = getattr(settings, "time_window", None) or timedelta(
            hours=DEFAULT_TIME_WINDOW_HOURS
        )

        try:
            self.session = self._create_session()
            self.parent_process_parser = ParentProcessParser()
        except NetWitnessValidationError:
            raise
        except Exception as e:
            raise NetWitnessSessionError(f"Failed to create HTTP session: {e}") from e

        self.logger.info(f"{LOG_PREFIX} NetWitness API client initialized")

    def _create_session(self) -> requests.Session:
        """Create an HTTP session with basic or bearer-token authentication.

        A configured token takes precedence over a username/password pair.

        Returns:
            Configured requests.Session with authentication.

        Raises:
            NetWitnessValidationError: If no authentication is configured.

        """
        session = requests.Session()
        headers = {"Accept": "application/json"}
        if self.token:
            headers["Authorization"] = f"Bearer {self.token}"
        elif self.username and self.password:
            session.auth = (self.username, self.password)
        else:
            raise NetWitnessValidationError(
                "Either a token or a username/password pair is required"
            )
        session.headers.update(headers)
        session.verify = self.verify_ssl
        return session

    def fetch_signatures(
        self, search_signatures: list[dict[str, Any]], expectation_type: str
    ) -> list[NetWitnessAlert]:
        """Fetch NetWitness sessions based on search signatures.

        Equivalent to ``fetch_with_retry`` using the configured retry count
        and offset.

        Args:
            search_signatures: List of signature dictionaries.
            expectation_type: Type of expectation for the fetched data.

        Returns:
            List of NetWitnessAlert objects.

        Raises:
            NetWitnessValidationError: If inputs are invalid.
            NetWitnessAPIError: If API operations fail.

        """
        return self.fetch_with_retry(search_signatures, expectation_type)

    def fetch_with_retry(
        self,
        search_signatures: list[dict[str, Any]],
        expectation_type: str,
        max_retries: int | None = None,
        offset_seconds: int | None = None,
    ) -> list[NetWitnessAlert]:
        """Fetch NetWitness sessions with a retry mechanism.

        Args:
            search_signatures: List of signature dictionaries.
            expectation_type: Type of expectation for the fetched data.
            max_retries: Maximum number of retry attempts (defaults to config value).
            offset_seconds: Seconds to wait between retries (defaults to config value).

        Returns:
            List of NetWitnessAlert objects.

        Raises:
            NetWitnessValidationError: If inputs are invalid.
            NetWitnessAPIError: If all retry attempts fail.

        """
        self._validate_inputs(search_signatures, expectation_type)
        search_criteria = self._build_search_criteria(search_signatures)
        return self._execute_query_with_retry(
            search_criteria,
            max_retries=max_retries if max_retries is not None else self.max_retry,
            offset_seconds=(
                offset_seconds if offset_seconds is not None else int(self.offset)
            ),
        )

    def _validate_inputs(
        self, search_signatures: list[dict[str, Any]], expectation_type: str
    ) -> None:
        """Validate fetch inputs.

        Args:
            search_signatures: List of signature dictionaries.
            expectation_type: Type of expectation.

        Raises:
            NetWitnessValidationError: If ``search_signatures`` is empty or
                ``expectation_type`` is anything other than ``"detection"``.

        """
        if not search_signatures:
            raise NetWitnessValidationError("search_signatures cannot be empty")
        if expectation_type != "detection":
            raise NetWitnessValidationError(
                f"Invalid expectation_type: {expectation_type}. NetWitness only supports 'detection'"
            )

    def _build_search_criteria(
        self, search_signatures: list[dict[str, str]]
    ) -> NetWitnessSearchCriteria:
        """Build a NetWitnessSearchCriteria object from search signatures.

        Signatures with an unrecognised ``type`` are ignored. When several
        ``start_date`` or ``end_date`` signatures are present, the last one
        wins.

        Args:
            search_signatures: List of signature dictionaries, each with a
                ``type`` and a ``value`` key.

        Returns:
            NetWitnessSearchCriteria object.

        Raises:
            NetWitnessValidationError: If signature format is invalid.

        """
        source_ips: list[Any] = []
        target_ips: list[Any] = []
        parent_process_names: list[Any] = []
        dates: dict[str, Any] = {"start_date": None, "end_date": None}

        # Signature types whose values are accumulated into a list.
        collectors = {
            "source_ipv4_address": source_ips,
            "source_ipv6_address": source_ips,
            "target_ipv4_address": target_ips,
            "target_ipv6_address": target_ips,
            "parent_process_name": parent_process_names,
        }

        for sig in search_signatures:
            if not isinstance(sig, dict) or "type" not in sig or "value" not in sig:
                raise NetWitnessValidationError(f"Invalid signature format: {sig}")

            sig_type = sig["type"]
            sig_value = sig["value"]

            if sig_type in collectors:
                collectors[sig_type].append(sig_value)
            elif sig_type in dates:
                dates[sig_type] = sig_value

        return NetWitnessSearchCriteria(
            source_ips=source_ips,
            target_ips=target_ips,
            parent_process_names=parent_process_names,
            start_date=dates["start_date"],
            end_date=dates["end_date"],
        )

    def _build_query(
        self, search_criteria: NetWitnessSearchCriteria, extend_end_seconds: int = 0
    ) -> str:
        """Build an NWQL query string from search criteria.

        All match conditions are OR-ed together; when there are none the
        query falls back to ``ip.src exists``. The time range always ends at
        the current UTC time and starts ``time_window + extend_end_seconds``
        earlier, so a larger ``extend_end_seconds`` moves the start further
        back.

        Args:
            search_criteria: NetWitnessSearchCriteria object.
            extend_end_seconds: Optional seconds to widen the time window on retries.

        Returns:
            The NWQL query expression.

        """
        fields = "time,ip.src,ip.dst,url,service,alert"

        # NOTE(review): values are interpolated into the query as-is; confirm
        # that signature values are trusted or validated upstream.
        conditions: list[str] = [
            f"ip.src={ip}" for ip in search_criteria.source_ips or []
        ]
        conditions.extend(f"ip.dst={ip}" for ip in search_criteria.target_ips or [])
        for parent_process_name in search_criteria.parent_process_names or []:
            uuids = self.parent_process_parser.extract_uuids_from_parent_process_name(
                parent_process_name
            )
            if uuids:
                inject_uuid, agent_uuid = uuids
                path = f"/api/injects/{inject_uuid}/{agent_uuid}/executable-payload"
                conditions.append(f"url contains '{path}'")

        match_clause = " || ".join(conditions) if conditions else "ip.src exists"

        # NOTE(review): search_criteria.start_date / end_date are not read
        # here; the range is derived from the clock only - confirm intended.
        end = datetime.now(timezone.utc)
        start = end - (self.time_window + timedelta(seconds=extend_end_seconds))
        time_clause = (
            f'time="{start.strftime(NW_TIME_FORMAT)}"'
            f'-"{end.strftime(NW_TIME_FORMAT)}"'
        )

        return f"select {fields} where {time_clause} && ({match_clause})"

    def _execute_query(
        self, search_criteria: NetWitnessSearchCriteria, extend_end_seconds: int = 0
    ) -> list[NetWitnessAlert]:
        """Execute a single NetWitness Core SDK query.

        Args:
            search_criteria: NetWitnessSearchCriteria object with search parameters.
            extend_end_seconds: Optional seconds to widen the time window for retries.

        Returns:
            List of NetWitnessAlert objects.

        Raises:
            NetWitnessAuthenticationError: If the service answers with HTTP 401.
            NetWitnessAPIError: If the service answers with any other non-200
                status, or the HTTP request fails.
            NetWitnessNetworkError: If a connection error or timeout occurs.
            NetWitnessQueryError: If query execution fails unexpectedly.

        """
        try:
            query = self._build_query(search_criteria, extend_end_seconds)
            endpoint = f"{self.base_url}/sdk"
            params = {
                "msg": "query",
                "query": query,
                "force-content-type": "application/json",
                "size": self.max_results,
            }
            response = self.session.get(
                endpoint, params=params, timeout=REQUEST_TIMEOUT_SECONDS
            )

            if response.status_code == 401:
                raise NetWitnessAuthenticationError(
                    "Authentication with NetWitness failed"
                )
            if response.status_code != 200:
                raise NetWitnessAPIError(
                    f"NetWitness API returned status {response.status_code}: {response.text}"
                )

            netwitness_response = NetWitnessResponse.from_raw_response(response.json())
            self.logger.info(
                f"{LOG_PREFIX} Retrieved {len(netwitness_response.results)} sessions"
            )
            return netwitness_response.results

        except (NetWitnessAuthenticationError, NetWitnessAPIError):
            # Raised above on purpose; must not be re-wrapped by the
            # catch-all handler at the bottom.
            raise
        except (ConnectionError, Timeout) as e:
            raise NetWitnessNetworkError(f"Network error during query: {e}") from e
        except RequestException as e:
            raise NetWitnessAPIError(f"HTTP request failed during query: {e}") from e
        except Exception as e:
            raise NetWitnessQueryError(f"Unexpected error executing query: {e}") from e

    def _execute_query_with_retry(
        self,
        search_criteria: NetWitnessSearchCriteria,
        max_retries: int | None = None,
        offset_seconds: int | None = None,
    ) -> list[NetWitnessAlert]:
        """Execute a NetWitness query with a retry mechanism.

        The query runs once and is then retried up to ``max_retries`` times
        while it returns nothing or fails with a retryable error. Before each
        retry the method sleeps ``offset_seconds`` and widens the query window
        by ``offset_seconds * attempt``. Authentication and validation errors
        are never retried.

        Args:
            search_criteria: NetWitnessSearchCriteria object with search parameters.
            max_retries: Maximum number of retry attempts (defaults to the
                configured value). Negative values are treated as 0 so the
                query is always executed at least once.
            offset_seconds: Seconds to wait between retries (defaults to the
                configured value). Negative values are treated as 0.

        Returns:
            List of NetWitnessAlert objects (empty if none found after all retries).

        Raises:
            NetWitnessAuthenticationError: If authentication fails.
            NetWitnessAPIError: If the final attempt fails with an error.

        """
        retries = max_retries if max_retries is not None else self.max_retry
        offset = offset_seconds if offset_seconds is not None else int(self.offset)
        # Clamp so a misconfigured negative value can neither skip the query
        # altogether nor make time.sleep() raise.
        retries = max(retries, 0)
        offset = max(offset, 0)

        for attempt in range(retries + 1):
            is_last_attempt = attempt == retries
            if attempt > 0:
                time.sleep(offset)

            try:
                alerts = self._execute_query(search_criteria, offset * attempt)
            except (NetWitnessAuthenticationError, NetWitnessValidationError):
                raise
            except (
                NetWitnessAPIError,
                NetWitnessNetworkError,
                NetWitnessQueryError,
                ConnectionError,
                Timeout,
                RequestException,
            ) as e:
                self.logger.warning(f"{LOG_PREFIX} Attempt {attempt + 1} failed: {e}")
                if is_last_attempt:
                    raise NetWitnessAPIError(
                        f"All NetWitness fetch attempts failed. Last error: {e}"
                    ) from e
                continue

            if alerts:
                self.logger.info(
                    f"{LOG_PREFIX} Attempt {attempt + 1}: found {len(alerts)} sessions"
                )
                return alerts
            if is_last_attempt:
                self.logger.warning(
                    f"{LOG_PREFIX} No sessions found after all retry attempts"
                )

        return []
def convert_data_to_oaev_data(
    self,
    data: NetWitnessAlert | list[NetWitnessAlert] | None,
) -> list[dict[str, Any]]:
    """Convert NetWitness data to OAEV format.

    A single alert is treated as a one-element list. Items that are not
    NetWitness alerts are logged and skipped, and alerts that convert to an
    empty dictionary are left out of the result.

    Args:
        data: Raw NetWitness alert data.

    Returns:
        List of OAEV data dictionaries (empty when ``data`` is falsy).

    Raises:
        NetWitnessDataConversionError: If conversion fails. Any exception
            raised while converting an item is wrapped in this type.

    """
    if not data:
        self.logger.debug(
            f"{LOG_PREFIX} No data provided for conversion, returning empty list"
        )
        return []

    items = data if isinstance(data, list) else [data]

    try:
        total = len(items)
        self.logger.debug(
            f"{LOG_PREFIX} Converting {total} NetWitness alert items to OAEV format"
        )
        converted = []
        n_alerts = 0
        n_unknown = 0

        for position, entry in enumerate(items, 1):
            self.logger.debug(
                f"{LOG_PREFIX} Processing alert item {position}/{total}"
            )

            try:
                if not self._is_alert_data(entry):
                    n_unknown += 1
                    self.logger.warning(
                        f"{LOG_PREFIX} Unknown data type for item {position}: {type(entry)}"
                    )
                    continue

                oaev_entry = self._alert_data(entry)
                n_alerts += 1
                self.logger.debug(
                    f"{LOG_PREFIX} Converted NetWitness alert item {position}"
                )

                if not oaev_entry:
                    self.logger.debug(
                        f"{LOG_PREFIX} Item {position} conversion resulted in empty OAEV data - filtering out"
                    )
                else:
                    converted.append(oaev_entry)
                    self.logger.debug(
                        f"{LOG_PREFIX} Successfully converted item {position} to OAEV format"
                    )

            except Exception as e:
                raise NetWitnessDataConversionError(
                    f"Failed to convert data item {position}: {e}"
                ) from e

        self.logger.info(
            f"{LOG_PREFIX} NetWitness to OAEV conversion: processed {total} items -> {len(converted)} results"
        )

        self.logger.info(
            f"{LOG_PREFIX} Conversion completed: {n_alerts} alerts, "
            f"{n_unknown} unknown items -> {len(converted)} OAEV items"
        )
        return converted

    except NetWitnessDataConversionError:
        # Already carries the failing item number; pass it through untouched.
        raise
    except Exception as e:
        raise NetWitnessDataConversionError(
            f"Unexpected error converting data to OAEV format: {e}"
        ) from e
{alert_data.rule_name}" + ) + + self.logger.debug( + f"{LOG_PREFIX} Converted NetWitness alert to OAEV with {len(oaev_data)} fields" + ) + return oaev_data if oaev_data else {} + + except NetWitnessValidationError: + raise + except Exception as e: + raise NetWitnessDataConversionError( + f"Error converting NetWitness alert data to OAEV: {e}" + ) from e + + def _extract_source_ips(self, alert_data: NetWitnessAlert) -> list[str]: + """Extract source IP addresses from alert data. + + Args: + alert_data: NetWitnessAlert object. + + Returns: + List of unique source IP addresses. + + """ + source_ips = [] + + if alert_data.src_ip and alert_data.src_ip not in source_ips: + source_ips.append(alert_data.src_ip) + + return source_ips + + def _extract_target_ips(self, alert_data: NetWitnessAlert) -> list[str]: + """Extract target IP addresses from alert data. + + Args: + alert_data: NetWitnessAlert object. + + Returns: + List of unique target IP addresses. + + """ + target_ips = [] + + if alert_data.dst_ip and alert_data.dst_ip not in target_ips: + target_ips.append(alert_data.dst_ip) + + return target_ips + + def _extract_parent_process_name(self, alert_data: NetWitnessAlert) -> str: + """Extract parent process name from alert data. + + This method reconstructs the parent process name from the URL path + found in the alert data. + + Args: + alert_data: NetWitnessAlert object. + + Returns: + Reconstructed parent process name if UUIDs found in URL path, empty string otherwise. 
+ + """ + if not alert_data.url_path: + self.logger.debug(f"{LOG_PREFIX} No URL path found in alert data") + return "" + + try: + self.logger.debug( + f"{LOG_PREFIX} Extracting parent process name from URL path: {alert_data.url_path}" + ) + + uuids = self.parent_process_parser.extract_uuids_from_url_path( + alert_data.url_path + ) + if uuids: + inject_uuid, agent_uuid = uuids + parent_process_name = ( + self.parent_process_parser.construct_parent_process_name( + inject_uuid, agent_uuid + ) + ) + self.logger.debug( + f"{LOG_PREFIX} Reconstructed parent process name: {parent_process_name}" + ) + return parent_process_name + else: + self.logger.debug( + f"{LOG_PREFIX} No UUIDs found in URL path: {alert_data.url_path}" + ) + return "" + except Exception as e: + self.logger.error(f"{LOG_PREFIX} Error extracting parent process name: {e}") + return "" diff --git a/netwitness/src/services/exception.py b/netwitness/src/services/exception.py new file mode 100644 index 00000000..3951a853 --- /dev/null +++ b/netwitness/src/services/exception.py @@ -0,0 +1,94 @@ +"""NetWitness Service Exceptions. + +Custom exceptions for NetWitness service operations. 
class NetWitnessServiceError(Exception):
    """Root of the NetWitness service exception hierarchy."""


class NetWitnessConfigurationError(NetWitnessServiceError):
    """Configuration is missing or invalid."""


class NetWitnessExpectationError(NetWitnessServiceError):
    """Processing an expectation failed."""


class NetWitnessFetchError(NetWitnessServiceError):
    """Fetching data from the NetWitness API failed."""


class NetWitnessMatchingError(NetWitnessServiceError):
    """Matching alerts against an expectation failed."""


class NetWitnessNoAlertsFoundError(NetWitnessServiceError):
    """The search criteria returned no alerts."""


class NetWitnessNoMatchingAlertsError(NetWitnessServiceError):
    """Alerts were found but none matched the expectation."""


class NetWitnessDataConversionError(NetWitnessServiceError):
    """Converting data between formats failed."""


class NetWitnessAPIError(NetWitnessServiceError):
    """A NetWitness API operation failed."""


class NetWitnessNetworkError(NetWitnessServiceError):
    """A network connectivity problem occurred."""


class NetWitnessSessionError(NetWitnessServiceError):
    """Session management failed."""


class NetWitnessQueryError(NetWitnessServiceError):
    """A query operation failed."""


class NetWitnessValidationError(NetWitnessServiceError):
    """Input validation failed."""


class NetWitnessTimeoutError(NetWitnessServiceError):
    """An operation timed out."""


class NetWitnessAuthenticationError(NetWitnessServiceError):
    """Authentication failed."""
class NetWitnessExpectationService:
    """NetWitness-specific service provider for expectation handling.

    This class contains the business logic specific to NetWitness:
    - Which signature types are supported (IP addresses, dates, parent process)
    - How to fetch data from NetWitness
    - How to validate expectations against that data
    - How a batch of expectations is processed
    """

    SUPPORTED_SIGNATURES = [
        SignatureTypes.SIG_TYPE_SOURCE_IPV4_ADDRESS,
        SignatureTypes.SIG_TYPE_TARGET_IPV4_ADDRESS,
        SignatureTypes.SIG_TYPE_SOURCE_IPV6_ADDRESS,
        SignatureTypes.SIG_TYPE_TARGET_IPV6_ADDRESS,
        SignatureTypes.SIG_TYPE_START_DATE,
        SignatureTypes.SIG_TYPE_END_DATE,
        SignatureTypes.SIG_TYPE_PARENT_PROCESS_NAME,
    ]

    def __init__(self, config: ConfigLoader | None = None) -> None:
        """Initialize the NetWitness service provider.

        Args:
            config: Configuration loader instance for service settings.

        Raises:
            NetWitnessValidationError: If config is None.
            NetWitnessConfigurationError: If service components fail to
                initialize or the configured retry offset is not a duration.

        """
        if config is None:
            raise NetWitnessValidationError(
                "Config is required for expectation service"
            )

        self.logger = logging.getLogger(__name__)
        self.config = config

        try:
            self.logger.debug(
                f"{LOG_PREFIX} Initializing NetWitness service components..."
            )
            self.client_api = NetWitnessClientAPI(config)
            self.converter = Converter()
            self.logger.info(
                f"{LOG_PREFIX} NetWitness expectation service initialized successfully"
            )
        except (NetWitnessValidationError, NetWitnessConfigurationError):
            raise
        except Exception as e:
            raise NetWitnessConfigurationError(
                f"Failed to initialize NetWitness service components: {e}"
            ) from e

        if (
            hasattr(config, "netwitness")
            and hasattr(config.netwitness, "time_window")
            and config.netwitness.time_window
        ):
            self.time_window = config.netwitness.time_window
            self.logger.debug(
                f"{LOG_PREFIX} Using configured time window: {self.time_window}"
            )
        else:
            self.time_window = timedelta(hours=1)
            self.logger.warning(
                f"{LOG_PREFIX} No time_window configured, using default 1 hour"
            )

        if hasattr(config, "netwitness"):
            self.max_retry = getattr(config.netwitness, "max_retry", 3)
            self.offset = self._offset_to_seconds(
                getattr(config.netwitness, "offset", timedelta(seconds=30))
            )
            self.logger.debug(
                f"{LOG_PREFIX} Using configured retry parameters: max_retry={self.max_retry}, offset={self.offset}s"
            )
        else:
            self.max_retry = 3
            self.offset = 30
            self.logger.warning(
                f"{LOG_PREFIX} No retry configuration found, using defaults: max_retry={self.max_retry}, offset={self.offset}s"
            )

    @staticmethod
    def _offset_to_seconds(offset: Any) -> float:
        """Normalize a configured retry offset to seconds.

        Accepts a ``timedelta`` (the expected form), a plain number of
        seconds, or ``None`` (treated as the 30 second default), so an unset
        or numeric setting does not crash on ``.total_seconds()``.

        Raises:
            NetWitnessConfigurationError: If the value cannot be read as a
                duration.

        """
        if offset is None:
            return 30.0
        if isinstance(offset, timedelta):
            return offset.total_seconds()
        try:
            return float(offset)
        except (TypeError, ValueError) as e:
            raise NetWitnessConfigurationError(
                f"Invalid retry offset in configuration: {offset!r}"
            ) from e

    def get_supported_signatures(self) -> list[SignatureTypes]:
        """Get the signature types this service supports.

        Returns:
            List of SignatureTypes that this service can process.

        """
        self.logger.debug(
            f"{LOG_PREFIX} Returning {len(self.SUPPORTED_SIGNATURES)} supported signature types"
        )
        return self.SUPPORTED_SIGNATURES

    def handle_batch_expectations(
        self,
        expectations: list[DetectionExpectation | PreventionExpectation],
        detection_helper: OpenAEVDetectionHelper,
    ) -> list[ExpectationResult]:
        """Handle a batch of expectations.

        Each expectation is processed on its own; a failure on one of them is
        turned into an invalid result so the rest of the batch still runs.

        Args:
            expectations: List of expectations to process.
            detection_helper: OpenAEV detection helper.

        Returns:
            One ExpectationResult per expectation, in input order.

        Raises:
            NetWitnessExpectationError: If batch processing itself fails.

        """
        if not expectations:
            self.logger.info(f"{LOG_PREFIX} No expectations to process")
            return []

        try:
            self.logger.info(
                f"{LOG_PREFIX} Starting batch processing of {len(expectations)} expectations"
            )

            all_results_with_expectations_associated = []

            for i, expectation in enumerate(expectations, 1):
                expectation_id = str(expectation.inject_expectation_id)
                self.logger.debug(
                    f"{LOG_PREFIX} Processing expectation {i}/{len(expectations)}: {expectation_id}"
                )

                try:
                    result = self.process_expectation(expectation, detection_helper)
                    if result.is_valid:
                        self.logger.debug(
                            f"{LOG_PREFIX} Expectation {expectation_id} processed successfully"
                        )
                    else:
                        self.logger.debug(
                            f"{LOG_PREFIX} Expectation {expectation_id} failed validation"
                        )

                except NetWitnessServiceError as e:
                    # Expected outcomes (no alerts, no match, API failure).
                    self.logger.warning(
                        f"{LOG_PREFIX} NetWitness service error for expectation {expectation_id}: {e}"
                    )
                    result = self._create_error_result_object(e, expectation)
                except Exception as e:
                    self.logger.error(
                        f"{LOG_PREFIX} Unexpected error processing expectation {expectation_id}: {e}"
                    )
                    result = self._create_error_result_object(
                        NetWitnessExpectationError(f"Unexpected error: {e}"),
                        expectation,
                    )

                all_results_with_expectations_associated.append(result)

            valid_count = sum(
                1 for r in all_results_with_expectations_associated if r.is_valid
            )
            invalid_count = len(all_results_with_expectations_associated) - valid_count

            self.logger.info(
                f"{LOG_PREFIX} Batch expectation processing: processed {len(expectations)} items -> {len(all_results_with_expectations_associated)} results"
            )
            self.logger.info(
                f"{LOG_PREFIX} Batch processing completed: {valid_count} valid, {invalid_count} invalid"
            )

            return all_results_with_expectations_associated

        except Exception as e:
            raise NetWitnessExpectationError(
                f"Error in handle_batch_expectations: {e}"
            ) from e

    def process_expectation(
        self,
        expectation: DetectionExpectation | PreventionExpectation,
        detection_helper: OpenAEVDetectionHelper,
    ) -> ExpectationResult:
        """Process a single expectation based on its type.

        Args:
            expectation: The expectation to process (Detection only for NetWitness).
            detection_helper: OpenAEV detection helper instance.

        Returns:
            ExpectationResult containing the processing outcome. Prevention
            expectations yield an invalid result rather than an error.

        Raises:
            NetWitnessExpectationError: If expectation type is unsupported.

        """
        expectation_id = str(expectation.inject_expectation_id)

        if isinstance(expectation, DetectionExpectation):
            self.logger.debug(
                f"{LOG_PREFIX} Processing detection expectation: {expectation_id}"
            )
            return self.handle_detection_expectation(expectation, detection_helper)

        if isinstance(expectation, PreventionExpectation):
            self.logger.warning(
                f"{LOG_PREFIX} NetWitness service warning for expectation {expectation_id}: NetWitness only supports DetectionExpectations, not PreventionExpectations, marking them as invalid"
            )
            return ExpectationResult(
                expectation_id=expectation_id,
                is_valid=False,
                expectation=expectation,
                error_message="NetWitness only supports DetectionExpectations, not PreventionExpectations",
            )

        self.logger.error(
            f"{LOG_PREFIX} Unsupported expectation type for {expectation_id}: {type(expectation).__name__}"
        )
        raise NetWitnessExpectationError(
            f"Unsupported expectation type: {type(expectation).__name__}"
        )

    def handle_detection_expectation(
        self,
        expectation: DetectionExpectation,
        detection_helper: OpenAEVDetectionHelper,
    ) -> ExpectationResult:
        """Handle a detection expectation.

        Args:
            expectation: The detection expectation to process.
            detection_helper: OpenAEV detection helper instance.

        Returns:
            ExpectationResult containing the processing outcome.

        """
        result_dict = self._handle_expectation(
            expectation, detection_helper, "detection"
        )
        return self._convert_dict_to_result(result_dict, expectation)

    def handle_prevention_expectation(
        self,
        expectation: PreventionExpectation,
        detection_helper: OpenAEVDetectionHelper,
    ) -> ExpectationResult:
        """Handle a prevention expectation.

        NetWitness only supports detection, so this logs a warning and
        returns an invalid result instead of raising.

        Args:
            expectation: The prevention expectation to process.
            detection_helper: OpenAEV detection helper instance (unused).

        Returns:
            ExpectationResult indicating that prevention is not supported.

        """
        expectation_id = str(expectation.inject_expectation_id)
        self.logger.warning(
            f"{LOG_PREFIX} NetWitness service error for expectation {expectation_id}: NetWitness only supports DetectionExpectations, not PreventionExpectations"
        )
        return ExpectationResult(
            expectation_id=expectation_id,
            is_valid=False,
            expectation=expectation,
            error_message="NetWitness only supports DetectionExpectations, not PreventionExpectations",
        )

    def _handle_expectation(
        self,
        expectation: DetectionExpectation,
        detection_helper: OpenAEVDetectionHelper,
        expectation_type: str,
    ) -> dict[str, Any]:
        """Run the extract -> fetch -> convert -> match pipeline.

        Args:
            expectation: The expectation to process.
            detection_helper: OpenAEV detection helper instance.
            expectation_type: Type of expectation ('detection').

        Returns:
            Dictionary containing processing results.

        Raises:
            NetWitnessServiceError: Any service error raised by a pipeline
                step is propagated unchanged (no alerts, no match, API or
                conversion failures).
            NetWitnessExpectationError: If an unexpected error occurs.

        """
        expectation_id = expectation.inject_expectation_id

        try:
            self.logger.debug(
                f"{LOG_PREFIX} Starting {expectation_type} expectation processing: {expectation_id}"
            )

            self.logger.debug(f"{LOG_PREFIX} Extracting signatures from expectation...")
            search_signatures, matching_signatures = self._extract_signatures(
                expectation
            )
            self.logger.debug(
                f"{LOG_PREFIX} Extracted {len(search_signatures)} search signatures, {len(matching_signatures)} matching signatures"
            )

            self.logger.debug(
                f"{LOG_PREFIX} Fetching NetWitness data for {expectation_type} expectation..."
            )
            netwitness_data = self.client_api.fetch_with_retry(
                search_signatures, expectation_type, self.max_retry, int(self.offset)
            )
            self.logger.debug(
                f"{LOG_PREFIX} Fetched {len(netwitness_data)} data items from NetWitness"
            )

            self.logger.debug(
                f"{LOG_PREFIX} Converting NetWitness data to OAEV format..."
            )
            oaev_data = self.converter.convert_data_to_oaev_data(netwitness_data)
            self.logger.debug(
                f"{LOG_PREFIX} Converted to {len(oaev_data)} OAEV data items"
            )

            self.logger.debug(
                f"{LOG_PREFIX} Matching data against expectation signatures..."
            )
            return self._match(
                oaev_data, matching_signatures, detection_helper, expectation_type
            )

        except NetWitnessServiceError:
            # Base class of every NetWitness exception; keep the original type.
            raise
        except Exception as e:
            raise NetWitnessExpectationError(
                f"Unexpected error processing expectation: {e}"
            ) from e

    def _extract_signatures(
        self, expectation: DetectionExpectation
    ) -> tuple[list[dict[str, str]], list[dict[str, str]]]:
        """Extract and filter signatures from expectation.

        Args:
            expectation: The expectation to extract signatures from.

        Returns:
            Tuple of (search_signatures, matching_signatures):
            - search_signatures: supported signatures, used to build the query
            - matching_signatures: the same minus start/end date metadata

        Raises:
            NetWitnessExpectationError: If signature extraction fails.

        """
        try:
            all_signatures = [
                {"type": sig.type.value, "value": sig.value}
                for sig in expectation.inject_expectation_signatures
            ]
            self.logger.debug(
                f"{LOG_PREFIX} Found {len(all_signatures)} total signatures in expectation"
            )

            # Built once instead of once per signature.
            supported_types = {s.value for s in self.SUPPORTED_SIGNATURES}
            date_signature_types = {
                SignatureTypes.SIG_TYPE_START_DATE.value,
                SignatureTypes.SIG_TYPE_END_DATE.value,
            }

            search_signatures = [
                sig for sig in all_signatures if sig["type"] in supported_types
            ]
            matching_signatures = [
                sig
                for sig in search_signatures
                if sig["type"] not in date_signature_types
            ]

            self.logger.debug(
                f"{LOG_PREFIX} Filtered to {len(search_signatures)} search signatures and {len(matching_signatures)} matching signatures"
            )

            return search_signatures, matching_signatures

        except Exception as e:
            raise NetWitnessExpectationError(
                f"Failed to extract signatures from expectation: {e}"
            ) from e

    def _match(
        self,
        oaev_data: list[dict[str, Any]],
        matching_signatures: list[dict[str, str]],
        detection_helper: OpenAEVDetectionHelper,
        expectation_type: str,
    ) -> dict[str, Any]:
        """Match OAEV data against expectation signatures.

        Data items are tested in order and the first match wins.

        Args:
            oaev_data: List of OAEV formatted data.
            matching_signatures: Signatures to match against.
            detection_helper: OpenAEV detection helper.
            expectation_type: Type of expectation ('detection').

        Returns:
            Result dictionary with ``is_valid``, ``matching_data`` and
            ``total_data_found``.

        Raises:
            NetWitnessNoAlertsFoundError: If no data available for matching.
            NetWitnessNoMatchingAlertsError: If no matching alerts found.
            NetWitnessMatchingError: If matching process fails.

        """
        try:
            if not oaev_data:
                self.logger.debug(f"{LOG_PREFIX} No OAEV data available for matching")
                raise NetWitnessNoAlertsFoundError("No data available for matching")

            total = len(oaev_data)
            self.logger.debug(
                f"{LOG_PREFIX} Attempting to match {total} data items against {len(matching_signatures)} signatures"
            )

            for position, data_item in enumerate(oaev_data, 1):
                self.logger.debug(f"{LOG_PREFIX} Data item {position}: {data_item}")
                self.logger.debug(
                    f"{LOG_PREFIX} Matching data item {position}/{total}"
                )

                # Only signatures whose type is present in the item can match.
                available_signatures = [
                    sig for sig in matching_signatures if sig["type"] in data_item
                ]

                self.logger.debug(
                    f"{LOG_PREFIX} Data item {position} has {len(available_signatures)} available signatures out of {len(matching_signatures)} total signatures"
                )

                if not available_signatures:
                    self.logger.debug(
                        f"{LOG_PREFIX} Data item {position} has no available signatures to match against"
                    )
                    continue

                try:
                    self.logger.debug(
                        f"{LOG_PREFIX} Testing match for data item {position} with {len(available_signatures)} signatures"
                    )
                    matched = self._match_with_detection_helper(
                        available_signatures, data_item, detection_helper
                    )
                except Exception as e:
                    self.logger.error(
                        f"{LOG_PREFIX} Error during matching for data item {position}: {e}"
                    )
                    # A message keeps the resulting error_message non-empty.
                    raise NetWitnessNoMatchingAlertsError(
                        f"Error during matching for data item {position}: {e}"
                    ) from e

                if not matched:
                    self.logger.debug(
                        f"{LOG_PREFIX} No match for data item {position}"
                    )
                    continue

                self.logger.debug(
                    f"{LOG_PREFIX} Match found for data item {position}!"
                )
                self.logger.info(
                    f"{LOG_PREFIX} Successful match found for {expectation_type} expectation"
                )
                self.logger.debug(f"{LOG_PREFIX} Matching data: {data_item}")

                return {
                    "is_valid": True,
                    "matching_data": [data_item],
                    "total_data_found": total,
                }

            self.logger.info(
                f"{LOG_PREFIX} No matching alerts found after checking {total} data items"
            )
            raise NetWitnessNoMatchingAlertsError(
                f"No matching alerts found after checking {total} data items"
            )

        except NetWitnessServiceError:
            raise
        except Exception as e:
            raise NetWitnessMatchingError(
                f"Unexpected error while matching alerts: {e}"
            ) from e

    def _any_ip_signature_matches(
        self,
        ip_types: list[str],
        signature_groups: dict[str, list[dict[str, str]]],
        data_item: dict[str, Any],
        detection_helper: OpenAEVDetectionHelper,
    ) -> bool:
        """Return True as soon as one signature of *ip_types* matches."""
        for ip_type in ip_types:
            if ip_type not in signature_groups or ip_type not in data_item:
                continue

            ip_sigs = signature_groups[ip_type]
            self.logger.debug(
                f"{LOG_PREFIX} Checking {ip_type} with {len(ip_sigs)} signatures"
            )

            filtered_data = {ip_type: data_item[ip_type]}
            for sig in ip_sigs:
                # One signature per call gives OR semantics across addresses.
                if detection_helper.match_alert_elements([sig], filtered_data):
                    self.logger.debug(
                        f"{LOG_PREFIX} ✓ {ip_type} signature matched: {sig['value']}"
                    )
                    return True
        return False

    def _match_with_detection_helper(
        self,
        signatures: list[dict[str, str]],
        data_item: dict[str, Any],
        detection_helper: OpenAEVDetectionHelper,
    ) -> bool:
        """Match signatures using detection_helper with OR logic on IPs.

        Args:
            signatures: List of signature dictionaries.
            data_item: OAEV data item to match against.
            detection_helper: OpenAEV detection helper instance.

        Returns:
            True if matching succeeds, False otherwise (also on any error,
            which is logged).

        Logic:
            1. Parent process: must match when such a signature is present;
               evaluation stops on failure.
            2. Source IPs: each signature is tested alone, first match wins.
            3. Target IPs: each signature is tested alone, first match wins.
            4. With both source and target signatures, one side matching is
               enough; with only one side, that side must match; with no IP
               signatures, the parent-process result decides.

        """
        try:
            signature_groups: dict[str, list[dict[str, str]]] = {}
            for sig in signatures:
                signature_groups.setdefault(sig["type"], []).append(sig)

            self.logger.debug(
                f"{LOG_PREFIX} Processing {len(signature_groups)} signature groups"
            )

            # Only enforced when a parent_process_name signature is present,
            # so IP-only expectations are not rejected.
            parent_process_match = True

            if "parent_process_name" in signature_groups:
                parent_sigs = signature_groups["parent_process_name"]
                self.logger.debug(
                    f"{LOG_PREFIX} Checking parent process with {len(parent_sigs)} signatures"
                )

                filtered_data = {
                    k: v for k, v in data_item.items() if k == "parent_process_name"
                }

                parent_process_match = detection_helper.match_alert_elements(
                    parent_sigs, filtered_data
                )

                self.logger.debug(
                    f"{LOG_PREFIX} Parent process match: {parent_process_match}"
                )

                if not parent_process_match:
                    self.logger.debug(f"{LOG_PREFIX} Parent process failed - stopping")
                    return False

            source_ip_types = ["source_ipv4_address", "source_ipv6_address"]
            target_ip_types = ["target_ipv4_address", "target_ipv6_address"]

            source_ip_match = self._any_ip_signature_matches(
                source_ip_types, signature_groups, data_item, detection_helper
            )
            target_ip_match = self._any_ip_signature_matches(
                target_ip_types, signature_groups, data_item, detection_helper
            )

            has_source_sigs = any(t in signature_groups for t in source_ip_types)
            has_target_sigs = any(t in signature_groups for t in target_ip_types)

            self.logger.debug(
                f"{LOG_PREFIX} Match results - Parent: {parent_process_match}, "
                f"Source IP: {source_ip_match} (required: {has_source_sigs}), "
                f"Target IP: {target_ip_match} (required: {has_target_sigs})"
            )

            if has_source_sigs and has_target_sigs:
                result = source_ip_match or target_ip_match
            elif has_source_sigs:
                result = source_ip_match
            elif has_target_sigs:
                result = target_ip_match
            else:
                result = True

            self.logger.debug(f"{LOG_PREFIX} Final match result: {result}")
            return result

        except Exception as e:
            self.logger.error(f"{LOG_PREFIX} Error in detection_helper matching: {e}")
            return False

    def _create_error_result(
        self,
        error: NetWitnessServiceError,
        expectation: DetectionExpectation | None = None,
    ) -> dict[str, Any]:
        """Create an error result dictionary from a NetWitness service error.

        Args:
            error: The NetWitness service error that occurred.
            expectation: Optional expectation object that caused the error.

        Returns:
            Dictionary with ``is_valid=False``, the error text and class name,
            plus ``status_code`` / ``response_data`` when the error carries
            them and the expectation and its id when one is given.

        """
        result: dict[str, Any] = {
            "is_valid": False,
            "error": str(error),
            "error_type": error.__class__.__name__,
        }

        if hasattr(error, "status_code") and error.status_code:
            result["status_code"] = error.status_code

        if hasattr(error, "response_data") and error.response_data:
            result["response_data"] = error.response_data

        if expectation:
            result["expectation"] = expectation
            result["expectation_id"] = str(expectation.inject_expectation_id)

        return result

    def _create_error_result_object(
        self,
        error: NetWitnessServiceError,
        expectation: DetectionExpectation | PreventionExpectation | None = None,
    ) -> ExpectationResult:
        """Create an ExpectationResult object from a NetWitness service error.

        Args:
            error: The NetWitness service error that occurred.
            expectation: Optional expectation object that caused the error.

        Returns:
            Invalid ExpectationResult whose message is the error text, with
            the status code appended when the error carries one.

        """
        expectation_id = (
            str(expectation.inject_expectation_id) if expectation else "unknown"
        )

        error_message = str(error)
        if hasattr(error, "status_code") and error.status_code:
            error_message += f" (Status: {error.status_code})"

        return ExpectationResult(
            expectation_id=expectation_id,
            is_valid=False,
            expectation=expectation,
            error_message=error_message,
        )

    def _convert_dict_to_result(
        self,
        result_dict: dict[str, Any],
        expectation: DetectionExpectation,
    ) -> ExpectationResult:
        """Convert a dictionary result to ExpectationResult object.

        Args:
            result_dict: Dictionary containing processing results.
            expectation: The expectation that was processed.

        Returns:
            ExpectationResult built from the ``is_valid``, ``matching_data``
            and ``error`` keys (missing keys default to False / None).

        """
        return ExpectationResult(
            expectation_id=str(expectation.inject_expectation_id),
            is_valid=result_dict.get("is_valid", False),
            expectation=expectation,
            matched_alerts=result_dict.get("matching_data"),
            error_message=result_dict.get("error"),
        )

    def get_service_info(self) -> dict[str, Any]:
        """Get information about this service provider.

        Returns:
            Dictionary containing service metadata and capabilities.

        """
        info = {
            "service_name": "NetWitness",
            "supported_signatures": [sig.value for sig in self.SUPPORTED_SIGNATURES],
            "supports_detection": True,
            "supports_prevention": False,
            "description": f"NetWitness expectation validation service ({len(self.SUPPORTED_SIGNATURES)} signature types, detection only)",
        }
        self.logger.debug(f"{LOG_PREFIX} Service info: {info}")
        return info
+""" + +from typing import Any, Optional + +from pydantic import BaseModel, Field + + +def _meta_str(meta: dict[str, Any], keys: list[str]) -> Optional[str]: + """Return the first non-empty meta value among several meta keys.""" + for key in keys: + value = meta.get(key) + if isinstance(value, list) and value: + value = value[0] + if value not in (None, ""): + return str(value) + return None + + +class NetWitnessSearchCriteria(BaseModel): + """Search criteria for NetWitness Core SDK queries.""" + + source_ips: Optional[list[str]] = Field( + default_factory=list, description="Source IP addresses to search for" + ) + target_ips: Optional[list[str]] = Field( + default_factory=list, description="Target IP addresses to search for" + ) + parent_process_names: Optional[list[str]] = Field( + default_factory=list, description="Parent process names to search for" + ) + start_date: Optional[str] = Field( + None, description="Start date for the search in ISO format" + ) + end_date: Optional[str] = Field( + None, description="End date for the search in ISO format" + ) + + +class NetWitnessAlert(BaseModel): + """NetWitness alert model (mapped from a grouped SDK query result).""" + + time: str = Field(..., description="Session time") + src_ip: Optional[str] = Field(None, description="Source IP address (ip.src)") + dst_ip: Optional[str] = Field(None, description="Destination IP address (ip.dst)") + url_path: Optional[str] = Field(None, description="URL meta value") + signature: Optional[str] = Field(None, description="Alert / risk meta") + rule_name: Optional[str] = Field(None, description="Rule / alert name") + event_type: Optional[str] = Field(None, description="Service / event category") + severity: Optional[str] = Field(None, description="Risk severity") + + +class NetWitnessResponse(BaseModel): + """Response from the NetWitness Core SDK query API.""" + + results: list[NetWitnessAlert] = Field( + default_factory=list, description="List of NetWitness alerts" + ) + + 
@classmethod + def from_raw_response(cls, response_data: dict[str, Any]) -> "NetWitnessResponse": + """Create from a raw SDK ``msg=query`` response. + + The SDK returns a flat list of meta ``fields`` tagged with a ``group`` + (session) identifier; this groups them back into per-session records. + + Args: + response_data: Raw response from ``/sdk?msg=query``. + + Returns: + NetWitnessResponse instance with parsed alerts. + + """ + results = response_data.get("results", {}) + fields = results.get("fields", []) if isinstance(results, dict) else [] + if not isinstance(fields, list): + fields = [] + + grouped: dict[Any, dict[str, Any]] = {} + order: list[Any] = [] + for field in fields: + if not isinstance(field, dict): + continue + group = field.get("group", 0) + meta_type = field.get("type") + if group not in grouped: + grouped[group] = {} + order.append(group) + if meta_type and meta_type not in grouped[group]: + grouped[group][meta_type] = field.get("value") + + alerts = [] + for group in order: + meta = grouped[group] + alert = NetWitnessAlert( + time=_meta_str(meta, ["time", "event.time"]) or "", + src_ip=_meta_str(meta, ["ip.src", "ipv6.src"]), + dst_ip=_meta_str(meta, ["ip.dst", "ipv6.dst"]), + url_path=_meta_str(meta, ["url", "web.host", "alias.host"]), + signature=_meta_str(meta, ["alert", "risk.info", "risk.warning"]), + rule_name=_meta_str(meta, ["alert.id", "rule.name"]), + event_type=_meta_str(meta, ["service", "event.cat.name"]), + severity=_meta_str(meta, ["risk", "severity"]), + ) + alerts.append(alert) + + return cls(results=alerts) diff --git a/netwitness/src/services/trace_service.py b/netwitness/src/services/trace_service.py new file mode 100644 index 00000000..f91badc6 --- /dev/null +++ b/netwitness/src/services/trace_service.py @@ -0,0 +1,303 @@ +"""NetWitness Trace Service Provider. + +This module provides NetWitness-specific logic for creating expectation traces +from processing results. 
+""" + +import logging +from datetime import datetime +from typing import Any +from urllib.parse import quote + +from pyoaev.apis.inject_expectation.model import ( # type: ignore[import-untyped] + DetectionExpectation, + PreventionExpectation, +) + +from ..collector.models import ExpectationResult, ExpectationTrace +from ..models.configs.config_loader import ConfigLoader +from .client_api import NetWitnessClientAPI +from .exception import NetWitnessDataConversionError, NetWitnessValidationError + +LOG_PREFIX = "[NetWitnessTraceService]" + + +class NetWitnessTraceService: + """NetWitness-specific trace service provider. + + This service extracts trace information from expectation processing results + and converts them into OpenAEV expectation traces using proper Pydantic models. + """ + + def __init__(self, config: ConfigLoader | None = None) -> None: + """Initialize the NetWitness trace service. + + Args: + config: Configuration loader instance for trace service settings. + + Raises: + NetWitnessValidationError: If config is None. + + """ + if config is None: + raise NetWitnessValidationError("Config is required for trace service") + + self.logger = logging.getLogger(__name__) + self.config = config + self.client_api = NetWitnessClientAPI(config) + self.logger.debug(f"{LOG_PREFIX} NetWitness trace service initialized") + + def create_traces_from_results( + self, results: list[ExpectationResult], collector_id: str + ) -> list[ExpectationTrace]: + """Create trace data from processing results. + + Args: + results: List of expectation processing results. + collector_id: ID of the collector. + + Returns: + List of ExpectationTrace models for OpenAEV. + + Raises: + NetWitnessValidationError: If inputs are invalid. + NetWitnessDataConversionError: If trace creation fails. 
+ + """ + if not collector_id: + raise NetWitnessValidationError("collector_id cannot be empty") + + if not isinstance(results, list): + raise NetWitnessValidationError("results must be a list") + + try: + valid_results = [r for r in results if r.is_valid and r.matched_alerts] + + if not valid_results: + self.logger.info( + f"{LOG_PREFIX} No valid results with matching data for traces out of {len(results)} results" + ) + return [] + + self.logger.info( + f"{LOG_PREFIX} Creating traces for {len(valid_results)} valid results out of {len(results)} total" + ) + + traces = [] + + for i, result in enumerate(valid_results, 1): + expectation_id = result.expectation_id + if not expectation_id: + self.logger.warning( + f"{LOG_PREFIX} Skipping result {i} - missing expectation_id" + ) + continue + + self.logger.debug( + f"{LOG_PREFIX} Creating trace {i}/{len(valid_results)} for expectation {expectation_id}" + ) + + try: + trace = self._create_expectation_trace( + result, expectation_id, collector_id + ) + + if trace: + traces.append(trace) + self.logger.debug( + f"{LOG_PREFIX} Created trace for expectation {expectation_id}: {trace.inject_expectation_trace_alert_name}" + ) + else: + self.logger.warning( + f"{LOG_PREFIX} Trace creation returned None for expectation {expectation_id}" + ) + except Exception as e: + raise NetWitnessDataConversionError( + f"Error creating trace for expectation {expectation_id}: {e}" + ) from e + + self.logger.info( + f"{LOG_PREFIX} Successfully created {len(traces)} traces from {len(valid_results)} valid results" + ) + return traces + + except NetWitnessDataConversionError: + raise + except Exception as e: + raise NetWitnessDataConversionError( + f"Unexpected error creating traces from results: {e}" + ) from e + + def _create_expectation_trace( + self, result: ExpectationResult, expectation_id: str, collector_id: str + ) -> ExpectationTrace: + """Create ExpectationTrace model from a single result. + + Args: + result: Processing result dictionary. 
+ expectation_id: ID of the expectation. + collector_id: ID of the collector. + + Returns: + ExpectationTrace model for OpenAEV. + + Raises: + NetWitnessValidationError: If inputs are invalid. + NetWitnessDataConversionError: If trace creation fails. + + """ + if not expectation_id: + raise NetWitnessValidationError("expectation_id cannot be empty") + + if not collector_id: + raise NetWitnessValidationError("collector_id cannot be empty") + + if not result.matched_alerts: + raise NetWitnessValidationError( + "result must have matched_alerts for trace creation" + ) + + try: + matching_data = result.matched_alerts[0] or {} + self.logger.debug( + f"{LOG_PREFIX} Processing matching data with {len(matching_data)} fields" + ) + + alert_name = self._determine_alert_name(matching_data) + + self.logger.debug(f"{LOG_PREFIX} Building trace URL from matching data...") + trace_link = self._build_trace_url_from_expectation(result.expectation) + self.logger.debug(f"{LOG_PREFIX} Generated trace link: {trace_link}") + + trace_date = datetime.utcnow().replace(microsecond=0) + date_str = trace_date.isoformat() + "Z" + self.logger.debug(f"{LOG_PREFIX} Generated trace date: {date_str}") + + trace = ExpectationTrace( + inject_expectation_trace_expectation=str(expectation_id), + inject_expectation_trace_source_id=str(collector_id), + inject_expectation_trace_alert_name=alert_name, + inject_expectation_trace_alert_link=trace_link, + inject_expectation_trace_date=date_str, + ) + + self.logger.debug( + f"{LOG_PREFIX} Created ExpectationTrace with alert name: {alert_name}" + ) + return trace + + except NetWitnessValidationError: + raise + except Exception as e: + raise NetWitnessDataConversionError( + f"Error creating expectation trace: {e}" + ) from e + + def _determine_alert_name(self, matching_data: dict[str, Any]) -> str: + """Determine alert name based on matching data content. + + Args: + matching_data: Dictionary containing the matched data elements. 
+ + Returns: + Human-readable alert name based on data content. + + """ + self.logger.debug(f"{LOG_PREFIX} Creating trace for NetWitness alert") + self.logger.debug( + f"{LOG_PREFIX} Creating trace from matching data {matching_data}" + ) + + if ( + "source_ipv4_address" in matching_data + or "source_ipv6_address" in matching_data + ): + self.logger.debug( + f"{LOG_PREFIX} Creating trace for detection event (source IP)" + ) + return "NetWitness Detection Alert - Source IP" + elif ( + "target_ipv4_address" in matching_data + or "target_ipv6_address" in matching_data + ): + self.logger.debug( + f"{LOG_PREFIX} Creating trace for detection event (target IP)" + ) + return "NetWitness Detection Alert - Target IP" + else: + self.logger.debug( + f"{LOG_PREFIX} Using generic alert name - no specific IP data type identified" + ) + return "NetWitness Detection Alert" + + def _build_trace_url_from_expectation( + self, expectation: DetectionExpectation | PreventionExpectation + ) -> str: + """Build a NetWitness Investigate URL from the expectation signatures. + + Reuses ``client_api._build_search_criteria`` to extract the source and + destination IPs from the expectation signatures, then builds an + Investigate query hint from those IPs only. Unlike the NWQL query built + by ``client_api._build_query``, this URL does not include the + parent-process ``url`` match or the time window. + + Args: + expectation: The expectation object with signatures. + + Returns: + NetWitness Investigate URL hinted with the expectation's source and + destination IPs. + + Raises: + NetWitnessDataConversionError: If URL building fails. 
+ + """ + try: + if not hasattr(self.config, "netwitness"): + self.logger.warning( + f"{LOG_PREFIX} No NetWitness config available, returning empty URL" + ) + return "" + + console_url = getattr(self.config.netwitness, "console_url", None) + if console_url: + web_base_url = str(console_url).rstrip("/") + else: + web_base_url = str(self.config.netwitness.base_url).rstrip("/") + self.logger.debug( + f"{LOG_PREFIX} Using NetWitness console URL: {web_base_url}" + ) + + search_signatures = [] + for sig in expectation.inject_expectation_signatures: + search_signatures.append({"type": sig.type.value, "value": sig.value}) + + search_criteria = self.client_api._build_search_criteria(search_signatures) + ip_terms = list(search_criteria.source_ips or []) + list( + search_criteria.target_ips or [] + ) + query_hint = quote(" ".join(ip_terms)) + url = f"{web_base_url}/investigate?query={query_hint}" + + self.logger.debug(f"{LOG_PREFIX} Built trace URL: {url}") + return url + + except Exception as e: + raise NetWitnessDataConversionError(f"Error building trace URL: {e}") from e + + def get_service_info(self) -> dict[str, Any]: + """Get information about this trace service. + + Returns: + Dictionary containing service metadata and capabilities. 
+ + """ + info = { + "service_type": "netwitness_trace", + "supported_result_types": ["NetWitness processing results"], + "creates_detection_traces": True, + "creates_prevention_traces": False, + "description": "Creates traces from NetWitness expectation processing results", + } + self.logger.debug(f"{LOG_PREFIX} Trace service info: {info}") + return info diff --git a/netwitness/src/services/utils/__init__.py b/netwitness/src/services/utils/__init__.py new file mode 100644 index 00000000..67d78f52 --- /dev/null +++ b/netwitness/src/services/utils/__init__.py @@ -0,0 +1,5 @@ +from src.services.utils.config_loader import NetWitnessConfig + +__all__ = [ + "NetWitnessConfig", +] diff --git a/netwitness/src/services/utils/config_loader.py b/netwitness/src/services/utils/config_loader.py new file mode 100644 index 00000000..37f844e4 --- /dev/null +++ b/netwitness/src/services/utils/config_loader.py @@ -0,0 +1,72 @@ +"""Configuration loader.""" + +import logging + +from pydantic import ValidationError +from src.models import ConfigLoader + +LOG_PREFIX = "[CollectorConfig]" + + +class NetWitnessConfig: + """Class for loading NetWitness configuration.""" + + def __init__(self) -> None: + """Initialize NetWitness configuration loader. + + Loads configuration from YAML files, environment variables, and defaults. + Sets up logging and validates the configuration structure. + + Raises: + ValueError: If configuration loading or validation fails. + + """ + self.logger = logging.getLogger(__name__) + self.logger.debug(f"{LOG_PREFIX} Initializing NetWitness configuration loader") + self.load = self._load_config() + self.logger.info(f"{LOG_PREFIX} NetWitness configuration loaded successfully") + + def _load_config(self) -> ConfigLoader: + """Load configuration with proper error handling and logging. + + Loads configuration from multiple sources and validates the structure. + Logs configuration details for debugging purposes. 
+ + Returns: + ConfigLoader instance with validated configuration. + + Raises: + ValueError: If configuration validation or loading fails. + + """ + try: + self.logger.debug( + f"{LOG_PREFIX} Loading configuration from sources (YAML/ENV/defaults)" + ) + load_settings = ConfigLoader() + + self.logger.debug( + f"{LOG_PREFIX} Collector ID: {load_settings.collector.id}" + ) + self.logger.debug( + f"{LOG_PREFIX} Collector name: {load_settings.collector.name}" + ) + self.logger.debug( + f"{LOG_PREFIX} Log level: {load_settings.collector.log_level}" + ) + self.logger.debug(f"{LOG_PREFIX} OpenAEV URL: {load_settings.openaev.url}") + self.logger.debug( + f"{LOG_PREFIX} NetWitness base URL: {load_settings.netwitness.base_url}" + ) + + return load_settings + except ValidationError as err: + self.logger.error( + f"{LOG_PREFIX} Error in configuration validation: {err} (Context: error_type=ValidationError)" + ) + raise ValueError(f"Configuration validation failed: {err}") from err + except Exception as err: + self.logger.error( + f"{LOG_PREFIX} Error in configuration loading: {err} (Context: error_type={type(err).__name__})" + ) + raise ValueError(f"Configuration loading failed: {err}") from err diff --git a/netwitness/src/services/utils/parent_process_parser.py b/netwitness/src/services/utils/parent_process_parser.py new file mode 100644 index 00000000..a9fa40f3 --- /dev/null +++ b/netwitness/src/services/utils/parent_process_parser.py @@ -0,0 +1,207 @@ +"""Utility functions for parsing parent process names and extracting UUIDs. + +This module provides functions to extract UUIDs from parent process names +and reconstruct them for matching purposes. 
+""" + +import logging +import re +from typing import Optional, Tuple + +LOG_PREFIX = "[ParentProcessParser]" + + +class ParentProcessParser: + """Parser for extracting and reconstructing parent process name data.""" + + def __init__(self) -> None: + """Initialize the parent process parser.""" + self.logger = logging.getLogger(__name__) + + self.uuid_pattern = ( + r"[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}" + ) + self.parent_process_pattern = r"oaev-implant-([0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12})-agent-([0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12})" + + def extract_uuids_from_parent_process_name( + self, parent_process_name: str + ) -> Optional[Tuple[str, str]]: + """Extract UUIDs from parent process name. + + Args: + parent_process_name: The parent process name containing UUIDs. + Expected format: 'oaev-implant-{UUID1}-agent-{UUID2}' + + Returns: + Tuple of (inject_uuid, agent_uuid) if found, None otherwise. + + Example: + Input: 'oaev-implant-877b423b-ae91-4fc5-86c3-fa8ea3c938ba-agent-1402422f-2eaa-4fbd-80b2-b30df1b83b19' + Output: ('877b423b-ae91-4fc5-86c3-fa8ea3c938ba', '1402422f-2eaa-4fbd-80b2-b30df1b83b19') + + """ + if not parent_process_name: + self.logger.debug(f"{LOG_PREFIX} Empty parent process name provided") + return None + + try: + self.logger.debug( + f"{LOG_PREFIX} Extracting UUIDs from: {parent_process_name}" + ) + + match = re.search( + self.parent_process_pattern, parent_process_name, re.IGNORECASE + ) + if match: + inject_uuid = match.group(1) + agent_uuid = match.group(2) + self.logger.debug( + f"{LOG_PREFIX} Extracted UUIDs - inject: {inject_uuid}, agent: {agent_uuid}" + ) + return (inject_uuid, agent_uuid) + else: + self.logger.debug( + f"{LOG_PREFIX} No UUIDs found in parent process name: {parent_process_name}" + ) + return None + + except Exception as e: + self.logger.error( + f"{LOG_PREFIX} Error extracting UUIDs from parent process name: {e}" + ) + return None + + def 
construct_parent_process_name(self, inject_uuid: str, agent_uuid: str) -> str: + """Construct parent process name from UUIDs. + + Args: + inject_uuid: The inject UUID. + agent_uuid: The agent UUID. + + Returns: + Constructed parent process name. + + Example: + Input: inject_uuid='877b423b-ae91-4fc5-86c3-fa8ea3c938ba', + agent_uuid='1402422f-2eaa-4fbd-80b2-b30df1b83b19' + Output: 'oaev-implant-877b423b-ae91-4fc5-86c3-fa8ea3c938ba-agent-1402422f-2eaa-4fbd-80b2-b30df1b83b19' + + """ + if not inject_uuid or not agent_uuid: + self.logger.warning( + f"{LOG_PREFIX} Missing UUIDs for parent process construction" + ) + return "" + + try: + parent_process_name = f"oaev-implant-{inject_uuid}-agent-{agent_uuid}" + self.logger.debug( + f"{LOG_PREFIX} Constructed parent process name: {parent_process_name}" + ) + return parent_process_name + + except Exception as e: + self.logger.error( + f"{LOG_PREFIX} Error constructing parent process name: {e}" + ) + return "" + + def extract_uuids_from_url_path(self, url_path: str) -> Optional[Tuple[str, str]]: + """Extract UUIDs from URL path. + + Args: + url_path: URL path containing UUIDs. + Expected format: '/api/injects/{UUID1}/{UUID2}/executable-payload' + + Returns: + Tuple of (inject_uuid, agent_uuid) if found, None otherwise. 
+ + Example: + Input: '/api/injects/877b423b-ae91-4fc5-86c3-fa8ea3c938ba/1402422f-2eaa-4fbd-80b2-b30df1b83b19/executable-payload' + Output: ('877b423b-ae91-4fc5-86c3-fa8ea3c938ba', '1402422f-2eaa-4fbd-80b2-b30df1b83b19') + + """ + if not url_path: + self.logger.debug(f"{LOG_PREFIX} Empty URL path provided") + return None + + try: + self.logger.debug( + f"{LOG_PREFIX} Extracting UUIDs from URL path: {url_path}" + ) + + url_pattern = r"/api/injects/([0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12})/([0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12})/executable-payload" + + match = re.search(url_pattern, url_path, re.IGNORECASE) + if match: + inject_uuid = match.group(1) + agent_uuid = match.group(2) + self.logger.debug( + f"{LOG_PREFIX} Extracted UUIDs from URL - inject: {inject_uuid}, agent: {agent_uuid}" + ) + return (inject_uuid, agent_uuid) + else: + self.logger.debug( + f"{LOG_PREFIX} No UUIDs found in URL path: {url_path}" + ) + return None + + except Exception as e: + self.logger.error(f"{LOG_PREFIX} Error extracting UUIDs from URL path: {e}") + return None + + def build_url_path_search_query(self, inject_uuid: str, agent_uuid: str) -> str: + """Build URL path search query from UUIDs. + + Args: + inject_uuid: The inject UUID. + agent_uuid: The agent UUID. + + Returns: + URL path search query string. 
+ + Example: + Input: inject_uuid='877b423b-ae91-4fc5-86c3-fa8ea3c938ba', + agent_uuid='1402422f-2eaa-4fbd-80b2-b30df1b83b19' + Output: 'url_path="/api/injects/877b423b-ae91-4fc5-86c3-fa8ea3c938ba/1402422f-2eaa-4fbd-80b2-b30df1b83b19/executable-payload"' + + """ + if not inject_uuid or not agent_uuid: + self.logger.warning(f"{LOG_PREFIX} Missing UUIDs for URL path search query") + return "" + + try: + url_path = f"/api/injects/{inject_uuid}/{agent_uuid}/executable-payload" + + url_fields = ["url_path", "url", "path", "query"] + url_conditions = [] + for field in url_fields: + url_conditions.append(f'{field}="{url_path}"') + + search_query = f"({' OR '.join(url_conditions)})" + self.logger.debug( + f"{LOG_PREFIX} Built URL path search query: {search_query}" + ) + return search_query + + except Exception as e: + self.logger.error(f"{LOG_PREFIX} Error building URL path search query: {e}") + return "" + + def validate_uuid_format(self, uuid_string: str) -> bool: + """Validate if string matches UUID format. + + Args: + uuid_string: String to validate. + + Returns: + True if string matches UUID format, False otherwise. + + """ + if not uuid_string: + return False + + try: + return bool(re.match(f"^{self.uuid_pattern}$", uuid_string, re.IGNORECASE)) + except Exception: + return False diff --git a/netwitness/tests/__init__.py b/netwitness/tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/netwitness/tests/conftest.py b/netwitness/tests/conftest.py new file mode 100644 index 00000000..b12c3f89 --- /dev/null +++ b/netwitness/tests/conftest.py @@ -0,0 +1,93 @@ +"""Conftest file for Pytest fixtures.""" + +from typing import TYPE_CHECKING, Any +from unittest.mock import patch + +from pytest import fixture + +if TYPE_CHECKING: + from os import _Environ + + +def mock_env_vars(os_environ: "_Environ[str]", wanted_env: dict[str, str]) -> Any: + """Fixture to mock environment variables dynamically and clean up after. 
+ + Args: + os_environ: The os.environ object to patch. + wanted_env: Dictionary of environment variables to mock. + + Returns: + Mock object for environment variable patching. + + """ + mock_env = patch.dict(os_environ, wanted_env) + mock_env.start() + + return mock_env + + +@fixture(autouse=True) +def mock_openaev_client() -> Any: + """Fixture to mock OpenAEV calls and clean up after. + + Auto-applies to all tests to prevent actual OpenAEV API calls. + Mocks urllib3, pyoaev client, and collector daemon setup. + + Yields: + Tuple of mock objects (urllib, pyoaev, daemon_setup). + + """ + mock_urllib = patch("urllib3.connectionpool.HTTPConnectionPool.urlopen") + mock_pyoaev = patch("pyoaev.client.OpenAEV.http_request") + mock_daemon_setup = patch("pyoaev.daemons.collector_daemon.CollectorDaemon._setup") + + mock_urllib.start() + mock_pyoaev.start() + mock_daemon_setup.start() + + yield mock_urllib, mock_pyoaev, mock_daemon_setup + + mock_urllib.stop() + mock_pyoaev.stop() + mock_daemon_setup.stop() + + +@fixture(autouse=True) +def disable_config_yml() -> Any: + """Fixture to disable config.yml and .env files for tests, forcing environment variable usage only. + + Auto-applies to all tests to ensure consistent configuration loading + from environment variables instead of config files. + + Yields: + Patcher object for the settings customization. 
+ + """ + + def fake_settings_customise_sources( + cls, + settings_cls, + init_settings, + env_settings, + dotenv_settings, + file_secret_settings, + ): + from pydantic_settings import EnvSettingsSource + + # Return only environment settings source, ignoring files + return ( + EnvSettingsSource( + settings_cls, + env_ignore_empty=True, + ), + ) + + patcher = patch( + "src.models.configs.config_loader.ConfigLoader.settings_customise_sources", + new=classmethod(fake_settings_customise_sources), + ) + patcher.start() + + yield patcher + + patcher.stop() diff --git a/netwitness/tests/services/__init__.py b/netwitness/tests/services/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/netwitness/tests/services/conftest.py b/netwitness/tests/services/conftest.py new file mode 100644 index 00000000..710d0abb --- /dev/null +++ b/netwitness/tests/services/conftest.py @@ -0,0 +1,397 @@ +"""Conftest for services tests with polyfactory fixtures.""" + +from unittest.mock import Mock, patch + +import pytest +from tests.services.fixtures.factories import ( + ConfigLoaderFactory, + ExpectationResultFactory, + ExpectationTraceFactory, + MockObjectsFactory, + NetWitnessAlertFactory, + TestDataFactory, + create_test_config, +) + + +@pytest.fixture +def mock_config(): + """Provide a mock configuration for tests. + + Returns: + ConfigLoader instance with test configuration values. + + """ + return create_test_config() + + +@pytest.fixture +def mock_client_api(): + """Provide a mock NetWitness client API. + + Returns: + Mock NetWitness client API instance for testing. + + """ + return MockObjectsFactory.create_mock_client_api() + + +@pytest.fixture +def mock_detection_helper(): + """Provide a mock detection helper that matches by default. + + Returns: + Mock OpenAEV detection helper that returns True for matches. 
+ + """ + return MockObjectsFactory.create_mock_detection_helper(match_result=True) + + +@pytest.fixture +def mock_detection_helper_no_match(): + """Provide a mock detection helper that doesn't match. + + Returns: + Mock OpenAEV detection helper that returns False for matches. + + """ + return MockObjectsFactory.create_mock_detection_helper(match_result=False) + + +@pytest.fixture +def sample_alert(): + """Provide a sample NetWitness alert. + + Returns: + NetWitnessAlert instance for testing. + + """ + return NetWitnessAlertFactory.build() + + +@pytest.fixture +def sample_alerts(): + """Provide a list of sample NetWitness alerts. + + Returns: + List of 3 NetWitnessAlert instances for testing. + + """ + return [NetWitnessAlertFactory.build() for _ in range(3)] + + +@pytest.fixture +def sample_expectation_result(): + """Provide a sample expectation result. + + Returns: + ExpectationResult instance for testing. + + """ + return ExpectationResultFactory.build() + + +@pytest.fixture +def sample_expectation_trace(): + """Provide a sample expectation trace. + + Returns: + ExpectationTrace instance for testing. + + """ + return ExpectationTraceFactory.build() + + +@pytest.fixture +def detection_signatures(): + """Provide sample detection expectation signatures. + + Returns: + List of signature dictionaries for detection expectations. + + """ + return TestDataFactory.create_expectation_signatures( + signature_type="source_ipv4_address" + ) + + +@pytest.fixture +def ip_signatures(): + """Provide sample IP-based expectation signatures. + + Returns: + List of signature dictionaries for IP-based expectations. + + """ + return TestDataFactory.create_expectation_signatures( + signature_type="target_ipv4_address" + ) + + +@pytest.fixture +def oaev_detection_data(): + """Provide sample OAEV detection data. + + Returns: + List of OAEV-formatted detection data dictionaries. 
+ + """ + return TestDataFactory.create_oaev_detection_data() + + +@pytest.fixture +def mixed_netwitness_data(): + """Provide mixed NetWitness data (alerts). + + Returns: + List containing NetWitnessAlert instances. + + """ + return TestDataFactory.create_mixed_netwitness_data() + + +@pytest.fixture +def mock_expectation_detection(): + """Provide a mock detection expectation. + + Returns: + Mock DetectionExpectation instance for testing. + + """ + return MockObjectsFactory.create_mock_expectation(expectation_type="detection") + + +@pytest.fixture +def mock_requests_session(): + """Provide a mock requests session. + + Returns: + Mock requests.Session instance for HTTP testing. + + """ + return MockObjectsFactory.create_mock_session() + + +@pytest.fixture +def api_response_data(): + """Provide sample API response data. + + Returns: + Dictionary containing mock NetWitness API response. + + """ + return TestDataFactory.create_api_response_data() + + +@pytest.fixture(autouse=True) +def mock_logging(): + """Auto-mock logging to reduce noise in tests. + + Auto-applies to all tests to prevent logging output during test execution. + + Yields: + Mock logger instance. + + """ + with patch("logging.getLogger") as mock_logger: + mock_logger.return_value = Mock() + yield mock_logger + + +@pytest.fixture +def disable_sleep(): + """Disable time.sleep in tests for faster execution. + + Patches time.sleep to prevent actual delays during testing. + + Yields: + None (context manager for sleep patching). + + """ + with patch("time.sleep"): + yield + + +# Parametrized fixtures for testing different scenarios +@pytest.fixture(params=[1, 3, 5]) +def various_counts(request): + """Provide various counts for testing different data sizes. + + Args: + request: Pytest request object containing parameter values. + + Returns: + Integer count (1, 3, or 5) for parameterized testing. 
+ + """ + return request.param + + +@pytest.fixture(params=[True, False]) +def match_scenarios(request): + """Provide both matching and non-matching scenarios. + + Args: + request: Pytest request object containing parameter values. + + Returns: + Boolean value (True or False) for match testing scenarios. + + """ + return request.param + + +@pytest.fixture( + params=[ + "source_ipv4_address", + "target_ipv4_address", + "source_ipv6_address", + "target_ipv6_address", + ] +) +def ip_signature_types(request): + """Provide different IP signature types. + + Args: + request: Pytest request object containing parameter values. + + Returns: + String IP signature type for testing. + + """ + return request.param + + +# Factory fixtures that can be called in tests +@pytest.fixture +def config_factory(): + """Provide the ConfigLoaderFactory for creating configs in tests. + + Returns: + ConfigLoaderFactory class for generating test configurations. + + """ + return ConfigLoaderFactory + + +@pytest.fixture +def alert_factory(): + """Provide the NetWitnessAlertFactory for creating alerts. + + Returns: + NetWitnessAlertFactory class for generating test alerts. + + """ + return NetWitnessAlertFactory + + +@pytest.fixture +def expectation_result_factory(): + """Provide the ExpectationResultFactory for creating results. + + Returns: + ExpectationResultFactory class for generating test results. + + """ + return ExpectationResultFactory + + +@pytest.fixture +def expectation_trace_factory(): + """Provide the ExpectationTraceFactory for creating traces. + + Returns: + ExpectationTraceFactory class for generating test traces. + + """ + return ExpectationTraceFactory + + +@pytest.fixture +def test_data_factory(): + """Provide the TestDataFactory for creating test data combinations. + + Returns: + TestDataFactory class for generating complex test data scenarios. 
+ + """ + return TestDataFactory + + +@pytest.fixture +def mock_objects_factory(): + """Provide the MockObjectsFactory for creating mock objects. + + Returns: + MockObjectsFactory class for generating mock instances. + + """ + return MockObjectsFactory + + +# Cleanup fixtures +@pytest.fixture(autouse=True) +def cleanup_mocks(): + """Auto-cleanup mocks after each test. + + Auto-applies to all tests to ensure proper mock cleanup. + + Yields: + None (context manager for cleanup operations). + + """ + yield + # Any cleanup logic can go here if needed + + +# Error simulation fixtures +@pytest.fixture +def api_error_responses(): + """Provide various API error responses for testing error handling. + + Returns: + Dictionary mapping HTTP status codes to error response data. + + """ + return { + "400": { + "status_code": 400, + "text": "Bad Request", + "json": {"errors": ["Bad request"]}, + }, + "401": { + "status_code": 401, + "text": "Unauthorized", + "json": {"errors": ["Unauthorized"]}, + }, + "403": { + "status_code": 403, + "text": "Forbidden", + "json": {"errors": ["Forbidden"]}, + }, + "404": { + "status_code": 404, + "text": "Not Found", + "json": {"errors": ["Not found"]}, + }, + "500": { + "status_code": 500, + "text": "Internal Server Error", + "json": {"errors": ["Server error"]}, + }, + } + + +@pytest.fixture +def network_errors(): + """Provide various network errors for testing error handling. + + Returns: + List of different exception types for network error testing. 
+ + """ + return [ + ConnectionError("Connection failed"), + TimeoutError("Request timeout"), + Exception("Generic network error"), + ] diff --git a/netwitness/tests/services/fixtures/__init__.py b/netwitness/tests/services/fixtures/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/netwitness/tests/services/fixtures/factories.py b/netwitness/tests/services/fixtures/factories.py new file mode 100644 index 00000000..08e392d8 --- /dev/null +++ b/netwitness/tests/services/fixtures/factories.py @@ -0,0 +1,437 @@ +"""Essential polyfactory factories for NetWitness models and test fixtures.""" + +import os +import uuid +from datetime import datetime, timedelta, timezone +from typing import Any, Dict, List +from unittest.mock import Mock + +from polyfactory import Use +from polyfactory.factories.pydantic_factory import ModelFactory +from src.collector.models import ExpectationResult, ExpectationTrace +from src.models.configs.collector_configs import _ConfigLoaderOAEV +from src.models.configs.config_loader import ConfigLoader, ConfigLoaderCollector +from src.models.configs.netwitness_configs import _ConfigLoaderNetWitness +from src.services.models import NetWitnessAlert, NetWitnessSearchCriteria + + +class ConfigLoaderOAEVFactory(ModelFactory[_ConfigLoaderOAEV]): + """Factory for OpenAEV configuration. + + Creates test instances of OpenAEV configuration with required + environment variables automatically set. + """ + + __check_model__ = False + + @classmethod + def build(cls, **kwargs): + """Build the model with required environment variables set. + + Args: + **kwargs: Additional keyword arguments for model creation. + + Returns: + _ConfigLoaderOAEV instance with test configuration. 
+ + """ + os.environ["OPENAEV_URL"] = "https://test-openaev.example.com" + os.environ["OPENAEV_TOKEN"] = "test-openaev-token-12345" # noqa: S105 + return super().build(**kwargs) + + +class ConfigLoaderNetWitnessFactory(ModelFactory[_ConfigLoaderNetWitness]): + """Factory for NetWitness configuration. + + Constructs deterministic test instances of the NetWitness + configuration (basic authentication, instant retries). The model is built + directly so values are passed as the highest-priority init source and are + never overridden by ambient environment variables. + """ + + __check_model__ = False + + token = Use(lambda: None) + username = Use(lambda: "test-user") + password = Use(lambda: "test-password") # noqa: S106 + + @classmethod + def build(cls, **kwargs): + """Build the model with required environment variables set. + + Two build paths must both yield valid auth: polyfactory's own field + generation (covered by the ``Use`` overrides above) and the nested + ``netwitness`` settings rebuilt from the environment when the parent + ``ConfigLoader`` is constructed (covered by the env vars below). + + Args: + **kwargs: Additional keyword arguments for model creation. + + Returns: + _ConfigLoaderNetWitness instance with test configuration. + + """ + os.environ["NETWITNESS_BASE_URL"] = "https://test-netwitness.example.com:50103" + os.environ["NETWITNESS_USERNAME"] = "test-user" + os.environ["NETWITNESS_PASSWORD"] = "test-password" # noqa: S105 + os.environ["NETWITNESS_MAX_RETRY"] = "1" + os.environ["NETWITNESS_OFFSET"] = "PT0S" + os.environ.pop("NETWITNESS_TOKEN", None) + os.environ.pop("NETWITNESS_CONSOLE_URL", None) + return super().build(**kwargs) + + +class ConfigLoaderCollectorFactory(ModelFactory[ConfigLoaderCollector]): + """Factory for Collector configuration. + + Creates test instances of collector configuration with auto-generated + UUIDs and sensible defaults. 
+    """
+
+    __check_model__ = False
+
+    id = Use(lambda: f"netwitness--{uuid.uuid4()}")
+    name = "NetWitness"
+
+
+class ConfigLoaderFactory(ModelFactory[ConfigLoader]):
+    """Factory for main configuration.
+
+    Creates complete test configuration instances combining OpenAEV,
+    collector, and NetWitness settings using subfactories.
+    """
+
+    __check_model__ = False
+
+    openaev = Use(ConfigLoaderOAEVFactory.build)
+    collector = Use(ConfigLoaderCollectorFactory.build)
+    netwitness = Use(ConfigLoaderNetWitnessFactory.build)
+
+
+class NetWitnessSearchCriteriaFactory(ModelFactory[NetWitnessSearchCriteria]):
+    """Factory for NetWitnessSearchCriteria.
+
+    Creates test instances of NetWitness search criteria with
+    realistic IP addresses and date ranges for queries. The default window
+    runs from one hour ago to now, as UTC ISO-8601 strings ending in "Z".
+    """
+
+    __check_model__ = False
+
+    source_ips = Use(lambda: ["192.168.1.100", "10.0.0.50"])
+    target_ips = Use(lambda: ["172.16.0.10", "203.0.113.5"])
+    # isoformat() on an aware UTC datetime already ends in "+00:00"; swap that
+    # for "Z" rather than appending a second UTC designator ("...+00:00Z").
+    start_date = Use(
+        lambda: (datetime.now(timezone.utc) - timedelta(hours=1))
+        .isoformat()
+        .replace("+00:00", "Z")
+    )
+    end_date = Use(
+        lambda: (datetime.now(timezone.utc) + timedelta(microseconds=1))
+        .isoformat()
+        .replace("+00:00", "Z")
+    )
+
+
+class NetWitnessAlertFactory(ModelFactory[NetWitnessAlert]):
+    """Factory for NetWitness alerts.
+
+    Creates test instances of NetWitness alert objects with
+    randomized IP addresses and alert metadata.
+    """
+
+    __check_model__ = False
+
+    # Same "Z" normalisation as the search-criteria dates: one UTC designator.
+    time = Use(lambda: datetime.now(timezone.utc).isoformat().replace("+00:00", "Z"))
+    src_ip = Use(lambda: f"192.168.1.{uuid.uuid4().int % 255}")
+    dst_ip = Use(lambda: f"10.0.0.{uuid.uuid4().int % 255}")
+    signature = Use(lambda: f"Test Malicious Activity {uuid.uuid4().hex[:8]}")
+    rule_name = Use(lambda: f"Test Security Rule {uuid.uuid4().hex[:8]}")
+    event_type = "Security Alert"
+    severity = "High"
+
+
+class ExpectationResultFactory(ModelFactory[ExpectationResult]):
+    """Factory for ExpectationResult.
+
+    Creates test instances of expectation processing results with
+    valid expectation IDs and configurable validation status.
+    """
+
+    __check_model__ = False
+
+    expectation_id = Use(lambda: str(uuid.uuid4()))
+    is_valid = True
+    error_message = None
+    matched_alerts = Use(lambda: [])
+
+
+class ExpectationTraceFactory(ModelFactory[ExpectationTrace]):
+    """Factory for ExpectationTrace.
+
+    Creates test instances of expectation traces for OpenAEV
+    with properly formatted trace data.
+    """
+
+    __check_model__ = False
+
+    inject_expectation_trace_expectation = Use(lambda: str(uuid.uuid4()))
+    inject_expectation_trace_source_id = Use(lambda: f"netwitness--{uuid.uuid4()}")
+    inject_expectation_trace_alert_name = "NetWitness Detection Alert"
+    inject_expectation_trace_alert_link = Use(
+        lambda: f"https://test-netwitness.example.com:5601/app/security/alerts?query=test-{uuid.uuid4().hex[:8]}"
+    )
+    # isoformat() on an aware UTC datetime already ends in "+00:00"; swap that
+    # for "Z" rather than appending a second UTC designator ("...+00:00Z").
+    inject_expectation_trace_date = Use(
+        lambda: datetime.now(timezone.utc).isoformat().replace("+00:00", "Z")
+    )
+
+
+# Mock Objects Factory
+class MockObjectsFactory:
+    """Factory for creating mock objects.
+
+    Provides static methods for creating various mock objects
+    used throughout the test suite.
+    """
+
+    @staticmethod
+    def create_mock_client_api():
+        """Create mock NetWitness client API.
+
+        Returns:
+            Mock NetWitnessClientAPI instance with basic attributes set.
+
+        """
+        mock_client = Mock()
+        mock_client.base_url = "https://test-netwitness.example.com:9200"
+        mock_client.session = Mock()
+        mock_client.session.headers = {}
+        return mock_client
+
+    @staticmethod
+    def create_mock_detection_helper(match_result: bool = True):
+        """Create mock detection helper.
+
+        Args:
+            match_result: Whether the helper should return matches (default True).
+
+        Returns:
+            Mock OpenAEVDetectionHelper instance.
+
+        """
+        mock_helper = Mock()
+        mock_helper.match_alert_elements.return_value = match_result
+        return mock_helper
+
+    @staticmethod
+    def create_mock_expectation(
+        expectation_type: str = "detection", expectation_id: str = None
+    ):
+        """Create mock expectation for testing.
+
+        Args:
+            expectation_type: Type of expectation ("detection" only for NetWitness).
+            expectation_id: Optional custom expectation ID.
+
+        Returns:
+            Mock expectation object with required attributes.
+
+        """
+        mock_expectation = Mock()
+        mock_expectation.inject_expectation_id = expectation_id or str(uuid.uuid4())
+        mock_expectation.inject_expectation_signatures = []
+        mock_expectation.expectation_type = expectation_type
+        return mock_expectation
+
+    @staticmethod
+    def create_mock_session():
+        """Create mock requests session.
+
+        Returns:
+            Mock standing in for requests.Session, with empty headers and a
+            basic-auth ``(username, password)`` tuple preset.
+
+        """
+        mock_session = Mock()
+        mock_session.headers = {}
+        mock_session.auth = ("test-user", "test-password")
+        return mock_session
+
+
+# Test Data Factory
+class TestDataFactory:
+    """Factory for creating essential test data.
+
+    Provides static methods for creating complex test data structures
+    that simulate real-world scenarios.
+    """
+
+    @staticmethod
+    def create_expectation_signatures(
+        signature_type: str = "source_ipv4_address", signature_value: str = None
+    ) -> List[Dict[str, Any]]:
+        """Create expectation signatures.
+
+        When no value is given, one is generated to match the signature type:
+        an IPv6 literal for ``*ipv6*`` types, an IPv4 literal for other IP
+        types, and a random ``test-<type>-<hex>`` token otherwise.
+
+        Args:
+            signature_type: Type of signature to create.
+            signature_value: Optional custom signature value.
+
+        Returns:
+            List of signature dictionaries for testing.
+
+        """
+        if signature_value is None:
+            if "ipv6" in signature_type:
+                # Checked before the generic "ip" test so IPv6 signature types
+                # are not handed an IPv4 literal; 2001:db8::/32 is the
+                # documentation prefix (RFC 3849).
+                signature_value = f"2001:db8::{uuid.uuid4().int % 0xFFFF:x}"
+            elif "ip" in signature_type:
+                signature_value = f"192.168.1.{uuid.uuid4().int % 255}"
+            else:
+                signature_value = f"test-{signature_type}-{uuid.uuid4().hex[:8]}"
+
+        return [{"type": signature_type, "value": signature_value}]
+
+    @staticmethod
+    def create_oaev_detection_data() -> List[Dict[str, Any]]:
+        """Create OAEV detection data for IP-based matching.
+
+        Returns:
+            List of OAEV-formatted detection data dictionaries.
+ + """ + return [ + { + "source_ipv4_address": { + "type": "simple", + "data": f"192.168.1.{uuid.uuid4().int % 255}", + }, + "target_ipv4_address": { + "type": "simple", + "data": f"10.0.0.{uuid.uuid4().int % 255}", + }, + "parent_process_name": { + "type": "simple", + "data": "test_process.exe", + }, + } + ] + + @staticmethod + def create_mixed_netwitness_data() -> List[Any]: + """Create mixed NetWitness data (alerts). + + Returns: + List containing NetWitnessAlert instances. + + """ + return create_test_netwitness_alerts(count=3) + + @staticmethod + def create_api_response_data() -> Dict[str, Any]: + """Create mock NetWitness Core SDK query results data. + + Returns: + Dictionary simulating a ``/sdk?msg=query`` JSON response. + + """ + return { + "results": { + "fields": [ + { + "count": 1, + "format": 32, + "group": 1, + "type": "ip.src", + "value": "192.168.1.100", + }, + { + "count": 1, + "format": 32, + "group": 1, + "type": "ip.dst", + "value": "10.0.0.50", + }, + { + "count": 1, + "format": 8, + "group": 1, + "type": "url", + "value": "/api/test", + }, + { + "count": 1, + "format": 8, + "group": 1, + "type": "service", + "value": "HTTP", + }, + { + "count": 1, + "format": 32, + "group": 2, + "type": "ip.src", + "value": "172.16.0.10", + }, + { + "count": 1, + "format": 32, + "group": 2, + "type": "ip.dst", + "value": "203.0.113.5", + }, + { + "count": 1, + "format": 8, + "group": 2, + "type": "service", + "value": "DNS", + }, + ] + } + } + + +# Helper functions +def create_test_config(**overrides) -> ConfigLoader: + """Create test configuration. + + Args: + **overrides: Configuration values to override defaults. + + Returns: + ConfigLoader instance with test configuration. + + """ + return ConfigLoaderFactory.build(**overrides) + + +def create_test_netwitness_alerts(count: int = 1) -> List[NetWitnessAlert]: + """Create test NetWitness alerts with varied IP configurations. + + Args: + count: Number of alerts to create (default 1). 
+ + Returns: + List of NetWitnessAlert instances with test data. + + """ + alerts = [] + for i in range(count): + if i % 2 == 0: + alert = NetWitnessAlertFactory.build( + src_ip=f"192.168.1.{100 + i}", + dst_ip=f"10.0.0.{50 + i}", + ) + else: + alert = NetWitnessAlertFactory.build( + source_ip=f"172.16.0.{10 + i}", + destination_ip=f"203.0.113.{5 + i}", + src_ip=None, + dst_ip=None, + ) + alerts.append(alert) + return alerts + + +def create_test_search_criteria(**overrides) -> NetWitnessSearchCriteria: + """Create test search criteria. + + Args: + **overrides: Criteria values to override defaults. + + Returns: + NetWitnessSearchCriteria instance with test configuration. + + """ + return NetWitnessSearchCriteriaFactory.build(**overrides) diff --git a/netwitness/tests/services/test_client_api_essential.py b/netwitness/tests/services/test_client_api_essential.py new file mode 100644 index 00000000..041fd9f0 --- /dev/null +++ b/netwitness/tests/services/test_client_api_essential.py @@ -0,0 +1,236 @@ +"""Essential tests for NetWitness Client API service.""" + +from unittest.mock import Mock, patch + +import pytest +from pydantic import SecretStr +from requests import Session +from src.services.client_api import NetWitnessClientAPI +from src.services.exception import ( + NetWitnessAPIError, + NetWitnessAuthenticationError, + NetWitnessValidationError, +) +from src.services.models import NetWitnessSearchCriteria +from tests.services.fixtures.factories import TestDataFactory, create_test_config + +PARENT_PROCESS_NAME = ( + "oaev-implant-12345678-1234-1234-1234-123456789abc" + "-agent-87654321-4321-4321-4321-cba987654321" +) + +SIGNATURES = [{"type": "source_ipv4_address", "value": "192.168.1.100"}] + + +def _ok_response() -> Mock: + """Mock a successful SDK query response.""" + response = Mock() + response.status_code = 200 + response.json.return_value = TestDataFactory.create_api_response_data() + return response + + +class TestNetWitnessClientAPIEssential: + """Essential 
test cases for NetWitnessClientAPI.""" + + def test_init_with_valid_config(self): + """Client initializes with config values and a session.""" + config = create_test_config() + client = NetWitnessClientAPI(config=config) + + assert client.config == config # noqa: S101 + assert client.base_url == str(config.netwitness.base_url).rstrip( + "/" + ) # noqa: S101 + assert client.username == config.netwitness.username # noqa: S101 + assert ( # noqa: S101 + client.password == config.netwitness.password.get_secret_value() + ) + assert isinstance(client.session, Session) # noqa: S101 + + def test_init_without_config_raises_error(self): + """Initialization without config raises a validation error.""" + with pytest.raises(NetWitnessValidationError): + NetWitnessClientAPI(config=None) + + def test_create_session_with_basic_auth(self): + """Username/password configure HTTP basic auth, not a bearer header.""" + config = create_test_config() + client = NetWitnessClientAPI(config=config) + + assert client.session.auth == ("test-user", "test-password") # noqa: S101 + assert "Authorization" not in client.session.headers # noqa: S101 + + def test_create_session_with_token(self): + """A token configures the bearer Authorization header.""" + config = create_test_config() + config.netwitness.token = SecretStr("my-token") + client = NetWitnessClientAPI(config=config) + + assert ( + client.session.headers["Authorization"] == "Bearer my-token" + ) # noqa: S101 + + @patch("requests.Session.get") + def test_fetch_signatures_detection_success(self, mock_get): + """A successful query returns parsed alerts grouped by session.""" + config = create_test_config() + client = NetWitnessClientAPI(config=config) + mock_get.return_value = _ok_response() + + result = client.fetch_signatures(SIGNATURES, "detection") + + assert len(result) == 2 # noqa: S101 + assert result[0].src_ip == "192.168.1.100" # noqa: S101 + assert result[0].dst_ip == "10.0.0.50" # noqa: S101 + + @patch("requests.Session.get") + 
def test_fetch_signatures_builds_query_with_ips(self, mock_get): + """IP signatures are turned into an NWQL query passed to the SDK.""" + config = create_test_config() + client = NetWitnessClientAPI(config=config) + mock_get.return_value = _ok_response() + + client.fetch_signatures( + [ + {"type": "source_ipv4_address", "value": "192.168.1.100"}, + {"type": "target_ipv4_address", "value": "10.0.0.50"}, + ], + "detection", + ) + + query = mock_get.call_args.kwargs["params"]["query"] + assert "ip.src=192.168.1.100" in query # noqa: S101 + assert "ip.dst=10.0.0.50" in query # noqa: S101 + + @patch("requests.Session.get") + def test_fetch_signatures_authentication_error(self, mock_get): + """A 401 raises NetWitnessAuthenticationError.""" + config = create_test_config() + client = NetWitnessClientAPI(config=config) + + response = Mock() + response.status_code = 401 + response.text = "Unauthorized" + mock_get.return_value = response + + with pytest.raises(NetWitnessAuthenticationError): + client.fetch_signatures(SIGNATURES, "detection") + + @patch("src.services.client_api.time.sleep") + @patch("requests.Session.get") + def test_fetch_signatures_no_data_returns_empty(self, mock_get, mock_sleep): + """No result fields yields an empty list.""" + config = create_test_config() + client = NetWitnessClientAPI(config=config) + + empty = Mock() + empty.status_code = 200 + empty.json.return_value = {"results": {"fields": []}} + mock_get.return_value = empty + + result = client.fetch_signatures(SIGNATURES, "detection") + + assert result == [] # noqa: S101 + + @patch("src.services.client_api.time.sleep") + @patch("requests.Session.get") + def test_fetch_signatures_exception_handling(self, mock_get, mock_sleep): + """Repeated errors are wrapped in NetWitnessAPIError.""" + config = create_test_config() + client = NetWitnessClientAPI(config=config) + + mock_get.side_effect = Exception("Network Error") + + with pytest.raises(NetWitnessAPIError) as exc_info: + 
client.fetch_signatures(SIGNATURES, "detection") + + assert "All NetWitness fetch attempts failed." in str( # noqa: S101 + exc_info.value + ) + + def test_build_query_with_ips(self): + """NWQL building includes IP conditions and a time range.""" + config = create_test_config() + client = NetWitnessClientAPI(config=config) + + criteria = NetWitnessSearchCriteria( + source_ips=["192.168.1.100"], target_ips=["10.0.0.50"] + ) + query = client._build_query(criteria) + + assert "ip.src=192.168.1.100" in query # noqa: S101 + assert "ip.dst=10.0.0.50" in query # noqa: S101 + assert query.startswith("select ") # noqa: S101 + assert "time=" in query # noqa: S101 + + def test_build_query_with_parent_process(self): + """NWQL building includes a URL contains clause for parent process.""" + config = create_test_config() + client = NetWitnessClientAPI(config=config) + + criteria = NetWitnessSearchCriteria( + source_ips=["192.168.1.100"], + parent_process_names=[PARENT_PROCESS_NAME], + ) + query = client._build_query(criteria) + + assert "url contains" in query # noqa: S101 + assert "/api/injects/" in query # noqa: S101 + assert "executable-payload" in query # noqa: S101 + + def test_build_query_time_window_extension(self): + """Retries widen the NWQL time window.""" + config = create_test_config() + client = NetWitnessClientAPI(config=config) + + criteria = NetWitnessSearchCriteria(source_ips=["192.168.1.100"]) + query1 = client._build_query(criteria, extend_end_seconds=0) + query2 = client._build_query(criteria, extend_end_seconds=86400) + + assert query1 != query2 # noqa: S101 + + def test_build_search_criteria_from_signatures(self): + """Signatures are extracted into a NetWitnessSearchCriteria object.""" + config = create_test_config() + client = NetWitnessClientAPI(config=config) + + criteria = client._build_search_criteria( + [ + {"type": "source_ipv4_address", "value": "192.168.1.100"}, + {"type": "target_ipv6_address", "value": "2001:db8::1"}, + {"type": 
"parent_process_name", "value": PARENT_PROCESS_NAME}, + {"type": "start_date", "value": "2024-01-01T00:00:00Z"}, + {"type": "end_date", "value": "2024-01-01T23:59:59Z"}, + ] + ) + + assert criteria.source_ips == ["192.168.1.100"] # noqa: S101 + assert criteria.target_ips == ["2001:db8::1"] # noqa: S101 + assert criteria.parent_process_names == [PARENT_PROCESS_NAME] # noqa: S101 + assert criteria.start_date == "2024-01-01T00:00:00Z" # noqa: S101 + assert criteria.end_date == "2024-01-01T23:59:59Z" # noqa: S101 + + def test_prevention_expectation_not_supported(self): + """Prevention expectations raise a validation error.""" + config = create_test_config() + client = NetWitnessClientAPI(config=config) + + with pytest.raises(NetWitnessValidationError) as exc_info: + client.fetch_signatures(SIGNATURES, "prevention") + + assert "Invalid expectation_type" in str(exc_info.value) # noqa: S101 + + def test_parent_process_uuid_extraction(self): + """UUIDs are extracted from a parent process name.""" + config = create_test_config() + client = NetWitnessClientAPI(config=config) + + uuids = client.parent_process_parser.extract_uuids_from_parent_process_name( + PARENT_PROCESS_NAME + ) + + assert uuids is not None # noqa: S101 + inject_uuid, agent_uuid = uuids + assert inject_uuid == "12345678-1234-1234-1234-123456789abc" # noqa: S101 + assert agent_uuid == "87654321-4321-4321-4321-cba987654321" # noqa: S101 diff --git a/netwitness/tests/services/test_client_api_extra.py b/netwitness/tests/services/test_client_api_extra.py new file mode 100644 index 00000000..2de6c891 --- /dev/null +++ b/netwitness/tests/services/test_client_api_extra.py @@ -0,0 +1,82 @@ +"""Additional branch-coverage tests for NetWitnessClientAPI.""" + +from unittest.mock import Mock, patch + +import pytest +import requests +from src.services.client_api import NetWitnessClientAPI +from src.services.exception import NetWitnessAPIError, NetWitnessValidationError +from tests.services.fixtures.factories import 
TestDataFactory, create_test_config + +SIGNATURES = [{"type": "source_ipv4_address", "value": "1.2.3.4"}] + + +def _ok_response() -> Mock: + """Mock a successful SDK query response.""" + response = Mock() + response.status_code = 200 + response.json.return_value = TestDataFactory.create_api_response_data() + return response + + +class TestNetWitnessClientAPIExtra: + """Branch-coverage tests for NetWitnessClientAPI.""" + + @patch("src.services.client_api.time.sleep") + @patch("requests.Session.get") + def test_network_error_wrapped(self, mock_get, mock_sleep): + """Connection errors are retried then wrapped in NetWitnessAPIError.""" + client = NetWitnessClientAPI(config=create_test_config()) + mock_get.side_effect = requests.exceptions.ConnectionError("net down") + + with pytest.raises(NetWitnessAPIError): + client.fetch_signatures(SIGNATURES, "detection") + + @patch("src.services.client_api.time.sleep") + @patch("requests.Session.get") + def test_server_error_status(self, mock_get, mock_sleep): + """A 500 response is retried then wrapped in NetWitnessAPIError.""" + client = NetWitnessClientAPI(config=create_test_config()) + response = Mock() + response.status_code = 500 + response.text = "server error" + mock_get.return_value = response + + with pytest.raises(NetWitnessAPIError): + client.fetch_signatures(SIGNATURES, "detection") + + def test_fetch_with_retry_empty_signatures(self): + """Empty signatures raise a validation error.""" + client = NetWitnessClientAPI(config=create_test_config()) + with pytest.raises(NetWitnessValidationError): + client.fetch_with_retry([], "detection") + + def test_fetch_with_retry_invalid_type(self): + """Non-detection expectation types raise a validation error.""" + client = NetWitnessClientAPI(config=create_test_config()) + with pytest.raises(NetWitnessValidationError): + client.fetch_with_retry(SIGNATURES, "prevention") + + @patch("requests.Session.get") + def test_fetch_with_retry_success(self, mock_get): + """A successful query 
yields parsed alerts.""" + client = NetWitnessClientAPI(config=create_test_config()) + mock_get.return_value = _ok_response() + + result = client.fetch_with_retry(SIGNATURES, "detection") + + assert len(result) == 2 # noqa: S101 + + @patch("src.services.client_api.time.sleep") + @patch("requests.Session.get") + def test_fetch_with_retry_empty_returns_empty(self, mock_get, mock_sleep): + """No result fields after all retries returns an empty list.""" + client = NetWitnessClientAPI(config=create_test_config()) + empty = Mock() + empty.status_code = 200 + empty.json.return_value = {"results": {"fields": []}} + mock_get.return_value = empty + + result = client.fetch_with_retry(SIGNATURES, "detection") + + assert result == [] # noqa: S101 diff --git a/netwitness/tests/services/test_converter_essential.py b/netwitness/tests/services/test_converter_essential.py new file mode 100644 index 00000000..2387a68e --- /dev/null +++ b/netwitness/tests/services/test_converter_essential.py @@ -0,0 +1,194 @@ +"""Essential tests for NetWitness Converter service.""" + +from src.services.converter import Converter +from tests.services.fixtures.factories import ( + NetWitnessAlertFactory, + create_test_netwitness_alerts, +) + + +class TestConverterEssential: + """Essential test cases for NetWitness Converter. + + Tests the core functionality of the NetWitness data converter including + initialization, data type detection, and conversion to OAEV format. + """ + + def test_init(self): + """Test that Converter initializes correctly. + + Verifies that the converter instance is properly initialized + with a logger and ready for data conversion operations. + """ + converter = Converter() + assert converter.logger is not None # noqa: S101 + + def test_convert_empty_data_returns_empty_list(self): + """Test converting empty data returns empty list. + + Verifies that both None and empty list inputs result in + empty list outputs without raising exceptions. 
+ """ + converter = Converter() + + result_none = converter.convert_data_to_oaev_data(None) + result_empty = converter.convert_data_to_oaev_data([]) + + assert result_none == [] # noqa: S101 + assert result_empty == [] # noqa: S101 + + def test_convert_alert_with_source_and_target_ips(self): + """Test converting alert with both source and target IP addresses. + + Verifies that NetWitness alerts containing both source and target IPs + are properly converted to OAEV format with correct structure. + """ + converter = Converter() + alert = NetWitnessAlertFactory.build(src_ip="192.168.1.100", dst_ip="10.0.0.50") + + result = converter.convert_data_to_oaev_data(alert) + + assert len(result) == 1 # noqa: S101 + assert "source_ipv4_address" in result[0] # noqa: S101 + assert "target_ipv4_address" in result[0] # noqa: S101 + assert result[0]["source_ipv4_address"]["type"] == "simple" # noqa: S101 + assert result[0]["target_ipv4_address"]["type"] == "simple" # noqa: S101 + assert result[0]["source_ipv4_address"]["data"] == [ # noqa: S101 + "192.168.1.100" + ] + assert result[0]["target_ipv4_address"]["data"] == ["10.0.0.50"] # noqa: S101 + + def test_convert_alert_with_only_source_ip(self): + """Test converting alert with only source IP address. + + Verifies that alerts containing only source IP addresses + result in OAEV data with only source IP field populated. + """ + converter = Converter() + alert = NetWitnessAlertFactory.build( + src_ip="192.168.1.100", dst_ip=None, source_ip=None, destination_ip=None + ) + + result = converter.convert_data_to_oaev_data(alert) + + assert len(result) == 1 # noqa: S101 + assert "source_ipv4_address" in result[0] # noqa: S101 + assert "target_ipv4_address" not in result[0] # noqa: S101 + assert result[0]["source_ipv4_address"]["data"] == [ # noqa: S101 + "192.168.1.100" + ] + + def test_convert_alert_with_only_target_ip(self): + """Test converting alert with only target IP address. 
+ + Verifies that alerts containing only target IP addresses + result in OAEV data with only target IP field populated. + """ + converter = Converter() + alert = NetWitnessAlertFactory.build( + src_ip=None, dst_ip="10.0.0.50", source_ip=None, destination_ip=None + ) + + result = converter.convert_data_to_oaev_data(alert) + + assert len(result) == 1 # noqa: S101 + assert "source_ipv4_address" not in result[0] # noqa: S101 + assert "target_ipv4_address" in result[0] # noqa: S101 + assert result[0]["target_ipv4_address"]["data"] == ["10.0.0.50"] # noqa: S101 + + def test_convert_alert_without_ips(self): + """Test converting alert without any IP addresses. + + Verifies that alerts without any IP addresses are filtered out + and do not appear in the final OAEV data list. + """ + converter = Converter() + alert = NetWitnessAlertFactory.build( + src_ip=None, dst_ip=None, source_ip=None, destination_ip=None + ) + + result = converter.convert_data_to_oaev_data(alert) + + # Should be filtered out completely - no results + assert len(result) == 0 # noqa: S101 + + def test_convert_multiple_alerts_list(self): + """Test converting list of multiple alerts. + + Verifies that lists containing multiple NetWitnessAlert objects + are processed correctly, with each alert converted to its + appropriate OAEV format. + """ + converter = Converter() + + alerts = create_test_netwitness_alerts(count=3) + + result = converter.convert_data_to_oaev_data(alerts) + + # Filter out alerts without IPs - some test alerts may not have IPs + assert len(result) >= 2 # noqa: S101 + # Each result should be a dictionary + assert all(isinstance(item, dict) for item in result) # noqa: S101 + + def test_convert_invalid_data_handles_gracefully(self): + """Test converting invalid data handles gracefully. + + Verifies that unknown or invalid data types are handled gracefully + by returning empty results without raising exceptions. 
+ """ + converter = Converter() + invalid_data = {"unknown": "data", "type": "mystery"} + + result = converter.convert_data_to_oaev_data(invalid_data) + + assert result == [] # noqa: S101 + + def test_extract_source_ips_from_multiple_fields(self): + """Test extracting source IPs from multiple possible fields. + + Verifies that the converter correctly extracts source IPs from + both src_ip and source_ip fields, handling duplicates properly. + """ + converter = Converter() + alert = NetWitnessAlertFactory.build( + src_ip="192.168.1.100", + source_ip="192.168.1.100", # Same IP in both fields + ) + + source_ips = converter._extract_source_ips(alert) + + # Should have single IP (consolidated field) + assert len(source_ips) == 1 # noqa: S101 + assert "192.168.1.100" in source_ips # noqa: S101 + + def test_extract_target_ips_from_multiple_fields(self): + """Test extracting target IPs from multiple possible fields. + + Verifies that the converter uses consolidated target IP field, + prioritizing dst_ip over destination_ip. + """ + converter = Converter() + alert = NetWitnessAlertFactory.build( + dst_ip="10.0.0.50", + destination_ip="203.0.113.5", # Different IP in alternative field + ) + + target_ips = converter._extract_target_ips(alert) + + # Should use consolidated field (dst_ip takes priority) + assert len(target_ips) == 1 # noqa: S101 + assert "10.0.0.50" in target_ips # noqa: S101 + + def test_alert_data_type_detection(self): + """Test that converter correctly detects NetWitnessAlert data type. + + Verifies that the _is_alert_data method correctly identifies + NetWitnessAlert instances vs other data types. 
+ """ + converter = Converter() + alert = NetWitnessAlertFactory.build() + non_alert = {"not": "an alert"} + + assert converter._is_alert_data(alert) is True # noqa: S101 + assert converter._is_alert_data(non_alert) is False # noqa: S101 + assert converter._is_alert_data(None) is False # noqa: S101 diff --git a/netwitness/tests/services/test_converter_extra.py b/netwitness/tests/services/test_converter_extra.py new file mode 100644 index 00000000..f61d88a7 --- /dev/null +++ b/netwitness/tests/services/test_converter_extra.py @@ -0,0 +1,57 @@ +"""Additional branch-coverage tests for the NetWitness Converter.""" + +import pytest +from src.services.converter import Converter +from src.services.exception import NetWitnessValidationError +from src.services.models import NetWitnessAlert + +INJECT_UUID = "12345678-1234-1234-1234-123456789abc" +AGENT_UUID = "87654321-4321-4321-4321-cba987654321" +URL_PATH = f"/api/injects/{INJECT_UUID}/{AGENT_UUID}/executable-payload" + + +class TestConverterExtra: + """Branch-coverage tests for Converter.""" + + def test_convert_alert_with_parent_process(self): + """An alert whose url_path encodes UUIDs yields a parent_process_name field.""" + converter = Converter() + alert = NetWitnessAlert(time="t", src_ip="1.2.3.4", url_path=URL_PATH) + + result = converter.convert_data_to_oaev_data(alert) + + assert len(result) == 1 # noqa: S101 + assert "parent_process_name" in result[0] # noqa: S101 + assert result[0]["parent_process_name"]["type"] == "fuzzy" # noqa: S101 + + def test_convert_alert_with_signature_and_rule(self): + """Signature and rule name on the alert are handled without error.""" + converter = Converter() + alert = NetWitnessAlert( + time="t", + src_ip="1.2.3.4", + signature="Malicious Activity", + rule_name="High Risk Rule", + ) + + result = converter.convert_data_to_oaev_data(alert) + + assert result[0]["source_ipv4_address"]["data"] == ["1.2.3.4"] # noqa: S101 + + def test_extract_parent_process_name_no_url(self): + """An 
alert without a url_path yields no parent process name.""" + converter = Converter() + alert = NetWitnessAlert(time="t", src_ip="1.2.3.4") + assert converter._extract_parent_process_name(alert) == "" # noqa: S101 + + def test_extract_parent_process_name_non_matching_url(self): + """A url_path without UUIDs yields no parent process name.""" + converter = Converter() + alert = NetWitnessAlert(time="t", src_ip="1.2.3.4", url_path="/api/other") + assert converter._extract_parent_process_name(alert) == "" # noqa: S101 + + def test_alert_data_rejects_invalid_type(self): + """_alert_data rejects non-NetWitnessAlert input.""" + converter = Converter() + with pytest.raises(NetWitnessValidationError): + converter._alert_data({"not": "an-alert"}) diff --git a/netwitness/tests/services/test_expectation_service_essential.py b/netwitness/tests/services/test_expectation_service_essential.py new file mode 100644 index 00000000..6c817fb8 --- /dev/null +++ b/netwitness/tests/services/test_expectation_service_essential.py @@ -0,0 +1,319 @@ +"""Essential tests for NetWitness Expectation Service.""" + +from unittest.mock import Mock + +import pytest +from pyoaev.signatures.types import SignatureTypes +from src.collector.models import ExpectationResult +from src.services.exception import ( + NetWitnessExpectationError, + NetWitnessNoAlertsFoundError, + NetWitnessNoMatchingAlertsError, + NetWitnessValidationError, +) +from src.services.expectation_service import NetWitnessExpectationService +from tests.services.fixtures.factories import ( + MockObjectsFactory, + TestDataFactory, + create_test_config, +) + + +class TestNetWitnessExpectationServiceEssential: + """Essential test cases for NetWitnessExpectationService. + + Tests the core functionality of the NetWitness expectation service including + initialization, signature support, batch processing, and matching operations. + """ + + def test_init_with_valid_config(self): + """Test that service initializes correctly with valid config. 
+ + Verifies that the service properly initializes with configuration values, + sets up client API and converter components, and configures time window. + """ + config = create_test_config() + + service = NetWitnessExpectationService(config=config) + + assert service.config == config # noqa: S101 + assert service.client_api is not None # noqa: S101 + assert service.converter is not None # noqa: S101 + assert service.time_window is not None # noqa: S101 + + def test_init_without_config_raises_error(self): + """Test that initialization without config raises configuration error. + + Verifies that attempting to initialize the service without a valid + configuration raises a NetWitnessValidationError. + """ + with pytest.raises(NetWitnessValidationError): + NetWitnessExpectationService(config=None) + + def test_get_supported_signatures(self): + """Test that service returns correct supported signatures. + + Verifies that the service returns the expected list of signature types + it can process for expectation handling (only IP addresses and dates). + """ + config = create_test_config() + service = NetWitnessExpectationService(config=config) + + signatures = service.get_supported_signatures() + + expected_signatures = [ + SignatureTypes.SIG_TYPE_SOURCE_IPV4_ADDRESS, + SignatureTypes.SIG_TYPE_TARGET_IPV4_ADDRESS, + SignatureTypes.SIG_TYPE_SOURCE_IPV6_ADDRESS, + SignatureTypes.SIG_TYPE_TARGET_IPV6_ADDRESS, + SignatureTypes.SIG_TYPE_START_DATE, + SignatureTypes.SIG_TYPE_END_DATE, + SignatureTypes.SIG_TYPE_PARENT_PROCESS_NAME, + ] + assert signatures == expected_signatures # noqa: S101 + + def test_handle_batch_expectations_success(self): + """Test successful batch expectation handling. + + Verifies that the service can process multiple expectations in batch, + returning appropriate ExpectationResult objects for each. 
+ """ + config = create_test_config() + service = NetWitnessExpectationService(config=config) + + mock_result = ExpectationResult( + expectation_id="test-id", + is_valid=True, + expectation=None, + ) + service.process_expectation = Mock(return_value=mock_result) + + expectations = [ + MockObjectsFactory.create_mock_expectation(expectation_type="detection"), + MockObjectsFactory.create_mock_expectation(expectation_type="detection"), + ] + + mock_detection_helper = MockObjectsFactory.create_mock_detection_helper() + + results = service.handle_batch_expectations(expectations, mock_detection_helper) + + assert len(results) == 2 # noqa: S101 + assert all(isinstance(r, ExpectationResult) for r in results) # noqa: S101 + assert service.process_expectation.call_count == 2 # noqa: S101 + + def test_handle_batch_expectations_with_error(self): + """Test batch expectation handling when expectation fails. + + Verifies that individual expectation failures are handled gracefully + in batch processing, returning error results without stopping the batch. + """ + config = create_test_config() + service = NetWitnessExpectationService(config=config) + + service.process_expectation = Mock( + side_effect=NetWitnessExpectationError("Test error") + ) + + expectations = [MockObjectsFactory.create_mock_expectation()] + mock_detection_helper = MockObjectsFactory.create_mock_detection_helper() + + results = service.handle_batch_expectations(expectations, mock_detection_helper) + + assert len(results) == 1 # noqa: S101 + assert results[0].is_valid is False # noqa: S101 + + def test_prevention_expectation_not_supported(self): + """Test that prevention expectations raise error. + + Verifies that NetWitness correctly rejects prevention expectation + types as it only supports detection expectations. 
+ """ + config = create_test_config() + service = NetWitnessExpectationService(config=config) + + mock_prevention_expectation = Mock() + mock_prevention_expectation.inject_expectation_id = "test-prevention-id" + + # Mock the isinstance check to return False for DetectionExpectation + # and True for PreventionExpectation + # We'll simulate this by calling the method that checks expectation type + from pyoaev.apis.inject_expectation.model import PreventionExpectation + + # Create a mock that will fail the detection check + prevention_mock = Mock(spec=PreventionExpectation) + prevention_mock.inject_expectation_id = "test-prevention-id" + + mock_detection_helper = MockObjectsFactory.create_mock_detection_helper() + + result = service.process_expectation(prevention_mock, mock_detection_helper) + + assert isinstance(result, ExpectationResult) # noqa: S101 + assert result.is_valid is False # noqa: S101 + assert ( # noqa: S101 + "only supports DetectionExpectations" in result.error_message + ) + + def test_match_success(self): + """Test successful matching for detection expectation. + + Verifies that the matching logic correctly identifies when OAEV data + matches expectation signatures and returns appropriate result data. 
+ """ + config = create_test_config() + service = NetWitnessExpectationService(config=config) + + oaev_data = TestDataFactory.create_oaev_detection_data() + matching_signatures = [ + { + "type": "source_ipv4_address", + "value": oaev_data[0]["source_ipv4_address"]["data"], + }, + { + "type": "parent_process_name", + "value": "test_process.exe", + }, + ] + + mock_detection_helper = MockObjectsFactory.create_mock_detection_helper( + match_result=True + ) + + result = service._match( + oaev_data, matching_signatures, mock_detection_helper, "detection" + ) + + assert result["is_valid"] is True # noqa: S101 + assert result["matching_data"] == [oaev_data[0]] # noqa: S101 + + def test_match_no_data_raises_exception(self): + """Test matching with no data raises NoAlertsFound exception. + + Verifies that attempting to match against empty data properly + raises NetWitnessNoAlertsFoundError. + """ + config = create_test_config() + service = NetWitnessExpectationService(config=config) + + mock_detection_helper = MockObjectsFactory.create_mock_detection_helper() + + with pytest.raises(NetWitnessNoAlertsFoundError): + service._match([], [], mock_detection_helper, "detection") + + def test_match_no_matching_alerts_raises_exception(self): + """Test matching that finds no matches raises NoMatchingAlerts exception. + + Verifies that when data is available but no matches are found, + the service raises NetWitnessNoMatchingAlertsError. 
+ """ + config = create_test_config() + service = NetWitnessExpectationService(config=config) + + oaev_data = TestDataFactory.create_oaev_detection_data() + matching_signatures = [ + {"type": "source_ipv4_address", "value": "192.168.99.99"} # Different IP + ] + + mock_detection_helper = MockObjectsFactory.create_mock_detection_helper( + match_result=False + ) + + with pytest.raises(NetWitnessNoMatchingAlertsError): + service._match( + oaev_data, matching_signatures, mock_detection_helper, "detection" + ) + + def test_extract_signatures_filters_correctly(self): + """Test signature extraction and filtering. + + Verifies that signature extraction properly separates search signatures + from matching signatures, excluding date metadata from matching. + """ + config = create_test_config() + service = NetWitnessExpectationService(config=config) + + # Create mock expectation with mixed signature types + mock_expectation = Mock() + mock_signature_ip = Mock() + mock_signature_ip.type.value = "source_ipv4_address" + mock_signature_ip.value = "192.168.1.100" + + mock_signature_date = Mock() + mock_signature_date.type.value = "start_date" + mock_signature_date.value = "2024-01-01T00:00:00Z" + + mock_expectation.inject_expectation_signatures = [ + mock_signature_ip, + mock_signature_date, + ] + + search_signatures, matching_signatures = service._extract_signatures( + mock_expectation + ) + + # Search signatures should include both + assert len(search_signatures) == 2 # noqa: S101 + + # Matching signatures should exclude dates + assert len(matching_signatures) == 1 # noqa: S101 + assert matching_signatures[0]["type"] == "source_ipv4_address" # noqa: S101 + + def test_create_error_result_object(self): + """Test creating error result objects from exceptions. + + Verifies that service errors are properly converted to ExpectationResult + objects with appropriate error information and validation status. 
+ """ + config = create_test_config() + service = NetWitnessExpectationService(config=config) + + mock_expectation = MockObjectsFactory.create_mock_expectation() + error = NetWitnessNoAlertsFoundError("No alerts found") + + result = service._create_error_result_object(error, mock_expectation) + + assert isinstance(result, ExpectationResult) # noqa: S101 + assert result.is_valid is False # noqa: S101 + assert result.error_message is not None # noqa: S101 + assert "No alerts found" in result.error_message # noqa: S101 + + def test_get_service_info(self): + """Test getting service information. + + Verifies that the service provides accurate metadata about its + capabilities, supported signatures, and service type information. + """ + config = create_test_config() + service = NetWitnessExpectationService(config=config) + + info = service.get_service_info() + + assert info["service_name"] == "NetWitness" # noqa: S101 + assert info["supports_detection"] is True # noqa: S101 + assert info["supports_prevention"] is False # noqa: S101 + assert "supported_signatures" in info # noqa: S101 + assert len(info["supported_signatures"]) == 7 # noqa: S101 + + def test_convert_dict_to_result(self): + """Test converting dictionary results to ExpectationResult objects. + + Verifies that result dictionaries are properly converted to + structured ExpectationResult instances. 
+ """ + config = create_test_config() + service = NetWitnessExpectationService(config=config) + + mock_expectation = MockObjectsFactory.create_mock_expectation() + result_dict = { + "is_valid": True, + "matching_data": [ + {"source_ipv4_address": {"type": "simple", "data": "192.168.1.100"}} + ], + "total_data_found": 1, + } + + result = service._convert_dict_to_result(result_dict, mock_expectation) + + assert isinstance(result, ExpectationResult) # noqa: S101 + assert result.is_valid is True # noqa: S101 + assert result.matched_alerts is not None # noqa: S101 + assert result.expectation == mock_expectation # noqa: S101 diff --git a/netwitness/tests/services/test_expectation_service_flow.py b/netwitness/tests/services/test_expectation_service_flow.py new file mode 100644 index 00000000..45b3aff4 --- /dev/null +++ b/netwitness/tests/services/test_expectation_service_flow.py @@ -0,0 +1,184 @@ +"""Flow-level tests for NetWitnessExpectationService (real processing paths).""" + +from unittest.mock import Mock + +import pytest +from pyoaev.apis.inject_expectation.model import ( + DetectionExpectation, + PreventionExpectation, +) +from src.services.exception import NetWitnessNoAlertsFoundError +from src.services.expectation_service import NetWitnessExpectationService +from tests.services.fixtures.factories import create_test_config + +PARENT_VALUE = ( + "oaev-implant-12345678-1234-1234-1234-123456789abc" + "-agent-87654321-4321-4321-4321-cba987654321" +) + + +def _service() -> NetWitnessExpectationService: + """Build an NetWitnessExpectationService from a test config.""" + return NetWitnessExpectationService(config=create_test_config()) + + +def _detection_expectation(signatures: list[tuple[str, str]]) -> Mock: + """Build a mock detection expectation with the given (type, value) signatures.""" + expectation = Mock(spec=DetectionExpectation) + expectation.inject_expectation_id = "exp-1" + sig_objs = [] + for sig_type, value in signatures: + sig = Mock() + sig.type.value = 
sig_type + sig.value = value + sig_objs.append(sig) + expectation.inject_expectation_signatures = sig_objs + return expectation + + +class TestExpectationServiceFlow: + """Flow-level tests covering process/handle/match paths.""" + + def test_process_expectation_detection_success(self): + """A detection expectation with a matching alert returns a valid result.""" + service = _service() + service.client_api.fetch_with_retry = Mock(return_value=[Mock()]) + service.converter.convert_data_to_oaev_data = Mock( + return_value=[ + { + "source_ipv4_address": {"type": "simple", "data": ["1.2.3.4"]}, + "parent_process_name": {"type": "simple", "data": PARENT_VALUE}, + } + ] + ) + helper = Mock() + helper.match_alert_elements.return_value = True + + expectation = _detection_expectation( + [ + ("source_ipv4_address", "1.2.3.4"), + ("parent_process_name", PARENT_VALUE), + ] + ) + + result = service.process_expectation(expectation, helper) + + assert result.is_valid is True # noqa: S101 + assert result.matched_alerts is not None # noqa: S101 + + def test_process_expectation_no_alerts_raises(self): + """No fetched alerts raises NetWitnessNoAlertsFoundError.""" + service = _service() + service.client_api.fetch_with_retry = Mock(return_value=[]) + service.converter.convert_data_to_oaev_data = Mock(return_value=[]) + + expectation = _detection_expectation([("source_ipv4_address", "1.2.3.4")]) + + with pytest.raises(NetWitnessNoAlertsFoundError): + service.process_expectation(expectation, Mock()) + + def test_handle_prevention_expectation_invalid(self): + """Prevention expectations are reported as invalid.""" + service = _service() + expectation = Mock(spec=PreventionExpectation) + expectation.inject_expectation_id = "prev-1" + + result = service.handle_prevention_expectation(expectation, Mock()) + + assert result.is_valid is False # noqa: S101 + assert ( + "only supports DetectionExpectations" in result.error_message + ) # noqa: S101 + + def 
test_match_with_detection_helper_parent_fail(self): + """A failing parent-process match short-circuits to False.""" + service = _service() + helper = Mock() + helper.match_alert_elements.return_value = False + + signatures = [{"type": "parent_process_name", "value": PARENT_VALUE}] + data_item = {"parent_process_name": {"type": "simple", "data": PARENT_VALUE}} + + assert ( # noqa: S101 + service._match_with_detection_helper(signatures, data_item, helper) is False + ) + + def test_match_with_detection_helper_target_only(self): + """A target-IP match (with parent) returns True.""" + service = _service() + helper = Mock() + helper.match_alert_elements.return_value = True + + signatures = [ + {"type": "parent_process_name", "value": PARENT_VALUE}, + {"type": "target_ipv4_address", "value": "10.0.0.1"}, + ] + data_item = { + "parent_process_name": {"type": "simple", "data": PARENT_VALUE}, + "target_ipv4_address": {"type": "simple", "data": ["10.0.0.1"]}, + } + + assert ( # noqa: S101 + service._match_with_detection_helper(signatures, data_item, helper) is True + ) + + def test_match_with_detection_helper_source_and_target(self): + """Source and target IP signatures both present resolves via OR logic.""" + service = _service() + helper = Mock() + helper.match_alert_elements.return_value = True + + signatures = [ + {"type": "parent_process_name", "value": PARENT_VALUE}, + {"type": "source_ipv4_address", "value": "1.2.3.4"}, + {"type": "target_ipv4_address", "value": "10.0.0.1"}, + ] + data_item = { + "parent_process_name": {"type": "simple", "data": PARENT_VALUE}, + "source_ipv4_address": {"type": "simple", "data": ["1.2.3.4"]}, + "target_ipv4_address": {"type": "simple", "data": ["10.0.0.1"]}, + } + + assert ( # noqa: S101 + service._match_with_detection_helper(signatures, data_item, helper) is True + ) + + def test_match_with_detection_helper_source_only_no_parent(self): + """An IP-only expectation (no parent_process signature) can still match.""" + service = _service() 
+ helper = Mock() + helper.match_alert_elements.return_value = True + + signatures = [{"type": "source_ipv4_address", "value": "1.2.3.4"}] + data_item = { + "source_ipv4_address": {"type": "simple", "data": ["1.2.3.4"]}, + } + + assert ( # noqa: S101 + service._match_with_detection_helper(signatures, data_item, helper) is True + ) + + def test_match_with_detection_helper_source_only_no_parent_no_match(self): + """An IP-only expectation returns False when the IP does not match.""" + service = _service() + helper = Mock() + helper.match_alert_elements.return_value = False + + signatures = [{"type": "source_ipv4_address", "value": "1.2.3.4"}] + data_item = { + "source_ipv4_address": {"type": "simple", "data": ["9.9.9.9"]}, + } + + assert ( # noqa: S101 + service._match_with_detection_helper(signatures, data_item, helper) is False + ) + + def test_create_error_result_dict(self): + """_create_error_result builds an error dictionary from a service error.""" + service = _service() + error = NetWitnessNoAlertsFoundError("none found") + + result = service._create_error_result(error) + + assert result["is_valid"] is False # noqa: S101 + assert result["error_type"] == "NetWitnessNoAlertsFoundError" # noqa: S101 diff --git a/netwitness/tests/services/test_parent_process_parser.py b/netwitness/tests/services/test_parent_process_parser.py new file mode 100644 index 00000000..e6e8674a --- /dev/null +++ b/netwitness/tests/services/test_parent_process_parser.py @@ -0,0 +1,77 @@ +"""Tests for the ParentProcessParser utility.""" + +from src.services.utils.parent_process_parser import ParentProcessParser + +INJECT_UUID = "877b423b-ae91-4fc5-86c3-fa8ea3c938ba" +AGENT_UUID = "1402422f-2eaa-4fbd-80b2-b30df1b83b19" +PARENT_PROCESS_NAME = f"oaev-implant-{INJECT_UUID}-agent-{AGENT_UUID}" +URL_PATH = f"/api/injects/{INJECT_UUID}/{AGENT_UUID}/executable-payload" + + +class TestParentProcessParser: + """Test cases for ParentProcessParser.""" + + def 
test_extract_uuids_from_parent_process_name_valid(self): + """Valid parent process names yield the inject and agent UUIDs.""" + parser = ParentProcessParser() + result = parser.extract_uuids_from_parent_process_name(PARENT_PROCESS_NAME) + assert result == (INJECT_UUID, AGENT_UUID) # noqa: S101 + + def test_extract_uuids_from_parent_process_name_empty(self): + """Empty input returns None.""" + parser = ParentProcessParser() + assert parser.extract_uuids_from_parent_process_name("") is None # noqa: S101 + + def test_extract_uuids_from_parent_process_name_no_match(self): + """Non-matching input returns None.""" + parser = ParentProcessParser() + assert ( # noqa: S101 + parser.extract_uuids_from_parent_process_name("not-a-match") is None + ) + + def test_construct_parent_process_name(self): + """UUIDs are recombined into the canonical parent process name.""" + parser = ParentProcessParser() + result = parser.construct_parent_process_name(INJECT_UUID, AGENT_UUID) + assert result == PARENT_PROCESS_NAME # noqa: S101 + + def test_construct_parent_process_name_missing(self): + """Missing UUIDs produce an empty string.""" + parser = ParentProcessParser() + assert parser.construct_parent_process_name("", AGENT_UUID) == "" # noqa: S101 + + def test_extract_uuids_from_url_path_valid(self): + """Valid URL paths yield the inject and agent UUIDs.""" + parser = ParentProcessParser() + result = parser.extract_uuids_from_url_path(URL_PATH) + assert result == (INJECT_UUID, AGENT_UUID) # noqa: S101 + + def test_extract_uuids_from_url_path_empty(self): + """Empty URL path returns None.""" + parser = ParentProcessParser() + assert parser.extract_uuids_from_url_path("") is None # noqa: S101 + + def test_extract_uuids_from_url_path_no_match(self): + """Non-matching URL path returns None.""" + parser = ParentProcessParser() + assert parser.extract_uuids_from_url_path("/api/other") is None # noqa: S101 + + def test_build_url_path_search_query(self): + """The search query lists all URL field 
aliases and the injected path.""" + parser = ParentProcessParser() + query = parser.build_url_path_search_query(INJECT_UUID, AGENT_UUID) + assert URL_PATH in query # noqa: S101 + for field in ("url_path", "url", "path", "query"): + assert f'{field}="{URL_PATH}"' in query # noqa: S101 + + def test_build_url_path_search_query_missing(self): + """Missing UUIDs produce an empty query.""" + parser = ParentProcessParser() + assert parser.build_url_path_search_query("", "") == "" # noqa: S101 + + def test_validate_uuid_format(self): + """UUID validation accepts valid UUIDs and rejects invalid ones.""" + parser = ParentProcessParser() + assert parser.validate_uuid_format(INJECT_UUID) is True # noqa: S101 + assert parser.validate_uuid_format("not-a-uuid") is False # noqa: S101 + assert parser.validate_uuid_format("") is False # noqa: S101 diff --git a/netwitness/tests/services/test_trace_service.py b/netwitness/tests/services/test_trace_service.py new file mode 100644 index 00000000..0cf93826 --- /dev/null +++ b/netwitness/tests/services/test_trace_service.py @@ -0,0 +1,117 @@ +"""Tests for the NetWitness trace service.""" + +from unittest.mock import Mock + +import pytest +from src.collector.models import ExpectationResult +from src.services.exception import NetWitnessValidationError +from src.services.trace_service import NetWitnessTraceService +from tests.services.fixtures.factories import create_test_config + + +def _make_expectation(sig_type: str = "source_ipv4_address", value: str = "1.2.3.4"): + """Build a mock expectation with a single signature.""" + signature = Mock() + signature.type.value = sig_type + signature.value = value + expectation = Mock() + expectation.inject_expectation_signatures = [signature] + return expectation + + +def _make_result(matching_data: dict) -> ExpectationResult: + """Build a valid ExpectationResult with a matched alert.""" + return ExpectationResult( + expectation_id="exp-1", + is_valid=True, + expectation=_make_expectation(), + 
matched_alerts=[matching_data], + ) + + +class TestNetWitnessTraceService: + """Test cases for NetWitnessTraceService.""" + + def test_init_without_config_raises(self): + """Initialization without a config raises a validation error.""" + with pytest.raises(NetWitnessValidationError): + NetWitnessTraceService(config=None) + + def test_create_traces_from_results_success(self): + """A valid result produces a single trace with a source-IP alert name.""" + service = NetWitnessTraceService(config=create_test_config()) + result = _make_result({"source_ipv4_address": {"data": "1.2.3.4"}}) + + traces = service.create_traces_from_results([result], "netwitness--collector") + + assert len(traces) == 1 # noqa: S101 + trace = traces[0] + assert trace.inject_expectation_trace_expectation == "exp-1" # noqa: S101 + assert ( + trace.inject_expectation_trace_source_id == "netwitness--collector" + ) # noqa: S101 + assert "Source IP" in trace.inject_expectation_trace_alert_name # noqa: S101 + assert trace.inject_expectation_trace_alert_link.startswith( + "http" + ) # noqa: S101 + + def test_create_traces_target_ip_alert_name(self): + """A target-IP match yields a target-IP alert name.""" + service = NetWitnessTraceService(config=create_test_config()) + result = _make_result({"target_ipv4_address": {"data": "10.0.0.1"}}) + + traces = service.create_traces_from_results([result], "netwitness--collector") + + assert ( + "Target IP" in traces[0].inject_expectation_trace_alert_name + ) # noqa: S101 + + def test_create_traces_generic_alert_name(self): + """A non-IP match yields the generic alert name.""" + service = NetWitnessTraceService(config=create_test_config()) + result = _make_result({"parent_process_name": {"data": "x.exe"}}) + + traces = service.create_traces_from_results([result], "netwitness--collector") + + name = traces[0].inject_expectation_trace_alert_name + assert name == "NetWitness Detection Alert" # noqa: S101 + + def test_create_traces_empty_collector_id_raises(self): + 
"""An empty collector_id raises a validation error.""" + service = NetWitnessTraceService(config=create_test_config()) + with pytest.raises(NetWitnessValidationError): + service.create_traces_from_results([], "") + + def test_create_traces_non_list_raises(self): + """A non-list results argument raises a validation error.""" + service = NetWitnessTraceService(config=create_test_config()) + with pytest.raises(NetWitnessValidationError): + service.create_traces_from_results("nope", "netwitness--collector") + + def test_create_traces_no_valid_results(self): + """Invalid results (no matches) produce no traces.""" + service = NetWitnessTraceService(config=create_test_config()) + invalid = ExpectationResult( + expectation_id="exp-2", is_valid=False, matched_alerts=None + ) + assert service.create_traces_from_results([invalid], "c") == [] # noqa: S101 + + def test_create_traces_console_url_used(self): + """When console_url is configured it is used for the trace link.""" + config = create_test_config() + config.netwitness.console_url = "https://console.example.com" + service = NetWitnessTraceService(config=config) + result = _make_result({"source_ipv4_address": {"data": "1.2.3.4"}}) + + traces = service.create_traces_from_results([result], "netwitness--collector") + + assert traces[0].inject_expectation_trace_alert_link.startswith( # noqa: S101 + "https://console.example.com" + ) + + def test_get_service_info(self): + """The service exposes detection-only metadata.""" + service = NetWitnessTraceService(config=create_test_config()) + info = service.get_service_info() + assert info["creates_detection_traces"] is True # noqa: S101 + assert info["creates_prevention_traces"] is False # noqa: S101 diff --git a/netwitness/tests/test_collector_models.py b/netwitness/tests/test_collector_models.py new file mode 100644 index 00000000..1a0c19b2 --- /dev/null +++ b/netwitness/tests/test_collector_models.py @@ -0,0 +1,86 @@ +"""Tests for collector Pydantic models.""" + +import pytest 
+from pydantic import ValidationError +from src.collector.models import ( + ExpectationResult, + ExpectationTrace, + ProcessingSummary, +) + + +def _valid_trace_kwargs() -> dict: + """Return a set of valid ExpectationTrace field values.""" + return { + "inject_expectation_trace_expectation": "exp-1", + "inject_expectation_trace_source_id": "netwitness--collector", + "inject_expectation_trace_alert_name": "NetWitness Detection Alert", + "inject_expectation_trace_alert_link": "https://kibana/app/security/alerts", + "inject_expectation_trace_date": "2026-01-01T00:00:00Z", + } + + +class TestExpectationTrace: + """Test cases for the ExpectationTrace model.""" + + def test_valid_trace(self): + """A fully populated trace builds successfully.""" + trace = ExpectationTrace(**_valid_trace_kwargs()) + assert trace.inject_expectation_trace_expectation == "exp-1" # noqa: S101 + + def test_to_api_dict_stringifies_values(self): + """to_api_dict returns string values for all fields.""" + trace = ExpectationTrace(**_valid_trace_kwargs()) + api_dict = trace.to_api_dict() + assert all(isinstance(value, str) for value in api_dict.values()) # noqa: S101 + assert api_dict["inject_expectation_trace_expectation"] == "exp-1" # noqa: S101 + + def test_values_are_trimmed(self): + """Leading/trailing whitespace is stripped from values.""" + kwargs = _valid_trace_kwargs() + kwargs["inject_expectation_trace_expectation"] = " exp-1 " + trace = ExpectationTrace(**kwargs) + assert trace.inject_expectation_trace_expectation == "exp-1" # noqa: S101 + + @pytest.mark.parametrize( + "field", + [ + "inject_expectation_trace_expectation", + "inject_expectation_trace_source_id", + "inject_expectation_trace_alert_name", + "inject_expectation_trace_alert_link", + "inject_expectation_trace_date", + ], + ) + def test_empty_field_raises(self, field): + """Each required field rejects empty/whitespace-only values.""" + kwargs = _valid_trace_kwargs() + kwargs[field] = " " + with pytest.raises(ValidationError): + 
ExpectationTrace(**kwargs) + + +class TestExpectationResult: + """Test cases for the ExpectationResult model.""" + + def test_valid_result(self): + """An ExpectationResult builds with required and optional fields.""" + result = ExpectationResult( + expectation_id="exp-1", + is_valid=True, + matched_alerts=[{"source_ipv4_address": {"data": "1.2.3.4"}}], + ) + assert result.is_valid is True # noqa: S101 + assert result.matched_alerts is not None # noqa: S101 + assert result.error_message is None # noqa: S101 + + +class TestProcessingSummary: + """Test cases for the ProcessingSummary model.""" + + def test_valid_summary(self): + """A ProcessingSummary builds with all counters.""" + summary = ProcessingSummary(processed=3, valid=2, invalid=1, skipped=0) + assert summary.processed == 3 # noqa: S101 + assert summary.valid == 2 # noqa: S101 + assert summary.invalid == 1 # noqa: S101 diff --git a/netwitness/tests/test_create_collector.py b/netwitness/tests/test_create_collector.py new file mode 100644 index 00000000..23b1f828 --- /dev/null +++ b/netwitness/tests/test_create_collector.py @@ -0,0 +1,224 @@ +"""Test module for the NetWitness Collector initialization.""" + +from os import environ as os_environ +from typing import Any +from uuid import UUID + +import pytest +from src.collector import Collector +from src.collector.exception import CollectorConfigError +from tests.conftest import mock_env_vars + +# -------- +# Fixtures +# -------- + + +@pytest.fixture() +def collector_config() -> dict[str, str]: # type: ignore + """Fixture for minimum required configuration. + + Returns: + Dictionary containing all required environment variables + for collector initialization with test values. 
+ + """ + return { + "OPENAEV_URL": "http://fake-url/", + "OPENAEV_TOKEN": "fake-oaev-token", + "OPENAEV_TENANT_ID": "deadbeef-dead-beef-dead-beefdeadbeef", + "COLLECTOR_ID": "fake-collector-id", + "COLLECTOR_NAME": "NetWitness", + "NETWITNESS_BASE_URL": "https://fake-netwitness.net:50103/", + "NETWITNESS_USERNAME": "fake-user", + "NETWITNESS_PASSWORD": "fake-password", + "COLLECTOR_ICON_FILEPATH": "src/img/netwitness-logo.png", + "COLLECTOR_LOG_LEVEL": "debug", + } + + +# -------- +# Tests +# -------- + + +# Scenario: Create a collector with success. +def test_success_create_collector(capfd, collector_config): # type: ignore + """Test that the main function initializes and start the NetWitness Collector. + + Args: + capfd: Pytest fixture for capturing stdout and stderr output. + collector_config: Fixture providing valid collector configuration. + + """ + # Given I have a valid configuration to start the NetWitness Collector. + data = {**collector_config} + mock_env = _given_setup_config(data) + + # When I create the collector. + collector = _when_create_collector() + + # Then the collector should be created successfully + _then_collector_created_successfully(capfd, mock_env, collector, data) + + +# Scenario: Create a collector with missing required config +def test_collector_config_missing_required_values() -> None: + """Test for the collector with missing required configuration values. + + Verifies that collector creation fails appropriately when required + configuration values are missing, specifically the NetWitness username. 
+ + """ + # Given configuration with missing required NetWitness username + data = { + "OPENAEV_URL": "http://fake-url", + "OPENAEV_TOKEN": "fake-oaev-token", + "OPENAEV_TENANT_ID": "deadbeef-dead-beef-dead-beefdeadbeef", + "COLLECTOR_ID": "fake-collector-id", + "COLLECTOR_NAME": "NetWitness", + "NETWITNESS_BASE_URL": "https://fake-netwitness.net:9200/", + # Missing NETWITNESS_USERNAME - this should cause validation error + "NETWITNESS_PASSWORD": "fake-password", + "NETWITNESS_ALERTS_INDEX": ".alerts-security.alerts-*", + "COLLECTOR_ICON_FILEPATH": "src/img/netwitness-logo.png", + "COLLECTOR_LOG_LEVEL": "debug", + } + mock_env = _given_setup_config(data) + + # Remove username env var if it was set by factory + if "NETWITNESS_USERNAME" in os_environ: + del os_environ["NETWITNESS_USERNAME"] + + # Then the collector config should raise a custom ConfigurationException + with pytest.raises((CollectorConfigError, ValueError)): + # When the collector is created + _when_create_collector() + + mock_env.stop() + + +# Scenario: Create a collector with missing password +def test_collector_config_missing_password() -> None: + """Test for the collector with missing password configuration. + + Verifies that collector creation fails appropriately when required + password configuration is missing. 
+ + """ + # Given configuration with missing required NetWitness password + data = { + "OPENAEV_URL": "http://fake-url", + "OPENAEV_TOKEN": "fake-oaev-token", + "OPENAEV_TENANT_ID": "deadbeef-dead-beef-dead-beefdeadbeef", + "COLLECTOR_ID": "fake-collector-id", + "COLLECTOR_NAME": "NetWitness", + "NETWITNESS_BASE_URL": "https://fake-netwitness.net:9200/", + "NETWITNESS_USERNAME": "fake-user", + # Missing NETWITNESS_PASSWORD - this should cause validation error + "NETWITNESS_ALERTS_INDEX": ".alerts-security.alerts-*", + "COLLECTOR_ICON_FILEPATH": "src/img/netwitness-logo.png", + "COLLECTOR_LOG_LEVEL": "debug", + } + mock_env = _given_setup_config(data) + + # Remove password env var if it was set by factory + if "NETWITNESS_PASSWORD" in os_environ: + del os_environ["NETWITNESS_PASSWORD"] + + # Then the collector config should raise a custom ConfigurationException + with pytest.raises((CollectorConfigError, ValueError)): + # When the collector is created + _when_create_collector() + + mock_env.stop() + + +# --------- +# Given +# --------- + + +# Given setup config +def _given_setup_config(data: dict[str, str]) -> Any: # type: ignore + """Set up the environment variables for the test. + + Args: + data: Dictionary of environment variables to mock. + + Returns: + Mock environment variable patcher object. + + """ + mock_env = mock_env_vars(os_environ, data) + return mock_env + + +# --------- +# When +# --------- + + +# When the collector is created +def _when_create_collector() -> Collector: # type: ignore + """Create the collector. + + Returns: + Collector instance for testing. + + """ + collector = Collector() + return collector + + +# --------- +# Then +# --------- + + +# Then the collector should be created successfully +def _then_collector_created_successfully(capfd, mock_env, collector, data) -> None: # type: ignore + """Check if the connector was created successfully. + + Args: + capfd: Pytest fixture for capturing stdout and stderr output. 
+ mock_env: Mock environment variable patcher to clean up. + collector: The created collector instance to verify. + data: Expected configuration data to validate against. + + """ + assert collector is not None # noqa: S101 + + # Check that the collector has the expected configuration + daemon_config = collector.config_instance.to_daemon_config() + + # Verify key configuration values + assert daemon_config.get("openaev_url") == data.get("OPENAEV_URL") # noqa: S101 + assert daemon_config.get("openaev_token") == data.get("OPENAEV_TOKEN") # noqa: S101 + assert daemon_config.get("openaev_tenant_id") == UUID( + data.get("OPENAEV_TENANT_ID") + ) # noqa: S101 + assert daemon_config.get("collector_id") == data.get("COLLECTOR_ID") # noqa: S101 + assert daemon_config.get("collector_name") == data.get( # noqa: S101 + "COLLECTOR_NAME" + ) + assert daemon_config.get("netwitness_base_url") == data.get( # noqa: S101 + "NETWITNESS_BASE_URL" + ) + assert daemon_config.get("netwitness_username") == data.get( # noqa: S101 + "NETWITNESS_USERNAME" + ) + assert daemon_config.get("netwitness_password") == data.get( # noqa: S101 + "NETWITNESS_PASSWORD" + ) + assert daemon_config.get("netwitness_max_results") is not None # noqa: S101 + assert daemon_config.get("collector_log_level") == data.get( # noqa: S101 + "COLLECTOR_LOG_LEVEL" + ) + + log_records = capfd.readouterr() + if daemon_config.get("collector_log_level") in ["info", "debug"]: + registered_message = "NetWitness Collector initialized successfully" + assert registered_message in log_records.err # noqa: S101 + + mock_env.stop() diff --git a/netwitness/tests/test_expectation_handler.py b/netwitness/tests/test_expectation_handler.py new file mode 100644 index 00000000..8eff0fa8 --- /dev/null +++ b/netwitness/tests/test_expectation_handler.py @@ -0,0 +1,121 @@ +"""Tests for the GenericExpectationHandler.""" + +from unittest.mock import Mock + +import pytest +from pyoaev.apis.inject_expectation.model import ( + DetectionExpectation, + 
PreventionExpectation, +) +from pyoaev.signatures.types import SignatureTypes +from src.collector.exception import ExpectationHandlerError +from src.collector.expectation_handler import GenericExpectationHandler +from src.collector.models import ExpectationResult + + +def _service_provider() -> Mock: + """Build a mock service provider exposing supported signatures.""" + provider = Mock() + provider.get_supported_signatures.return_value = [ + SignatureTypes.SIG_TYPE_SOURCE_IPV4_ADDRESS + ] + return provider + + +class TestGenericExpectationHandler: + """Test cases for GenericExpectationHandler.""" + + def test_init_registers_with_registry(self): + """Initialization registers handlers using the provider's signatures.""" + provider = _service_provider() + handler = GenericExpectationHandler(provider) + assert handler.service_provider is provider # noqa: S101 + provider.get_supported_signatures.assert_called() + + def test_handle_detection_expectation(self): + """Detection expectations are delegated to the detection handler.""" + provider = _service_provider() + expected = ExpectationResult(expectation_id="e1", is_valid=True) + provider.handle_detection_expectation.return_value = expected + handler = GenericExpectationHandler(provider) + + expectation = Mock(spec=DetectionExpectation) + expectation.inject_expectation_id = "e1" + + result = handler.handle_expectation(expectation, Mock()) + + assert result is expected # noqa: S101 + provider.handle_detection_expectation.assert_called_once() + + def test_handle_prevention_expectation(self): + """Prevention expectations are delegated to the prevention handler.""" + provider = _service_provider() + expected = ExpectationResult(expectation_id="e2", is_valid=False) + provider.handle_prevention_expectation.return_value = expected + handler = GenericExpectationHandler(provider) + + expectation = Mock(spec=PreventionExpectation) + expectation.inject_expectation_id = "e2" + + result = handler.handle_expectation(expectation, 
Mock()) + + assert result is expected # noqa: S101 + + def test_handle_unsupported_type(self): + """Unsupported expectation types yield an invalid result.""" + provider = _service_provider() + handler = GenericExpectationHandler(provider) + + expectation = Mock() + expectation.inject_expectation_id = "e3" + + result = handler.handle_expectation(expectation, Mock()) + + assert result.is_valid is False # noqa: S101 + assert "Unsupported" in result.error_message # noqa: S101 + + def test_handle_expectation_propagates_errors(self): + """Errors from the service provider are propagated.""" + provider = _service_provider() + provider.handle_detection_expectation.side_effect = RuntimeError("boom") + handler = GenericExpectationHandler(provider) + + expectation = Mock(spec=DetectionExpectation) + expectation.inject_expectation_id = "e4" + + with pytest.raises(RuntimeError): + handler.handle_expectation(expectation, Mock()) + + def test_handle_batch_post_processes_results(self): + """Batch handling fills in missing expectation IDs and objects.""" + provider = _service_provider() + provider.handle_batch_expectations.return_value = [ + ExpectationResult(expectation_id="", is_valid=True) + ] + handler = GenericExpectationHandler(provider) + + expectation = Mock() + expectation.inject_expectation_id = "e5" + + results = handler.handle_batch_expectations([expectation], Mock()) + + assert len(results) == 1 # noqa: S101 + assert results[0].expectation is expectation # noqa: S101 + assert results[0].expectation_id == "e5" # noqa: S101 + + def test_handle_batch_wraps_errors(self): + """Batch failures are wrapped in ExpectationHandlerError.""" + provider = _service_provider() + provider.handle_batch_expectations.side_effect = RuntimeError("x") + handler = GenericExpectationHandler(provider) + + with pytest.raises(ExpectationHandlerError): + handler.handle_batch_expectations([Mock()], Mock()) + + def test_get_supported_signatures(self): + """The handler exposes the provider's supported 
signatures.""" + provider = _service_provider() + handler = GenericExpectationHandler(provider) + assert handler.get_supported_signatures() == [ # noqa: S101 + SignatureTypes.SIG_TYPE_SOURCE_IPV4_ADDRESS + ] diff --git a/netwitness/tests/test_expectation_manager.py b/netwitness/tests/test_expectation_manager.py new file mode 100644 index 00000000..55fa116b --- /dev/null +++ b/netwitness/tests/test_expectation_manager.py @@ -0,0 +1,147 @@ +"""Tests for the GenericExpectationManager.""" + +from unittest.mock import Mock + +import pytest +from pyoaev.apis.inject_expectation.model import ( + DetectionExpectation, + PreventionExpectation, +) +from pyoaev.signatures.types import SignatureTypes +from src.collector.exception import ExpectationUpdateError +from src.collector.expectation_manager import GenericExpectationManager +from src.collector.models import ExpectationResult + + +def _detection_exp(exp_id: str = "e1", end_date: bool = True) -> Mock: + """Build a mock detection expectation, optionally with an end_date signature.""" + expectation = Mock(spec=DetectionExpectation) + expectation.inject_expectation_id = exp_id + signatures = [] + if end_date: + sig = Mock() + sig.type = SignatureTypes.SIG_TYPE_END_DATE + signatures.append(sig) + expectation.inject_expectation_signatures = signatures + return expectation + + +def _manager() -> tuple[GenericExpectationManager, Mock, Mock]: + """Build a manager with mock API and handler.""" + api = Mock() + handler = Mock() + manager = GenericExpectationManager(api, "collector-1", handler) + return manager, api, handler + + +class TestGenericExpectationManager: + """Test cases for GenericExpectationManager.""" + + def test_init_requires_api(self): + """A missing API raises ValueError.""" + with pytest.raises(ValueError): + GenericExpectationManager(None, "c", Mock()) + + def test_init_requires_collector_id(self): + """A missing collector id raises ValueError.""" + with pytest.raises(ValueError): + 
GenericExpectationManager(Mock(), "", Mock()) + + def test_init_requires_handler(self): + """A missing handler raises ValueError.""" + with pytest.raises(ValueError): + GenericExpectationManager(Mock(), "c", None) + + def test_process_expectations_success(self): + """A full processing cycle returns a summary and updates the API.""" + manager, api, handler = _manager() + expectation = _detection_exp() + api.inject_expectation.expectations_models_for_source.return_value = [ + expectation + ] + handler.handle_batch_expectations.return_value = [ + ExpectationResult( + expectation_id="e1", is_valid=True, expectation=expectation + ) + ] + + summary = manager.process_expectations(Mock()) + + assert summary.processed == 1 # noqa: S101 + assert summary.valid == 1 # noqa: S101 + api.inject_expectation.bulk_update.assert_called_once() + + def test_check_for_end_date(self): + """end_date detection returns True only when present.""" + manager, _, _ = _manager() + assert ( + manager._check_for_end_date([_detection_exp(end_date=True)]) is True + ) # noqa: S101 + assert ( # noqa: S101 + manager._check_for_end_date([_detection_exp(end_date=False)]) is False + ) + + def test_prepare_bulk_data_filters(self): + """Bulk data preparation skips results without ids or expectations.""" + manager, _, _ = _manager() + expectation = _detection_exp() + results = [ + ExpectationResult( + expectation_id="e1", is_valid=True, expectation=expectation + ), + ExpectationResult( + expectation_id="", is_valid=True, expectation=expectation + ), + ExpectationResult(expectation_id="e3", is_valid=False, expectation=None), + ] + + bulk = manager._prepare_bulk_data(results) + + assert "e1" in bulk # noqa: S101 + assert "e3" not in bulk # noqa: S101 + assert bulk["e1"]["is_success"] is True # noqa: S101 + + def test_get_result_text(self): + """Result text reflects expectation type and validity.""" + manager, _, _ = _manager() + detection = Mock(spec=DetectionExpectation) + prevention = 
Mock(spec=PreventionExpectation) + assert manager._get_result_text(detection, True) == "Detected" # noqa: S101 + assert ( + manager._get_result_text(detection, False) == "Not Detected" + ) # noqa: S101 + assert manager._get_result_text(prevention, True) == "Prevented" # noqa: S101 + + def test_attempt_bulk_update_success(self): + """A successful bulk update calls the API once.""" + manager, api, _ = _manager() + manager._attempt_bulk_update({"e1": {"x": 1}}) + api.inject_expectation.bulk_update.assert_called_once() + + def test_attempt_bulk_update_falls_back_to_individual(self): + """A bulk failure falls back to individual updates without raising.""" + manager, api, _ = _manager() + api.inject_expectation.bulk_update.side_effect = RuntimeError("bulk") + api.inject_expectation.update.side_effect = RuntimeError("individual") + + manager._attempt_bulk_update({"e1": {"x": 1}}) + + api.inject_expectation.update.assert_called() + + def test_update_expectation_error(self): + """A failed individual update raises ExpectationUpdateError.""" + manager, api, _ = _manager() + api.inject_expectation.update.side_effect = RuntimeError("x") + with pytest.raises(ExpectationUpdateError): + manager._update_expectation("e1", {"x": 1}) + + def test_bulk_update_empty_results(self): + """No results means no bulk update call.""" + manager, api, _ = _manager() + manager._bulk_update_expectations([]) + api.inject_expectation.bulk_update.assert_not_called() + + def test_interruptible_sleep_zero_returns(self): + """A non-positive sleep returns immediately.""" + manager, _, _ = _manager() + manager._interruptible_sleep(0) diff --git a/netwitness/tests/test_signature_registry.py b/netwitness/tests/test_signature_registry.py new file mode 100644 index 00000000..76637a16 --- /dev/null +++ b/netwitness/tests/test_signature_registry.py @@ -0,0 +1,83 @@ +"""Tests for the SignatureRegistry.""" + +import pytest +from pyoaev.signatures.types import SignatureTypes +from src.collector.models import 
ExpectationResult +from src.collector.signature_registry import ( + ExpectationHandlerType, + SignatureRegistry, + get_registry, +) + +SOURCE_IP = SignatureTypes.SIG_TYPE_SOURCE_IPV4_ADDRESS +TARGET_IP = SignatureTypes.SIG_TYPE_TARGET_IPV4_ADDRESS + + +def _handler(expectation, helper) -> ExpectationResult: + """Return a trivial valid result (test handler).""" + return ExpectationResult(expectation_id="x", is_valid=True) + + +class TestSignatureRegistry: + """Test cases for SignatureRegistry.""" + + def test_subscribe_and_get_signatures(self): + """Subscribed signatures are returned by the registry.""" + registry = SignatureRegistry() + registry.subscribe_to_signatures([SOURCE_IP]) + assert SOURCE_IP in registry.get_subscribed_signatures() # noqa: S101 + + def test_register_handler(self): + """Registering a handler records it and its signatures.""" + registry = SignatureRegistry() + registry.register_handler( + ExpectationHandlerType.DETECTION, _handler, [SOURCE_IP] + ) + + assert ( # noqa: S101 + registry.get_handler(ExpectationHandlerType.DETECTION) is _handler + ) + assert registry.has_handler_for_signatures( # noqa: S101 + ExpectationHandlerType.DETECTION, [SOURCE_IP] + ) + assert registry.is_signature_supported(SOURCE_IP) # noqa: S101 + assert ( + ExpectationHandlerType.DETECTION in registry.get_handler_types() + ) # noqa: S101 + + def test_has_handler_false_when_unregistered(self): + """An unregistered handler type reports no support.""" + registry = SignatureRegistry() + assert not registry.has_handler_for_signatures( # noqa: S101 + ExpectationHandlerType.PREVENTION, [SOURCE_IP] + ) + + def test_has_handler_false_without_overlap(self): + """A handler reports no support for unrelated signatures.""" + registry = SignatureRegistry() + registry.register_handler( + ExpectationHandlerType.DETECTION, _handler, [SOURCE_IP] + ) + assert not registry.has_handler_for_signatures( # noqa: S101 + ExpectationHandlerType.DETECTION, [TARGET_IP] + ) + + def 
test_get_handler_missing_raises(self): + """Retrieving a missing handler raises KeyError.""" + registry = SignatureRegistry() + with pytest.raises(KeyError): + registry.get_handler(ExpectationHandlerType.DETECTION) + + def test_clear(self): + """Clearing the registry removes all registrations.""" + registry = SignatureRegistry() + registry.register_handler( + ExpectationHandlerType.DETECTION, _handler, [SOURCE_IP] + ) + registry.clear() + assert registry.get_subscribed_signatures() == [] # noqa: S101 + assert registry.get_handler_types() == [] # noqa: S101 + + def test_get_registry_is_singleton(self): + """The module-level registry getter returns a singleton.""" + assert get_registry() is get_registry() # noqa: S101 diff --git a/netwitness/tests/test_trace_manager.py b/netwitness/tests/test_trace_manager.py new file mode 100644 index 00000000..30b9613b --- /dev/null +++ b/netwitness/tests/test_trace_manager.py @@ -0,0 +1,72 @@ +"""Tests for the TraceManager.""" + +from unittest.mock import Mock + +import pytest +from src.collector.exception import TracingError +from src.collector.trace_manager import TraceManager + + +def _trace() -> Mock: + """Build a mock trace with an API dict representation.""" + trace = Mock() + trace.to_api_dict.return_value = {"inject_expectation_trace_expectation": "e1"} + return trace + + +class TestTraceManager: + """Test cases for TraceManager.""" + + def test_no_trace_service_skips_submission(self): + """Without a trace service, no API calls are made.""" + api = Mock() + manager = TraceManager(api, "collector-id", trace_service=None) + manager.create_and_submit_traces([Mock()]) + api.inject_expectation_trace.bulk_create.assert_not_called() + + def test_create_and_submit_success(self): + """Traces are created and bulk-submitted to the API.""" + api = Mock() + trace_service = Mock() + trace_service.create_traces_from_results.return_value = [_trace()] + manager = TraceManager(api, "collector-id", trace_service=trace_service) + + 
manager.create_and_submit_traces([Mock()]) + + api.inject_expectation_trace.bulk_create.assert_called_once() + + def test_no_traces_created_skips_submission(self): + """When no traces are produced, no submission happens.""" + api = Mock() + trace_service = Mock() + trace_service.create_traces_from_results.return_value = [] + manager = TraceManager(api, "collector-id", trace_service=trace_service) + + manager.create_and_submit_traces([Mock()]) + + api.inject_expectation_trace.bulk_create.assert_not_called() + + def test_bulk_failure_falls_back_to_individual(self): + """A bulk failure triggers individual creation and raises TracingError.""" + api = Mock() + api.inject_expectation_trace.bulk_create.side_effect = RuntimeError("bulk") + trace_service = Mock() + trace_service.create_traces_from_results.return_value = [_trace()] + manager = TraceManager(api, "collector-id", trace_service=trace_service) + + with pytest.raises(TracingError): + manager.create_and_submit_traces([Mock()]) + + api.inject_expectation_trace.create.assert_called() + + def test_bulk_and_individual_failure(self): + """When both bulk and individual creation fail, TracingError is raised.""" + api = Mock() + api.inject_expectation_trace.bulk_create.side_effect = RuntimeError("bulk") + api.inject_expectation_trace.create.side_effect = RuntimeError("individual") + trace_service = Mock() + trace_service.create_traces_from_results.return_value = [_trace()] + manager = TraceManager(api, "collector-id", trace_service=trace_service) + + with pytest.raises(TracingError): + manager.create_and_submit_traces([Mock()])