diff --git a/.github/actions/docker-cache/action.yml b/.github/actions/docker-cache/action.yml
new file mode 100644
index 00000000..253885e2
--- /dev/null
+++ b/.github/actions/docker-cache/action.yml
@@ -0,0 +1,64 @@
+name: 'Docker Image Cache'
+description: 'Cache and load Docker images for CI jobs'
+
+inputs:
+ images:
+ description: 'Space-separated list of Docker images to cache'
+ required: true
+
+runs:
+ using: 'composite'
+ steps:
+ - name: Generate cache key from images
+ id: cache-key
+ shell: bash
+ env:
+ IMAGES_INPUT: ${{ inputs.images }}
+ run: |
+ # Create a stable hash from the sorted image list
+ # Using env var to prevent script injection
+ IMAGES_HASH=$(echo "$IMAGES_INPUT" | tr ' ' '\n' | sort | md5sum | cut -d' ' -f1)
+ echo "key=docker-${{ runner.os }}-${IMAGES_HASH}" >> $GITHUB_OUTPUT
+
+ - name: Cache Docker images
+ uses: actions/cache@v5
+ id: docker-cache
+ with:
+ path: /tmp/docker-cache
+ key: ${{ steps.cache-key.outputs.key }}
+
+ - name: Load cached Docker images
+ if: steps.docker-cache.outputs.cache-hit == 'true'
+ shell: bash
+ run: |
+ echo "Loading cached images..."
+ for f in /tmp/docker-cache/*.tar.zst; do
+ zstd -d -c "$f" | docker load &
+ done
+ wait
+ docker images
+
+ - name: Pull and save Docker images
+ if: steps.docker-cache.outputs.cache-hit != 'true'
+ shell: bash
+ env:
+ IMAGES_INPUT: ${{ inputs.images }}
+ run: |
+ mkdir -p /tmp/docker-cache
+
+ echo "Pulling images in parallel..."
+ for img in $IMAGES_INPUT; do
+ docker pull "$img" &
+ done
+ wait
+
+ echo "Saving images with zstd compression..."
+ for img in $IMAGES_INPUT; do
+ # Create filename from image name (replace special chars)
+ filename=$(echo "$img" | tr '/:' '_')
+ docker save "$img" | zstd -T0 -3 > "/tmp/docker-cache/${filename}.tar.zst" &
+ done
+ wait
+
+ echo "Cache size:"
+ du -sh /tmp/docker-cache/
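A rough Python equivalent of the cache-key derivation above, for reference only (the `runner_os` default is illustrative): the key depends only on the set of images, not the order they are listed in.

```python
import hashlib

def docker_cache_key(images: str, runner_os: str = "Linux") -> str:
    # sorted() makes the key independent of the order images are listed in
    names = sorted(images.split())
    digest = hashlib.md5(("\n".join(names) + "\n").encode()).hexdigest()
    return f"docker-{runner_os}-{digest}"

# Same image set, different order, same cache key
assert docker_cache_key("redis:7-alpine mongo:8.0") == docker_cache_key("mongo:8.0 redis:7-alpine")
```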
diff --git a/.github/actions/setup-ci-compose/action.yml b/.github/actions/setup-ci-compose/action.yml
deleted file mode 100644
index 7ca7ccc9..00000000
--- a/.github/actions/setup-ci-compose/action.yml
+++ /dev/null
@@ -1,55 +0,0 @@
-name: Setup CI Compose
-description: Creates docker-compose.ci.yaml with CI-specific modifications
-
-inputs:
- kubeconfig-path:
- description: Path to kubeconfig file for cert-generator mount
- required: true
-
-runs:
- using: composite
- steps:
- - name: Install yq
- shell: bash
- run: |
- sudo wget -qO /usr/local/bin/yq https://github.com/mikefarah/yq/releases/download/v4.50.1/yq_linux_amd64
- sudo chmod +x /usr/local/bin/yq
-
- - name: Create CI compose configuration
- shell: bash
- env:
- KUBECONFIG_PATH: ${{ inputs.kubeconfig-path }}
- run: |
- cp docker-compose.yaml docker-compose.ci.yaml
-
- # Backend environment variables
- yq eval '.services.backend.environment += ["TESTING=true"]' -i docker-compose.ci.yaml
- yq eval '.services.backend.environment += ["MONGO_ROOT_USER=root"]' -i docker-compose.ci.yaml
- yq eval '.services.backend.environment += ["MONGO_ROOT_PASSWORD=rootpassword"]' -i docker-compose.ci.yaml
- yq eval '.services.backend.environment += ["OTEL_SDK_DISABLED=true"]' -i docker-compose.ci.yaml
-
- # Remove hot-reload volume mounts (causes permission issues and slow rebuilds in CI)
- yq eval '.services.backend.volumes = [.services.backend.volumes[] | select(. != "./backend:/app")]' -i docker-compose.ci.yaml
- yq eval '.services."k8s-worker".volumes = [.services."k8s-worker".volumes[] | select(. != "./backend:/app:ro")]' -i docker-compose.ci.yaml
- yq eval '.services."pod-monitor".volumes = [.services."pod-monitor".volumes[] | select(. != "./backend:/app:ro")]' -i docker-compose.ci.yaml
- yq eval '.services."result-processor".volumes = [.services."result-processor".volumes[] | select(. != "./backend:/app:ro")]' -i docker-compose.ci.yaml
- yq eval '.services.frontend.volumes = [.services.frontend.volumes[] | select(. != "./frontend:/app")]' -i docker-compose.ci.yaml
-
- # Disable Kafka SASL authentication for CI
- yq eval 'del(.services.kafka.environment.KAFKA_OPTS)' -i docker-compose.ci.yaml
- yq eval 'del(.services.zookeeper.environment.KAFKA_OPTS)' -i docker-compose.ci.yaml
- yq eval 'del(.services.zookeeper.environment.ZOOKEEPER_AUTH_PROVIDER_1)' -i docker-compose.ci.yaml
- yq eval '.services.kafka.volumes = [.services.kafka.volumes[] | select(. | contains("jaas.conf") | not)]' -i docker-compose.ci.yaml
- yq eval '.services.zookeeper.volumes = [.services.zookeeper.volumes[] | select(. | contains("/etc/kafka") | not)]' -i docker-compose.ci.yaml
-
- # Simplify Zookeeper for CI
- yq eval '.services.zookeeper.environment.ZOOKEEPER_4LW_COMMANDS_WHITELIST = "ruok,srvr"' -i docker-compose.ci.yaml
- yq eval 'del(.services.zookeeper.healthcheck)' -i docker-compose.ci.yaml
- yq eval '.services.kafka.depends_on.zookeeper.condition = "service_started"' -i docker-compose.ci.yaml
-
- # Cert-generator CI configuration
- yq eval '.services."cert-generator".extra_hosts = ((.services."cert-generator".extra_hosts // []) + ["host.docker.internal:host-gateway"] | unique)' -i docker-compose.ci.yaml
- yq eval '.services."cert-generator".environment = ((.services."cert-generator".environment // []) + ["CI=true"] | unique)' -i docker-compose.ci.yaml
- yq eval ".services.\"cert-generator\".volumes += [\"${KUBECONFIG_PATH}:/root/.kube/config:ro\"]" -i docker-compose.ci.yaml
-
- echo "Created docker-compose.ci.yaml"
diff --git a/.github/workflows/backend-ci.yml b/.github/workflows/backend-ci.yml
index 8b1f0305..5c4cd373 100644
--- a/.github/workflows/backend-ci.yml
+++ b/.github/workflows/backend-ci.yml
@@ -15,13 +15,64 @@ on:
- 'docker-compose.ci.yaml'
workflow_dispatch:
+# Pin image versions for cache key consistency
+env:
+ MONGO_IMAGE: mongo:8.0
+ REDIS_IMAGE: redis:7-alpine
+ KAFKA_IMAGE: apache/kafka:3.9.0
+ SCHEMA_REGISTRY_IMAGE: confluentinc/cp-schema-registry:7.5.0
+
jobs:
+ unit:
+ name: Unit Tests
+ runs-on: ubuntu-latest
+
+ steps:
+ - uses: actions/checkout@v6
+
+ - name: Set up uv
+ uses: astral-sh/setup-uv@v7
+ with:
+ enable-cache: true
+ cache-dependency-glob: "backend/uv.lock"
+
+ - name: Install Python dependencies
+ run: |
+ cd backend
+ uv python install 3.12
+ uv sync --frozen
+
+ - name: Run unit tests
+ timeout-minutes: 5
+ run: |
+ cd backend
+ uv run pytest tests/unit -v -rs \
+ --cov=app \
+ --cov-report=xml --cov-report=term
+
+ - name: Upload coverage to Codecov
+ uses: codecov/codecov-action@v5
+ if: always()
+ with:
+ token: ${{ secrets.CODECOV_TOKEN }}
+ files: backend/coverage.xml
+ flags: backend-unit
+ name: backend-unit-coverage
+ fail_ci_if_error: false
+ verbose: true
+
integration:
name: Integration Tests
runs-on: ubuntu-latest
+
steps:
- uses: actions/checkout@v6
+ - name: Cache and load Docker images
+ uses: ./.github/actions/docker-cache
+ with:
+ images: ${{ env.MONGO_IMAGE }} ${{ env.REDIS_IMAGE }} ${{ env.KAFKA_IMAGE }} ${{ env.SCHEMA_REGISTRY_IMAGE }}
+
- name: Set up uv
uses: astral-sh/setup-uv@v7
with:
@@ -34,70 +85,114 @@ jobs:
uv python install 3.12
uv sync --frozen
- - name: Setup Docker Buildx
- uses: docker/setup-buildx-action@v3
+ - name: Start infrastructure services
+ run: |
+ docker compose -f docker-compose.ci.yaml up -d --wait --wait-timeout 120
+ docker compose -f docker-compose.ci.yaml ps
- - name: Setup Kubernetes (k3s)
+ - name: Run integration tests
+ timeout-minutes: 10
+ env:
+ MONGO_ROOT_USER: root
+ MONGO_ROOT_PASSWORD: rootpassword
+ MONGODB_HOST: 127.0.0.1
+ MONGODB_PORT: 27017
+ MONGODB_URL: mongodb://root:rootpassword@127.0.0.1:27017/?authSource=admin
+ KAFKA_BOOTSTRAP_SERVERS: localhost:9092
+ SCHEMA_REGISTRY_URL: http://localhost:8081
+ REDIS_HOST: localhost
+ REDIS_PORT: 6379
+ SCHEMA_SUBJECT_PREFIX: "ci.${{ github.run_id }}."
run: |
- curl -sfL https://get.k3s.io | INSTALL_K3S_EXEC="--disable=traefik --tls-san host.docker.internal" sh -
- mkdir -p /home/runner/.kube
- sudo k3s kubectl config view --raw > /home/runner/.kube/config
- sudo chmod 600 /home/runner/.kube/config
- export KUBECONFIG=/home/runner/.kube/config
- timeout 90 bash -c 'until sudo k3s kubectl cluster-info; do sleep 5; done'
+ cd backend
+ uv run pytest tests/integration -v -rs \
+ --ignore=tests/integration/k8s \
+ --cov=app \
+ --cov-report=xml --cov-report=term
- - name: Create kubeconfig for CI Docker containers
+ - name: Upload coverage to Codecov
+ uses: codecov/codecov-action@v5
+ if: always()
+ with:
+ token: ${{ secrets.CODECOV_TOKEN }}
+ files: backend/coverage.xml
+ flags: backend-integration
+ name: backend-integration-coverage
+ fail_ci_if_error: false
+ verbose: true
+
+ - name: Collect logs
+ if: failure()
run: |
- # Copy real k3s kubeconfig with valid credentials, but change server address
- # from 127.0.0.1 to host.docker.internal for Docker container networking
- # (k3s was started with --tls-san host.docker.internal so the cert is valid)
- sed 's|https://127.0.0.1:6443|https://host.docker.internal:6443|g' \
- /home/runner/.kube/config > backend/kubeconfig.yaml
- chmod 644 backend/kubeconfig.yaml
-
- - name: Setup CI Compose
- uses: ./.github/actions/setup-ci-compose
+ mkdir -p logs
+ docker compose -f docker-compose.ci.yaml logs > logs/docker-compose.log 2>&1
+ docker compose -f docker-compose.ci.yaml logs kafka > logs/kafka.log 2>&1
+ docker compose -f docker-compose.ci.yaml logs schema-registry > logs/schema-registry.log 2>&1
+
+ - name: Upload logs
+ if: failure()
+ uses: actions/upload-artifact@v6
with:
- kubeconfig-path: /home/runner/.kube/config
+ name: backend-logs
+ path: logs/
+
+ e2e:
+ name: E2E Tests
+ runs-on: ubuntu-latest
+
+ steps:
+ - uses: actions/checkout@v6
- - name: Build services
- uses: docker/bake-action@v6
+ - name: Cache and load Docker images
+ uses: ./.github/actions/docker-cache
with:
- source: .
- files: docker-compose.ci.yaml
- load: true
- set: |
- *.cache-from=type=gha,scope=buildkit-${{ github.repository }}-${{ github.ref_name }}
- *.cache-from=type=gha,scope=buildkit-${{ github.repository }}-main
- *.cache-to=type=gha,mode=max,scope=buildkit-${{ github.repository }}-${{ github.ref_name }}
- *.pull=true
- env:
- BUILDKIT_PROGRESS: plain
+ images: ${{ env.MONGO_IMAGE }} ${{ env.REDIS_IMAGE }} ${{ env.KAFKA_IMAGE }} ${{ env.SCHEMA_REGISTRY_IMAGE }}
+
+ - name: Set up uv
+ uses: astral-sh/setup-uv@v7
+ with:
+ enable-cache: true
+ cache-dependency-glob: "backend/uv.lock"
- - name: Start services
+ - name: Install Python dependencies
run: |
- docker compose -f docker-compose.ci.yaml up -d --remove-orphans
- docker compose -f docker-compose.ci.yaml ps
+ cd backend
+ uv python install 3.12
+ uv sync --frozen
- - name: Wait for backend
+ - name: Start infrastructure services
run: |
- curl --retry 60 --retry-delay 5 --retry-all-errors -ksf https://127.0.0.1:443/api/v1/health/live
+ docker compose -f docker-compose.ci.yaml up -d --wait --wait-timeout 120
docker compose -f docker-compose.ci.yaml ps
- kubectl get pods -A -o wide
- - name: Run integration tests
+ - name: Setup Kubernetes (k3s)
+ run: |
+ curl -sfL https://get.k3s.io | INSTALL_K3S_EXEC="--disable=traefik" sh -
+ mkdir -p /home/runner/.kube
+ sudo k3s kubectl config view --raw > /home/runner/.kube/config
+ sudo chmod 600 /home/runner/.kube/config
+ export KUBECONFIG=/home/runner/.kube/config
+ timeout 90 bash -c 'until sudo k3s kubectl cluster-info; do sleep 5; done'
+ kubectl create namespace integr8scode --dry-run=client -o yaml | kubectl apply -f -
+
+ - name: Run E2E tests
timeout-minutes: 10
env:
- BACKEND_BASE_URL: https://127.0.0.1:443
MONGO_ROOT_USER: root
MONGO_ROOT_PASSWORD: rootpassword
- MONGODB_HOST: 127.0.0.1
- MONGODB_PORT: 27017
MONGODB_URL: mongodb://root:rootpassword@127.0.0.1:27017/?authSource=admin
+ KAFKA_BOOTSTRAP_SERVERS: localhost:9092
+ SCHEMA_REGISTRY_URL: http://localhost:8081
+ REDIS_HOST: localhost
+ REDIS_PORT: 6379
SCHEMA_SUBJECT_PREFIX: "ci.${{ github.run_id }}."
+ KUBECONFIG: /home/runner/.kube/config
+ K8S_NAMESPACE: integr8scode
run: |
cd backend
- uv run pytest tests/integration -v -rs --cov=app --cov-branch --cov-report=xml --cov-report=term
+ uv run pytest tests/integration/k8s -v -rs \
+ --cov=app \
+ --cov-report=xml --cov-report=term
- name: Upload coverage to Codecov
uses: codecov/codecov-action@v5
@@ -105,8 +200,8 @@ jobs:
with:
token: ${{ secrets.CODECOV_TOKEN }}
files: backend/coverage.xml
- flags: backend
- name: backend-coverage
+ flags: backend-e2e
+ name: backend-e2e-coverage
fail_ci_if_error: false
verbose: true
@@ -114,15 +209,13 @@ jobs:
if: failure()
run: |
mkdir -p logs
- docker compose -f docker-compose.ci.yaml logs > logs/docker-compose.log
- docker compose -f docker-compose.ci.yaml logs backend > logs/backend.log
- docker compose -f docker-compose.ci.yaml logs mongo > logs/mongo.log
- kubectl get events --sort-by='.metadata.creationTimestamp' > logs/k8s-events.log 2>&1 || true
+ docker compose -f docker-compose.ci.yaml logs > logs/docker-compose.log 2>&1
+ kubectl get events --sort-by='.metadata.creationTimestamp' -A > logs/k8s-events.log 2>&1 || true
kubectl describe pods -A > logs/k8s-describe-pods.log 2>&1 || true
- name: Upload logs
if: failure()
uses: actions/upload-artifact@v6
with:
- name: backend-logs
+ name: k8s-logs
path: logs/
diff --git a/.github/workflows/frontend-ci.yml b/.github/workflows/frontend-ci.yml
index dca9efe9..e5d8c29d 100644
--- a/.github/workflows/frontend-ci.yml
+++ b/.github/workflows/frontend-ci.yml
@@ -81,59 +81,18 @@ jobs:
export KUBECONFIG=/home/runner/.kube/config
timeout 90 bash -c 'until sudo k3s kubectl cluster-info; do sleep 5; done'
- - name: Create kubeconfig for CI
+ - name: Create kubeconfig for Docker containers
run: |
- cat > backend/kubeconfig.yaml <
+
+
diff --git a/backend/.dockerignore b/backend/.dockerignore
new file mode 100644
index 00000000..5a9bec3f
--- /dev/null
+++ b/backend/.dockerignore
@@ -0,0 +1,25 @@
+# Virtual environments
+.venv/
+venv/
+__pycache__/
+*.pyc
+
+# IDE
+.idea/
+.vscode/
+*.swp
+
+# Test artifacts
+.pytest_cache/
+.coverage
+htmlcov/
+.mypy_cache/
+.ruff_cache/
+
+# Git
+.git/
+.gitignore
+
+# Local dev files
+*.log
+.DS_Store
diff --git a/backend/Dockerfile b/backend/Dockerfile
index 52f33b13..b9897fac 100644
--- a/backend/Dockerfile
+++ b/backend/Dockerfile
@@ -14,24 +14,34 @@ RUN mkdir -p /app/certs
# Expose metrics port
EXPOSE 9090
-# Simplified CMD
-CMD bash -c "\
- while [ ! -f /app/certs/server.key ]; do echo 'Waiting for TLS certs...'; sleep 2; done && \
- echo 'Starting application...' && \
- # Use kubeconfig if present, but do not block startup\
- if [ -f /app/kubeconfig.yaml ]; then export KUBECONFIG=/app/kubeconfig.yaml; fi && \
- WEB_CONCURRENCY=${WEB_CONCURRENCY:-4} WEB_THREADS=${WEB_THREADS:-1} WEB_TIMEOUT=${WEB_TIMEOUT:-60} \
- uv run gunicorn app.main:app \
- -k uvicorn.workers.UvicornWorker \
- --bind 0.0.0.0:443 \
- --workers ${WEB_CONCURRENCY} \
- --threads ${WEB_THREADS} \
- --timeout ${WEB_TIMEOUT} \
- --graceful-timeout 30 \
- --keep-alive 2 \
- --backlog ${WEB_BACKLOG:-2048} \
- --log-level info \
- --access-logfile - \
- --error-logfile - \
- --keyfile /app/certs/server.key \
- --certfile /app/certs/server.crt"
+# Create entrypoint script inline (BuildKit heredoc)
+COPY <<'EOF' /entrypoint.sh
+#!/bin/bash
+set -e
+
+while [ ! -f /app/certs/server.key ]; do
+ echo "Waiting for TLS certs..."
+ sleep 2
+done
+
+echo "Starting application..."
+[ -f /app/kubeconfig.yaml ] && export KUBECONFIG=/app/kubeconfig.yaml
+
+exec gunicorn app.main:app \
+ -k uvicorn.workers.UvicornWorker \
+ --bind 0.0.0.0:443 \
+ --workers ${WEB_CONCURRENCY:-4} \
+ --threads ${WEB_THREADS:-1} \
+ --timeout ${WEB_TIMEOUT:-60} \
+ --graceful-timeout 30 \
+ --keep-alive 2 \
+ --backlog ${WEB_BACKLOG:-2048} \
+ --log-level info \
+ --access-logfile - \
+ --error-logfile - \
+ --keyfile /app/certs/server.key \
+ --certfile /app/certs/server.crt
+EOF
+
+RUN chmod +x /entrypoint.sh
+CMD ["/entrypoint.sh"]
diff --git a/backend/Dockerfile.base b/backend/Dockerfile.base
index 8515ceb1..c06680c4 100644
--- a/backend/Dockerfile.base
+++ b/backend/Dockerfile.base
@@ -19,8 +19,8 @@ COPY --from=astral/uv:latest /uv /uvx /bin/
COPY pyproject.toml uv.lock ./
# Install Python dependencies (production only)
-# --no-install-project: don't install project itself, only dependencies
-RUN uv sync --frozen --no-dev --no-install-project
+RUN uv sync --locked --no-dev --no-install-project
-# Set Python path so imports work
+# Set paths: PYTHONPATH for imports, PATH for venv binaries (no uv run needed at runtime)
ENV PYTHONPATH=/app
+ENV PATH="/app/.venv/bin:$PATH"
diff --git a/backend/app/api/routes/admin/events.py b/backend/app/api/routes/admin/events.py
index 89a802ba..6dbda0e0 100644
--- a/backend/app/api/routes/admin/events.py
+++ b/backend/app/api/routes/admin/events.py
@@ -1,3 +1,4 @@
+from dataclasses import asdict
from datetime import datetime
from typing import Annotated
@@ -9,11 +10,9 @@
from app.api.dependencies import admin_user
from app.core.correlation import CorrelationContext
+from app.domain.admin import ReplayQuery
from app.domain.enums.events import EventType
-from app.infrastructure.mappers import (
- AdminReplayApiMapper,
- EventFilterMapper,
-)
+from app.domain.events.event_models import EventFilter
from app.schemas_pydantic.admin_events import (
EventBrowseRequest,
EventBrowseResponse,
@@ -24,7 +23,6 @@
EventReplayStatusResponse,
EventStatsResponse,
)
-from app.schemas_pydantic.admin_events import EventFilter as AdminEventFilter
from app.schemas_pydantic.user import UserResponse
from app.services.admin import AdminEventsService
@@ -36,7 +34,7 @@
@router.post("/browse")
async def browse_events(request: EventBrowseRequest, service: FromDishka[AdminEventsService]) -> EventBrowseResponse:
try:
- event_filter = EventFilterMapper.from_admin_pydantic(request.filters)
+ event_filter = EventFilter(**request.filters.model_dump())
result = await service.browse_events(
event_filter=event_filter,
@@ -79,12 +77,10 @@ async def export_events_csv(
limit: int = Query(default=10000, le=50000),
) -> StreamingResponse:
try:
- export_filter = EventFilterMapper.from_admin_pydantic(
- AdminEventFilter(
- event_types=event_types,
- start_time=start_time,
- end_time=end_time,
- )
+ export_filter = EventFilter(
+ event_types=[str(et) for et in event_types] if event_types else None,
+ start_time=start_time,
+ end_time=end_time,
)
result = await service.export_events_csv_content(event_filter=export_filter, limit=limit)
return StreamingResponse(
@@ -111,16 +107,14 @@ async def export_events_json(
) -> StreamingResponse:
"""Export events as JSON with comprehensive filtering."""
try:
- export_filter = EventFilterMapper.from_admin_pydantic(
- AdminEventFilter(
- event_types=event_types,
- aggregate_id=aggregate_id,
- correlation_id=correlation_id,
- user_id=user_id,
- service_name=service_name,
- start_time=start_time,
- end_time=end_time,
- )
+ export_filter = EventFilter(
+ event_types=[str(et) for et in event_types] if event_types else None,
+ aggregate_id=aggregate_id,
+ correlation_id=correlation_id,
+ user_id=user_id,
+ service_name=service_name,
+ start_time=start_time,
+ end_time=end_time,
)
result = await service.export_events_json_content(event_filter=export_filter, limit=limit)
return StreamingResponse(
@@ -159,7 +153,13 @@ async def replay_events(
) -> EventReplayResponse:
try:
replay_correlation_id = f"replay_{CorrelationContext.get_correlation_id()}"
- rq = AdminReplayApiMapper.request_to_query(request)
+ rq = ReplayQuery(
+ event_ids=request.event_ids,
+ correlation_id=request.correlation_id,
+ aggregate_id=request.aggregate_id,
+ start_time=request.start_time,
+ end_time=request.end_time,
+ )
try:
result = await service.prepare_or_schedule_replay(
replay_query=rq,
@@ -201,7 +201,17 @@ async def get_replay_status(session_id: str, service: FromDishka[AdminEventsServ
if not status:
raise HTTPException(status_code=404, detail="Replay session not found")
- return EventReplayStatusResponse.model_validate(status)
+ session = status.session
+ estimated_completion = status.estimated_completion
+ execution_results = status.execution_results
+ return EventReplayStatusResponse(
+ **{
+ **asdict(session),
+ "status": session.status.value,
+ "estimated_completion": estimated_completion,
+ "execution_results": execution_results,
+ }
+ )
except HTTPException:
raise
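A minimal, self-contained sketch of the asdict-spread pattern used in `get_replay_status` above; `Status`, `Session`, and `SessionResponse` here are hypothetical stand-ins, not the real app models.

```python
from dataclasses import asdict, dataclass
from enum import Enum

from pydantic import BaseModel


class Status(str, Enum):
    RUNNING = "running"


@dataclass
class Session:
    session_id: str
    status: Status
    total_events: int


class SessionResponse(BaseModel):
    session_id: str
    status: str
    total_events: int


session = Session(session_id="abc", status=Status.RUNNING, total_events=42)
# asdict() keeps enum members as-is, so the "status" key is overridden with its .value
response = SessionResponse(**{**asdict(session), "status": session.status.value})
assert response.status == "running"
```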
diff --git a/backend/app/api/routes/admin/settings.py b/backend/app/api/routes/admin/settings.py
index 87ea4cf5..23a82628 100644
--- a/backend/app/api/routes/admin/settings.py
+++ b/backend/app/api/routes/admin/settings.py
@@ -6,11 +6,41 @@
from pydantic import ValidationError
from app.api.dependencies import admin_user
-from app.infrastructure.mappers import SettingsMapper
+from app.domain.admin import (
+ ExecutionLimits,
+ LogLevel,
+ MonitoringSettings,
+ SecuritySettings,
+)
+from app.domain.admin import (
+ SystemSettings as DomainSystemSettings,
+)
from app.schemas_pydantic.admin_settings import SystemSettings
from app.schemas_pydantic.user import UserResponse
from app.services.admin import AdminSettingsService
+
+def _domain_to_pydantic(domain: DomainSystemSettings) -> SystemSettings:
+ """Convert domain SystemSettings to Pydantic schema."""
+ return SystemSettings.model_validate(domain, from_attributes=True)
+
+
+def _pydantic_to_domain(schema: SystemSettings) -> DomainSystemSettings:
+ """Convert Pydantic schema to domain SystemSettings."""
+ data = schema.model_dump()
+ mon = data.get("monitoring_settings", {})
+ return DomainSystemSettings(
+ execution_limits=ExecutionLimits(**data.get("execution_limits", {})),
+ security_settings=SecuritySettings(**data.get("security_settings", {})),
+ monitoring_settings=MonitoringSettings(
+ metrics_retention_days=mon.get("metrics_retention_days", 30),
+ log_level=LogLevel(mon.get("log_level", "INFO")),
+ enable_tracing=mon.get("enable_tracing", True),
+ sampling_rate=mon.get("sampling_rate", 0.1),
+ ),
+ )
+
+
router = APIRouter(
prefix="/admin/settings", tags=["admin", "settings"], route_class=DishkaRoute, dependencies=[Depends(admin_user)]
)
@@ -23,8 +53,7 @@ async def get_system_settings(
) -> SystemSettings:
try:
domain_settings = await service.get_system_settings(admin.username)
- settings_mapper = SettingsMapper()
- return SystemSettings(**settings_mapper.system_settings_to_pydantic_dict(domain_settings))
+ return _domain_to_pydantic(domain_settings)
except Exception:
raise HTTPException(status_code=500, detail="Failed to retrieve settings")
@@ -37,8 +66,7 @@ async def update_system_settings(
service: FromDishka[AdminSettingsService],
) -> SystemSettings:
try:
- settings_mapper = SettingsMapper()
- domain_settings = settings_mapper.system_settings_from_pydantic(settings.model_dump())
+ domain_settings = _pydantic_to_domain(settings)
except (ValueError, ValidationError, KeyError) as e:
raise HTTPException(status_code=422, detail=f"Invalid settings: {str(e)}")
except Exception:
@@ -52,9 +80,7 @@ async def update_system_settings(
user_id=admin.user_id,
)
- # Convert back to pydantic schema for response
- settings_mapper = SettingsMapper()
- return SystemSettings(**settings_mapper.system_settings_to_pydantic_dict(updated_domain_settings))
+ return _domain_to_pydantic(updated_domain_settings)
except Exception:
raise HTTPException(status_code=500, detail="Failed to update settings")
@@ -67,8 +93,7 @@ async def reset_system_settings(
) -> SystemSettings:
try:
reset_domain_settings = await service.reset_system_settings(admin.username, admin.user_id)
- settings_mapper = SettingsMapper()
- return SystemSettings(**settings_mapper.system_settings_to_pydantic_dict(reset_domain_settings))
+ return _domain_to_pydantic(reset_domain_settings)
except Exception:
raise HTTPException(status_code=500, detail="Failed to reset settings")
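For clarity, a self-contained sketch of the `from_attributes` conversion that `_domain_to_pydantic` relies on; `Limits` and `LimitsSchema` are illustrative stand-ins rather than the real settings models.

```python
from dataclasses import dataclass

from pydantic import BaseModel


@dataclass
class Limits:  # stand-in for a domain dataclass
    max_timeout_seconds: int
    max_memory_mb: int


class LimitsSchema(BaseModel):  # stand-in for the API schema
    max_timeout_seconds: int
    max_memory_mb: int


domain = Limits(max_timeout_seconds=300, max_memory_mb=512)
schema = LimitsSchema.model_validate(domain, from_attributes=True)
assert schema.max_memory_mb == 512
```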
diff --git a/backend/app/api/routes/auth.py b/backend/app/api/routes/auth.py
index da0eb586..606d57f2 100644
--- a/backend/app/api/routes/auth.py
+++ b/backend/app/api/routes/auth.py
@@ -1,16 +1,16 @@
-from datetime import datetime, timedelta, timezone
-from uuid import uuid4
+import logging
+from datetime import timedelta
from dishka import FromDishka
from dishka.integrations.fastapi import DishkaRoute
from fastapi import APIRouter, Depends, HTTPException, Request, Response
from fastapi.security import OAuth2PasswordRequestForm
+from pymongo.errors import DuplicateKeyError
-from app.core.logging import logger
from app.core.security import security_service
from app.core.utils import get_client_ip
from app.db.repositories import UserRepository
-from app.domain.user import User as DomainAdminUser
+from app.domain.user import DomainUserCreate
from app.schemas_pydantic.user import (
LoginResponse,
MessageResponse,
@@ -29,6 +29,7 @@ async def login(
request: Request,
response: Response,
user_repo: FromDishka[UserRepository],
+ logger: FromDishka[logging.Logger],
form_data: OAuth2PasswordRequestForm = Depends(),
) -> LoginResponse:
logger.info(
@@ -126,6 +127,7 @@ async def register(
request: Request,
user: UserCreate,
user_repo: FromDishka[UserRepository],
+ logger: FromDishka[logging.Logger],
) -> UserResponse:
logger.info(
"Registration attempt",
@@ -151,19 +153,15 @@ async def register(
try:
hashed_password = security_service.get_password_hash(user.password)
- now = datetime.now(timezone.utc)
- domain_user = DomainAdminUser(
- user_id=str(uuid4()),
+ create_data = DomainUserCreate(
username=user.username,
email=str(user.email),
+ hashed_password=hashed_password,
role=user.role,
is_active=True,
is_superuser=False,
- hashed_password=hashed_password,
- created_at=now,
- updated_at=now,
)
- created_user = await user_repo.create_user(domain_user)
+ created_user = await user_repo.create_user(create_data)
logger.info(
"Registration successful",
@@ -184,6 +182,15 @@ async def register(
updated_at=created_user.updated_at,
)
+ except DuplicateKeyError as e:
+ logger.warning(
+ "Registration failed - duplicate email",
+ extra={
+ "username": user.username,
+ "client_ip": get_client_ip(request),
+ },
+ )
+ raise HTTPException(status_code=409, detail="Email already registered") from e
except Exception as e:
logger.error(
f"Registration failed - database error: {str(e)}",
@@ -204,6 +211,7 @@ async def get_current_user_profile(
request: Request,
response: Response,
auth_service: FromDishka[AuthService],
+ logger: FromDishka[logging.Logger],
) -> UserResponse:
current_user = await auth_service.get_current_user(request)
@@ -227,6 +235,7 @@ async def get_current_user_profile(
async def verify_token(
request: Request,
auth_service: FromDishka[AuthService],
+ logger: FromDishka[logging.Logger],
) -> TokenValidationResponse:
current_user = await auth_service.get_current_user(request)
logger.info(
@@ -278,6 +287,7 @@ async def verify_token(
async def logout(
request: Request,
response: Response,
+ logger: FromDishka[logging.Logger],
) -> MessageResponse:
logger.info(
"Logout attempt",
diff --git a/backend/app/api/routes/dlq.py b/backend/app/api/routes/dlq.py
index ffb0b23a..1ab6136d 100644
--- a/backend/app/api/routes/dlq.py
+++ b/backend/app/api/routes/dlq.py
@@ -72,8 +72,8 @@ async def get_dlq_message(event_id: str, repository: FromDishka[DLQRepository])
raise HTTPException(status_code=404, detail="Message not found")
return DLQMessageDetail(
- event_id=message.event_id or "unknown",
- event=message.event.to_dict(),
+ event_id=message.event_id,
+ event=message.event.model_dump(),
event_type=message.event_type,
original_topic=message.original_topic,
error=message.error,
diff --git a/backend/app/api/routes/events.py b/backend/app/api/routes/events.py
index 7fd6303d..bc15166f 100644
--- a/backend/app/api/routes/events.py
+++ b/backend/app/api/routes/events.py
@@ -1,4 +1,5 @@
import asyncio
+import logging
from datetime import datetime, timedelta, timezone
from typing import Annotated, Any, Dict, List
@@ -8,7 +9,6 @@
from app.api.dependencies import admin_user, current_user
from app.core.correlation import CorrelationContext
-from app.core.logging import logger
from app.core.utils import get_client_ip
from app.domain.enums.common import SortOrder
from app.domain.events.event_models import EventFilter
@@ -107,7 +107,7 @@ async def query_events(
service_name=filter_request.service_name,
start_time=filter_request.start_time,
end_time=filter_request.end_time,
- text_search=filter_request.text_search,
+ search_text=filter_request.search_text,
)
result = await event_service.query_events_advanced(
@@ -115,7 +115,6 @@ async def query_events(
user_role=current_user.role,
filters=event_filter,
sort_by=filter_request.sort_by,
- sort_order=filter_request.sort_order,
limit=filter_request.limit,
skip=filter_request.skip,
)
@@ -283,6 +282,7 @@ async def delete_event(
event_id: str,
admin: Annotated[UserResponse, Depends(admin_user)],
event_service: FromDishka[EventService],
+ logger: FromDishka[logging.Logger],
) -> DeleteEventResponse:
result = await event_service.delete_event_with_archival(event_id=event_id, deleted_by=str(admin.email))
@@ -290,8 +290,10 @@ async def delete_event(
raise HTTPException(status_code=404, detail="Event not found")
logger.warning(
- f"Event {event_id} deleted by admin {admin.email}",
+ "Event deleted by admin",
extra={
+ "event_id": event_id,
+ "admin_email": admin.email,
"event_type": result.event_type,
"aggregate_id": result.aggregate_id,
"correlation_id": result.correlation_id,
@@ -309,6 +311,7 @@ async def replay_aggregate_events(
admin: Annotated[UserResponse, Depends(admin_user)],
event_service: FromDishka[EventService],
kafka_event_service: FromDishka[KafkaEventService],
+ logger: FromDishka[logging.Logger],
target_service: str | None = Query(None, description="Service to replay events to"),
dry_run: bool = Query(True, description="If true, only show what would be replayed"),
) -> ReplayAggregateResponse:
diff --git a/backend/app/api/routes/execution.py b/backend/app/api/routes/execution.py
index df1d40d4..37723a01 100644
--- a/backend/app/api/routes/execution.py
+++ b/backend/app/api/routes/execution.py
@@ -7,12 +7,12 @@
from fastapi import APIRouter, Depends, Header, HTTPException, Path, Query, Request
from app.api.dependencies import admin_user, current_user
-from app.core.exceptions import IntegrationException
from app.core.tracing import EventAttributes, add_span_attributes
from app.core.utils import get_client_ip
from app.domain.enums.events import EventType
from app.domain.enums.execution import ExecutionStatus
from app.domain.enums.user import UserRole
+from app.domain.exceptions import DomainError
from app.infrastructure.kafka.events.base import BaseEvent
from app.infrastructure.kafka.events.metadata import AvroEventMetadata as EventMetadata
from app.schemas_pydantic.execution import (
@@ -127,7 +127,7 @@ async def create_execution(
return ExecutionResponse.model_validate(exec_result)
- except IntegrationException as e:
+ except DomainError as e:
# Mark as failed for idempotency
if idempotency_key and pseudo_event:
await idempotency_manager.mark_failed(
@@ -136,7 +136,7 @@ async def create_execution(
key_strategy="custom",
custom_key=f"http:{current_user.user_id}:{idempotency_key}",
)
- raise HTTPException(status_code=e.status_code, detail=e.detail) from e
+ raise
except Exception as e:
# Mark as failed for idempotency
if idempotency_key and pseudo_event:
diff --git a/backend/app/api/routes/replay.py b/backend/app/api/routes/replay.py
index 0d4c1a94..f9919cb5 100644
--- a/backend/app/api/routes/replay.py
+++ b/backend/app/api/routes/replay.py
@@ -1,10 +1,12 @@
+from dataclasses import asdict
+
from dishka import FromDishka
from dishka.integrations.fastapi import DishkaRoute
from fastapi import APIRouter, Depends, Query
from app.api.dependencies import admin_user
from app.domain.enums.replay import ReplayStatus
-from app.infrastructure.mappers import ReplayApiMapper
+from app.domain.replay import ReplayConfig
from app.schemas_pydantic.replay import (
CleanupResponse,
ReplayRequest,
@@ -22,9 +24,8 @@ async def create_replay_session(
replay_request: ReplayRequest,
service: FromDishka[ReplayService],
) -> ReplayResponse:
- cfg = ReplayApiMapper.request_to_config(replay_request)
- result = await service.create_session_from_config(cfg)
- return ReplayApiMapper.op_to_response(result.session_id, result.status, result.message)
+ result = await service.create_session_from_config(ReplayConfig(**replay_request.model_dump()))
+ return ReplayResponse(session_id=result.session_id, status=result.status, message=result.message)
@router.post("/sessions/{session_id}/start", response_model=ReplayResponse)
@@ -33,7 +34,7 @@ async def start_replay_session(
service: FromDishka[ReplayService],
) -> ReplayResponse:
result = await service.start_session(session_id)
- return ReplayApiMapper.op_to_response(result.session_id, result.status, result.message)
+ return ReplayResponse(session_id=result.session_id, status=result.status, message=result.message)
@router.post("/sessions/{session_id}/pause", response_model=ReplayResponse)
@@ -42,19 +43,19 @@ async def pause_replay_session(
service: FromDishka[ReplayService],
) -> ReplayResponse:
result = await service.pause_session(session_id)
- return ReplayApiMapper.op_to_response(result.session_id, result.status, result.message)
+ return ReplayResponse(session_id=result.session_id, status=result.status, message=result.message)
@router.post("/sessions/{session_id}/resume", response_model=ReplayResponse)
async def resume_replay_session(session_id: str, service: FromDishka[ReplayService]) -> ReplayResponse:
result = await service.resume_session(session_id)
- return ReplayApiMapper.op_to_response(result.session_id, result.status, result.message)
+ return ReplayResponse(session_id=result.session_id, status=result.status, message=result.message)
@router.post("/sessions/{session_id}/cancel", response_model=ReplayResponse)
async def cancel_replay_session(session_id: str, service: FromDishka[ReplayService]) -> ReplayResponse:
result = await service.cancel_session(session_id)
- return ReplayApiMapper.op_to_response(result.session_id, result.status, result.message)
+ return ReplayResponse(session_id=result.session_id, status=result.status, message=result.message)
@router.get("/sessions", response_model=list[SessionSummary])
@@ -63,8 +64,10 @@ async def list_replay_sessions(
status: ReplayStatus | None = Query(None),
limit: int = Query(100, ge=1, le=1000),
) -> list[SessionSummary]:
- states = service.list_sessions(status=status, limit=limit)
- return [ReplayApiMapper.session_to_summary(s) for s in states]
+ return [
+ SessionSummary.model_validate({**asdict(s), **asdict(s)["config"]})
+ for s in service.list_sessions(status=status, limit=limit)
+ ]
@router.get("/sessions/{session_id}", response_model=ReplaySession)
@@ -79,4 +82,4 @@ async def cleanup_old_sessions(
older_than_hours: int = Query(24, ge=1),
) -> CleanupResponse:
result = await service.cleanup_old_sessions(older_than_hours)
- return ReplayApiMapper.cleanup_to_response(result.removed_sessions, result.message)
+ return CleanupResponse(removed_sessions=result.removed_sessions, message=result.message)
diff --git a/backend/app/api/routes/saga.py b/backend/app/api/routes/saga.py
index 40037a3c..3dff8d11 100644
--- a/backend/app/api/routes/saga.py
+++ b/backend/app/api/routes/saga.py
@@ -3,7 +3,6 @@
from fastapi import APIRouter, Query, Request
from app.domain.enums.saga import SagaState
-from app.infrastructure.mappers import UserMapper as AdminUserMapper
from app.schemas_pydantic.saga import (
SagaCancellationResponse,
SagaListResponse,
@@ -42,11 +41,9 @@ async def get_saga_status(
HTTPException: 404 if saga not found, 403 if access denied
"""
current_user = await auth_service.get_current_user(request)
-
- service_user = User.from_response(current_user)
- domain_user = AdminUserMapper.from_pydantic_service_user(service_user)
- saga = await saga_service.get_saga_with_access_check(saga_id, domain_user)
- return SagaStatusResponse.from_domain(saga)
+ user = User.model_validate(current_user)
+ saga = await saga_service.get_saga_with_access_check(saga_id, user)
+ return SagaStatusResponse.model_validate(saga)
@router.get("/execution/{execution_id}", response_model=SagaListResponse)
@@ -77,11 +74,9 @@ async def get_execution_sagas(
HTTPException: 403 if access denied
"""
current_user = await auth_service.get_current_user(request)
-
- service_user = User.from_response(current_user)
- domain_user = AdminUserMapper.from_pydantic_service_user(service_user)
- result = await saga_service.get_execution_sagas(execution_id, domain_user, state, limit=limit, skip=skip)
- saga_responses = [SagaStatusResponse.from_domain(s) for s in result.sagas]
+ user = User.model_validate(current_user)
+ result = await saga_service.get_execution_sagas(execution_id, user, state, limit=limit, skip=skip)
+ saga_responses = [SagaStatusResponse.model_validate(s) for s in result.sagas]
return SagaListResponse(
sagas=saga_responses,
total=result.total,
@@ -114,11 +109,9 @@ async def list_sagas(
Paginated list of sagas
"""
current_user = await auth_service.get_current_user(request)
-
- service_user = User.from_response(current_user)
- domain_user = AdminUserMapper.from_pydantic_service_user(service_user)
- result = await saga_service.list_user_sagas(domain_user, state, limit, skip)
- saga_responses = [SagaStatusResponse.from_domain(s) for s in result.sagas]
+ user = User.model_validate(current_user)
+ result = await saga_service.list_user_sagas(user, state, limit, skip)
+ saga_responses = [SagaStatusResponse.model_validate(s) for s in result.sagas]
return SagaListResponse(
sagas=saga_responses,
total=result.total,
@@ -150,10 +143,8 @@ async def cancel_saga(
HTTPException: 404 if not found, 403 if denied, 400 if invalid state
"""
current_user = await auth_service.get_current_user(request)
-
- service_user = User.from_response(current_user)
- domain_user = AdminUserMapper.from_pydantic_service_user(service_user)
- success = await saga_service.cancel_saga(saga_id, domain_user)
+ user = User.model_validate(current_user)
+ success = await saga_service.cancel_saga(saga_id, user)
return SagaCancellationResponse(
success=success,
diff --git a/backend/app/core/container.py b/backend/app/core/container.py
index fef3e1b3..c7f8b9d7 100644
--- a/backend/app/core/container.py
+++ b/backend/app/core/container.py
@@ -9,6 +9,7 @@
CoreServicesProvider,
DatabaseProvider,
EventProvider,
+ LoggingProvider,
MessagingProvider,
RedisProvider,
ResultProcessorProvider,
@@ -23,6 +24,7 @@ def create_app_container() -> AsyncContainer:
"""
return make_async_container(
SettingsProvider(),
+ LoggingProvider(),
DatabaseProvider(),
RedisProvider(),
CoreServicesProvider(),
@@ -44,6 +46,7 @@ def create_result_processor_container() -> AsyncContainer:
"""
return make_async_container(
SettingsProvider(),
+ LoggingProvider(),
DatabaseProvider(),
CoreServicesProvider(),
ConnectionProvider(),
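The `LoggingProvider` registered here is not defined in this diff. A plausible shape, assuming dishka's `Provider`/`Scope`/`provide` API and a `Settings` object (from `app.settings`, name assumed) exposing `LOG_LEVEL`, would be:

```python
import logging

from dishka import Provider, Scope, provide

from app.core.logging import setup_logger
from app.settings import Settings  # assumed import path


class LoggingProvider(Provider):
    scope = Scope.APP

    @provide
    def logger(self, settings: Settings) -> logging.Logger:
        # Settings is expected to come from SettingsProvider; LOG_LEVEL drives setup_logger()
        return setup_logger(settings.LOG_LEVEL)
```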
diff --git a/backend/app/core/correlation.py b/backend/app/core/correlation.py
index ea62ac34..4007fef6 100644
--- a/backend/app/core/correlation.py
+++ b/backend/app/core/correlation.py
@@ -5,7 +5,7 @@
from starlette.datastructures import MutableHeaders
from starlette.types import ASGIApp, Message, Receive, Scope, Send
-from app.core.logging import correlation_id_context, logger, request_metadata_context
+from app.core.logging import correlation_id_context, request_metadata_context
class CorrelationContext:
@@ -16,7 +16,6 @@ def generate_correlation_id() -> str:
@staticmethod
def set_correlation_id(correlation_id: str) -> str:
correlation_id_context.set(correlation_id)
- logger.debug(f"Set correlation ID: {correlation_id}")
return correlation_id
@staticmethod
@@ -26,7 +25,6 @@ def get_correlation_id() -> str:
@staticmethod
def set_request_metadata(metadata: Dict[str, Any]) -> None:
request_metadata_context.set(metadata)
- logger.debug(f"Set request metadata: {metadata}")
@staticmethod
def get_request_metadata() -> Dict[str, Any]:
@@ -36,7 +34,6 @@ def get_request_metadata() -> Dict[str, Any]:
def clear() -> None:
correlation_id_context.set(None)
request_metadata_context.set(None)
- logger.debug("Cleared correlation context")
class CorrelationMiddleware:
diff --git a/backend/app/core/database_context.py b/backend/app/core/database_context.py
index d7c40631..06913e03 100644
--- a/backend/app/core/database_context.py
+++ b/backend/app/core/database_context.py
@@ -1,281 +1,14 @@
-import contextvars
-from collections.abc import AsyncIterator
-from contextlib import asynccontextmanager
-from dataclasses import dataclass
-from typing import Any, AsyncContextManager, Protocol, TypeVar, runtime_checkable
+from typing import Any
-from motor.motor_asyncio import (
- AsyncIOMotorClient,
- AsyncIOMotorClientSession,
- AsyncIOMotorCollection,
- AsyncIOMotorCursor,
- AsyncIOMotorDatabase,
-)
-from pymongo.errors import ServerSelectionTimeoutError
+from pymongo.asynchronous.client_session import AsyncClientSession
+from pymongo.asynchronous.collection import AsyncCollection
+from pymongo.asynchronous.cursor import AsyncCursor
+from pymongo.asynchronous.database import AsyncDatabase
+from pymongo.asynchronous.mongo_client import AsyncMongoClient
-from app.core.logging import logger
-
-# Python 3.12 type aliases using the new 'type' statement
-# MongoDocument represents the raw document type returned by Motor operations
type MongoDocument = dict[str, Any]
-type DBClient = AsyncIOMotorClient[MongoDocument]
-type Database = AsyncIOMotorDatabase[MongoDocument]
-type Collection = AsyncIOMotorCollection[MongoDocument]
-type Cursor = AsyncIOMotorCursor[MongoDocument]
-type DBSession = AsyncIOMotorClientSession
-
-# Type variable for generic database provider
-T = TypeVar("T")
-
-
-class DatabaseError(Exception):
- pass
-
-
-class DatabaseNotInitializedError(DatabaseError):
- """Raised when attempting to use database before initialization."""
-
- pass
-
-
-class DatabaseAlreadyInitializedError(DatabaseError):
- """Raised when attempting to initialize an already initialized database."""
-
- pass
-
-
-@dataclass(frozen=True)
-class DatabaseConfig:
- mongodb_url: str
- db_name: str
- server_selection_timeout_ms: int = 5000
- connect_timeout_ms: int = 10000
- max_pool_size: int = 100
- min_pool_size: int = 10
- retry_writes: bool = True
- retry_reads: bool = True
- write_concern: str = "majority"
- journal: bool = True
-
-
-@runtime_checkable
-class DatabaseProvider(Protocol):
- @property
- def client(self) -> DBClient:
- """Get the MongoDB client."""
- ...
-
- @property
- def database(self) -> Database:
- """Get the database instance."""
- ...
-
- @property
- def db_name(self) -> str:
- """Get the database name."""
- ...
-
- def is_initialized(self) -> bool:
- """Check if the provider is initialized."""
- ...
-
- def session(self) -> AsyncContextManager[DBSession]:
- """Create a database session for transactions."""
- ...
-
-
-class AsyncDatabaseConnection:
- __slots__ = ("_client", "_database", "_db_name", "_config")
-
- def __init__(self, config: DatabaseConfig) -> None:
- self._config = config
- self._client: DBClient | None = None
- self._database: Database | None = None
- self._db_name: str = config.db_name
-
- async def connect(self) -> None:
- """
- Establish connection to MongoDB.
-
- Raises:
- DatabaseAlreadyInitializedError: If already connected
- ServerSelectionTimeoutError: If cannot connect to MongoDB
- """
- if self._client is not None:
- raise DatabaseAlreadyInitializedError("Connection already established")
-
- logger.info(f"Connecting to MongoDB database: {self._db_name}")
-
- # Always explicitly bind to current event loop for consistency
- import asyncio
-
- client: DBClient = AsyncIOMotorClient(
- self._config.mongodb_url,
- serverSelectionTimeoutMS=self._config.server_selection_timeout_ms,
- connectTimeoutMS=self._config.connect_timeout_ms,
- maxPoolSize=self._config.max_pool_size,
- minPoolSize=self._config.min_pool_size,
- retryWrites=self._config.retry_writes,
- retryReads=self._config.retry_reads,
- w=self._config.write_concern,
- journal=self._config.journal,
- io_loop=asyncio.get_running_loop(), # Always bind to current loop
- )
-
- # Verify connection
- try:
- await client.admin.command("ping")
- logger.info("Successfully connected to MongoDB")
- except ServerSelectionTimeoutError as e:
- logger.error(f"Failed to connect to MongoDB: {e}")
- client.close()
- raise
-
- self._client = client
- self._database = client[self._db_name]
-
- async def disconnect(self) -> None:
- if self._client is not None:
- logger.info("Closing MongoDB connection")
- self._client.close()
- self._client = None
- self._database = None
-
- @property
- def client(self) -> DBClient:
- if self._client is None:
- raise DatabaseNotInitializedError("Database connection not established")
- return self._client
-
- @property
- def database(self) -> Database:
- if self._database is None:
- raise DatabaseNotInitializedError("Database connection not established")
- return self._database
-
- @property
- def db_name(self) -> str:
- return self._db_name
-
- def is_connected(self) -> bool:
- return self._client is not None
-
- @asynccontextmanager
- async def session(self) -> AsyncIterator[DBSession]:
- """
- Create a database session for transactions.
-
- Yields:
- Database session for use in transactions
-
- Example:
- async with connection.session() as session:
- await collection.insert_one(doc, session=session)
- """
- async with await self.client.start_session() as session:
- async with session.start_transaction():
- yield session
-
-
-class ContextualDatabaseProvider(DatabaseProvider):
- def __init__(self) -> None:
- self._connection_var: contextvars.ContextVar[AsyncDatabaseConnection | None] = contextvars.ContextVar(
- "db_connection", default=None
- )
-
- def set_connection(self, connection: AsyncDatabaseConnection) -> None:
- self._connection_var.set(connection)
-
- def clear_connection(self) -> None:
- self._connection_var.set(None)
-
- @property
- def _connection(self) -> AsyncDatabaseConnection:
- connection = self._connection_var.get()
- if connection is None:
- raise DatabaseNotInitializedError(
- "No database connection in current context. Ensure connection is set in the request lifecycle."
- )
- return connection
-
- @property
- def client(self) -> DBClient:
- return self._connection.client
-
- @property
- def database(self) -> Database:
- return self._connection.database
-
- @property
- def db_name(self) -> str:
- return self._connection.db_name
-
- def is_initialized(self) -> bool:
- connection = self._connection_var.get()
- return connection is not None and connection.is_connected()
-
- def session(self) -> AsyncContextManager[DBSession]:
- return self._connection.session()
-
-
-class DatabaseConnectionPool:
- def __init__(self) -> None:
- self._connections: dict[str, AsyncDatabaseConnection] = {}
-
- async def create_connection(self, key: str, config: DatabaseConfig) -> AsyncDatabaseConnection:
- """
- Create and store a new database connection.
-
- Args:
- key: Unique identifier for this connection
- config: Database configuration
-
- Returns:
- The created connection
-
- Raises:
- DatabaseAlreadyInitializedError: If key already exists
- """
- if key in self._connections:
- raise DatabaseAlreadyInitializedError(f"Connection '{key}' already exists")
-
- connection = AsyncDatabaseConnection(config)
- await connection.connect()
- self._connections[key] = connection
- return connection
-
- def get_connection(self, key: str) -> AsyncDatabaseConnection:
- """
- Get a connection by key.
-
- Raises:
- KeyError: If connection not found
- """
- return self._connections[key]
-
- async def close_connection(self, key: str) -> None:
- if key in self._connections:
- await self._connections[key].disconnect()
- del self._connections[key]
-
- async def close_all(self) -> None:
- for connection in self._connections.values():
- await connection.disconnect()
- self._connections.clear()
-
-
-# Factory functions for dependency injection
-def create_database_connection(config: DatabaseConfig) -> AsyncDatabaseConnection:
- return AsyncDatabaseConnection(config)
-
-
-def create_contextual_provider() -> ContextualDatabaseProvider:
- return ContextualDatabaseProvider()
-
-
-def create_connection_pool() -> DatabaseConnectionPool:
- return DatabaseConnectionPool()
-
-
-async def get_database_provider() -> DatabaseProvider:
- raise RuntimeError("Database provider not configured. This dependency should be overridden in app startup.")
+type DBClient = AsyncMongoClient[MongoDocument]
+type Database = AsyncDatabase[MongoDocument]
+type Collection = AsyncCollection[MongoDocument]
+type Cursor = AsyncCursor[MongoDocument]
+type DBSession = AsyncClientSession
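A minimal sketch of how these PyMongo-async aliases would be wired up; the URL, database name, and timeout are illustrative, not part of this change.

```python
import asyncio

from pymongo import AsyncMongoClient

from app.core.database_context import Database, DBClient


async def connect(url: str, db_name: str) -> Database:
    client: DBClient = AsyncMongoClient(url, serverSelectionTimeoutMS=5000)
    await client.admin.command("ping")  # fail fast if the server is unreachable
    return client[db_name]


if __name__ == "__main__":
    asyncio.run(connect("mongodb://localhost:27017", "integr8scode"))
```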
diff --git a/backend/app/core/dishka_lifespan.py b/backend/app/core/dishka_lifespan.py
index 60017e52..038fd18d 100644
--- a/backend/app/core/dishka_lifespan.py
+++ b/backend/app/core/dishka_lifespan.py
@@ -1,15 +1,16 @@
+import logging
from contextlib import AsyncExitStack, asynccontextmanager
from typing import AsyncGenerator
import redis.asyncio as redis
+from beanie import init_beanie
from dishka import AsyncContainer
from fastapi import FastAPI
from app.core.database_context import Database
-from app.core.logging import logger
from app.core.startup import initialize_metrics_context, initialize_rate_limits
from app.core.tracing import init_tracing
-from app.db.schema.schema_manager import SchemaManager
+from app.db.docs import ALL_DOCUMENTS
from app.events.event_store_consumer import EventStoreConsumer
from app.events.schema.schema_registry import SchemaRegistryManager, initialize_event_schemas
from app.services.sse.kafka_redis_bridge import SSEKafkaRedisBridge
@@ -27,6 +28,11 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
- Dishka handles all lifecycle automatically
"""
settings = get_settings()
+
+ # Get logger from DI container
+ container: AsyncContainer = app.state.dishka_container
+ logger = await container.get(logging.Logger)
+
logger.info(
"Starting application with dishka DI",
extra={
@@ -41,6 +47,7 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
# Initialize tracing
instrumentation_report = init_tracing(
service_name=settings.TRACING_SERVICE_NAME,
+ logger=logger,
service_version=settings.TRACING_SERVICE_VERSION,
sampling_rate=settings.TRACING_SAMPLING_RATE,
enable_console_exporter=settings.TESTING,
@@ -59,24 +66,22 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
)
# Initialize schema registry once at startup
- container: AsyncContainer = app.state.dishka_container
schema_registry = await container.get(SchemaRegistryManager)
await initialize_event_schemas(schema_registry)
- # Initialize database schema at application scope using app-scoped DB
+ # Initialize Beanie ODM with database from DI container
database = await container.get(Database)
- schema_manager = SchemaManager(database)
- await schema_manager.apply_all()
- logger.info("Database schema ensured by SchemaManager")
+ await init_beanie(database=database, document_models=ALL_DOCUMENTS)
+ logger.info(f"Beanie ODM initialized with {len(ALL_DOCUMENTS)} document models")
# Initialize metrics context with instances from DI container
# This must happen early so services can access metrics via contextvars
- await initialize_metrics_context(container)
+ await initialize_metrics_context(container, logger)
logger.info("Metrics context initialized with contextvars")
# Initialize default rate limits in Redis
redis_client = await container.get(redis.Redis)
- await initialize_rate_limits(redis_client, settings)
+ await initialize_rate_limits(redis_client, settings, logger)
logger.info("Rate limits initialized in Redis")
# Rate limit middleware added during app creation; service resolved lazily at runtime
diff --git a/backend/app/core/exceptions/__init__.py b/backend/app/core/exceptions/__init__.py
index 30a7fd2d..6c67f136 100644
--- a/backend/app/core/exceptions/__init__.py
+++ b/backend/app/core/exceptions/__init__.py
@@ -1,13 +1,25 @@
-from app.core.exceptions.base import AuthenticationError, IntegrationException, ServiceError
-
-# Import handler configuration function
from app.core.exceptions.handlers import configure_exception_handlers
+from app.domain.exceptions import (
+ ConflictError,
+ DomainError,
+ ForbiddenError,
+ InfrastructureError,
+ InvalidStateError,
+ NotFoundError,
+ ThrottledError,
+ UnauthorizedError,
+ ValidationError,
+)
__all__ = [
- # Exception classes
- "IntegrationException",
- "AuthenticationError",
- "ServiceError",
- # Configuration function
+ "ConflictError",
+ "DomainError",
+ "ForbiddenError",
+ "InfrastructureError",
+ "InvalidStateError",
+ "NotFoundError",
+ "ThrottledError",
+ "UnauthorizedError",
+ "ValidationError",
"configure_exception_handlers",
]
diff --git a/backend/app/core/exceptions/base.py b/backend/app/core/exceptions/base.py
deleted file mode 100644
index beeab159..00000000
--- a/backend/app/core/exceptions/base.py
+++ /dev/null
@@ -1,24 +0,0 @@
-class IntegrationException(Exception):
- """Exception raised for integration errors."""
-
- def __init__(self, status_code: int, detail: str) -> None:
- self.status_code = status_code
- self.detail = detail
- super().__init__(detail)
-
-
-class AuthenticationError(Exception):
- """Exception raised for authentication errors."""
-
- def __init__(self, detail: str) -> None:
- self.detail = detail
- super().__init__(detail)
-
-
-class ServiceError(Exception):
- """Exception raised for service-related errors."""
-
- def __init__(self, message: str, status_code: int = 500) -> None:
- self.message = message
- self.status_code = status_code
- super().__init__(message)
diff --git a/backend/app/core/exceptions/handlers.py b/backend/app/core/exceptions/handlers.py
index 0f57f97b..0d7e2e0a 100644
--- a/backend/app/core/exceptions/handlers.py
+++ b/backend/app/core/exceptions/handlers.py
@@ -1,44 +1,44 @@
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse
-from app.core.exceptions.base import AuthenticationError, IntegrationException, ServiceError
-from app.domain.saga.exceptions import (
- SagaAccessDeniedError,
- SagaInvalidStateError,
- SagaNotFoundError,
+from app.domain.exceptions import (
+ ConflictError,
+ DomainError,
+ ForbiddenError,
+ InfrastructureError,
+ InvalidStateError,
+ NotFoundError,
+ ThrottledError,
+ UnauthorizedError,
+ ValidationError,
)
def configure_exception_handlers(app: FastAPI) -> None:
- @app.exception_handler(IntegrationException)
- async def integration_exception_handler(request: Request, exc: IntegrationException) -> JSONResponse:
+ @app.exception_handler(DomainError)
+ async def domain_error_handler(request: Request, exc: DomainError) -> JSONResponse:
+ status_code = _map_to_status_code(exc)
return JSONResponse(
- status_code=exc.status_code,
- content={"detail": exc.detail},
+ status_code=status_code,
+ content={"detail": exc.message, "type": type(exc).__name__},
)
- @app.exception_handler(AuthenticationError)
- async def authentication_error_handler(request: Request, exc: AuthenticationError) -> JSONResponse:
- return JSONResponse(
- status_code=401,
- content={"detail": exc.detail},
- )
-
- @app.exception_handler(ServiceError)
- async def service_error_handler(request: Request, exc: ServiceError) -> JSONResponse:
- return JSONResponse(
- status_code=exc.status_code,
- content={"detail": exc.message},
- )
-
- @app.exception_handler(SagaNotFoundError)
- async def saga_not_found_handler(request: Request, exc: SagaNotFoundError) -> JSONResponse:
- return JSONResponse(status_code=404, content={"detail": "Saga not found"})
-
- @app.exception_handler(SagaAccessDeniedError)
- async def saga_access_denied_handler(request: Request, exc: SagaAccessDeniedError) -> JSONResponse:
- return JSONResponse(status_code=403, content={"detail": "Access denied"})
- @app.exception_handler(SagaInvalidStateError)
- async def saga_invalid_state_handler(request: Request, exc: SagaInvalidStateError) -> JSONResponse:
- return JSONResponse(status_code=400, content={"detail": str(exc)})
+def _map_to_status_code(exc: DomainError) -> int:
+ if isinstance(exc, NotFoundError):
+ return 404
+ if isinstance(exc, ValidationError):
+ return 422
+ if isinstance(exc, ThrottledError):
+ return 429
+ if isinstance(exc, ConflictError):
+ return 409
+ if isinstance(exc, UnauthorizedError):
+ return 401
+ if isinstance(exc, ForbiddenError):
+ return 403
+ if isinstance(exc, InvalidStateError):
+ return 400
+ if isinstance(exc, InfrastructureError):
+ return 500
+ return 500
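The handler assumes `app.domain.exceptions` defines a `DomainError` base carrying `.message` plus the subclasses mapped above; that module is not shown in this diff, so the following is an assumed sketch of its shape.

```python
# Assumed shape of app.domain.exceptions (not shown in this diff)
class DomainError(Exception):
    def __init__(self, message: str) -> None:
        self.message = message
        super().__init__(message)


class NotFoundError(DomainError): ...
class ValidationError(DomainError): ...
class ConflictError(DomainError): ...
class UnauthorizedError(DomainError): ...
class ForbiddenError(DomainError): ...
class InvalidStateError(DomainError): ...
class ThrottledError(DomainError): ...
class InfrastructureError(DomainError): ...
```

With a hierarchy like this, `_map_to_status_code` resolves the most specific match first and falls back to 500 for any other `DomainError`.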
diff --git a/backend/app/core/k8s_clients.py b/backend/app/core/k8s_clients.py
index ba953f0d..2a475df3 100644
--- a/backend/app/core/k8s_clients.py
+++ b/backend/app/core/k8s_clients.py
@@ -1,10 +1,9 @@
+import logging
from dataclasses import dataclass
from kubernetes import client as k8s_client
from kubernetes import config as k8s_config
-from app.core.logging import logger
-
@dataclass(frozen=True)
class K8sClients:
@@ -14,7 +13,9 @@ class K8sClients:
networking_v1: k8s_client.NetworkingV1Api
-def create_k8s_clients(kubeconfig_path: str | None = None, in_cluster: bool | None = None) -> K8sClients:
+def create_k8s_clients(
+ logger: logging.Logger, kubeconfig_path: str | None = None, in_cluster: bool | None = None
+) -> K8sClients:
if in_cluster:
k8s_config.load_incluster_config()
elif kubeconfig_path:
diff --git a/backend/app/core/logging.py b/backend/app/core/logging.py
index 99a0b5c3..45fcc24b 100644
--- a/backend/app/core/logging.py
+++ b/backend/app/core/logging.py
@@ -7,8 +7,6 @@
from opentelemetry import trace
-from app.settings import get_settings
-
correlation_id_context: contextvars.ContextVar[str | None] = contextvars.ContextVar("correlation_id", default=None)
request_metadata_context: contextvars.ContextVar[Dict[str, Any] | None] = contextvars.ContextVar(
@@ -100,9 +98,19 @@ def format(self, record: logging.LogRecord) -> str:
return json.dumps(log_data, ensure_ascii=False)
-def setup_logger() -> logging.Logger:
- logger = logging.getLogger("integr8scode")
- logger.handlers.clear()
+LOG_LEVELS: dict[str, int] = {
+ "DEBUG": logging.DEBUG,
+ "INFO": logging.INFO,
+ "WARNING": logging.WARNING,
+ "ERROR": logging.ERROR,
+ "CRITICAL": logging.CRITICAL,
+}
+
+
+def setup_logger(log_level: str) -> logging.Logger:
+ """Create and configure the application logger. Called by DI with Settings.LOG_LEVEL."""
+ new_logger = logging.getLogger("integr8scode")
+ new_logger.handlers.clear()
console_handler = logging.StreamHandler()
formatter = JSONFormatter()
@@ -131,15 +139,9 @@ def filter(self, record: logging.LogRecord) -> bool:
console_handler.addFilter(TracingFilter())
- logger.addHandler(console_handler)
-
- # Get log level from configuration
- settings = get_settings()
- log_level_name = settings.LOG_LEVEL.upper()
- log_level = getattr(logging, log_level_name, logging.DEBUG)
- logger.setLevel(log_level)
-
- return logger
+ new_logger.addHandler(console_handler)
+ level = LOG_LEVELS.get(log_level.upper(), logging.DEBUG)
+ new_logger.setLevel(level)
-logger = setup_logger()
+ return new_logger
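With the import-time logger singleton removed, the logger is built once from Settings.LOG_LEVEL and handed out via DI. A minimal sketch of direct use (the "INFO" value is illustrative; names come from this module):

# Sketch: building the application logger outside DI.
import logging

log = setup_logger("INFO")
log.info("application starting")  # emitted as JSON by JSONFormatter
assert LOG_LEVELS.get("VERBOSE", logging.DEBUG) == logging.DEBUG  # unknown names fall back to DEBUG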
diff --git a/backend/app/core/metrics/context.py b/backend/app/core/metrics/context.py
index 1f382cbe..54a88e60 100644
--- a/backend/app/core/metrics/context.py
+++ b/backend/app/core/metrics/context.py
@@ -1,7 +1,7 @@
import contextvars
+import logging
from typing import Any, Generic, Optional, Type, TypeVar
-from app.core.logging import logger
from app.core.metrics import (
ConnectionMetrics,
CoordinatorMetrics,
@@ -29,17 +29,19 @@ class MetricsContextVar(Generic[T]):
and provides a clean interface for getting and setting metrics.
"""
- def __init__(self, name: str, metric_class: Type[T]) -> None:
+ def __init__(self, name: str, metric_class: Type[T], logger: logging.Logger) -> None:
"""
Initialize a metrics context variable.
Args:
name: Name for the context variable (for debugging)
metric_class: The class of the metric this context holds
+ logger: Logger used for debug output during lazy metric initialization
"""
self._context_var: contextvars.ContextVar[Optional[T]] = contextvars.ContextVar(f"metrics_{name}", default=None)
self._metric_class = metric_class
self._name = name
+ self.logger = logger
def get(self) -> T:
"""
@@ -55,7 +57,7 @@ def get(self) -> T:
metric = self._context_var.get()
if metric is None:
# Lazy initialization with logging
- logger.debug(f"Lazy initializing {self._name} metrics in context")
+ self.logger.debug(f"Lazy initializing {self._name} metrics in context")
metric = self._metric_class()
self._context_var.set(metric)
return metric
@@ -81,20 +83,32 @@ def is_set(self) -> bool:
return self._context_var.get() is not None
+# Module-level logger for lazy initialization
+_module_logger: Optional[logging.Logger] = None
+
+
+def _get_module_logger() -> logging.Logger:
+ """Get or create module logger for lazy initialization."""
+ global _module_logger
+ if _module_logger is None:
+ _module_logger = logging.getLogger(__name__)
+ return _module_logger
+
+
# Create module-level context variables for each metric type
# These are singletons that live for the lifetime of the application
-_connection_ctx = MetricsContextVar("connection", ConnectionMetrics)
-_coordinator_ctx = MetricsContextVar("coordinator", CoordinatorMetrics)
-_database_ctx = MetricsContextVar("database", DatabaseMetrics)
-_dlq_ctx = MetricsContextVar("dlq", DLQMetrics)
-_event_ctx = MetricsContextVar("event", EventMetrics)
-_execution_ctx = MetricsContextVar("execution", ExecutionMetrics)
-_health_ctx = MetricsContextVar("health", HealthMetrics)
-_kubernetes_ctx = MetricsContextVar("kubernetes", KubernetesMetrics)
-_notification_ctx = MetricsContextVar("notification", NotificationMetrics)
-_rate_limit_ctx = MetricsContextVar("rate_limit", RateLimitMetrics)
-_replay_ctx = MetricsContextVar("replay", ReplayMetrics)
-_security_ctx = MetricsContextVar("security", SecurityMetrics)
+_connection_ctx = MetricsContextVar("connection", ConnectionMetrics, _get_module_logger())
+_coordinator_ctx = MetricsContextVar("coordinator", CoordinatorMetrics, _get_module_logger())
+_database_ctx = MetricsContextVar("database", DatabaseMetrics, _get_module_logger())
+_dlq_ctx = MetricsContextVar("dlq", DLQMetrics, _get_module_logger())
+_event_ctx = MetricsContextVar("event", EventMetrics, _get_module_logger())
+_execution_ctx = MetricsContextVar("execution", ExecutionMetrics, _get_module_logger())
+_health_ctx = MetricsContextVar("health", HealthMetrics, _get_module_logger())
+_kubernetes_ctx = MetricsContextVar("kubernetes", KubernetesMetrics, _get_module_logger())
+_notification_ctx = MetricsContextVar("notification", NotificationMetrics, _get_module_logger())
+_rate_limit_ctx = MetricsContextVar("rate_limit", RateLimitMetrics, _get_module_logger())
+_replay_ctx = MetricsContextVar("replay", ReplayMetrics, _get_module_logger())
+_security_ctx = MetricsContextVar("security", SecurityMetrics, _get_module_logger())
class MetricsContext:
@@ -107,7 +121,7 @@ class MetricsContext:
"""
@classmethod
- def initialize_all(cls, **metrics: Any) -> None:
+ def initialize_all(cls, logger: logging.Logger, **metrics: Any) -> None:
"""
Initialize all metrics contexts at application startup.
@@ -150,7 +164,7 @@ def initialize_all(cls, **metrics: Any) -> None:
logger.info(f"Initialized {name} metrics in context")
@classmethod
- def reset_all(cls) -> None:
+ def reset_all(cls, logger: logging.Logger) -> None:
"""
Reset all metrics contexts.
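MetricsContextVar instances now carry their own logger, and MetricsContext.initialize_all / reset_all take one explicitly. A short sketch of both paths (the "event" keyword is assumed to match the context-var names used above):

# Sketch: explicit startup initialization vs. lazy fallback, using names from this module.
import logging

log = logging.getLogger("integr8scode")

# Startup path: push pre-built metric singletons into the context vars by name.
MetricsContext.initialize_all(logger=log, event=EventMetrics())

# Lazy path: an uninitialized context var builds its metric on first access and
# logs "Lazy initializing ... metrics in context" at DEBUG via its own logger.
db_metrics = _database_ctx.get()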
diff --git a/backend/app/core/middlewares/metrics.py b/backend/app/core/middlewares/metrics.py
index 562b8489..3513bee4 100644
--- a/backend/app/core/middlewares/metrics.py
+++ b/backend/app/core/middlewares/metrics.py
@@ -1,3 +1,4 @@
+import logging
import os
import re
import time
@@ -12,7 +13,6 @@
from opentelemetry.sdk.resources import SERVICE_NAME, SERVICE_VERSION, Resource
from starlette.types import ASGIApp, Message, Receive, Scope, Send
-from app.core.logging import logger
from app.settings import get_settings
@@ -118,7 +118,7 @@ def _get_path_template(path: str) -> str:
return path
-def setup_metrics(app: FastAPI) -> None:
+def setup_metrics(app: FastAPI, logger: logging.Logger) -> None:
"""Set up OpenTelemetry metrics with OTLP exporter."""
settings = get_settings()
# Fast opt-out for tests or when explicitly disabled
diff --git a/backend/app/core/providers.py b/backend/app/core/providers.py
index ddc4937d..0ab818d4 100644
--- a/backend/app/core/providers.py
+++ b/backend/app/core/providers.py
@@ -1,16 +1,13 @@
+import logging
from typing import AsyncIterator
import redis.asyncio as redis
from dishka import Provider, Scope, provide
+from pymongo.asynchronous.mongo_client import AsyncMongoClient
-from app.core.database_context import (
- AsyncDatabaseConnection,
- Database,
- DatabaseConfig,
- create_database_connection,
-)
+from app.core.database_context import Database
from app.core.k8s_clients import K8sClients, close_k8s_clients, create_k8s_clients
-from app.core.logging import logger
+from app.core.logging import setup_logger
from app.core.metrics import (
CoordinatorMetrics,
DatabaseMetrics,
@@ -83,37 +80,19 @@ def get_settings(self) -> Settings:
return get_settings()
-class DatabaseProvider(Provider):
+class LoggingProvider(Provider):
scope = Scope.APP
- @provide(scope=Scope.APP)
- async def get_database_connection(self, settings: Settings) -> AsyncIterator[AsyncDatabaseConnection]:
- db_config = DatabaseConfig(
- mongodb_url=settings.MONGODB_URL,
- db_name=settings.DATABASE_NAME,
- server_selection_timeout_ms=5000,
- connect_timeout_ms=5000,
- max_pool_size=50,
- min_pool_size=10,
- )
-
- db_connection = create_database_connection(db_config)
- await db_connection.connect()
- try:
- yield db_connection
- finally:
- await db_connection.disconnect()
-
@provide
- def get_database(self, db_connection: AsyncDatabaseConnection) -> Database:
- return db_connection.database
+ def get_logger(self, settings: Settings) -> logging.Logger:
+ return setup_logger(settings.LOG_LEVEL)
class RedisProvider(Provider):
scope = Scope.APP
@provide
- async def get_redis_client(self, settings: Settings) -> AsyncIterator[redis.Redis]:
+ async def get_redis_client(self, settings: Settings, logger: logging.Logger) -> AsyncIterator[redis.Redis]:
# Create Redis client - it will automatically use the current event loop
client = redis.Redis(
host=settings.REDIS_HOST,
@@ -141,6 +120,22 @@ def get_rate_limit_service(
return RateLimitService(redis_client, settings, rate_limit_metrics)
+class DatabaseProvider(Provider):
+ scope = Scope.APP
+
+ @provide
+ async def get_database(self, settings: Settings, logger: logging.Logger) -> AsyncIterator[Database]:
+ client: AsyncMongoClient[dict[str, object]] = AsyncMongoClient(
+ settings.MONGODB_URL, tz_aware=True, serverSelectionTimeoutMS=5000
+ )
+ database = client[settings.DATABASE_NAME]
+ logger.info(f"MongoDB connected: {settings.DATABASE_NAME}")
+ try:
+ yield database
+ finally:
+ await client.close()
+
+
class CoreServicesProvider(Provider):
scope = Scope.APP
@@ -154,10 +149,10 @@ class MessagingProvider(Provider):
@provide
async def get_kafka_producer(
- self, settings: Settings, schema_registry: SchemaRegistryManager
+ self, settings: Settings, schema_registry: SchemaRegistryManager, logger: logging.Logger
) -> AsyncIterator[UnifiedProducer]:
config = ProducerConfig(bootstrap_servers=settings.KAFKA_BOOTSTRAP_SERVERS)
- producer = UnifiedProducer(config, schema_registry)
+ producer = UnifiedProducer(config, schema_registry, logger)
await producer.start()
try:
yield producer
@@ -165,8 +160,10 @@ async def get_kafka_producer(
await producer.stop()
@provide
- async def get_dlq_manager(self, database: Database) -> AsyncIterator[DLQManager]:
- manager = create_dlq_manager(database)
+ async def get_dlq_manager(
+ self, schema_registry: SchemaRegistryManager, logger: logging.Logger
+ ) -> AsyncIterator[DLQManager]:
+ manager = create_dlq_manager(schema_registry, logger)
await manager.start()
try:
yield manager
@@ -178,8 +175,10 @@ def get_idempotency_repository(self, redis_client: redis.Redis) -> RedisIdempote
return RedisIdempotencyRepository(redis_client, key_prefix="idempotency")
@provide
- async def get_idempotency_manager(self, repo: RedisIdempotencyRepository) -> AsyncIterator[IdempotencyManager]:
- manager = create_idempotency_manager(repository=repo, config=IdempotencyConfig())
+ async def get_idempotency_manager(
+ self, repo: RedisIdempotencyRepository, logger: logging.Logger
+ ) -> AsyncIterator[IdempotencyManager]:
+ manager = create_idempotency_manager(repository=repo, config=IdempotencyConfig(), logger=logger)
await manager.initialize()
try:
yield manager
@@ -191,17 +190,21 @@ class EventProvider(Provider):
scope = Scope.APP
@provide
- def get_schema_registry(self) -> SchemaRegistryManager:
- return create_schema_registry_manager()
+ def get_schema_registry(self, logger: logging.Logger) -> SchemaRegistryManager:
+ return create_schema_registry_manager(logger)
@provide
- async def get_event_store(self, database: Database, schema_registry: SchemaRegistryManager) -> EventStore:
- store = create_event_store(db=database, schema_registry=schema_registry, ttl_days=90)
+ async def get_event_store(self, schema_registry: SchemaRegistryManager, logger: logging.Logger) -> EventStore:
+ store = create_event_store(schema_registry=schema_registry, logger=logger, ttl_days=90)
return store
@provide
async def get_event_store_consumer(
- self, event_store: EventStore, schema_registry: SchemaRegistryManager, kafka_producer: UnifiedProducer
+ self,
+ event_store: EventStore,
+ schema_registry: SchemaRegistryManager,
+ kafka_producer: UnifiedProducer,
+ logger: logging.Logger,
) -> EventStoreConsumer:
topics = get_all_topics()
return create_event_store_consumer(
@@ -209,11 +212,12 @@ async def get_event_store_consumer(
topics=list(topics),
schema_registry_manager=schema_registry,
producer=kafka_producer,
+ logger=logger,
)
@provide
- async def get_event_bus_manager(self) -> AsyncIterator[EventBusManager]:
- manager = EventBusManager()
+ async def get_event_bus_manager(self, logger: logging.Logger) -> AsyncIterator[EventBusManager]:
+ manager = EventBusManager(logger)
try:
yield manager
finally:
@@ -224,8 +228,8 @@ class KubernetesProvider(Provider):
scope = Scope.APP
@provide
- async def get_k8s_clients(self, settings: Settings) -> AsyncIterator[K8sClients]:
- clients = create_k8s_clients()
+ async def get_k8s_clients(self, settings: Settings, logger: logging.Logger) -> AsyncIterator[K8sClients]:
+ clients = create_k8s_clients(logger)
try:
yield clients
finally:
@@ -287,8 +291,8 @@ def get_security_metrics(self) -> SecurityMetrics:
return SecurityMetrics()
@provide(scope=Scope.REQUEST)
- def get_sse_shutdown_manager(self) -> SSEShutdownManager:
- return create_sse_shutdown_manager()
+ def get_sse_shutdown_manager(self, logger: logging.Logger) -> SSEShutdownManager:
+ return create_sse_shutdown_manager(logger=logger)
@provide(scope=Scope.APP)
async def get_sse_kafka_redis_bridge(
@@ -297,21 +301,23 @@ async def get_sse_kafka_redis_bridge(
settings: Settings,
event_metrics: EventMetrics,
sse_redis_bus: SSERedisBus,
+ logger: logging.Logger,
) -> SSEKafkaRedisBridge:
return create_sse_kafka_redis_bridge(
schema_registry=schema_registry,
settings=settings,
event_metrics=event_metrics,
sse_bus=sse_redis_bus,
+ logger=logger,
)
@provide
- def get_sse_repository(self, database: Database) -> SSERepository:
- return SSERepository(database)
+ def get_sse_repository(self) -> SSERepository:
+ return SSERepository()
@provide
- async def get_sse_redis_bus(self, redis_client: redis.Redis) -> AsyncIterator[SSERedisBus]:
- bus = SSERedisBus(redis_client)
+ async def get_sse_redis_bus(self, redis_client: redis.Redis, logger: logging.Logger) -> AsyncIterator[SSERedisBus]:
+ bus = SSERedisBus(redis_client, logger)
yield bus
@provide(scope=Scope.REQUEST)
@@ -322,6 +328,7 @@ def get_sse_service(
sse_redis_bus: SSERedisBus,
shutdown_manager: SSEShutdownManager,
settings: Settings,
+ logger: logging.Logger,
) -> SSEService:
# Ensure shutdown manager coordinates with the router in this request scope
shutdown_manager.set_router(router)
@@ -331,6 +338,7 @@ def get_sse_service(
sse_bus=sse_redis_bus,
shutdown_manager=shutdown_manager,
settings=settings,
+ logger=logger,
)
@@ -338,24 +346,24 @@ class AuthProvider(Provider):
scope = Scope.APP
@provide
- def get_user_repository(self, database: Database) -> UserRepository:
- return UserRepository(database)
+ def get_user_repository(self) -> UserRepository:
+ return UserRepository()
@provide
- def get_auth_service(self, user_repository: UserRepository) -> AuthService:
- return AuthService(user_repository)
+ def get_auth_service(self, user_repository: UserRepository, logger: logging.Logger) -> AuthService:
+ return AuthService(user_repository, logger)
class UserServicesProvider(Provider):
scope = Scope.APP
@provide
- def get_user_settings_repository(self, database: Database) -> UserSettingsRepository:
- return UserSettingsRepository(database)
+ def get_user_settings_repository(self, logger: logging.Logger) -> UserSettingsRepository:
+ return UserSettingsRepository(logger)
@provide
- def get_event_repository(self, database: Database) -> EventRepository:
- return EventRepository(database)
+ def get_event_repository(self, logger: logging.Logger) -> EventRepository:
+ return EventRepository(logger)
@provide
async def get_event_service(self, event_repository: EventRepository) -> EventService:
@@ -363,9 +371,9 @@ async def get_event_service(self, event_repository: EventRepository) -> EventSer
@provide
async def get_kafka_event_service(
- self, event_repository: EventRepository, kafka_producer: UnifiedProducer
+ self, event_repository: EventRepository, kafka_producer: UnifiedProducer, logger: logging.Logger
) -> KafkaEventService:
- return KafkaEventService(event_repository=event_repository, kafka_producer=kafka_producer)
+ return KafkaEventService(event_repository=event_repository, kafka_producer=kafka_producer, logger=logger)
@provide
async def get_user_settings_service(
@@ -373,8 +381,9 @@ async def get_user_settings_service(
repository: UserSettingsRepository,
kafka_event_service: KafkaEventService,
event_bus_manager: EventBusManager,
+ logger: logging.Logger,
) -> UserSettingsService:
- service = UserSettingsService(repository, kafka_event_service)
+ service = UserSettingsService(repository, kafka_event_service, logger)
await service.initialize(event_bus_manager)
return service
@@ -383,39 +392,41 @@ class AdminServicesProvider(Provider):
scope = Scope.APP
@provide
- def get_admin_events_repository(self, database: Database) -> AdminEventsRepository:
- return AdminEventsRepository(database)
+ def get_admin_events_repository(self) -> AdminEventsRepository:
+ return AdminEventsRepository()
@provide(scope=Scope.REQUEST)
def get_admin_events_service(
self,
admin_events_repository: AdminEventsRepository,
replay_service: ReplayService,
+ logger: logging.Logger,
) -> AdminEventsService:
- return AdminEventsService(admin_events_repository, replay_service)
+ return AdminEventsService(admin_events_repository, replay_service, logger)
@provide
- def get_admin_settings_repository(self, database: Database) -> AdminSettingsRepository:
- return AdminSettingsRepository(database)
+ def get_admin_settings_repository(self, logger: logging.Logger) -> AdminSettingsRepository:
+ return AdminSettingsRepository(logger)
@provide
def get_admin_settings_service(
self,
admin_settings_repository: AdminSettingsRepository,
+ logger: logging.Logger,
) -> AdminSettingsService:
- return AdminSettingsService(admin_settings_repository)
+ return AdminSettingsService(admin_settings_repository, logger)
@provide
- def get_admin_user_repository(self, database: Database) -> AdminUserRepository:
- return AdminUserRepository(database)
+ def get_admin_user_repository(self) -> AdminUserRepository:
+ return AdminUserRepository()
@provide
- def get_saga_repository(self, database: Database) -> SagaRepository:
- return SagaRepository(database)
+ def get_saga_repository(self) -> SagaRepository:
+ return SagaRepository()
@provide
- def get_notification_repository(self, database: Database) -> NotificationRepository:
- return NotificationRepository(database)
+ def get_notification_repository(self, logger: logging.Logger) -> NotificationRepository:
+ return NotificationRepository(logger)
@provide
def get_notification_service(
@@ -426,6 +437,7 @@ def get_notification_service(
schema_registry: SchemaRegistryManager,
sse_redis_bus: SSERedisBus,
settings: Settings,
+ logger: logging.Logger,
) -> NotificationService:
service = NotificationService(
notification_repository=notification_repository,
@@ -434,6 +446,7 @@ def get_notification_service(
schema_registry_manager=schema_registry,
sse_bus=sse_redis_bus,
settings=settings,
+ logger=logger,
)
service.initialize()
return service
@@ -442,32 +455,33 @@ def get_notification_service(
def get_grafana_alert_processor(
self,
notification_service: NotificationService,
+ logger: logging.Logger,
) -> GrafanaAlertProcessor:
- return GrafanaAlertProcessor(notification_service)
+ return GrafanaAlertProcessor(notification_service, logger)
class BusinessServicesProvider(Provider):
scope = Scope.REQUEST
@provide
- def get_execution_repository(self, database: Database) -> ExecutionRepository:
- return ExecutionRepository(database)
+ def get_execution_repository(self, logger: logging.Logger) -> ExecutionRepository:
+ return ExecutionRepository(logger)
@provide
- def get_resource_allocation_repository(self, database: Database) -> ResourceAllocationRepository:
- return ResourceAllocationRepository(database)
+ def get_resource_allocation_repository(self) -> ResourceAllocationRepository:
+ return ResourceAllocationRepository()
@provide
- def get_saved_script_repository(self, database: Database) -> SavedScriptRepository:
- return SavedScriptRepository(database)
+ def get_saved_script_repository(self) -> SavedScriptRepository:
+ return SavedScriptRepository()
@provide
- def get_dlq_repository(self, database: Database) -> DLQRepository:
- return DLQRepository(database)
+ def get_dlq_repository(self, logger: logging.Logger) -> DLQRepository:
+ return DLQRepository(logger)
@provide
- def get_replay_repository(self, database: Database) -> ReplayRepository:
- return ReplayRepository(database)
+ def get_replay_repository(self, logger: logging.Logger) -> ReplayRepository:
+ return ReplayRepository(logger)
@provide
async def get_saga_orchestrator(
@@ -507,9 +521,13 @@ def get_saga_service(
saga_repository: SagaRepository,
execution_repository: ExecutionRepository,
saga_orchestrator: SagaOrchestrator,
+ logger: logging.Logger,
) -> SagaService:
return SagaService(
- saga_repo=saga_repository, execution_repo=execution_repository, orchestrator=saga_orchestrator
+ saga_repo=saga_repository,
+ execution_repo=execution_repository,
+ orchestrator=saga_orchestrator,
+ logger=logger,
)
@provide
@@ -519,23 +537,34 @@ def get_execution_service(
kafka_producer: UnifiedProducer,
event_store: EventStore,
settings: Settings,
+ logger: logging.Logger,
) -> ExecutionService:
return ExecutionService(
- execution_repo=execution_repository, producer=kafka_producer, event_store=event_store, settings=settings
+ execution_repo=execution_repository,
+ producer=kafka_producer,
+ event_store=event_store,
+ settings=settings,
+ logger=logger,
)
@provide
- def get_saved_script_service(self, saved_script_repository: SavedScriptRepository) -> SavedScriptService:
- return SavedScriptService(saved_script_repository)
+ def get_saved_script_service(
+ self, saved_script_repository: SavedScriptRepository, logger: logging.Logger
+ ) -> SavedScriptService:
+ return SavedScriptService(saved_script_repository, logger)
@provide
async def get_replay_service(
- self, replay_repository: ReplayRepository, kafka_producer: UnifiedProducer, event_store: EventStore
+ self,
+ replay_repository: ReplayRepository,
+ kafka_producer: UnifiedProducer,
+ event_store: EventStore,
+ logger: logging.Logger,
) -> ReplayService:
event_replay_service = EventReplayService(
- repository=replay_repository, producer=kafka_producer, event_store=event_store
+ repository=replay_repository, producer=kafka_producer, event_store=event_store, logger=logger
)
- return ReplayService(replay_repository, event_replay_service)
+ return ReplayService(replay_repository, event_replay_service, logger)
@provide
def get_admin_user_service(
@@ -544,12 +573,14 @@ def get_admin_user_service(
event_service: EventService,
execution_service: ExecutionService,
rate_limit_service: RateLimitService,
+ logger: logging.Logger,
) -> AdminUserService:
return AdminUserService(
user_repository=admin_user_repository,
event_service=event_service,
execution_service=execution_service,
rate_limit_service=rate_limit_service,
+ logger=logger,
)
@provide
@@ -560,6 +591,7 @@ async def get_execution_coordinator(
event_store: EventStore,
execution_repository: ExecutionRepository,
idempotency_manager: IdempotencyManager,
+ logger: logging.Logger,
) -> AsyncIterator[ExecutionCoordinator]:
coordinator = ExecutionCoordinator(
producer=kafka_producer,
@@ -567,6 +599,7 @@ async def get_execution_coordinator(
event_store=event_store,
execution_repository=execution_repository,
idempotency_manager=idempotency_manager,
+ logger=logger,
)
try:
yield coordinator
@@ -578,5 +611,5 @@ class ResultProcessorProvider(Provider):
scope = Scope.APP
@provide
- def get_execution_repository(self, database: Database) -> ExecutionRepository:
- return ExecutionRepository(database)
+ def get_execution_repository(self, logger: logging.Logger) -> ExecutionRepository:
+ return ExecutionRepository(logger)
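Taken together, these providers replace the old global logger import with an APP-scoped logging.Logger supplied by LoggingProvider. A hedged sketch of container assembly (the provider list is partial, and SettingsProvider stands in for whichever provider exposes get_settings):

# Sketch: wiring the refactored providers into a dishka container.
import logging

from dishka import make_async_container

async def build_container():
    container = make_async_container(
        SettingsProvider(),   # assumed name for the provider exposing Settings
        LoggingProvider(),    # builds the logger from Settings.LOG_LEVEL
        DatabaseProvider(),   # yields the Database handle from AsyncMongoClient
        RedisProvider(),
        MessagingProvider(),
        EventProvider(),
    )
    logger = await container.get(logging.Logger)  # same APP-scoped instance everywhere
    logger.info("container ready")
    return container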
diff --git a/backend/app/core/security.py b/backend/app/core/security.py
index 4f0bb130..1e6e0277 100644
--- a/backend/app/core/security.py
+++ b/backend/app/core/security.py
@@ -2,10 +2,11 @@
from typing import Any
import jwt
-from fastapi import HTTPException, Request, status
+from fastapi import Request
from fastapi.security import OAuth2PasswordBearer
from passlib.context import CryptContext
+from app.domain.user import AuthenticationRequiredError, CSRFValidationError, InvalidCredentialsError
from app.domain.user import User as DomainAdminUser
from app.settings import get_settings
@@ -15,11 +16,7 @@
def get_token_from_cookie(request: Request) -> str:
token = request.cookies.get("access_token")
if not token:
- raise HTTPException(
- status_code=status.HTTP_401_UNAUTHORIZED,
- detail="Authentication token not found",
- headers={"WWW-Authenticate": "Bearer"},
- )
+ raise AuthenticationRequiredError("Authentication token not found")
return token
@@ -46,21 +43,16 @@ async def get_current_user(
token: str,
user_repo: Any, # Avoid circular import by using Any
) -> DomainAdminUser:
- credentials_exception = HTTPException(
- status_code=status.HTTP_401_UNAUTHORIZED,
- detail="Could not validate credentials",
- headers={"WWW-Authenticate": "Bearer"},
- )
try:
payload = jwt.decode(token, self.settings.SECRET_KEY, algorithms=[self.settings.ALGORITHM])
username: str = payload.get("sub")
if username is None:
- raise credentials_exception
+ raise InvalidCredentialsError()
except jwt.PyJWTError as e:
- raise credentials_exception from e
+ raise InvalidCredentialsError() from e
user = await user_repo.get_user(username)
if user is None:
- raise credentials_exception
+ raise InvalidCredentialsError()
return user # type: ignore[no-any-return]
def generate_csrf_token(self) -> str:
@@ -107,9 +99,9 @@ def validate_csrf_token(request: Request) -> str:
cookie_token = request.cookies.get("csrf_token", "")
if not header_token:
- raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="CSRF token missing")
+ raise CSRFValidationError("CSRF token missing")
if not security_service.validate_csrf_token(header_token, cookie_token):
- raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="CSRF token invalid")
+ raise CSRFValidationError("CSRF token invalid")
return header_token
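Security helpers now raise domain exceptions instead of HTTPException, deferring HTTP mapping to the global handler. A small sketch (whether AuthenticationRequiredError and CSRFValidationError resolve to 401/403 depends on where they sit in the DomainError hierarchy, which this diff implies but does not show):

# Sketch: cookie extraction signals failure with a domain error rather than an HTTP error.
from fastapi import Request

async def require_token(request: Request) -> str:
    # Raises AuthenticationRequiredError when the access_token cookie is missing;
    # configure_exception_handlers() is expected to translate it into a 401 response.
    return get_token_from_cookie(request)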
diff --git a/backend/app/core/startup.py b/backend/app/core/startup.py
index fccdfb86..afabada3 100644
--- a/backend/app/core/startup.py
+++ b/backend/app/core/startup.py
@@ -1,7 +1,8 @@
+import logging
+
import redis.asyncio as redis
from dishka import AsyncContainer
-from app.core.logging import logger
from app.core.metrics import (
ConnectionMetrics,
CoordinatorMetrics,
@@ -22,7 +23,7 @@
from app.settings import Settings
-async def initialize_metrics_context(container: AsyncContainer) -> None:
+async def initialize_metrics_context(container: AsyncContainer, logger: logging.Logger) -> None:
try:
# Get all metrics from the container
# These are created as APP-scoped singletons by providers
@@ -44,7 +45,7 @@ async def initialize_metrics_context(container: AsyncContainer) -> None:
metrics_mapping["security"] = await container.get(SecurityMetrics)
# Initialize the context with available metrics
- MetricsContext.initialize_all(**metrics_mapping)
+ MetricsContext.initialize_all(logger=logger, **metrics_mapping)
logger.info(f"Initialized metrics context with {len(metrics_mapping)} metric types")
@@ -54,7 +55,7 @@ async def initialize_metrics_context(container: AsyncContainer) -> None:
# The context will lazy-initialize metrics as needed
-async def initialize_rate_limits(redis_client: redis.Redis, settings: Settings) -> None:
+async def initialize_rate_limits(redis_client: redis.Redis, settings: Settings, logger: logging.Logger) -> None:
"""
Initialize default rate limits in Redis on application startup.
This ensures default limits are always available.
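The startup helpers now receive the logger explicitly rather than importing it. A sketch of how an application lifespan hook might call them (the lifespan wiring is an assumption; function names are from this module):

# Sketch: invoking the startup helpers from an application lifespan hook.
import logging

import redis.asyncio as redis
from dishka import AsyncContainer

from app.settings import Settings

async def on_startup(container: AsyncContainer) -> None:
    logger = await container.get(logging.Logger)
    settings = await container.get(Settings)
    redis_client = await container.get(redis.Redis)

    await initialize_metrics_context(container, logger)
    await initialize_rate_limits(redis_client, settings, logger)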
diff --git a/backend/app/core/tracing/config.py b/backend/app/core/tracing/config.py
index 379ed081..4eeb74c2 100644
--- a/backend/app/core/tracing/config.py
+++ b/backend/app/core/tracing/config.py
@@ -1,3 +1,4 @@
+import logging
import os
from opentelemetry import trace
@@ -14,7 +15,6 @@
from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator
from app.core.adaptive_sampling import create_adaptive_sampler
-from app.core.logging import logger
from app.core.tracing.models import (
InstrumentationReport,
InstrumentationResult,
@@ -87,9 +87,10 @@ def _get_environment(self) -> str:
class TracingInitializer:
"""Initializes OpenTelemetry tracing with instrumentation."""
- def __init__(self, config: TracingConfiguration) -> None:
+ def __init__(self, config: TracingConfiguration, logger: logging.Logger) -> None:
self.config = config
self.instrumentation_report = InstrumentationReport()
+ self.logger = logger
def initialize(self) -> InstrumentationReport:
"""Initialize tracing and instrument libraries."""
@@ -100,7 +101,7 @@ def initialize(self) -> InstrumentationReport:
self._instrument_libraries()
- logger.info(
+ self.logger.info(
f"OpenTelemetry tracing initialized for {self.config.service_name}",
extra={"instrumentation_summary": self.instrumentation_report.get_summary()},
)
@@ -167,7 +168,7 @@ def _instrument_library(self, lib: LibraryInstrumentation) -> InstrumentationRes
lib.instrumentor.instrument(**lib.config)
return InstrumentationResult(library=lib.name, status=InstrumentationStatus.SUCCESS)
except Exception as e:
- logger.warning(
+ self.logger.warning(
f"Failed to instrument {lib.name}", exc_info=True, extra={"library": lib.name, "error": str(e)}
)
return InstrumentationResult(library=lib.name, status=InstrumentationStatus.FAILED, error=e)
@@ -175,6 +176,7 @@ def _instrument_library(self, lib: LibraryInstrumentation) -> InstrumentationRes
def init_tracing(
service_name: str,
+ logger: logging.Logger,
service_version: str = "1.0.0",
otlp_endpoint: str | None = None,
enable_console_exporter: bool = False,
@@ -191,5 +193,5 @@ def init_tracing(
adaptive_sampling=adaptive_sampling,
)
- initializer = TracingInitializer(config)
+ initializer = TracingInitializer(config, logger)
return initializer.initialize()
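init_tracing now requires a logger alongside the service name. A minimal call sketch (the collector endpoint is illustrative):

# Sketch: initializing tracing with the DI-provided logger.
import logging

report = init_tracing(
    service_name="integr8scode-backend",
    logger=logging.getLogger("integr8scode"),
    otlp_endpoint="http://otel-collector:4317",  # illustrative collector endpoint
)
print(report.get_summary())  # per-library instrumentation results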
diff --git a/backend/app/db/__init__.py b/backend/app/db/__init__.py
index 110071b1..3203204b 100644
--- a/backend/app/db/__init__.py
+++ b/backend/app/db/__init__.py
@@ -11,7 +11,6 @@
UserRepository,
UserSettingsRepository,
)
-from app.db.schema.schema_manager import SchemaManager
__all__ = [
"AdminSettingsRepository",
@@ -25,5 +24,4 @@
"SSERepository",
"UserRepository",
"UserSettingsRepository",
- "SchemaManager",
]
diff --git a/backend/app/db/docs/__init__.py b/backend/app/db/docs/__init__.py
new file mode 100644
index 00000000..8d23bbc0
--- /dev/null
+++ b/backend/app/db/docs/__init__.py
@@ -0,0 +1,93 @@
+from app.db.docs.admin_settings import (
+ AuditLogDocument,
+ ExecutionLimitsConfig,
+ MonitoringSettingsConfig,
+ SecuritySettingsConfig,
+ SystemSettingsDocument,
+)
+from app.db.docs.dlq import DLQMessageDocument
+from app.db.docs.event import (
+ EventArchiveDocument,
+ EventDocument,
+ EventStoreDocument,
+)
+from app.db.docs.execution import ExecutionDocument, ResourceUsage
+from app.db.docs.notification import (
+ NotificationDocument,
+ NotificationSubscriptionDocument,
+)
+from app.db.docs.replay import (
+ ReplayConfig,
+ ReplayFilter,
+ ReplaySessionDocument,
+)
+from app.db.docs.resource import ResourceAllocationDocument
+from app.db.docs.saga import SagaDocument
+from app.db.docs.saved_script import SavedScriptDocument
+from app.db.docs.user import UserDocument
+from app.db.docs.user_settings import (
+ EditorSettings,
+ NotificationSettings,
+ UserSettingsDocument,
+ UserSettingsSnapshotDocument,
+)
+
+# All document classes that need to be initialized with Beanie
+ALL_DOCUMENTS = [
+ UserDocument,
+ ExecutionDocument,
+ SavedScriptDocument,
+ NotificationDocument,
+ NotificationSubscriptionDocument,
+ UserSettingsDocument,
+ UserSettingsSnapshotDocument,
+ SagaDocument,
+ DLQMessageDocument,
+ EventDocument,
+ EventStoreDocument,
+ EventArchiveDocument,
+ ReplaySessionDocument,
+ ResourceAllocationDocument,
+ SystemSettingsDocument,
+ AuditLogDocument,
+]
+
+__all__ = [
+ # User
+ "UserDocument",
+ # Execution
+ "ExecutionDocument",
+ "ResourceUsage",
+ # Saved Script
+ "SavedScriptDocument",
+ # Notification
+ "NotificationDocument",
+ "NotificationSubscriptionDocument",
+ # User Settings
+ "UserSettingsDocument",
+ "UserSettingsSnapshotDocument",
+ "NotificationSettings",
+ "EditorSettings",
+ # Saga
+ "SagaDocument",
+ # DLQ
+ "DLQMessageDocument",
+ # Event
+ "EventDocument",
+ "EventStoreDocument",
+ "EventArchiveDocument",
+ # Replay
+ "ReplaySessionDocument",
+ "ReplayConfig",
+ "ReplayFilter",
+ # Resource
+ "ResourceAllocationDocument",
+ # Admin Settings
+ "SystemSettingsDocument",
+ "AuditLogDocument",
+ "ExecutionLimitsConfig",
+ "SecuritySettingsConfig",
+ "MonitoringSettingsConfig",
+ # All documents list for Beanie init
+ "ALL_DOCUMENTS",
+]
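ALL_DOCUMENTS exists so the application can register every collection with Beanie in a single call. A sketch of the expected initialization, mirroring the AsyncMongoClient setup used by DatabaseProvider earlier in this diff (the helper name is hypothetical):

# Sketch: registering all document models with Beanie at startup.
from beanie import init_beanie
from pymongo.asynchronous.mongo_client import AsyncMongoClient

from app.db.docs import ALL_DOCUMENTS

async def init_documents(mongodb_url: str, db_name: str) -> None:
    client: AsyncMongoClient[dict[str, object]] = AsyncMongoClient(mongodb_url, tz_aware=True)
    # init_beanie binds each Document to its collection and creates the indexes
    # declared in each document's Settings.indexes (TTL, text, and compound indexes).
    await init_beanie(database=client[db_name], document_models=ALL_DOCUMENTS)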
diff --git a/backend/app/db/docs/admin_settings.py b/backend/app/db/docs/admin_settings.py
new file mode 100644
index 00000000..d4df5151
--- /dev/null
+++ b/backend/app/db/docs/admin_settings.py
@@ -0,0 +1,60 @@
+from datetime import datetime, timezone
+from typing import Any, Dict
+from uuid import uuid4
+
+from beanie import Document, Indexed
+from pydantic import BaseModel, ConfigDict, Field
+
+from app.domain.admin import AuditAction, LogLevel
+
+
+class ExecutionLimitsConfig(BaseModel):
+ max_timeout_seconds: int = 300
+ max_memory_mb: int = 512
+ max_cpu_cores: int = 2
+ max_concurrent_executions: int = 10
+
+
+class SecuritySettingsConfig(BaseModel):
+ password_min_length: int = 8
+ session_timeout_minutes: int = 60
+ max_login_attempts: int = 5
+ lockout_duration_minutes: int = 15
+
+
+class MonitoringSettingsConfig(BaseModel):
+ metrics_retention_days: int = 30
+ log_level: LogLevel = LogLevel.INFO
+ enable_tracing: bool = True
+ sampling_rate: float = 0.1
+
+
+class SystemSettingsDocument(Document):
+ settings_id: str = "global"
+ execution_limits: ExecutionLimitsConfig = Field(default_factory=ExecutionLimitsConfig)
+ security_settings: SecuritySettingsConfig = Field(default_factory=SecuritySettingsConfig)
+ monitoring_settings: MonitoringSettingsConfig = Field(default_factory=MonitoringSettingsConfig)
+ created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
+ updated_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
+
+ model_config = ConfigDict(from_attributes=True)
+
+ class Settings:
+ name = "system_settings"
+ use_state_management = True
+
+
+class AuditLogDocument(Document):
+ audit_id: Indexed(str, unique=True) = Field(default_factory=lambda: str(uuid4())) # type: ignore[valid-type]
+ action: AuditAction
+ user_id: Indexed(str) # type: ignore[valid-type]
+ username: str
+ timestamp: Indexed(datetime) = Field(default_factory=lambda: datetime.now(timezone.utc)) # type: ignore[valid-type]
+ changes: Dict[str, Any] = Field(default_factory=dict)
+ reason: str = ""
+
+ model_config = ConfigDict(from_attributes=True)
+
+ class Settings:
+ name = "audit_log"
+ use_state_management = True
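SystemSettingsDocument is effectively a singleton keyed by settings_id="global". A fetch-or-create sketch (the helper name is hypothetical; the upsert pattern is an assumption based on the default above):

# Sketch: fetch-or-create the single global settings document.
async def get_system_settings() -> SystemSettingsDocument:
    doc = await SystemSettingsDocument.find_one(SystemSettingsDocument.settings_id == "global")
    if doc is None:
        doc = SystemSettingsDocument()  # defaults populate all three config sections
        await doc.insert()
    return doc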
diff --git a/backend/app/db/docs/dlq.py b/backend/app/db/docs/dlq.py
new file mode 100644
index 00000000..192fe29c
--- /dev/null
+++ b/backend/app/db/docs/dlq.py
@@ -0,0 +1,60 @@
+from datetime import datetime, timezone
+from typing import Any
+
+from beanie import Document, Indexed
+from pydantic import ConfigDict, Field
+from pymongo import ASCENDING, DESCENDING, IndexModel
+
+from app.dlq.models import DLQMessageStatus
+from app.domain.enums.events import EventType
+
+
+class DLQMessageDocument(Document):
+ """Unified DLQ message document for the entire system.
+
+ Copied from DLQMessage dataclass.
+ """
+
+ # Core fields - always required
+ event: dict[str, Any] # The original event as dict (BaseEvent serialized)
+ event_id: Indexed(str, unique=True) # type: ignore[valid-type]
+ event_type: EventType # Indexed via Settings.indexes
+ original_topic: Indexed(str) # type: ignore[valid-type]
+ error: str # Error message from the failure
+ retry_count: Indexed(int) # type: ignore[valid-type]
+ failed_at: Indexed(datetime) # type: ignore[valid-type]
+ status: DLQMessageStatus # Indexed via Settings.indexes
+ producer_id: str # ID of the producer that sent to DLQ
+
+ # Optional fields
+ created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
+ last_updated: datetime | None = None
+ next_retry_at: Indexed(datetime) | None = None # type: ignore[valid-type]
+ retried_at: datetime | None = None
+ discarded_at: datetime | None = None
+ discard_reason: str | None = None
+ dlq_offset: int | None = None
+ dlq_partition: int | None = None
+ last_error: str | None = None
+
+ # Kafka message headers (optional)
+ headers: dict[str, str] = Field(default_factory=dict)
+
+ model_config = ConfigDict(from_attributes=True)
+
+ class Settings:
+ name = "dlq_messages"
+ use_state_management = True
+ indexes = [
+ IndexModel([("event_type", ASCENDING)], name="idx_dlq_event_type"),
+ IndexModel([("status", ASCENDING)], name="idx_dlq_status"),
+ IndexModel([("failed_at", DESCENDING)], name="idx_dlq_failed_desc"),
+ # TTL index - auto-delete after 7 days
+ IndexModel([("created_at", ASCENDING)], name="idx_dlq_created_ttl", expireAfterSeconds=7 * 24 * 3600),
+ ]
+
+ @property
+ def age_seconds(self) -> float:
+ """Get message age in seconds since failure."""
+ failed_at: datetime = self.failed_at
+ return (datetime.now(timezone.utc) - failed_at).total_seconds()
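The TTL index removes DLQ entries seven days after created_at, while next_retry_at drives retry scheduling. A query sketch for messages whose retry time has passed (the helper name is hypothetical):

# Sketch: find DLQ messages due for another retry attempt, oldest failures first.
from datetime import datetime, timezone

async def due_for_retry(limit: int = 100) -> list[DLQMessageDocument]:
    now = datetime.now(timezone.utc)
    return await (
        DLQMessageDocument.find(DLQMessageDocument.next_retry_at <= now)
        .sort(DLQMessageDocument.failed_at)
        .limit(limit)
        .to_list()
    )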
diff --git a/backend/app/db/docs/event.py b/backend/app/db/docs/event.py
new file mode 100644
index 00000000..ac64ad4e
--- /dev/null
+++ b/backend/app/db/docs/event.py
@@ -0,0 +1,174 @@
+from datetime import datetime, timedelta, timezone
+from typing import Any, Dict
+from uuid import uuid4
+
+import pymongo
+from beanie import Document, Indexed
+from pydantic import BaseModel, ConfigDict, Field
+from pymongo import ASCENDING, DESCENDING, IndexModel
+
+from app.domain.enums.common import Environment
+from app.domain.enums.events import EventType
+
+
+# Pydantic model required here because Beanie embedded documents must be Pydantic BaseModel subclasses.
+# This is NOT an API schema - it defines the MongoDB subdocument structure.
+class EventMetadata(BaseModel):
+ """Event metadata embedded document for Beanie storage."""
+
+ model_config = ConfigDict(from_attributes=True)
+
+ service_name: str
+ service_version: str
+ correlation_id: str = Field(default_factory=lambda: str(uuid4()))
+ user_id: str | None = None
+ ip_address: str | None = None
+ user_agent: str | None = None
+ environment: Environment = Environment.PRODUCTION
+
+
+class EventDocument(Document):
+ """Event document as stored in database.
+
+ Copied from EventInDB schema. Uses extra="allow" to store
+ additional fields from polymorphic BaseEvent subclasses.
+ """
+
+ event_id: Indexed(str, unique=True) = Field(default_factory=lambda: str(uuid4())) # type: ignore[valid-type]
+ event_type: EventType # Indexed via Settings.indexes
+ event_version: str = "1.0"
+ timestamp: Indexed(datetime) = Field(default_factory=lambda: datetime.now(timezone.utc)) # type: ignore[valid-type]
+ aggregate_id: Indexed(str) | None = None # type: ignore[valid-type]
+ metadata: EventMetadata
+ payload: Dict[str, Any] = Field(default_factory=dict)
+ stored_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
+ ttl_expires_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc) + timedelta(days=30))
+
+ model_config = ConfigDict(from_attributes=True, extra="allow")
+
+ class Settings:
+ name = "events"
+ use_state_management = True
+ indexes = [
+ # Compound indexes for common query patterns
+ IndexModel([("event_type", ASCENDING), ("timestamp", DESCENDING)], name="idx_event_type_ts"),
+ IndexModel([("aggregate_id", ASCENDING), ("timestamp", DESCENDING)], name="idx_aggregate_ts"),
+ IndexModel([("metadata.correlation_id", ASCENDING)], name="idx_meta_correlation"),
+ IndexModel([("metadata.user_id", ASCENDING), ("timestamp", DESCENDING)], name="idx_meta_user_ts"),
+ IndexModel([("metadata.service_name", ASCENDING), ("timestamp", DESCENDING)], name="idx_meta_service_ts"),
+ # Payload sparse indexes
+ IndexModel([("payload.execution_id", ASCENDING)], name="idx_payload_execution", sparse=True),
+ IndexModel([("payload.pod_name", ASCENDING)], name="idx_payload_pod", sparse=True),
+ # TTL index (expireAfterSeconds=0 means use ttl_expires_at value directly)
+ IndexModel([("ttl_expires_at", ASCENDING)], name="idx_ttl", expireAfterSeconds=0),
+ # Additional compound indexes for query optimization
+ IndexModel([("event_type", ASCENDING), ("aggregate_id", ASCENDING)], name="idx_events_type_agg"),
+ IndexModel([("aggregate_id", ASCENDING), ("timestamp", ASCENDING)], name="idx_events_agg_ts"),
+ IndexModel([("event_type", ASCENDING), ("timestamp", ASCENDING)], name="idx_events_type_ts_asc"),
+ IndexModel([("metadata.user_id", ASCENDING), ("timestamp", ASCENDING)], name="idx_events_user_ts"),
+ IndexModel([("metadata.user_id", ASCENDING), ("event_type", ASCENDING)], name="idx_events_user_type"),
+ IndexModel(
+ [("event_type", ASCENDING), ("metadata.user_id", ASCENDING), ("timestamp", DESCENDING)],
+ name="idx_events_type_user_ts",
+ ),
+ # Text search index
+ IndexModel(
+ [
+ ("event_type", pymongo.TEXT),
+ ("metadata.service_name", pymongo.TEXT),
+ ("metadata.user_id", pymongo.TEXT),
+ ("payload", pymongo.TEXT),
+ ],
+ name="idx_text_search",
+ language_override="none",
+ default_language="english",
+ ),
+ ]
+
+
+class EventStoreDocument(Document):
+ """Event store document for permanent event storage.
+
+ Same structure as EventDocument but in event_store collection.
+ Uses extra="allow" to store additional fields from polymorphic events.
+ No TTL index since this is permanent storage.
+ """
+
+ event_id: Indexed(str, unique=True) = Field(default_factory=lambda: str(uuid4())) # type: ignore[valid-type]
+ event_type: EventType # Indexed via Settings.indexes
+ event_version: str = "1.0"
+ timestamp: Indexed(datetime) = Field(default_factory=lambda: datetime.now(timezone.utc)) # type: ignore[valid-type]
+ aggregate_id: Indexed(str) | None = None # type: ignore[valid-type]
+ metadata: EventMetadata
+ payload: Dict[str, Any] = Field(default_factory=dict)
+ stored_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
+ ttl_expires_at: datetime | None = None
+
+ model_config = ConfigDict(from_attributes=True, extra="allow")
+
+ class Settings:
+ name = "event_store"
+ use_state_management = True
+ indexes = [
+ # Compound indexes for common query patterns
+ IndexModel([("event_type", ASCENDING), ("timestamp", DESCENDING)], name="idx_event_type_ts"),
+ IndexModel([("aggregate_id", ASCENDING), ("timestamp", DESCENDING)], name="idx_aggregate_ts"),
+ IndexModel([("metadata.correlation_id", ASCENDING)], name="idx_meta_correlation"),
+ IndexModel([("metadata.user_id", ASCENDING), ("timestamp", DESCENDING)], name="idx_meta_user_ts"),
+ IndexModel([("metadata.service_name", ASCENDING), ("timestamp", DESCENDING)], name="idx_meta_service_ts"),
+ # Payload sparse indexes
+ IndexModel([("payload.execution_id", ASCENDING)], name="idx_payload_execution", sparse=True),
+ IndexModel([("payload.pod_name", ASCENDING)], name="idx_payload_pod", sparse=True),
+ # Additional compound indexes for query optimization
+ IndexModel([("event_type", ASCENDING), ("aggregate_id", ASCENDING)], name="idx_events_type_agg"),
+ IndexModel([("aggregate_id", ASCENDING), ("timestamp", ASCENDING)], name="idx_events_agg_ts"),
+ IndexModel([("event_type", ASCENDING), ("timestamp", ASCENDING)], name="idx_events_type_ts_asc"),
+ IndexModel([("metadata.user_id", ASCENDING), ("timestamp", ASCENDING)], name="idx_events_user_ts"),
+ IndexModel([("metadata.user_id", ASCENDING), ("event_type", ASCENDING)], name="idx_events_user_type"),
+ IndexModel(
+ [("event_type", ASCENDING), ("metadata.user_id", ASCENDING), ("timestamp", DESCENDING)],
+ name="idx_events_type_user_ts",
+ ),
+ # Text search index
+ IndexModel(
+ [
+ ("event_type", pymongo.TEXT),
+ ("metadata.service_name", pymongo.TEXT),
+ ("metadata.user_id", pymongo.TEXT),
+ ("payload", pymongo.TEXT),
+ ],
+ name="idx_text_search",
+ language_override="none",
+ default_language="english",
+ ),
+ ]
+
+
+class EventArchiveDocument(Document):
+ """Archived event with deletion metadata.
+
+ Uses extra="allow" to preserve all fields from polymorphic events.
+ """
+
+ event_id: Indexed(str, unique=True) # type: ignore[valid-type]
+ event_type: EventType # Indexed via Settings.indexes
+ event_version: str = "1.0"
+ timestamp: Indexed(datetime) # type: ignore[valid-type]
+ aggregate_id: str | None = None
+ metadata: EventMetadata
+ payload: Dict[str, Any] = Field(default_factory=dict)
+ stored_at: datetime | None = None
+ ttl_expires_at: datetime | None = None
+
+ # Archive metadata
+ deleted_at: Indexed(datetime) = Field(default_factory=lambda: datetime.now(timezone.utc)) # type: ignore[valid-type]
+ deleted_by: str | None = None
+
+ model_config = ConfigDict(from_attributes=True, extra="allow")
+
+ class Settings:
+ name = "events_archive"
+ use_state_management = True
+ indexes = [
+ IndexModel([("event_type", 1)]),
+ ]
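Because the TTL index uses expireAfterSeconds=0, MongoDB expires each event at its own ttl_expires_at value. A sketch of storing an event with non-default retention (the helper and metadata values are illustrative):

# Sketch: store an event that should expire 90 days after insertion.
from datetime import datetime, timedelta, timezone
from typing import Any, Dict

async def store_event(event_type: EventType, payload: Dict[str, Any]) -> EventDocument:
    doc = EventDocument(
        event_type=event_type,
        metadata=EventMetadata(service_name="backend", service_version="1.0.0"),
        payload=payload,
        ttl_expires_at=datetime.now(timezone.utc) + timedelta(days=90),
    )
    await doc.insert()  # event_id, timestamp, and stored_at fall back to their defaults
    return doc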
diff --git a/backend/app/db/docs/execution.py b/backend/app/db/docs/execution.py
new file mode 100644
index 00000000..80724e35
--- /dev/null
+++ b/backend/app/db/docs/execution.py
@@ -0,0 +1,53 @@
+from datetime import datetime, timezone
+from uuid import uuid4
+
+from beanie import Document, Indexed
+from pydantic import BaseModel, ConfigDict, Field
+from pymongo import IndexModel
+
+from app.domain.enums.execution import ExecutionStatus
+from app.domain.enums.storage import ExecutionErrorType
+
+
+# Pydantic model required here because Beanie embedded documents must be Pydantic BaseModel subclasses.
+# This is NOT an API schema - it defines the MongoDB subdocument structure.
+class ResourceUsage(BaseModel):
+ execution_time_wall_seconds: float = 0.0
+ cpu_time_jiffies: int = 0
+ clk_tck_hertz: int = 0
+ peak_memory_kb: int = 0
+
+ model_config = ConfigDict(from_attributes=True)
+
+
+class ExecutionDocument(Document):
+ """Execution document as stored in database.
+
+ Copied from ExecutionInDB schema.
+ """
+
+ # From ExecutionBase
+ script: str = Field(..., max_length=50000, description="Script content (max 50,000 characters)")
+ status: ExecutionStatus = ExecutionStatus.QUEUED # Indexed via Settings.indexes
+ stdout: str | None = None
+ stderr: str | None = None
+ lang: str = "python"
+ lang_version: str = "3.11"
+
+ # From ExecutionInDB
+ execution_id: Indexed(str, unique=True) = Field(default_factory=lambda: str(uuid4())) # type: ignore[valid-type]
+ created_at: Indexed(datetime) = Field(default_factory=lambda: datetime.now(timezone.utc)) # type: ignore[valid-type]
+ updated_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
+ resource_usage: ResourceUsage | None = None
+ user_id: Indexed(str) | None = None # type: ignore[valid-type]
+ exit_code: int | None = None
+ error_type: ExecutionErrorType | None = None
+
+ model_config = ConfigDict(populate_by_name=True, from_attributes=True)
+
+ class Settings:
+ name = "executions"
+ use_state_management = True
+ indexes = [
+ IndexModel([("status", 1)]),
+ ]
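use_state_management=True means Beanie tracks field-level changes, so repositories can persist partial updates. A status-transition sketch (ExecutionStatus.RUNNING is an assumed enum member; only QUEUED is visible in this diff):

# Sketch: update only the mutated fields of an execution document.
from datetime import datetime, timezone

async def mark_running(execution_id: str) -> None:
    doc = await ExecutionDocument.find_one(ExecutionDocument.execution_id == execution_id)
    if doc is None:
        return
    doc.status = ExecutionStatus.RUNNING  # assumed member; only QUEUED appears here
    doc.updated_at = datetime.now(timezone.utc)
    await doc.save_changes()  # writes a partial update thanks to state management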
diff --git a/backend/app/db/docs/notification.py b/backend/app/db/docs/notification.py
new file mode 100644
index 00000000..70944d79
--- /dev/null
+++ b/backend/app/db/docs/notification.py
@@ -0,0 +1,112 @@
+from datetime import UTC, datetime
+from typing import Any
+from uuid import uuid4
+
+from beanie import Document, Indexed
+from pydantic import ConfigDict, Field, field_validator
+from pymongo import ASCENDING, DESCENDING, IndexModel
+
+from app.domain.enums.notification import (
+ NotificationChannel,
+ NotificationSeverity,
+ NotificationStatus,
+)
+
+
+class NotificationDocument(Document):
+ """Individual notification instance.
+
+ Copied from Notification schema.
+ """
+
+ notification_id: Indexed(str, unique=True) = Field(default_factory=lambda: str(uuid4())) # type: ignore[valid-type]
+ user_id: Indexed(str) # type: ignore[valid-type]
+ channel: NotificationChannel
+ severity: NotificationSeverity = NotificationSeverity.MEDIUM
+ status: NotificationStatus = NotificationStatus.PENDING # Indexed via Settings.indexes
+
+ # Content
+ subject: str
+ body: str
+ action_url: str | None = None
+ tags: list[str] = Field(default_factory=list)
+
+ # Tracking
+ created_at: Indexed(datetime) = Field(default_factory=lambda: datetime.now(UTC)) # type: ignore[valid-type]
+ scheduled_for: datetime | None = None
+ sent_at: datetime | None = None
+ delivered_at: datetime | None = None
+ read_at: datetime | None = None
+ clicked_at: datetime | None = None
+ failed_at: datetime | None = None
+
+ # Error handling
+ retry_count: int = 0
+ max_retries: int = 3
+ error_message: str | None = None
+
+ # Context
+ metadata: dict[str, Any] = Field(default_factory=dict)
+
+ # Webhook specific
+ webhook_url: str | None = None
+ webhook_headers: dict[str, str] | None = None
+
+ @field_validator("scheduled_for")
+ @classmethod
+ def validate_scheduled_for(cls, v: datetime | None) -> datetime | None:
+ if v and v < datetime.now(UTC):
+ raise ValueError("scheduled_for must be in the future")
+ return v
+
+ model_config = ConfigDict(from_attributes=True)
+
+ class Settings:
+ name = "notifications"
+ use_state_management = True
+ indexes = [
+ IndexModel([("user_id", ASCENDING), ("created_at", DESCENDING)], name="idx_notif_user_created_desc"),
+ IndexModel([("status", ASCENDING), ("scheduled_for", ASCENDING)], name="idx_notif_status_sched"),
+ ]
+
+
+class NotificationSubscriptionDocument(Document):
+ """User subscription preferences for notifications.
+
+ Copied from NotificationSubscription schema.
+ """
+
+ user_id: Indexed(str) # type: ignore[valid-type]
+ channel: NotificationChannel
+ severities: list[NotificationSeverity] = Field(default_factory=list)
+ include_tags: list[str] = Field(default_factory=list)
+ exclude_tags: list[str] = Field(default_factory=list)
+ enabled: bool = True # Indexed via Settings.indexes
+
+ # Channel-specific settings
+ webhook_url: str | None = None
+ slack_webhook: str | None = None
+
+ # Delivery preferences
+ quiet_hours_enabled: bool = False
+ quiet_hours_start: str | None = None # "22:00"
+ quiet_hours_end: str | None = None # "08:00"
+ timezone: str = "UTC"
+
+ # Batching preferences
+ batch_interval_minutes: int = 60
+
+ created_at: datetime = Field(default_factory=lambda: datetime.now(UTC))
+ updated_at: datetime = Field(default_factory=lambda: datetime.now(UTC))
+
+ model_config = ConfigDict(from_attributes=True)
+
+ class Settings:
+ name = "notification_subscriptions"
+ use_state_management = True
+ indexes = [
+ IndexModel(
+ [("user_id", ASCENDING), ("channel", ASCENDING)], name="idx_sub_user_channel_unique", unique=True
+ ),
+ IndexModel([("enabled", ASCENDING)], name="idx_sub_enabled"),
+ ]
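The status + scheduled_for compound index backs the delivery sweep. A query sketch for pending notifications that are due now (the sweep helper is hypothetical):

# Sketch: pending notifications whose scheduled time has arrived, or that have none.
from datetime import UTC, datetime

from beanie.operators import LTE, Or

async def due_notifications(limit: int = 100) -> list[NotificationDocument]:
    now = datetime.now(UTC)
    return await (
        NotificationDocument.find(
            NotificationDocument.status == NotificationStatus.PENDING,
            Or(
                NotificationDocument.scheduled_for == None,  # noqa: E711 - builds a Mongo equality filter
                LTE(NotificationDocument.scheduled_for, now),
            ),
        )
        .sort(NotificationDocument.created_at)
        .limit(limit)
        .to_list()
    )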
diff --git a/backend/app/db/docs/replay.py b/backend/app/db/docs/replay.py
new file mode 100644
index 00000000..b707cd0e
--- /dev/null
+++ b/backend/app/db/docs/replay.py
@@ -0,0 +1,114 @@
+from datetime import datetime, timezone
+from typing import Any, Dict, List
+from uuid import uuid4
+
+from beanie import Document, Indexed
+from pydantic import BaseModel, ConfigDict, Field
+from pymongo import IndexModel
+
+from app.domain.enums.events import EventType
+from app.domain.enums.replay import ReplayStatus, ReplayTarget, ReplayType
+
+
+class ReplayFilter(BaseModel):
+ """Replay filter configuration (embedded document).
+
+ Copied from domain/replay/models.py ReplayFilter.
+ """
+
+ execution_id: str | None = None
+ event_types: List[EventType] | None = None
+ start_time: datetime | None = None
+ end_time: datetime | None = None
+ user_id: str | None = None
+ service_name: str | None = None
+ custom_query: Dict[str, Any] | None = None
+ exclude_event_types: List[EventType] | None = None
+
+ model_config = ConfigDict(from_attributes=True)
+
+
+class ReplayConfig(BaseModel):
+ """Replay configuration (embedded document).
+
+ Copied from domain/replay/models.py ReplayConfig.
+ """
+
+ replay_type: ReplayType
+ target: ReplayTarget = ReplayTarget.KAFKA
+ filter: ReplayFilter
+
+ speed_multiplier: float = Field(default=1.0, ge=0.1, le=100.0)
+ preserve_timestamps: bool = False
+ batch_size: int = Field(default=100, ge=1, le=1000)
+ max_events: int | None = Field(default=None, ge=1)
+
+ target_topics: Dict[str, str] | None = None # EventType -> topic mapping as strings
+ target_file_path: str | None = None
+
+ skip_errors: bool = True
+ retry_failed: bool = False
+ retry_attempts: int = 3
+
+ enable_progress_tracking: bool = True
+
+ model_config = ConfigDict(from_attributes=True)
+
+
+class ReplaySessionDocument(Document):
+ """Domain replay session model stored in database.
+
+ Single source of truth for replay sessions. Used by both
+ ReplayService and AdminEventsRepository.
+ """
+
+ session_id: Indexed(str, unique=True) # type: ignore[valid-type]
+ config: ReplayConfig
+ status: ReplayStatus = ReplayStatus.CREATED # Indexed via Settings.indexes
+
+ total_events: int = 0
+ replayed_events: int = 0
+ failed_events: int = 0
+ skipped_events: int = 0
+
+ created_at: Indexed(datetime) = Field(default_factory=lambda: datetime.now(timezone.utc)) # type: ignore[valid-type]
+ started_at: datetime | None = None
+ completed_at: datetime | None = None
+ last_event_at: datetime | None = None
+
+ errors: list[dict[str, Any]] = Field(default_factory=list)
+
+ # Tracking and admin fields
+ correlation_id: str = Field(default_factory=lambda: str(uuid4()))
+ created_by: str | None = None
+ target_service: str | None = None
+ dry_run: bool = False
+ triggered_executions: list[str] = Field(default_factory=list)
+ error: str | None = None # Single error message for admin display
+
+ model_config = ConfigDict(from_attributes=True)
+
+ @property
+ def progress_percentage(self) -> float:
+ """Calculate progress percentage."""
+ if self.total_events == 0:
+ return 0.0
+ return round((self.replayed_events / self.total_events) * 100, 2)
+
+ @property
+ def is_completed(self) -> bool:
+ """Check if session is completed."""
+ return self.status in [ReplayStatus.COMPLETED, ReplayStatus.FAILED, ReplayStatus.CANCELLED]
+
+ @property
+ def is_running(self) -> bool:
+ """Check if session is running."""
+ return self.status == ReplayStatus.RUNNING
+
+ class Settings:
+ name = "replay_sessions"
+ use_state_management = True
+ indexes = [
+ IndexModel([("status", 1)]),
+ IndexModel([("correlation_id", 1)]),
+ ]
diff --git a/backend/app/db/docs/resource.py b/backend/app/db/docs/resource.py
new file mode 100644
index 00000000..ef16e814
--- /dev/null
+++ b/backend/app/db/docs/resource.py
@@ -0,0 +1,28 @@
+from datetime import datetime, timezone
+
+from beanie import Document, Indexed
+from pydantic import ConfigDict, Field
+
+
+class ResourceAllocationDocument(Document):
+ """Resource allocation bookkeeping document used by saga steps.
+
+ Based on ResourceAllocationRepository document structure.
+ """
+
+ allocation_id: Indexed(str, unique=True) # type: ignore[valid-type]
+ execution_id: Indexed(str) # type: ignore[valid-type]
+ language: Indexed(str) # type: ignore[valid-type]
+ cpu_request: str
+ memory_request: str
+ cpu_limit: str
+ memory_limit: str
+ status: Indexed(str) = "active" # type: ignore[valid-type] # "active" | "released"
+ allocated_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
+ released_at: datetime | None = None
+
+ model_config = ConfigDict(from_attributes=True)
+
+ class Settings:
+ name = "resource_allocations"
+ use_state_management = True
diff --git a/backend/app/db/docs/saga.py b/backend/app/db/docs/saga.py
new file mode 100644
index 00000000..87b7a62a
--- /dev/null
+++ b/backend/app/db/docs/saga.py
@@ -0,0 +1,40 @@
+from datetime import datetime, timezone
+from typing import Any
+from uuid import uuid4
+
+from beanie import Document, Indexed
+from pydantic import ConfigDict, Field
+from pymongo import ASCENDING, IndexModel
+
+from app.domain.enums.saga import SagaState
+
+
+class SagaDocument(Document):
+ """Domain model for saga stored in database.
+
+ Copied from Saga/SagaInstance dataclass.
+ """
+
+ saga_id: Indexed(str, unique=True) = Field(default_factory=lambda: str(uuid4())) # type: ignore[valid-type]
+ saga_name: Indexed(str) # type: ignore[valid-type]
+ execution_id: Indexed(str) # type: ignore[valid-type]
+ state: SagaState = SagaState.CREATED # Indexed via Settings.indexes
+ current_step: str | None = None
+ completed_steps: list[str] = Field(default_factory=list)
+ compensated_steps: list[str] = Field(default_factory=list)
+ context_data: dict[str, Any] = Field(default_factory=dict)
+ error_message: str | None = None
+ created_at: Indexed(datetime) = Field(default_factory=lambda: datetime.now(timezone.utc)) # type: ignore[valid-type]
+ updated_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
+ completed_at: datetime | None = None
+ retry_count: int = 0
+
+ model_config = ConfigDict(from_attributes=True)
+
+ class Settings:
+ name = "sagas"
+ use_state_management = True
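+ # Compound index supports listing sagas by state ordered by creation time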
+ indexes = [
+ IndexModel([("state", ASCENDING)], name="idx_saga_state"),
+ IndexModel([("state", ASCENDING), ("created_at", ASCENDING)], name="idx_saga_state_created"),
+ ]
diff --git a/backend/app/db/docs/saved_script.py b/backend/app/db/docs/saved_script.py
new file mode 100644
index 00000000..fd371d5e
--- /dev/null
+++ b/backend/app/db/docs/saved_script.py
@@ -0,0 +1,31 @@
+from datetime import datetime, timezone
+from uuid import uuid4
+
+from beanie import Document, Indexed
+from pydantic import ConfigDict, Field
+
+
+class SavedScriptDocument(Document):
+ """Saved script document as stored in database.
+
+ Copied from SavedScriptInDB schema.
+ """
+
+ # From SavedScriptBase
+ name: str
+ script: str
+ lang: str = "python"
+ lang_version: str = "3.11"
+ description: str | None = None
+
+ # From SavedScriptInDB
+ script_id: Indexed(str, unique=True) = Field(default_factory=lambda: str(uuid4())) # type: ignore[valid-type]
+ user_id: Indexed(str) # type: ignore[valid-type]
+ created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
+ updated_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
+
+ model_config = ConfigDict(from_attributes=True)
+
+ class Settings:
+ name = "saved_scripts"
+ use_state_management = True
diff --git a/backend/app/db/docs/user.py b/backend/app/db/docs/user.py
new file mode 100644
index 00000000..eea8cd91
--- /dev/null
+++ b/backend/app/db/docs/user.py
@@ -0,0 +1,33 @@
+from datetime import datetime, timezone
+from uuid import uuid4
+
+from beanie import Document, Indexed
+from pydantic import ConfigDict, EmailStr, Field
+
+from app.domain.enums.user import UserRole
+
+
+class UserDocument(Document):
+ """User document as stored in database (with hashed password).
+
+ Copied from UserInDB schema.
+ """
+
+ # From UserBase
+ username: Indexed(str, unique=True) # type: ignore[valid-type]
+ email: Indexed(EmailStr, unique=True) # type: ignore[valid-type]
+ role: UserRole = UserRole.USER
+ is_active: bool = True
+
+ # From UserInDB
+ user_id: Indexed(str, unique=True) = Field(default_factory=lambda: str(uuid4())) # type: ignore[valid-type]
+ hashed_password: str
+ is_superuser: bool = False
+ created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
+ updated_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
+
+ model_config = ConfigDict(from_attributes=True, arbitrary_types_allowed=True)
+
+ class Settings:
+ name = "users"
+ use_state_management = True
diff --git a/backend/app/db/docs/user_settings.py b/backend/app/db/docs/user_settings.py
new file mode 100644
index 00000000..3518633a
--- /dev/null
+++ b/backend/app/db/docs/user_settings.py
@@ -0,0 +1,108 @@
+from datetime import datetime, timezone
+from typing import Any, Dict, List
+
+from beanie import Document, Indexed
+from pydantic import BaseModel, ConfigDict, Field, field_validator
+
+from app.domain.enums.common import Theme
+from app.domain.enums.notification import NotificationChannel
+
+
+class NotificationSettings(BaseModel):
+ """User notification preferences (embedded document).
+
+ Copied from user_settings.py NotificationSettings.
+ """
+
+ model_config = ConfigDict(from_attributes=True)
+
+ execution_completed: bool = True
+ execution_failed: bool = True
+ system_updates: bool = True
+ security_alerts: bool = True
+ channels: List[NotificationChannel] = Field(default_factory=lambda: [NotificationChannel.IN_APP])
+
+
+class EditorSettings(BaseModel):
+ """Code editor preferences (embedded document).
+
+ Copied from user_settings.py EditorSettings.
+ """
+
+ model_config = ConfigDict(from_attributes=True)
+
+ theme: str = "auto"
+ font_size: int = 14
+ tab_size: int = 4
+ use_tabs: bool = False
+ word_wrap: bool = True
+ show_line_numbers: bool = True
+
+ @field_validator("font_size")
+ @classmethod
+ def validate_font_size(cls, v: int) -> int:
+ if v < 8 or v > 32:
+ raise ValueError("Font size must be between 8 and 32")
+ return v
+
+ @field_validator("tab_size")
+ @classmethod
+ def validate_tab_size(cls, v: int) -> int:
+ if v not in (2, 4, 8):
+ raise ValueError("Tab size must be 2, 4, or 8")
+ return v
+
+
+class UserSettingsDocument(Document):
+ """Complete user settings model.
+
+ Copied from UserSettings schema.
+ """
+
+ user_id: Indexed(str, unique=True) # type: ignore[valid-type]
+ theme: Theme = Theme.AUTO
+ timezone: str = "UTC"
+ date_format: str = "YYYY-MM-DD"
+ time_format: str = "24h"
+ notifications: NotificationSettings = Field(default_factory=NotificationSettings)
+ editor: EditorSettings = Field(default_factory=EditorSettings)
+ custom_settings: Dict[str, Any] = Field(default_factory=dict)
+ version: int = 1
+ created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
+ updated_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
+
+ model_config = ConfigDict(from_attributes=True)
+
+ class Settings:
+ name = "user_settings"
+ use_state_management = True
+
+
+class UserSettingsSnapshotDocument(Document):
+ """Snapshot of user settings for history/restore.
+
+ Based on UserSettings with additional snapshot metadata.
+ """
+
+ user_id: Indexed(str) # type: ignore[valid-type]
+ snapshot_at: Indexed(datetime) = Field(default_factory=lambda: datetime.now(timezone.utc)) # type: ignore[valid-type]
+
+ # Full settings snapshot
+ theme: Theme = Theme.AUTO
+ timezone: str = "UTC"
+ date_format: str = "YYYY-MM-DD"
+ time_format: str = "24h"
+ notifications: NotificationSettings = Field(default_factory=NotificationSettings)
+ editor: EditorSettings = Field(default_factory=EditorSettings)
+ custom_settings: Dict[str, Any] = Field(default_factory=dict)
+ version: int = 1
+
+ # Snapshot metadata
+ reason: str | None = None
+ correlation_id: str | None = None
+
+ model_config = ConfigDict(from_attributes=True)
+
+ class Settings:
+ name = "user_settings_snapshots"
+ use_state_management = True
diff --git a/backend/app/db/repositories/admin/admin_events_repository.py b/backend/app/db/repositories/admin/admin_events_repository.py
index a7aa3b00..1190752e 100644
--- a/backend/app/db/repositories/admin/admin_events_repository.py
+++ b/backend/app/db/repositories/admin/admin_events_repository.py
@@ -1,126 +1,121 @@
+from dataclasses import asdict
from datetime import datetime, timedelta, timezone
-from typing import Any, Dict, List
+from typing import Any
-from pymongo import ReturnDocument
+from beanie.odm.enums import SortDirection
+from beanie.operators import GTE, LTE, In, Text
-from app.core.database_context import Collection, Database
-from app.domain.admin import (
- ReplayQuery,
- ReplaySession,
- ReplaySessionData,
- ReplaySessionFields,
- ReplaySessionStatusDetail,
+from app.db.docs import (
+ EventArchiveDocument,
+ EventStoreDocument,
+ ExecutionDocument,
+ ReplaySessionDocument,
)
+from app.domain.admin import ReplayQuery, ReplaySessionData, ReplaySessionStatusDetail
from app.domain.admin.replay_updates import ReplaySessionUpdate
from app.domain.enums.replay import ReplayStatus
+from app.domain.events import EventMetadata as DomainEventMetadata
from app.domain.events.event_models import (
- CollectionNames,
Event,
EventBrowseResult,
EventDetail,
EventExportRow,
- EventFields,
EventFilter,
EventStatistics,
EventSummary,
HourlyEventCount,
- SortDirection,
UserEventCount,
)
-from app.domain.events.query_builders import (
- EventStatsAggregation,
-)
-from app.infrastructure.mappers import (
- EventExportRowMapper,
- EventFilterMapper,
- EventMapper,
- EventSummaryMapper,
- ReplayQueryMapper,
- ReplaySessionMapper,
-)
+from app.domain.events.query_builders import EventStatsAggregation
+from app.domain.replay.models import ReplayConfig, ReplaySessionState
class AdminEventsRepository:
- """Repository for admin event operations using domain models."""
-
- def __init__(self, db: Database):
- self.db = db
- self.events_collection: Collection = self.db.get_collection(CollectionNames.EVENTS)
- self.event_store_collection: Collection = self.db.get_collection(CollectionNames.EVENT_STORE)
- # Bind related collections used by this repository
- self.executions_collection: Collection = self.db.get_collection(CollectionNames.EXECUTIONS)
- self.events_archive_collection: Collection = self.db.get_collection(CollectionNames.EVENTS_ARCHIVE)
- self.replay_mapper = ReplaySessionMapper()
- self.replay_query_mapper = ReplayQueryMapper()
- self.replay_sessions_collection: Collection = self.db.get_collection(CollectionNames.REPLAY_SESSIONS)
- self.mapper = EventMapper()
- self.summary_mapper = EventSummaryMapper()
+ def _event_filter_conditions(self, f: EventFilter) -> list[Any]:
+ """Build Beanie query conditions from EventFilter for EventStoreDocument."""
+ conditions = [
+ In(EventStoreDocument.event_type, f.event_types) if f.event_types else None,
+ EventStoreDocument.aggregate_id == f.aggregate_id if f.aggregate_id else None,
+ EventStoreDocument.metadata.correlation_id == f.correlation_id if f.correlation_id else None,
+ EventStoreDocument.metadata.user_id == f.user_id if f.user_id else None,
+ EventStoreDocument.metadata.service_name == f.service_name if f.service_name else None,
+ GTE(EventStoreDocument.timestamp, f.start_time) if f.start_time else None,
+ LTE(EventStoreDocument.timestamp, f.end_time) if f.end_time else None,
+ Text(f.search_text) if f.search_text else None,
+ ]
+ return [c for c in conditions if c is not None]
+
+ def _replay_conditions_for_store(self, q: ReplayQuery) -> list[Any]:
+ """Build Beanie query conditions from ReplayQuery for EventStoreDocument."""
+ conditions = [
+ In(EventStoreDocument.event_id, q.event_ids) if q.event_ids else None,
+ EventStoreDocument.metadata.correlation_id == q.correlation_id if q.correlation_id else None,
+ EventStoreDocument.aggregate_id == q.aggregate_id if q.aggregate_id else None,
+ GTE(EventStoreDocument.timestamp, q.start_time) if q.start_time else None,
+ LTE(EventStoreDocument.timestamp, q.end_time) if q.end_time else None,
+ ]
+ return [c for c in conditions if c is not None]
async def browse_events(
self,
event_filter: EventFilter,
skip: int = 0,
limit: int = 50,
- sort_by: str = EventFields.TIMESTAMP,
- sort_order: int = SortDirection.DESCENDING,
+ sort_by: str = "timestamp",
+ sort_order: SortDirection = SortDirection.DESCENDING,
) -> EventBrowseResult:
- """Browse events with filters using domain models."""
- query = EventFilterMapper.to_mongo_query(event_filter)
-
- # Get total count
- total = await self.events_collection.count_documents(query)
-
- # Execute query with pagination
- cursor = self.events_collection.find(query)
- cursor = cursor.sort(sort_by, sort_order)
- cursor = cursor.skip(skip).limit(limit)
-
- # Fetch events and convert to domain models
- event_docs = await cursor.to_list(length=limit)
- events = [self.mapper.from_mongo_document(doc) for doc in event_docs]
+ conditions = self._event_filter_conditions(event_filter)
+ query = EventStoreDocument.find(*conditions)
+ total = await query.count()
+
+ docs = await query.sort([(sort_by, sort_order)]).skip(skip).limit(limit).to_list()
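+ # Map documents back to domain Event objects, excluding Beanie's internal id/revision_id fields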
+ doc_fields = set(EventStoreDocument.model_fields.keys()) - {"id", "revision_id"}
+ events = [
+ Event(**{**d.model_dump(include=doc_fields), "metadata": DomainEventMetadata(**d.metadata.model_dump())})
+ for d in docs
+ ]
return EventBrowseResult(events=events, total=total, skip=skip, limit=limit)
async def get_event_detail(self, event_id: str) -> EventDetail | None:
- """Get detailed information about an event."""
- event_doc = await self.events_collection.find_one({EventFields.EVENT_ID: event_id})
-
- if not event_doc:
+ doc = await EventStoreDocument.find_one({"event_id": event_id})
+ if not doc:
return None
- event = self.mapper.from_mongo_document(event_doc)
-
- # Get related events
- cursor = (
- self.events_collection.find(
- {EventFields.METADATA_CORRELATION_ID: event.correlation_id, EventFields.EVENT_ID: {"$ne": event_id}}
- )
- .sort(EventFields.TIMESTAMP, SortDirection.ASCENDING)
- .limit(10)
+ doc_fields = set(EventStoreDocument.model_fields.keys()) - {"id", "revision_id"}
+ event = Event(
+ **{**doc.model_dump(include=doc_fields), "metadata": DomainEventMetadata(**doc.metadata.model_dump())}
)
- related_docs = await cursor.to_list(length=10)
- related_events = [self.summary_mapper.from_mongo_document(doc) for doc in related_docs]
-
- # Build timeline (could be expanded with more logic)
- timeline = related_events[:5] # Simple timeline for now
-
- detail = EventDetail(event=event, related_events=related_events, timeline=timeline)
+ related_query = {"metadata.correlation_id": doc.metadata.correlation_id, "event_id": {"$ne": event_id}}
+ related_docs = await (
+ EventStoreDocument.find(related_query).sort([("timestamp", SortDirection.ASCENDING)]).limit(10).to_list()
+ )
+ related_events = [
+ EventSummary(
+ event_id=d.event_id,
+ event_type=str(d.event_type),
+ timestamp=d.timestamp,
+ aggregate_id=d.aggregate_id,
+ )
+ for d in related_docs
+ ]
+ timeline = related_events[:5]
- return detail
+ return EventDetail(event=event, related_events=related_events, timeline=timeline)
async def delete_event(self, event_id: str) -> bool:
- """Delete an event."""
- result = await self.events_collection.delete_one({EventFields.EVENT_ID: event_id})
- return result.deleted_count > 0
+ doc = await EventStoreDocument.find_one({"event_id": event_id})
+ if not doc:
+ return False
+ await doc.delete()
+ return True
async def get_event_stats(self, hours: int = 24) -> EventStatistics:
- """Get event statistics for the last N hours."""
start_time = datetime.now(timezone.utc) - timedelta(hours=hours)
- # Get overview statistics
overview_pipeline = EventStatsAggregation.build_overview_pipeline(start_time)
- overview_result = await self.events_collection.aggregate(overview_pipeline).to_list(1)
+ overview_result = await EventStoreDocument.aggregate(overview_pipeline).to_list()
stats = (
overview_result[0]
@@ -128,42 +123,31 @@ async def get_event_stats(self, hours: int = 24) -> EventStatistics:
else {"total_events": 0, "event_type_count": 0, "unique_user_count": 0, "service_count": 0}
)
- # Get error rate
- error_count = await self.events_collection.count_documents(
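+ # Error rate heuristic: count events whose type mentions failed/error/timeout (case-insensitive)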
+ error_count = await EventStoreDocument.find(
{
- EventFields.TIMESTAMP: {"$gte": start_time},
- EventFields.EVENT_TYPE: {"$regex": "failed|error|timeout", "$options": "i"},
+ "timestamp": {"$gte": start_time},
+ "event_type": {"$regex": "failed|error|timeout", "$options": "i"},
}
- )
+ ).count()
error_rate = (error_count / stats["total_events"] * 100) if stats["total_events"] > 0 else 0
- # Get event types with counts
type_pipeline = EventStatsAggregation.build_event_types_pipeline(start_time)
- top_types = await self.events_collection.aggregate(type_pipeline).to_list(10)
+ top_types = await EventStoreDocument.aggregate(type_pipeline).to_list()
events_by_type = {t["_id"]: t["count"] for t in top_types}
- # Get events by hour
hourly_pipeline = EventStatsAggregation.build_hourly_events_pipeline(start_time)
- hourly_cursor = self.events_collection.aggregate(hourly_pipeline)
+ hourly_result = await EventStoreDocument.aggregate(hourly_pipeline).to_list()
events_by_hour: list[HourlyEventCount | dict[str, Any]] = [
- HourlyEventCount(hour=doc["_id"], count=doc["count"]) async for doc in hourly_cursor
+ HourlyEventCount(hour=doc["_id"], count=doc["count"]) for doc in hourly_result
]
- # Get top users
user_pipeline = EventStatsAggregation.build_top_users_pipeline(start_time)
- top_users_cursor = self.events_collection.aggregate(user_pipeline)
+ top_users_result = await EventStoreDocument.aggregate(user_pipeline).to_list()
top_users = [
- UserEventCount(user_id=doc["_id"], event_count=doc["count"])
- async for doc in top_users_cursor
- if doc["_id"] # Filter out None user_ids
+ UserEventCount(user_id=doc["_id"], event_count=doc["count"]) for doc in top_users_result if doc["_id"]
]
- # Get average processing time from executions collection
- # Since execution timing data is stored in executions, not events
- executions_collection = self.executions_collection
-
- # Calculate average execution time from completed executions in the last 24 hours
exec_pipeline: list[dict[str, Any]] = [
{
"$match": {
@@ -175,12 +159,12 @@ async def get_event_stats(self, hours: int = 24) -> EventStatistics:
{"$group": {"_id": None, "avg_duration": {"$avg": "$resource_usage.execution_time_wall_seconds"}}},
]
- exec_result = await executions_collection.aggregate(exec_pipeline).to_list(1)
+ exec_result = await ExecutionDocument.aggregate(exec_pipeline).to_list()
avg_processing_time = (
exec_result[0]["avg_duration"] if exec_result and exec_result[0].get("avg_duration") else 0
)
- statistics = EventStatistics(
+ return EventStatistics(
total_events=stats["total_events"],
events_by_type=events_by_type,
events_by_hour=events_by_hour,
@@ -189,222 +173,202 @@ async def get_event_stats(self, hours: int = 24) -> EventStatistics:
avg_processing_time=round(avg_processing_time, 2),
)
- return statistics
-
- async def export_events_csv(self, event_filter: EventFilter) -> List[EventExportRow]:
- """Export events as CSV data."""
- query = EventFilterMapper.to_mongo_query(event_filter)
-
- cursor = self.events_collection.find(query).sort(EventFields.TIMESTAMP, SortDirection.DESCENDING).limit(10000)
-
- event_docs = await cursor.to_list(length=10000)
-
- # Convert to export rows
- export_rows = []
- for doc in event_docs:
- event = self.mapper.from_mongo_document(doc)
- export_row = EventExportRowMapper.from_event(event)
- export_rows.append(export_row)
+ async def export_events_csv(self, event_filter: EventFilter) -> list[EventExportRow]:
+ conditions = self._event_filter_conditions(event_filter)
+ docs = await (
+ EventStoreDocument.find(*conditions).sort([("timestamp", SortDirection.DESCENDING)]).limit(10000).to_list()
+ )
- return export_rows
+ return [
+ EventExportRow(
+ event_id=doc.event_id,
+ event_type=str(doc.event_type),
+ timestamp=doc.timestamp.isoformat(),
+ correlation_id=doc.metadata.correlation_id or "",
+ aggregate_id=doc.aggregate_id or "",
+ user_id=doc.metadata.user_id or "",
+ service=doc.metadata.service_name,
+ status="",
+ error="",
+ )
+ for doc in docs
+ ]
async def archive_event(self, event: Event, deleted_by: str) -> bool:
- """Archive an event before deletion."""
- # Add deletion metadata
- event_dict = self.mapper.to_mongo_document(event)
- event_dict["_deleted_at"] = datetime.now(timezone.utc)
- event_dict["_deleted_by"] = deleted_by
-
- # Insert into bound archive collection
- result = await self.events_archive_collection.insert_one(event_dict)
- return result.inserted_id is not None
-
- async def create_replay_session(self, session: ReplaySession) -> str:
- """Create a new replay session."""
- session_dict = self.replay_mapper.to_dict(session)
- await self.replay_sessions_collection.insert_one(session_dict)
+ archive_doc = EventArchiveDocument(
+ event_id=event.event_id,
+ event_type=event.event_type,
+ event_version=event.event_version,
+ timestamp=event.timestamp,
+ aggregate_id=event.aggregate_id,
+ metadata=event.metadata,
+ payload=event.payload,
+ stored_at=event.stored_at,
+ ttl_expires_at=event.ttl_expires_at,
+ deleted_at=datetime.now(timezone.utc),
+ deleted_by=deleted_by,
+ )
+ await archive_doc.insert()
+ return True
+
+ async def create_replay_session(self, session: ReplaySessionState) -> str:
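+ # ReplaySessionState is a dataclass whose config is a pydantic model, so dump the config explicitly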
+ data = asdict(session)
+ data["config"] = session.config.model_dump()
+ doc = ReplaySessionDocument(**data)
+ await doc.insert()
return session.session_id
- async def get_replay_session(self, session_id: str) -> ReplaySession | None:
- """Get replay session by ID."""
- doc = await self.replay_sessions_collection.find_one({ReplaySessionFields.SESSION_ID: session_id})
- return self.replay_mapper.from_dict(doc) if doc else None
+ async def get_replay_session(self, session_id: str) -> ReplaySessionState | None:
+ doc = await ReplaySessionDocument.find_one({"session_id": session_id})
+ if not doc:
+ return None
+ data = doc.model_dump(exclude={"id", "revision_id"})
+ data["config"] = ReplayConfig.model_validate(data["config"])
+ return ReplaySessionState(**data)
async def update_replay_session(self, session_id: str, updates: ReplaySessionUpdate) -> bool:
- """Update replay session fields."""
- if not updates.has_updates():
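+ # Flatten the update dataclass: unwrap enum values and drop fields that were not set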
+ update_dict = {k: (v.value if hasattr(v, "value") else v) for k, v in asdict(updates).items() if v is not None}
+ if not update_dict:
return False
- mongo_updates = updates.to_dict()
+ doc = await ReplaySessionDocument.find_one({"session_id": session_id})
+ if not doc:
+ return False
- result = await self.replay_sessions_collection.update_one(
- {ReplaySessionFields.SESSION_ID: session_id}, {"$set": mongo_updates}
- )
- return result.modified_count > 0
+ await doc.set(update_dict)
+ return True
async def get_replay_status_with_progress(self, session_id: str) -> ReplaySessionStatusDetail | None:
- """Get replay session status with progress updates."""
- doc = await self.replay_sessions_collection.find_one({ReplaySessionFields.SESSION_ID: session_id})
+ doc = await ReplaySessionDocument.find_one({"session_id": session_id})
if not doc:
return None
- session = self.replay_mapper.from_dict(doc)
current_time = datetime.now(timezone.utc)
- # Update status based on time if needed
- if session.status == ReplayStatus.SCHEDULED and session.created_at:
- time_since_created = current_time - session.created_at
+ # Auto-transition from SCHEDULED to RUNNING after 2 seconds
+ if doc.status == ReplayStatus.SCHEDULED and doc.created_at:
+ time_since_created = current_time - doc.created_at
if time_since_created.total_seconds() > 2:
- # Use atomic update to prevent race conditions
- update_result = await self.replay_sessions_collection.find_one_and_update(
- {ReplaySessionFields.SESSION_ID: session_id, ReplaySessionFields.STATUS: ReplayStatus.SCHEDULED},
- {
- "$set": {
- ReplaySessionFields.STATUS: ReplayStatus.RUNNING,
- ReplaySessionFields.STARTED_AT: current_time,
- }
- },
- return_document=ReturnDocument.AFTER,
- )
- if update_result:
- # Update local session object with the atomically updated values
- session = self.replay_mapper.from_dict(update_result)
-
- # Simulate progress if running
- if session.is_running and session.started_at:
- time_since_started = current_time - session.started_at
- # Assume 10 events per second processing rate
- estimated_progress = min(int(time_since_started.total_seconds() * 10), session.total_events)
-
- # Update progress - returns new instance
- updated_session = session.update_progress(estimated_progress)
-
- # Update in database
- session_update = ReplaySessionUpdate(replayed_events=updated_session.replayed_events)
-
- if updated_session.is_completed:
- session_update.status = updated_session.status
- session_update.completed_at = updated_session.completed_at
-
- await self.update_replay_session(session_id, session_update)
-
- # Use the updated session for the rest of the method
- session = updated_session
-
- # Calculate estimated completion
+ doc.status = ReplayStatus.RUNNING
+ doc.started_at = current_time
+ await doc.save()
+
+ # Update progress for running sessions
+ if doc.is_running and doc.started_at:
+ time_since_started = current_time - doc.started_at
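+ # Simulated progress assumes roughly 10 events processed per second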
+ estimated_progress = min(int(time_since_started.total_seconds() * 10), doc.total_events)
+ doc.replayed_events = estimated_progress
+
+ # Check if completed
+ if doc.replayed_events >= doc.total_events:
+ doc.status = ReplayStatus.COMPLETED
+ doc.completed_at = current_time
+ await doc.save()
+
+ # Calculate estimated completion time
estimated_completion = None
- if session.is_running and session.replayed_events > 0 and session.started_at:
- rate = session.replayed_events / (current_time - session.started_at).total_seconds()
- remaining = session.total_events - session.replayed_events
- if rate > 0:
- estimated_completion = current_time + timedelta(seconds=remaining / rate)
-
- # Fetch execution results from the original events that were replayed
- execution_results = []
- # Get the query that was used for replay from the session's config
- original_query = {}
- if doc and "config" in doc:
- config = doc.get("config", {})
- filter_config = config.get("filter", {})
- original_query = filter_config.get("custom_query", {})
-
- if original_query:
- # Find the original events that were replayed
- original_events = await self.events_collection.find(original_query).to_list(10)
-
- # Get unique execution IDs from original events
+ if doc.is_running and doc.replayed_events > 0 and doc.started_at:
+ elapsed = (current_time - doc.started_at).total_seconds()
+ if elapsed > 0:
+ rate = doc.replayed_events / elapsed
+ remaining = doc.total_events - doc.replayed_events
+ if rate > 0:
+ estimated_completion = current_time + timedelta(seconds=remaining / rate)
+
+ # Fetch related execution results
+ execution_results: list[dict[str, Any]] = []
+ if doc.config and doc.config.filter and doc.config.filter.custom_query:
+ original_query = doc.config.filter.custom_query
+ original_events = await EventStoreDocument.find(original_query).limit(10).to_list()
+
execution_ids = set()
for event in original_events:
- # Try to get execution_id from various locations
- exec_id = event.get("execution_id")
- if not exec_id and event.get("payload"):
- exec_id = event.get("payload", {}).get("execution_id")
- if not exec_id:
- exec_id = event.get("aggregate_id")
+ exec_id = event.payload.get("execution_id") or event.aggregate_id
if exec_id:
execution_ids.add(exec_id)
- # Fetch execution details
- if execution_ids:
- executions_collection = self.executions_collection
- for exec_id in list(execution_ids)[:10]: # Limit to 10
- exec_doc = await executions_collection.find_one({"execution_id": exec_id})
- if exec_doc:
- execution_results.append(
- {
- "execution_id": exec_doc.get("execution_id"),
- "status": exec_doc.get("status"),
- "stdout": exec_doc.get("stdout"),
- "stderr": exec_doc.get("stderr"),
- "exit_code": exec_doc.get("exit_code"),
- "execution_time": exec_doc.get("execution_time"),
- "lang": exec_doc.get("lang"),
- "lang_version": exec_doc.get("lang_version"),
- "created_at": exec_doc.get("created_at"),
- "updated_at": exec_doc.get("updated_at"),
- }
- )
+ for exec_id in list(execution_ids)[:10]:
+ exec_doc = await ExecutionDocument.find_one({"execution_id": exec_id})
+ if exec_doc:
+ execution_results.append(
+ {
+ "execution_id": exec_doc.execution_id,
+ "status": exec_doc.status.value if exec_doc.status else None,
+ "stdout": exec_doc.stdout,
+ "stderr": exec_doc.stderr,
+ "exit_code": exec_doc.exit_code,
+ "lang": exec_doc.lang,
+ "lang_version": exec_doc.lang_version,
+ "created_at": exec_doc.created_at,
+ "updated_at": exec_doc.updated_at,
+ }
+ )
+
+ # Convert document to domain
+ data = doc.model_dump(exclude={"id", "revision_id"})
+ data["config"] = ReplayConfig.model_validate(data["config"])
+ session = ReplaySessionState(**data)
return ReplaySessionStatusDetail(
- session=session, estimated_completion=estimated_completion, execution_results=execution_results
+ session=session,
+ estimated_completion=estimated_completion,
+ execution_results=execution_results,
)
- async def count_events_for_replay(self, query: Dict[str, Any]) -> int:
- """Count events matching replay query."""
- return await self.events_collection.count_documents(query)
-
- async def get_events_preview_for_replay(self, query: Dict[str, Any], limit: int = 100) -> List[EventSummary]:
- """Get preview of events for replay."""
- cursor = self.events_collection.find(query).limit(limit)
- event_docs = await cursor.to_list(length=limit)
- return [self.summary_mapper.from_mongo_document(doc) for doc in event_docs]
-
- def build_replay_query(self, replay_query: ReplayQuery) -> Dict[str, Any]:
- """Build MongoDB query from replay query model."""
- return self.replay_query_mapper.to_mongodb_query(replay_query)
+ async def count_events_for_replay(self, replay_query: ReplayQuery) -> int:
+ conditions = self._replay_conditions_for_store(replay_query)
+ return await EventStoreDocument.find(*conditions).count()
+
+ async def get_events_preview_for_replay(self, replay_query: ReplayQuery, limit: int = 100) -> list[EventSummary]:
+ conditions = self._replay_conditions_for_store(replay_query)
+ docs = await EventStoreDocument.find(*conditions).limit(limit).to_list()
+ return [
+ EventSummary(
+ event_id=doc.event_id,
+ event_type=str(doc.event_type),
+ timestamp=doc.timestamp,
+ aggregate_id=doc.aggregate_id,
+ )
+ for doc in docs
+ ]
async def prepare_replay_session(
- self, query: Dict[str, Any], dry_run: bool, replay_correlation_id: str, max_events: int = 1000
+ self, replay_query: ReplayQuery, dry_run: bool, replay_correlation_id: str, max_events: int = 1000
) -> ReplaySessionData:
- """Prepare replay session with validation and preview."""
- event_count = await self.count_events_for_replay(query)
+ event_count = await self.count_events_for_replay(replay_query)
if event_count == 0:
raise ValueError("No events found matching the criteria")
if event_count > max_events and not dry_run:
raise ValueError(f"Too many events to replay ({event_count}). Maximum is {max_events}.")
- # Get events preview for dry run
- events_preview: List[EventSummary] = []
+ events_preview: list[EventSummary] = []
if dry_run:
- events_preview = await self.get_events_preview_for_replay(query, limit=100)
+ events_preview = await self.get_events_preview_for_replay(replay_query, limit=100)
- # Return unified session data
- session_data = ReplaySessionData(
+ return ReplaySessionData(
total_events=event_count,
replay_correlation_id=replay_correlation_id,
dry_run=dry_run,
- query=query,
+ query=replay_query,
events_preview=events_preview,
)
- return session_data
-
async def get_replay_events_preview(
- self, event_ids: List[str] | None = None, correlation_id: str | None = None, aggregate_id: str | None = None
- ) -> Dict[str, Any]:
- """Get preview of events that would be replayed - backward compatibility."""
+ self, event_ids: list[str] | None = None, correlation_id: str | None = None, aggregate_id: str | None = None
+ ) -> dict[str, Any]:
replay_query = ReplayQuery(event_ids=event_ids, correlation_id=correlation_id, aggregate_id=aggregate_id)
+ conditions = self._replay_conditions_for_store(replay_query)
- query = self.replay_query_mapper.to_mongodb_query(replay_query)
-
- if not query:
+ if not conditions:
return {"events": [], "total": 0}
- total = await self.event_store_collection.count_documents(query)
-
- cursor = self.event_store_collection.find(query).sort(EventFields.TIMESTAMP, SortDirection.ASCENDING).limit(100)
-
- # Batch fetch all events from cursor
- events = await cursor.to_list(length=100)
+ total = await EventStoreDocument.find(*conditions).count()
+ docs = (
+ await EventStoreDocument.find(*conditions)
+ .sort([("timestamp", SortDirection.ASCENDING)])
+ .limit(100)
+ .to_list()
+ )
+ events = [doc.model_dump() for doc in docs]
return {"events": events, "total": total}
diff --git a/backend/app/db/repositories/admin/admin_settings_repository.py b/backend/app/db/repositories/admin/admin_settings_repository.py
index 1e2e0d19..2899e117 100644
--- a/backend/app/db/repositories/admin/admin_settings_repository.py
+++ b/backend/app/db/repositories/admin/admin_settings_repository.py
@@ -1,75 +1,83 @@
+import logging
from datetime import datetime, timezone
-from app.core.database_context import Collection, Database
-from app.core.logging import logger
-from app.domain.admin import (
- AuditAction,
- AuditLogEntry,
- SystemSettings,
+from app.db.docs import (
+ AuditLogDocument,
+ ExecutionLimitsConfig,
+ MonitoringSettingsConfig,
+ SecuritySettingsConfig,
+ SystemSettingsDocument,
)
-from app.infrastructure.mappers import AuditLogMapper, SettingsMapper
+from app.domain.admin import AuditAction, ExecutionLimits, MonitoringSettings, SecuritySettings, SystemSettings
class AdminSettingsRepository:
- def __init__(self, db: Database):
- self.db = db
- self.settings_collection: Collection = self.db.get_collection("system_settings")
- self.audit_log_collection: Collection = self.db.get_collection("audit_log")
- self.settings_mapper = SettingsMapper()
- self.audit_mapper = AuditLogMapper()
+ def __init__(self, logger: logging.Logger):
+ self.logger = logger
async def get_system_settings(self) -> SystemSettings:
- """Get system settings from database, creating defaults if not found."""
- settings_doc = await self.settings_collection.find_one({"_id": "global"})
- if not settings_doc:
- logger.info("System settings not found, creating defaults")
- # Create default settings
- default_settings = SystemSettings()
- settings_dict = self.settings_mapper.system_settings_to_dict(default_settings)
-
- # Insert default settings
- await self.settings_collection.insert_one(settings_dict)
- return default_settings
-
- return self.settings_mapper.system_settings_from_dict(settings_doc)
-
- async def update_system_settings(self, settings: SystemSettings, updated_by: str, user_id: str) -> SystemSettings:
- """Update system-wide settings."""
- # Update settings metadata
- settings.updated_at = datetime.now(timezone.utc)
-
- # Convert to dict and save
- settings_dict = self.settings_mapper.system_settings_to_dict(settings)
-
- await self.settings_collection.replace_one({"_id": "global"}, settings_dict, upsert=True)
+ doc = await SystemSettingsDocument.find_one({"settings_id": "global"})
+ if not doc:
+ self.logger.info("System settings not found, creating defaults")
+ doc = SystemSettingsDocument(
+ settings_id="global",
+ execution_limits=ExecutionLimitsConfig(),
+ security_settings=SecuritySettingsConfig(),
+ monitoring_settings=MonitoringSettingsConfig(),
+ )
+ await doc.insert()
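+ # Convert embedded config documents back into domain settings objects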
+ return SystemSettings(
+ execution_limits=ExecutionLimits(**doc.execution_limits.model_dump()),
+ security_settings=SecuritySettings(**doc.security_settings.model_dump()),
+ monitoring_settings=MonitoringSettings(**doc.monitoring_settings.model_dump()),
+ created_at=doc.created_at,
+ updated_at=doc.updated_at,
+ )
- # Create audit log entry
- audit_entry = AuditLogEntry(
+ async def update_system_settings(
+ self,
+ settings: SystemSettings,
+ updated_by: str,
+ user_id: str,
+ ) -> SystemSettings:
+ doc = await SystemSettingsDocument.find_one({"settings_id": "global"})
+ if not doc:
+ doc = SystemSettingsDocument(settings_id="global")
+
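+ # Overwrite the embedded configs with the incoming domain settings values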
+ doc.execution_limits = ExecutionLimitsConfig(**settings.execution_limits.__dict__)
+ doc.security_settings = SecuritySettingsConfig(**settings.security_settings.__dict__)
+ doc.monitoring_settings = MonitoringSettingsConfig(**settings.monitoring_settings.__dict__)
+ doc.updated_at = datetime.now(timezone.utc)
+ await doc.save()
+
+ audit_entry = AuditLogDocument(
action=AuditAction.SYSTEM_SETTINGS_UPDATED,
user_id=user_id,
username=updated_by,
timestamp=datetime.now(timezone.utc),
- changes=settings_dict,
+ changes=doc.model_dump(exclude={"id", "revision_id"}),
+ )
+ await audit_entry.insert()
+
+ return SystemSettings(
+ execution_limits=ExecutionLimits(**doc.execution_limits.model_dump()),
+ security_settings=SecuritySettings(**doc.security_settings.model_dump()),
+ monitoring_settings=MonitoringSettings(**doc.monitoring_settings.model_dump()),
+ created_at=doc.created_at,
+ updated_at=doc.updated_at,
)
-
- await self.audit_log_collection.insert_one(self.audit_mapper.to_dict(audit_entry))
-
- return settings
async def reset_system_settings(self, username: str, user_id: str) -> SystemSettings:
- """Reset system settings to defaults."""
- # Delete current settings
- await self.settings_collection.delete_one({"_id": "global"})
+ doc = await SystemSettingsDocument.find_one({"settings_id": "global"})
+ if doc:
+ await doc.delete()
- # Create audit log entry
- audit_entry = AuditLogEntry(
+ audit_entry = AuditLogDocument(
action=AuditAction.SYSTEM_SETTINGS_RESET,
user_id=user_id,
username=username,
timestamp=datetime.now(timezone.utc),
)
+ await audit_entry.insert()
- await self.audit_log_collection.insert_one(self.audit_mapper.to_dict(audit_entry))
-
- # Return default settings
return SystemSettings()
diff --git a/backend/app/db/repositories/admin/admin_user_repository.py b/backend/app/db/repositories/admin/admin_user_repository.py
index f4d38549..f7aed21a 100644
--- a/backend/app/db/repositories/admin/admin_user_repository.py
+++ b/backend/app/db/repositories/admin/admin_user_repository.py
@@ -1,133 +1,115 @@
+import re
+from dataclasses import asdict
from datetime import datetime, timezone
-from app.core.database_context import Collection, Database
+from beanie.odm.operators.find import BaseFindOperator
+from beanie.operators import Eq, Or, RegEx
+
from app.core.security import SecurityService
-from app.domain.enums import UserRole
-from app.domain.events.event_models import CollectionNames
-from app.domain.user import (
- PasswordReset,
- User,
- UserFields,
- UserListResult,
- UserSearchFilter,
- UserUpdate,
+from app.db.docs import (
+ EventDocument,
+ ExecutionDocument,
+ NotificationDocument,
+ SagaDocument,
+ SavedScriptDocument,
+ UserDocument,
+ UserSettingsDocument,
)
-from app.infrastructure.mappers import UserMapper
+from app.domain.enums import UserRole
+from app.domain.user import DomainUserCreate, PasswordReset, User, UserListResult, UserUpdate
class AdminUserRepository:
- def __init__(self, db: Database):
- self.db = db
- self.users_collection: Collection = self.db.get_collection(CollectionNames.USERS)
-
- # Related collections used by this repository (e.g., cascade deletes)
- self.executions_collection: Collection = self.db.get_collection(CollectionNames.EXECUTIONS)
- self.saved_scripts_collection: Collection = self.db.get_collection(CollectionNames.SAVED_SCRIPTS)
- self.notifications_collection: Collection = self.db.get_collection(CollectionNames.NOTIFICATIONS)
- self.user_settings_collection: Collection = self.db.get_collection(CollectionNames.USER_SETTINGS)
- self.events_collection: Collection = self.db.get_collection(CollectionNames.EVENTS)
- self.sagas_collection: Collection = self.db.get_collection(CollectionNames.SAGAS)
+ def __init__(self) -> None:
self.security_service = SecurityService()
- self.mapper = UserMapper()
+
+ async def create_user(self, create_data: DomainUserCreate) -> User:
+ doc = UserDocument(**asdict(create_data))
+ await doc.insert()
+ return User(**doc.model_dump(exclude={"id", "revision_id"}))
async def list_users(
self, limit: int = 100, offset: int = 0, search: str | None = None, role: UserRole | None = None
) -> UserListResult:
- """List all users with optional filtering."""
- # Create search filter
- search_filter = UserSearchFilter(search_text=search, role=role)
-
- query = self.mapper.search_filter_to_query(search_filter)
-
- # Get total count
- total = await self.users_collection.count_documents(query)
-
- # Get users with pagination
- cursor = self.users_collection.find(query).skip(offset).limit(limit)
-
- users = []
- async for user_doc in cursor:
- users.append(self.mapper.from_mongo_document(user_doc))
-
+ conditions: list[BaseFindOperator] = []
+
+ if search:
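+ # Escape the search term so user input is matched literally in the case-insensitive regex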
+ escaped_search = re.escape(search)
+ conditions.append(
+ Or(
+ RegEx(UserDocument.username, escaped_search, options="i"),
+ RegEx(UserDocument.email, escaped_search, options="i"),
+ )
+ )
+
+ if role:
+ conditions.append(Eq(UserDocument.role, role))
+
+ query = UserDocument.find(*conditions)
+ total = await query.count()
+ docs = await query.skip(offset).limit(limit).to_list()
+ users = [User(**doc.model_dump(exclude={"id", "revision_id"})) for doc in docs]
return UserListResult(users=users, total=total, offset=offset, limit=limit)
async def get_user_by_id(self, user_id: str) -> User | None:
- """Get user by ID."""
- user_doc = await self.users_collection.find_one({UserFields.USER_ID: user_id})
- if user_doc:
- return self.mapper.from_mongo_document(user_doc)
- return None
+ doc = await UserDocument.find_one({"user_id": user_id})
+ return User(**doc.model_dump(exclude={"id", "revision_id"})) if doc else None
async def update_user(self, user_id: str, update_data: UserUpdate) -> User | None:
- """Update user details."""
- if not update_data.has_updates():
- return await self.get_user_by_id(user_id)
-
- # Get update dict
- update_dict = self.mapper.to_update_dict(update_data)
-
- # Hash password if provided
- if update_data.password:
- update_dict[UserFields.HASHED_PASSWORD] = self.security_service.get_password_hash(update_data.password)
- # Ensure no plaintext password field is persisted
- update_dict.pop("password", None)
-
- # Add updated_at timestamp
- update_dict[UserFields.UPDATED_AT] = datetime.now(timezone.utc)
+ doc = await UserDocument.find_one({"user_id": user_id})
+ if not doc:
+ return None
- result = await self.users_collection.update_one({UserFields.USER_ID: user_id}, {"$set": update_dict})
+ update_dict = {k: v for k, v in asdict(update_data).items() if v is not None}
+ # Hash the plaintext password and never persist it directly
+ if "password" in update_dict:
+ update_dict["hashed_password"] = self.security_service.get_password_hash(update_dict.pop("password"))
- if result.modified_count > 0:
- return await self.get_user_by_id(user_id)
-
- return None
+ if update_dict:
+ update_dict["updated_at"] = datetime.now(timezone.utc)
+ await doc.set(update_dict)
+ return User(**doc.model_dump(exclude={"id", "revision_id"}))
async def delete_user(self, user_id: str, cascade: bool = True) -> dict[str, int]:
- """Delete user with optional cascade deletion of related data."""
deleted_counts = {}
- result = await self.users_collection.delete_one({UserFields.USER_ID: user_id})
- deleted_counts["user"] = result.deleted_count
+ doc = await UserDocument.find_one({"user_id": user_id})
+ if doc:
+ await doc.delete()
+ deleted_counts["user"] = 1
+ else:
+ deleted_counts["user"] = 0
if not cascade:
return deleted_counts
- # Delete user's executions
- executions_result = await self.executions_collection.delete_many({"user_id": user_id})
- deleted_counts["executions"] = executions_result.deleted_count
+ # Cascade delete related data
+ exec_result = await ExecutionDocument.find({"user_id": user_id}).delete()
+ deleted_counts["executions"] = exec_result.deleted_count if exec_result else 0
- # Delete user's saved scripts
- scripts_result = await self.saved_scripts_collection.delete_many({"user_id": user_id})
- deleted_counts["saved_scripts"] = scripts_result.deleted_count
+ scripts_result = await SavedScriptDocument.find({"user_id": user_id}).delete()
+ deleted_counts["saved_scripts"] = scripts_result.deleted_count if scripts_result else 0
- # Delete user's notifications
- notifications_result = await self.notifications_collection.delete_many({"user_id": user_id})
- deleted_counts["notifications"] = notifications_result.deleted_count
+ notif_result = await NotificationDocument.find({"user_id": user_id}).delete()
+ deleted_counts["notifications"] = notif_result.deleted_count if notif_result else 0
- # Delete user's settings
- settings_result = await self.user_settings_collection.delete_many({"user_id": user_id})
- deleted_counts["user_settings"] = settings_result.deleted_count
+ settings_result = await UserSettingsDocument.find({"user_id": user_id}).delete()
+ deleted_counts["user_settings"] = settings_result.deleted_count if settings_result else 0
- # Delete user's events (if needed)
- events_result = await self.events_collection.delete_many({"user_id": user_id})
- deleted_counts["events"] = events_result.deleted_count
+ events_result = await EventDocument.find({"metadata.user_id": user_id}).delete()
+ deleted_counts["events"] = events_result.deleted_count if events_result else 0
- # Delete user's sagas
- sagas_result = await self.sagas_collection.delete_many({"user_id": user_id})
- deleted_counts["sagas"] = sagas_result.deleted_count
+ sagas_result = await SagaDocument.find({"context_data.user_id": user_id}).delete()
+ deleted_counts["sagas"] = sagas_result.deleted_count if sagas_result else 0
return deleted_counts
- async def reset_user_password(self, password_reset: PasswordReset) -> bool:
- """Reset user password."""
- if not password_reset.is_valid():
- raise ValueError("Invalid password reset data")
-
- hashed_password = self.security_service.get_password_hash(password_reset.new_password)
-
- result = await self.users_collection.update_one(
- {UserFields.USER_ID: password_reset.user_id},
- {"$set": {UserFields.HASHED_PASSWORD: hashed_password, UserFields.UPDATED_AT: datetime.now(timezone.utc)}},
- )
+ async def reset_user_password(self, reset_data: PasswordReset) -> bool:
+ doc = await UserDocument.find_one({"user_id": reset_data.user_id})
+ if not doc:
+ return False
- return result.modified_count > 0
+ doc.hashed_password = self.security_service.get_password_hash(reset_data.new_password)
+ doc.updated_at = datetime.now(timezone.utc)
+ await doc.save()
+ return True
diff --git a/backend/app/db/repositories/dlq_repository.py b/backend/app/db/repositories/dlq_repository.py
index b1659a49..5ab12674 100644
--- a/backend/app/db/repositories/dlq_repository.py
+++ b/backend/app/db/repositories/dlq_repository.py
@@ -1,14 +1,14 @@
+import logging
from datetime import datetime, timezone
-from typing import Dict, List, Mapping
+from typing import Any, Dict, List, Mapping
-from app.core.database_context import Collection, Database
-from app.core.logging import logger
+from beanie.odm.enums import SortDirection
+
+from app.db.docs import DLQMessageDocument
from app.dlq import (
AgeStatistics,
DLQBatchRetryResult,
- DLQFields,
DLQMessage,
- DLQMessageFilter,
DLQMessageListResult,
DLQMessageStatus,
DLQRetryResult,
@@ -19,28 +19,25 @@
)
from app.dlq.manager import DLQManager
from app.domain.enums.events import EventType
-from app.domain.events.event_models import CollectionNames
-from app.infrastructure.mappers.dlq_mapper import DLQMapper
+from app.infrastructure.kafka.mappings import get_event_class_for_type
class DLQRepository:
- def __init__(self, db: Database):
- self.db = db
- self.dlq_collection: Collection = self.db.get_collection(CollectionNames.DLQ_MESSAGES)
+ def __init__(self, logger: logging.Logger):
+ self.logger = logger
+
+ def _doc_to_message(self, doc: DLQMessageDocument) -> DLQMessage:
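+ # Rebuild the typed event from its stored payload using the event class registered for this event type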
+ event_class = get_event_class_for_type(doc.event_type)
+ if not event_class:
+ raise ValueError(f"Unknown event type: {doc.event_type}")
+ data = doc.model_dump(exclude={"id", "revision_id"})
+ return DLQMessage(**{**data, "event": event_class(**data["event"])})
async def get_dlq_stats(self) -> DLQStatistics:
# Get counts by status
- status_pipeline: list[Mapping[str, object]] = [
- {"$group": {"_id": f"${DLQFields.STATUS}", "count": {"$sum": 1}}}
- ]
-
- status_results = []
- async for doc in self.dlq_collection.aggregate(status_pipeline):
- status_results.append(doc)
-
- # Convert status results to dict
+ status_pipeline: list[Mapping[str, object]] = [{"$group": {"_id": "$status", "count": {"$sum": 1}}}]
by_status: Dict[str, int] = {}
- for doc in status_results:
+ async for doc in DLQMessageDocument.aggregate(status_pipeline):
if doc["_id"]:
by_status[doc["_id"]] = doc["count"]
@@ -48,40 +45,36 @@ async def get_dlq_stats(self) -> DLQStatistics:
topic_pipeline: list[Mapping[str, object]] = [
{
"$group": {
- "_id": f"${DLQFields.ORIGINAL_TOPIC}",
+ "_id": "$original_topic",
"count": {"$sum": 1},
- "avg_retry_count": {"$avg": f"${DLQFields.RETRY_COUNT}"},
+ "avg_retry_count": {"$avg": "$retry_count"},
}
},
{"$sort": {"count": -1}},
{"$limit": 10},
]
-
by_topic: List[TopicStatistic] = []
- async for doc in self.dlq_collection.aggregate(topic_pipeline):
+ async for doc in DLQMessageDocument.aggregate(topic_pipeline):
by_topic.append(
TopicStatistic(topic=doc["_id"], count=doc["count"], avg_retry_count=round(doc["avg_retry_count"], 2))
)
# Get counts by event type
event_type_pipeline: list[Mapping[str, object]] = [
- {"$group": {"_id": f"${DLQFields.EVENT_TYPE}", "count": {"$sum": 1}}},
+ {"$group": {"_id": "$event_type", "count": {"$sum": 1}}},
{"$sort": {"count": -1}},
{"$limit": 10},
]
-
by_event_type: List[EventTypeStatistic] = []
- async for doc in self.dlq_collection.aggregate(event_type_pipeline):
- if doc["_id"]: # Skip null event types
+ async for doc in DLQMessageDocument.aggregate(event_type_pipeline):
+ if doc["_id"]:
by_event_type.append(EventTypeStatistic(event_type=doc["_id"], count=doc["count"]))
# Get age statistics
age_pipeline: list[Mapping[str, object]] = [
{
"$project": {
- "age_seconds": {
- "$divide": [{"$subtract": [datetime.now(timezone.utc), f"${DLQFields.FAILED_AT}"]}, 1000]
- }
+ "age_seconds": {"$divide": [{"$subtract": [datetime.now(timezone.utc), "$failed_at"]}, 1000]}
}
},
{
@@ -93,8 +86,9 @@ async def get_dlq_stats(self) -> DLQStatistics:
}
},
]
-
- age_result = await self.dlq_collection.aggregate(age_pipeline).to_list(1)
+ age_result = []
+ async for doc in DLQMessageDocument.aggregate(age_pipeline):
+ age_result.append(doc)
age_stats_data = age_result[0] if age_result else {}
age_stats = AgeStatistics(
min_age_seconds=age_stats_data.get("min_age", 0.0),
@@ -112,43 +106,46 @@ async def get_messages(
limit: int = 50,
offset: int = 0,
) -> DLQMessageListResult:
- msg_filter = DLQMessageFilter(status=status, topic=topic, event_type=event_type)
- query = DLQMapper.filter_to_query(msg_filter)
- total_count = await self.dlq_collection.count_documents(query)
-
- cursor = self.dlq_collection.find(query).sort(DLQFields.FAILED_AT, -1).skip(offset).limit(limit)
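+ # Include only the filters that were actually provided; None placeholders are stripped below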
+ conditions: list[Any] = [
+ DLQMessageDocument.status == status if status else None,
+ DLQMessageDocument.original_topic == topic if topic else None,
+ DLQMessageDocument.event_type == event_type if event_type else None,
+ ]
+ conditions = [c for c in conditions if c is not None]
- messages = []
- async for doc in cursor:
- messages.append(DLQMapper.from_mongo_document(doc))
+ query = DLQMessageDocument.find(*conditions)
+ total_count = await query.count()
+ docs = await query.sort([("failed_at", SortDirection.DESCENDING)]).skip(offset).limit(limit).to_list()
- return DLQMessageListResult(messages=messages, total=total_count, offset=offset, limit=limit)
+ return DLQMessageListResult(
+ messages=[self._doc_to_message(d) for d in docs],
+ total=total_count,
+ offset=offset,
+ limit=limit,
+ )
async def get_message_by_id(self, event_id: str) -> DLQMessage | None:
- doc = await self.dlq_collection.find_one({DLQFields.EVENT_ID: event_id})
- if not doc:
- return None
-
- return DLQMapper.from_mongo_document(doc)
+ doc = await DLQMessageDocument.find_one({"event_id": event_id})
+ return self._doc_to_message(doc) if doc else None
async def get_topics_summary(self) -> list[DLQTopicSummary]:
pipeline: list[Mapping[str, object]] = [
{
"$group": {
- "_id": f"${DLQFields.ORIGINAL_TOPIC}",
+ "_id": "$original_topic",
"count": {"$sum": 1},
- "statuses": {"$push": f"${DLQFields.STATUS}"},
- "oldest_message": {"$min": f"${DLQFields.FAILED_AT}"},
- "newest_message": {"$max": f"${DLQFields.FAILED_AT}"},
- "avg_retry_count": {"$avg": f"${DLQFields.RETRY_COUNT}"},
- "max_retry_count": {"$max": f"${DLQFields.RETRY_COUNT}"},
+ "statuses": {"$push": "$status"},
+ "oldest_message": {"$min": "$failed_at"},
+ "newest_message": {"$max": "$failed_at"},
+ "avg_retry_count": {"$avg": "$retry_count"},
+ "max_retry_count": {"$max": "$retry_count"},
}
},
{"$sort": {"count": -1}},
]
topics = []
- async for result in self.dlq_collection.aggregate(pipeline):
+ async for result in DLQMessageDocument.aggregate(pipeline):
status_counts: dict[str, int] = {}
for status in result["statuses"]:
status_counts[status] = status_counts.get(status, 0) + 1
@@ -168,55 +165,43 @@ async def get_topics_summary(self) -> list[DLQTopicSummary]:
return topics
async def mark_message_retried(self, event_id: str) -> bool:
+ doc = await DLQMessageDocument.find_one({"event_id": event_id})
+ if not doc:
+ return False
now = datetime.now(timezone.utc)
- result = await self.dlq_collection.update_one(
- {DLQFields.EVENT_ID: event_id},
- {
- "$set": {
- DLQFields.STATUS: DLQMessageStatus.RETRIED,
- DLQFields.RETRIED_AT: now,
- DLQFields.LAST_UPDATED: now,
- }
- },
- )
- return result.modified_count > 0
+ doc.status = DLQMessageStatus.RETRIED
+ doc.retried_at = now
+ doc.last_updated = now
+ await doc.save()
+ return True
async def mark_message_discarded(self, event_id: str, reason: str) -> bool:
+ doc = await DLQMessageDocument.find_one({"event_id": event_id})
+ if not doc:
+ return False
now = datetime.now(timezone.utc)
- result = await self.dlq_collection.update_one(
- {DLQFields.EVENT_ID: event_id},
- {
- "$set": {
- DLQFields.STATUS: DLQMessageStatus.DISCARDED.value,
- DLQFields.DISCARDED_AT: now,
- DLQFields.DISCARD_REASON: reason,
- DLQFields.LAST_UPDATED: now,
- }
- },
- )
- return result.modified_count > 0
+ doc.status = DLQMessageStatus.DISCARDED
+ doc.discarded_at = now
+ doc.discard_reason = reason
+ doc.last_updated = now
+ await doc.save()
+ return True
async def retry_messages_batch(self, event_ids: list[str], dlq_manager: DLQManager) -> DLQBatchRetryResult:
- """Retry a batch of DLQ messages."""
details = []
successful = 0
failed = 0
for event_id in event_ids:
try:
- # Get message from repository
- message = await self.get_message_by_id(event_id)
-
- if not message:
+ doc = await DLQMessageDocument.find_one({"event_id": event_id})
+ if not doc:
failed += 1
details.append(DLQRetryResult(event_id=event_id, status="failed", error="Message not found"))
continue
- # Use dlq_manager for retry logic
success = await dlq_manager.retry_message_manually(event_id)
-
if success:
- # Mark as retried
await self.mark_message_retried(event_id)
successful += 1
details.append(DLQRetryResult(event_id=event_id, status="success"))
@@ -225,7 +210,7 @@ async def retry_messages_batch(self, event_ids: list[str], dlq_manager: DLQManag
details.append(DLQRetryResult(event_id=event_id, status="failed", error="Retry failed"))
except Exception as e:
- logger.error(f"Error retrying message {event_id}: {e}")
+ self.logger.error(f"Error retrying message {event_id}: {e}")
failed += 1
details.append(DLQRetryResult(event_id=event_id, status="failed", error=str(e)))
diff --git a/backend/app/db/repositories/event_repository.py b/backend/app/db/repositories/event_repository.py
index a6b673fc..901f72ff 100644
--- a/backend/app/db/repositories/event_repository.py
+++ b/backend/app/db/repositories/event_repository.py
@@ -1,101 +1,89 @@
-from dataclasses import replace
+import logging
+from dataclasses import asdict
from datetime import datetime, timedelta, timezone
-from types import MappingProxyType
-from typing import Any, AsyncIterator, Mapping
+from typing import Any, Mapping
-from pymongo import ASCENDING, DESCENDING
+from beanie.odm.enums import SortDirection
+from beanie.operators import GTE, LT, LTE, In, Not, Or, RegEx
-from app.core.database_context import Collection, Database
-from app.core.logging import logger
from app.core.tracing import EventAttributes
from app.core.tracing.utils import add_span_attributes
+from app.db.docs import EventArchiveDocument, EventDocument
from app.domain.enums.events import EventType
-from app.domain.enums.user import UserRole
-from app.domain.events import (
+from app.domain.events import Event
+from app.domain.events import EventMetadata as DomainEventMetadata
+from app.domain.events.event_models import (
ArchivedEvent,
- Event,
EventAggregationResult,
- EventFields,
- EventFilter,
EventListResult,
EventReplayInfo,
EventStatistics,
)
-from app.domain.events.event_models import CollectionNames
-from app.infrastructure.mappers import ArchivedEventMapper, EventFilterMapper, EventMapper
class EventRepository:
- def __init__(self, database: Database) -> None:
- self.database = database
- self.mapper = EventMapper()
- self._collection: Collection = self.database.get_collection(CollectionNames.EVENTS)
+ def __init__(self, logger: logging.Logger) -> None:
+ self.logger = logger
+
+ def _time_conditions(self, start_time: datetime | None, end_time: datetime | None) -> list[Any]:
+ """Build time range conditions for queries."""
+ conditions = [
+ GTE(EventDocument.timestamp, start_time) if start_time else None,
+ LTE(EventDocument.timestamp, end_time) if end_time else None,
+ ]
+ return [c for c in conditions if c is not None]
def _build_time_filter(self, start_time: datetime | None, end_time: datetime | None) -> dict[str, object]:
- """Build time range filter, eliminating if-else branching."""
+ """Build time filter dict for aggregation pipelines."""
return {key: value for key, value in {"$gte": start_time, "$lte": end_time}.items() if value is not None}
async def store_event(self, event: Event) -> str:
- """
- Store an event in the collection
-
- Args:
- event: Event domain model to store
-
- Returns:
- Event ID of stored event
-
- Raises:
- DuplicateKeyError: If event with same ID already exists
- """
- if not event.stored_at:
- event = replace(event, stored_at=datetime.now(timezone.utc))
-
- event_doc = self.mapper.to_mongo_document(event)
+ data = asdict(event)
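+ # metadata may be a Pydantic model or a dataclass; dump it and coerce enum values to plain primitives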
+ meta = event.metadata.model_dump() if hasattr(event.metadata, "model_dump") else asdict(event.metadata)
+ data["metadata"] = {k: (v.value if hasattr(v, "value") else v) for k, v in meta.items()}
+ if not data.get("stored_at"):
+ data["stored_at"] = datetime.now(timezone.utc)
+ # Remove None values so EventDocument defaults can apply (e.g., ttl_expires_at)
+ data = {k: v for k, v in data.items() if v is not None}
+
+ doc = EventDocument(**data)
add_span_attributes(
**{
- str(EventAttributes.EVENT_TYPE): event.event_type,
+ str(EventAttributes.EVENT_TYPE): str(event.event_type),
str(EventAttributes.EVENT_ID): event.event_id,
str(EventAttributes.EXECUTION_ID): event.aggregate_id or "",
}
)
- _ = await self._collection.insert_one(event_doc)
-
- logger.debug(f"Stored event {event.event_id} of type {event.event_type}")
+ await doc.insert()
+ self.logger.debug(f"Stored event {event.event_id} of type {event.event_type}")
return event.event_id
async def store_events_batch(self, events: list[Event]) -> list[str]:
- """
- Store multiple events in a batch
-
- Args:
- events: List of event domain models to store
-
- Returns:
- List of stored event IDs
- """
if not events:
return []
now = datetime.now(timezone.utc)
- event_docs = []
+ docs = []
for event in events:
- if not event.stored_at:
- event = replace(event, stored_at=now)
- event_docs.append(self.mapper.to_mongo_document(event))
-
- result = await self._collection.insert_many(event_docs, ordered=False)
- add_span_attributes(
- **{
- "events.batch.count": len(event_docs),
- }
- )
-
- logger.info(f"Stored {len(result.inserted_ids)} events in batch")
+ data = asdict(event)
+ meta = event.metadata.model_dump() if hasattr(event.metadata, "model_dump") else asdict(event.metadata)
+ data["metadata"] = {k: (v.value if hasattr(v, "value") else v) for k, v in meta.items()}
+ if not data.get("stored_at"):
+ data["stored_at"] = now
+ # Remove None values so EventDocument defaults can apply
+ data = {k: v for k, v in data.items() if v is not None}
+ docs.append(EventDocument(**data))
+ await EventDocument.insert_many(docs)
+ add_span_attributes(**{"events.batch.count": len(events)})
+ self.logger.info(f"Stored {len(events)} events in batch")
return [event.event_id for event in events]
async def get_event(self, event_id: str) -> Event | None:
- result = await self._collection.find_one({EventFields.EVENT_ID: event_id})
- return self.mapper.from_mongo_document(result) if result else None
+ doc = await EventDocument.find_one({"event_id": event_id})
+ if not doc:
+ return None
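+ # Drop Beanie's internal id/revision fields before rebuilding the domain Event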
+ data = doc.model_dump(exclude={"id", "revision_id"})
+ data["metadata"] = DomainEventMetadata(**data["metadata"])
+ return Event(**data)
async def get_events_by_type(
self,
@@ -105,34 +93,63 @@ async def get_events_by_type(
limit: int = 100,
skip: int = 0,
) -> list[Event]:
- query: dict[str, Any] = {EventFields.EVENT_TYPE: event_type}
- time_filter = self._build_time_filter(start_time, end_time)
- if time_filter:
- query[EventFields.TIMESTAMP] = time_filter
-
- cursor = self._collection.find(query).sort(EventFields.TIMESTAMP, DESCENDING).skip(skip).limit(limit)
- docs = await cursor.to_list(length=limit)
- return [self.mapper.from_mongo_document(doc) for doc in docs]
+ conditions = [
+ EventDocument.event_type == event_type,
+ *self._time_conditions(start_time, end_time),
+ ]
+ docs = (
+ await EventDocument.find(*conditions)
+ .sort([("timestamp", SortDirection.DESCENDING)])
+ .skip(skip)
+ .limit(limit)
+ .to_list()
+ )
+ return [
+ Event(
+ **{
+ **d.model_dump(exclude={"id", "revision_id"}),
+ "metadata": DomainEventMetadata(**d.metadata.model_dump()),
+ }
+ )
+ for d in docs
+ ]
async def get_events_by_aggregate(
self, aggregate_id: str, event_types: list[EventType] | None = None, limit: int = 100
) -> list[Event]:
- query: dict[str, Any] = {EventFields.AGGREGATE_ID: aggregate_id}
- if event_types:
- query[EventFields.EVENT_TYPE] = {"$in": [t.value for t in event_types]}
-
- cursor = self._collection.find(query).sort(EventFields.TIMESTAMP, ASCENDING).limit(limit)
- docs = await cursor.to_list(length=limit)
- return [self.mapper.from_mongo_document(doc) for doc in docs]
+ conditions = [
+ EventDocument.aggregate_id == aggregate_id,
+ In(EventDocument.event_type, [t.value for t in event_types]) if event_types else None,
+ ]
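+ # Drop the None placeholder when no event_type filter was supplied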
+ conditions = [c for c in conditions if c is not None]
+ docs = (
+ await EventDocument.find(*conditions).sort([("timestamp", SortDirection.ASCENDING)]).limit(limit).to_list()
+ )
+ return [
+ Event(
+ **{
+ **d.model_dump(exclude={"id", "revision_id"}),
+ "metadata": DomainEventMetadata(**d.metadata.model_dump()),
+ }
+ )
+ for d in docs
+ ]
async def get_events_by_correlation(self, correlation_id: str, limit: int = 100, skip: int = 0) -> EventListResult:
- query: dict[str, Any] = {EventFields.METADATA_CORRELATION_ID: correlation_id}
- total_count = await self._collection.count_documents(query)
-
- cursor = self._collection.find(query).sort(EventFields.TIMESTAMP, ASCENDING).skip(skip).limit(limit)
- docs = await cursor.to_list(length=limit)
+ query = EventDocument.find(EventDocument.metadata.correlation_id == correlation_id)
+ total_count = await query.count()
+ docs = await query.sort([("timestamp", SortDirection.ASCENDING)]).skip(skip).limit(limit).to_list()
+ events = [
+ Event(
+ **{
+ **d.model_dump(exclude={"id", "revision_id"}),
+ "metadata": DomainEventMetadata(**d.metadata.model_dump()),
+ }
+ )
+ for d in docs
+ ]
return EventListResult(
- events=[self.mapper.from_mongo_document(doc) for doc in docs],
+ events=events,
total=total_count,
skip=skip,
limit=limit,
@@ -148,143 +165,89 @@ async def get_events_by_user(
limit: int = 100,
skip: int = 0,
) -> list[Event]:
- query: dict[str, Any] = {EventFields.METADATA_USER_ID: user_id}
- if event_types:
- query[EventFields.EVENT_TYPE] = {"$in": event_types}
- time_filter = self._build_time_filter(start_time, end_time)
- if time_filter:
- query[EventFields.TIMESTAMP] = time_filter
-
- cursor = self._collection.find(query).sort(EventFields.TIMESTAMP, DESCENDING).skip(skip).limit(limit)
- docs = await cursor.to_list(length=limit)
- return [self.mapper.from_mongo_document(doc) for doc in docs]
+ conditions = [
+ EventDocument.metadata.user_id == user_id,
+ In(EventDocument.event_type, event_types) if event_types else None,
+ *self._time_conditions(start_time, end_time),
+ ]
+ conditions = [c for c in conditions if c is not None]
+ docs = (
+ await EventDocument.find(*conditions)
+ .sort([("timestamp", SortDirection.DESCENDING)])
+ .skip(skip)
+ .limit(limit)
+ .to_list()
+ )
+ return [
+ Event(
+ **{
+ **d.model_dump(exclude={"id", "revision_id"}),
+ "metadata": DomainEventMetadata(**d.metadata.model_dump()),
+ }
+ )
+ for d in docs
+ ]
async def get_execution_events(
self, execution_id: str, limit: int = 100, skip: int = 0, exclude_system_events: bool = False
) -> EventListResult:
- query: dict[str, Any] = {
- "$or": [{EventFields.PAYLOAD_EXECUTION_ID: execution_id}, {EventFields.AGGREGATE_ID: execution_id}]
- }
-
- # Filter out system events at DB level for accurate pagination
- if exclude_system_events:
- query[EventFields.METADATA_SERVICE_NAME] = {"$not": {"$regex": "^system-"}}
-
- total_count = await self._collection.count_documents(query)
-
- cursor = self._collection.find(query).sort(EventFields.TIMESTAMP, ASCENDING).skip(skip).limit(limit)
- docs = await cursor.to_list(length=limit)
+ conditions: list[Any] = [
+ Or(
+ EventDocument.payload["execution_id"] == execution_id,
+ EventDocument.aggregate_id == execution_id,
+ ),
+ Not(RegEx(EventDocument.metadata.service_name, "^system-")) if exclude_system_events else None,
+ ]
+ conditions = [c for c in conditions if c is not None]
+ query = EventDocument.find(*conditions)
+ total_count = await query.count()
+ docs = await query.sort([("timestamp", SortDirection.ASCENDING)]).skip(skip).limit(limit).to_list()
+ events = [
+ Event(
+ **{
+ **d.model_dump(exclude={"id", "revision_id"}),
+ "metadata": DomainEventMetadata(**d.metadata.model_dump()),
+ }
+ )
+ for d in docs
+ ]
return EventListResult(
- events=[self.mapper.from_mongo_document(doc) for doc in docs],
+ events=events,
total=total_count,
skip=skip,
limit=limit,
has_more=(skip + limit) < total_count,
)
- async def search_events(
- self, text_query: str, filters: dict[str, object] | None = None, limit: int = 100, skip: int = 0
- ) -> list[Event]:
- query: dict[str, object] = {"$text": {"$search": text_query}}
- if filters:
- query.update(filters)
-
- cursor = self._collection.find(query).sort(EventFields.TIMESTAMP, DESCENDING).skip(skip).limit(limit)
- docs = await cursor.to_list(length=limit)
- return [self.mapper.from_mongo_document(doc) for doc in docs]
-
async def get_event_statistics(
- self, start_time: datetime | None = None, end_time: datetime | None = None
- ) -> EventStatistics:
- pipeline: list[Mapping[str, object]] = []
-
- time_filter = self._build_time_filter(start_time, end_time)
- if time_filter:
- pipeline.append({"$match": {EventFields.TIMESTAMP: time_filter}})
-
- pipeline.extend(
- [
- {
- "$facet": {
- "by_type": [
- {"$group": {"_id": f"${EventFields.EVENT_TYPE}", "count": {"$sum": 1}}},
- {"$sort": {"count": -1}},
- ],
- "by_service": [
- {"$group": {"_id": f"${EventFields.METADATA_SERVICE_NAME}", "count": {"$sum": 1}}},
- {"$sort": {"count": -1}},
- ],
- "by_hour": [
- {
- "$group": {
- "_id": {
- "$dateToString": {
- "format": "%Y-%m-%d %H:00",
- "date": f"${EventFields.TIMESTAMP}",
- }
- },
- "count": {"$sum": 1},
- }
- },
- {"$sort": {"_id": 1}},
- ],
- "total": [{"$count": "count"}],
- }
- }
- ]
- )
-
- result = await self._collection.aggregate(pipeline).to_list(length=1)
-
- if result:
- stats = result[0]
- return EventStatistics(
- total_events=stats["total"][0]["count"] if stats["total"] else 0,
- events_by_type={item["_id"]: item["count"] for item in stats["by_type"]},
- events_by_service={item["_id"]: item["count"] for item in stats["by_service"]},
- events_by_hour=stats["by_hour"],
- )
-
- return EventStatistics(total_events=0, events_by_type={}, events_by_service={}, events_by_hour=[])
-
- async def get_event_statistics_filtered(
self,
- match: Mapping[str, object] = MappingProxyType({}),
start_time: datetime | None = None,
end_time: datetime | None = None,
+ match: dict[str, object] | None = None,
) -> EventStatistics:
pipeline: list[Mapping[str, object]] = []
-
- and_clauses: list[dict[str, object]] = []
if match:
- and_clauses.append(dict(match))
+ pipeline.append({"$match": match})
time_filter = self._build_time_filter(start_time, end_time)
if time_filter:
- and_clauses.append({EventFields.TIMESTAMP: time_filter})
- if and_clauses:
- pipeline.append({"$match": {"$and": and_clauses}})
+ pipeline.append({"$match": {"timestamp": time_filter}})
pipeline.extend(
[
{
"$facet": {
"by_type": [
- {"$group": {"_id": f"${EventFields.EVENT_TYPE}", "count": {"$sum": 1}}},
+ {"$group": {"_id": "$event_type", "count": {"$sum": 1}}},
{"$sort": {"count": -1}},
],
"by_service": [
- {"$group": {"_id": f"${EventFields.METADATA_SERVICE_NAME}", "count": {"$sum": 1}}},
+ {"$group": {"_id": "$metadata.service_name", "count": {"$sum": 1}}},
{"$sort": {"count": -1}},
],
"by_hour": [
{
"$group": {
- "_id": {
- "$dateToString": {
- "format": "%Y-%m-%d %H:00",
- "date": f"${EventFields.TIMESTAMP}",
- }
- },
+ "_id": {"$dateToString": {"format": "%Y-%m-%d %H:00", "date": "$timestamp"}},
"count": {"$sum": 1},
}
},
@@ -292,70 +255,57 @@ async def get_event_statistics_filtered(
],
"total": [{"$count": "count"}],
}
- }
+ },
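+ # Reshape the $facet output into the exact field layout EventStatistics expects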
+ {
+ "$project": {
+ "_id": 0,
+ "total_events": {"$ifNull": [{"$arrayElemAt": ["$total.count", 0]}, 0]},
+ "events_by_type": {
+ "$arrayToObject": {
+ "$map": {"input": "$by_type", "as": "t", "in": {"k": "$$t._id", "v": "$$t.count"}}
+ }
+ },
+ "events_by_service": {
+ "$arrayToObject": {
+ "$map": {"input": "$by_service", "as": "s", "in": {"k": "$$s._id", "v": "$$s.count"}}
+ }
+ },
+ "events_by_hour": {
+ "$map": {
+ "input": "$by_hour",
+ "as": "h",
+ "in": {"hour": "$$h._id", "count": "$$h.count"},
+ }
+ },
+ }
+ },
]
)
- result = await self._collection.aggregate(pipeline).to_list(length=1)
- if result:
- stats = result[0]
- return EventStatistics(
- total_events=stats["total"][0]["count"] if stats["total"] else 0,
- events_by_type={item["_id"]: item["count"] for item in stats["by_type"]},
- events_by_service={item["_id"]: item["count"] for item in stats["by_service"]},
- events_by_hour=stats["by_hour"],
- )
- return EventStatistics(total_events=0, events_by_type={}, events_by_service={}, events_by_hour=[])
-
- async def stream_events(
- self, filters: dict[str, object] | None = None, start_after: dict[str, object] | None = None
- ) -> AsyncIterator[dict[str, object]]:
- """
- Stream events using change streams for real-time updates
-
- Args:
- filters: Optional filters for events
- start_after: Resume token for continuing from previous position
- """
- pipeline: list[Mapping[str, object]] = []
- if filters:
- pipeline.append({"$match": filters})
+ async for doc in EventDocument.aggregate(pipeline):
+ return EventStatistics(**doc)
- async with self._collection.watch(pipeline, start_after=start_after, full_document="updateLookup") as stream:
- async for change in stream:
- if change["operationType"] in ["insert", "update", "replace"]:
- yield change["fullDocument"]
+ return EventStatistics(total_events=0, events_by_type={}, events_by_service={}, events_by_hour=[])
async def cleanup_old_events(
self, older_than_days: int = 30, event_types: list[str] | None = None, dry_run: bool = False
) -> int:
- """
- Manually cleanup old events (in addition to TTL)
-
- Args:
- older_than_days: Delete events older than this many days
- event_types: Only cleanup specific event types
- dry_run: If True, only count events without deleting
-
- Returns:
- Number of events deleted (or would be deleted if dry_run)
- """
cutoff_dt = datetime.now(timezone.utc) - timedelta(days=older_than_days)
-
- query: dict[str, Any] = {EventFields.TIMESTAMP: {"$lt": cutoff_dt}}
- if event_types:
- query[EventFields.EVENT_TYPE] = {"$in": event_types}
+ conditions: list[Any] = [
+ LT(EventDocument.timestamp, cutoff_dt),
+ In(EventDocument.event_type, event_types) if event_types else None,
+ ]
+ conditions = [c for c in conditions if c is not None]
if dry_run:
- count = await self._collection.count_documents(query)
- logger.info(f"Would delete {count} events older than {older_than_days} days")
+ count = await EventDocument.find(*conditions).count()
+ self.logger.info(f"Would delete {count} events older than {older_than_days} days")
return count
- result = await self._collection.delete_many(query)
- logger.info(f"Deleted {result.deleted_count} events older than {older_than_days} days")
- return result.deleted_count
-
- # Access checks are handled in the service layer.
+ result = await EventDocument.find(*conditions).delete()
+ deleted_count = result.deleted_count if result else 0
+ self.logger.info(f"Deleted {deleted_count} events older than {older_than_days} days")
+ return deleted_count
async def get_user_events_paginated(
self,
@@ -367,180 +317,144 @@ async def get_user_events_paginated(
skip: int = 0,
sort_order: str = "desc",
) -> EventListResult:
- """Get paginated user events with count"""
- query: dict[str, Any] = {EventFields.METADATA_USER_ID: user_id}
- if event_types:
- query[EventFields.EVENT_TYPE] = {"$in": event_types}
- time_filter = self._build_time_filter(start_time, end_time)
- if time_filter:
- query[EventFields.TIMESTAMP] = time_filter
-
- total_count = await self._collection.count_documents(query)
-
- sort_direction = DESCENDING if sort_order == "desc" else ASCENDING
- cursor = self._collection.find(query)
- cursor = cursor.sort(EventFields.TIMESTAMP, sort_direction)
- cursor = cursor.skip(skip).limit(limit)
-
- docs = []
- async for doc in cursor:
- docs.append(doc)
+ conditions = [
+ EventDocument.metadata.user_id == user_id,
+ In(EventDocument.event_type, event_types) if event_types else None,
+ *self._time_conditions(start_time, end_time),
+ ]
+ conditions = [c for c in conditions if c is not None]
+
+ query = EventDocument.find(*conditions)
+ total_count = await query.count()
+ sort_direction = SortDirection.DESCENDING if sort_order == "desc" else SortDirection.ASCENDING
+ docs = await query.sort([("timestamp", sort_direction)]).skip(skip).limit(limit).to_list()
+ events = [
+ Event(
+ **{
+ **d.model_dump(exclude={"id", "revision_id"}),
+ "metadata": DomainEventMetadata(**d.metadata.model_dump()),
+ }
+ )
+ for d in docs
+ ]
return EventListResult(
- events=[self.mapper.from_mongo_document(doc) for doc in docs],
+ events=events,
total=total_count,
skip=skip,
limit=limit,
has_more=(skip + limit) < total_count,
)
- async def query_events_advanced(self, user_id: str, user_role: str, filters: EventFilter) -> EventListResult | None:
- """Advanced event query with filters"""
- query: dict[str, object] = {}
-
- # User access control
- if filters.user_id:
- if filters.user_id != user_id and user_role != UserRole.ADMIN:
- return None # Signal unauthorized
- query[EventFields.METADATA_USER_ID] = filters.user_id
- elif user_role != UserRole.ADMIN:
- query[EventFields.METADATA_USER_ID] = user_id
-
- # Apply filters using mapper from domain filter
- base_query = EventFilterMapper.to_mongo_query(filters)
- query.update(base_query)
-
- total_count = await self._collection.count_documents(query)
-
- sort_field = EventFields.TIMESTAMP
- sort_direction = DESCENDING
-
- cursor = self._collection.find(query)
- cursor = cursor.sort(sort_field, sort_direction)
- cursor = cursor.skip(0).limit(100)
+ async def count_events(self, *conditions: Any) -> int:
+ return await EventDocument.find(*conditions).count()
- docs = []
- async for doc in cursor:
- docs.append(doc)
-
- result_obj = EventListResult(
- events=[self.mapper.from_mongo_document(doc) for doc in docs],
- total=total_count,
- skip=0,
- limit=100,
- has_more=100 < total_count,
- )
- add_span_attributes(**{"events.query.total": total_count})
- return result_obj
-
- async def aggregate_events(self, pipeline: list[dict[str, object]], limit: int = 100) -> EventAggregationResult:
- pipeline = pipeline.copy()
- pipeline.append({"$limit": limit})
-
- results = []
- async for doc in self._collection.aggregate(pipeline):
- if "_id" in doc and isinstance(doc["_id"], dict):
- doc["_id"] = str(doc["_id"])
- results.append(doc)
-
- return EventAggregationResult(results=results, pipeline=pipeline)
-
- async def list_event_types(self, match: Mapping[str, object] = MappingProxyType({})) -> list[str]:
- pipeline: list[Mapping[str, object]] = []
- if match:
- pipeline.append({"$match": dict(match)})
- pipeline.extend([{"$group": {"_id": f"${EventFields.EVENT_TYPE}"}}, {"$sort": {"_id": 1}}])
- event_types: list[str] = []
- async for doc in self._collection.aggregate(pipeline):
- event_types.append(doc["_id"])
- return event_types
-
- async def query_events_generic(
+ async def query_events(
self,
- query: dict[str, object],
- sort_field: str,
- sort_direction: int,
- skip: int,
- limit: int,
+ query: dict[str, Any],
+ sort_field: str = "timestamp",
+ skip: int = 0,
+ limit: int = 100,
) -> EventListResult:
- total_count = await self._collection.count_documents(query)
-
- cursor = self._collection.find(query)
- cursor = cursor.sort(sort_field, sort_direction)
- cursor = cursor.skip(skip).limit(limit)
+ """Query events with filter, sort, and pagination. Always sorts descending (newest first)."""
+ cursor = EventDocument.find(query)
+ total_count = await cursor.count()
+ docs = await cursor.sort([(sort_field, SortDirection.DESCENDING)]).skip(skip).limit(limit).to_list()
+ events = [
+ Event(
+ **{
+ **d.model_dump(exclude={"id", "revision_id"}),
+ "metadata": DomainEventMetadata(**d.metadata.model_dump()),
+ }
+ )
+ for d in docs
+ ]
+ return EventListResult(
+ events=events, total=total_count, skip=skip, limit=limit, has_more=(skip + limit) < total_count
+ )
- docs = []
- async for doc in cursor:
- docs.append(doc)
+ async def aggregate_events(self, pipeline: list[dict[str, Any]], limit: int = 100) -> EventAggregationResult:
+ """Run aggregation pipeline on events."""
+ pipeline_with_limit = [*pipeline, {"$limit": limit}]
+ results = await EventDocument.aggregate(pipeline_with_limit).to_list()
+ return EventAggregationResult(results=results, pipeline=pipeline_with_limit)
- return EventListResult(
- events=[self.mapper.from_mongo_document(doc) for doc in docs],
- total=total_count,
- skip=skip,
- limit=limit,
- has_more=(skip + limit) < total_count,
+ async def list_event_types(self, match: dict[str, object] | None = None) -> list[str]:
+ """List distinct event types, optionally filtered."""
+ pipeline: list[dict[str, object]] = []
+ if match:
+ pipeline.append({"$match": match})
+ pipeline.extend(
+ [
+ {"$group": {"_id": "$event_type"}},
+ {"$sort": {"_id": 1}},
+ ]
)
+ results: list[dict[str, str]] = await EventDocument.aggregate(pipeline).to_list()
+ return [doc["_id"] for doc in results if doc.get("_id")]
async def delete_event_with_archival(
self, event_id: str, deleted_by: str, deletion_reason: str = "Admin deletion via API"
) -> ArchivedEvent | None:
- """Delete event and archive it"""
- event = await self.get_event(event_id)
-
- if not event:
+ doc = await EventDocument.find_one({"event_id": event_id})
+ if not doc:
return None
- # Create archived event
- archived_event = ArchivedEvent(
- event_id=event.event_id,
- event_type=event.event_type,
- event_version=event.event_version,
- timestamp=event.timestamp,
- metadata=event.metadata,
- payload=event.payload,
- aggregate_id=event.aggregate_id,
- stored_at=event.stored_at,
- ttl_expires_at=event.ttl_expires_at,
- status=event.status,
- error=event.error,
- deleted_at=datetime.now(timezone.utc),
+ deleted_at = datetime.now(timezone.utc)
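+ # Copy the event into the archive collection before removing it from the live collection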
+ archived_doc = EventArchiveDocument(
+ event_id=doc.event_id,
+ event_type=doc.event_type,
+ event_version=doc.event_version,
+ timestamp=doc.timestamp,
+ metadata=doc.metadata,
+ payload=doc.payload,
+ aggregate_id=doc.aggregate_id,
+ stored_at=doc.stored_at,
+ ttl_expires_at=doc.ttl_expires_at,
+ deleted_at=deleted_at,
+ deleted_by=deleted_by,
+ deletion_reason=deletion_reason,
+ )
+ await archived_doc.insert()
+ await doc.delete()
+ return ArchivedEvent(
+ **{
+ **doc.model_dump(exclude={"id", "revision_id"}),
+ "metadata": DomainEventMetadata(**doc.metadata.model_dump()),
+ },
+ deleted_at=deleted_at,
deleted_by=deleted_by,
deletion_reason=deletion_reason,
)
-
- # Archive the event
- archive_collection = self.database.get_collection(CollectionNames.EVENTS_ARCHIVE)
- archived_mapper = ArchivedEventMapper()
- await archive_collection.insert_one(archived_mapper.to_mongo_document(archived_event))
-
- # Delete from main collection
- result = await self._collection.delete_one({EventFields.EVENT_ID: event_id})
-
- if result.deleted_count == 0:
- raise Exception("Failed to delete event")
-
- return archived_event
async def get_aggregate_events_for_replay(self, aggregate_id: str, limit: int = 10000) -> list[Event]:
- """Get all events for an aggregate for replay purposes"""
- events = await self.get_events_by_aggregate(aggregate_id=aggregate_id, limit=limit)
-
- if not events:
- return []
-
- return events
+ return await self.get_events_by_aggregate(aggregate_id=aggregate_id, limit=limit)
async def get_aggregate_replay_info(self, aggregate_id: str) -> EventReplayInfo | None:
- """Get aggregate events and prepare replay information"""
- events = await self.get_aggregate_events_for_replay(aggregate_id)
-
- if not events:
- return None
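+ # A single aggregation returns the ordered event stream together with its summary stats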
+ pipeline = [
+ {"$match": {"aggregate_id": aggregate_id}},
+ {"$sort": {"timestamp": 1}},
+ {
+ "$group": {
+ "_id": None,
+ "events": {"$push": "$$ROOT"},
+ "event_count": {"$sum": 1},
+ "event_types": {"$addToSet": "$event_type"},
+ "start_time": {"$min": "$timestamp"},
+ "end_time": {"$max": "$timestamp"},
+ }
+ },
+ {"$project": {"_id": 0}},
+ ]
+
+ async for doc in EventDocument.aggregate(pipeline):
+ # Each pushed $$ROOT still carries Mongo's _id; drop it before building domain events
+ events = [
+ Event(**{**{k: v for k, v in e.items() if k != "_id"}, "metadata": DomainEventMetadata(**e["metadata"])})
+ for e in doc["events"]
+ ]
+ return EventReplayInfo(
+ events=events,
+ event_count=doc["event_count"],
+ event_types=doc["event_types"],
+ start_time=doc["start_time"],
+ end_time=doc["end_time"],
+ )
- return EventReplayInfo(
- events=events,
- event_count=len(events),
- event_types=list(set(e.event_type for e in events)),
- start_time=min(e.timestamp for e in events),
- end_time=max(e.timestamp for e in events),
- )
+ return None
diff --git a/backend/app/db/repositories/execution_repository.py b/backend/app/db/repositories/execution_repository.py
index a0d8fcd0..f0a8fcb6 100644
--- a/backend/app/db/repositories/execution_repository.py
+++ b/backend/app/db/repositories/execution_repository.py
@@ -1,162 +1,116 @@
+import logging
+from dataclasses import asdict
from datetime import datetime, timezone
from typing import Any
-from app.core.database_context import Collection, Database
-from app.core.logging import logger
-from app.domain.enums.execution import ExecutionStatus
-from app.domain.events.event_models import CollectionNames
-from app.domain.execution import DomainExecution, ExecutionResultDomain, ResourceUsageDomain
+from beanie.odm.enums import SortDirection
+from app.db.docs import ExecutionDocument, ResourceUsage
+from app.domain.execution import (
+ DomainExecution,
+ DomainExecutionCreate,
+ DomainExecutionUpdate,
+ ExecutionResultDomain,
+ ResourceUsageDomain,
+)
-class ExecutionRepository:
- def __init__(self, db: Database):
- self.db = db
- self.collection: Collection = self.db.get_collection(CollectionNames.EXECUTIONS)
- self.results_collection: Collection = self.db.get_collection(CollectionNames.EXECUTION_RESULTS)
- async def create_execution(self, execution: DomainExecution) -> DomainExecution:
- execution_dict = {
- "execution_id": execution.execution_id,
- "script": execution.script,
- "status": execution.status,
- "stdout": execution.stdout,
- "stderr": execution.stderr,
- "lang": execution.lang,
- "lang_version": execution.lang_version,
- "created_at": execution.created_at,
- "updated_at": execution.updated_at,
- "resource_usage": execution.resource_usage.to_dict() if execution.resource_usage else None,
- "user_id": execution.user_id,
- "exit_code": execution.exit_code,
- "error_type": execution.error_type,
- }
- logger.info(f"Inserting execution {execution.execution_id} into MongoDB")
- result = await self.collection.insert_one(execution_dict)
- logger.info(f"Inserted execution {execution.execution_id} with _id: {result.inserted_id}")
- return execution
+class ExecutionRepository:
+ def __init__(self, logger: logging.Logger):
+ self.logger = logger
+
+ async def create_execution(self, create_data: DomainExecutionCreate) -> DomainExecution:
+ doc = ExecutionDocument(**asdict(create_data))
+ self.logger.info("Inserting execution into MongoDB", extra={"execution_id": doc.execution_id})
+ await doc.insert()
+ self.logger.info("Inserted execution", extra={"execution_id": doc.execution_id})
+ return DomainExecution(
+ **{
+ **doc.model_dump(exclude={"id"}),
+ "resource_usage": ResourceUsageDomain(**doc.resource_usage.model_dump())
+ if doc.resource_usage
+ else None,
+ }
+ )
async def get_execution(self, execution_id: str) -> DomainExecution | None:
- logger.info(f"Searching for execution {execution_id} in MongoDB")
- document = await self.collection.find_one({"execution_id": execution_id})
- if not document:
- logger.warning(f"Execution {execution_id} not found in MongoDB")
+ self.logger.info("Searching for execution in MongoDB", extra={"execution_id": execution_id})
+ doc = await ExecutionDocument.find_one({"execution_id": execution_id})
+ if not doc:
+ self.logger.warning("Execution not found in MongoDB", extra={"execution_id": execution_id})
return None
- logger.info(f"Found execution {execution_id} in MongoDB")
-
- result_doc = await self.results_collection.find_one({"execution_id": execution_id})
- if result_doc:
- document["stdout"] = result_doc.get("stdout")
- document["stderr"] = result_doc.get("stderr")
- document["exit_code"] = result_doc.get("exit_code")
- document["resource_usage"] = result_doc.get("resource_usage")
- document["error_type"] = result_doc.get("error_type")
- if result_doc.get("status"):
- document["status"] = result_doc.get("status")
-
- sv = document.get("status")
- resource_usage_data = document.get("resource_usage")
+ self.logger.info("Found execution in MongoDB", extra={"execution_id": execution_id})
return DomainExecution(
- execution_id=document["execution_id"],
- script=document.get("script", ""),
- status=ExecutionStatus(str(sv)),
- stdout=document.get("stdout"),
- stderr=document.get("stderr"),
- lang=document.get("lang", "python"),
- lang_version=document.get("lang_version", "3.11"),
- created_at=document.get("created_at", datetime.now(timezone.utc)),
- updated_at=document.get("updated_at", datetime.now(timezone.utc)),
- resource_usage=(
- ResourceUsageDomain.from_dict(resource_usage_data) if resource_usage_data is not None else None
- ),
- user_id=document.get("user_id"),
- exit_code=document.get("exit_code"),
- error_type=document.get("error_type"),
+ **{
+ **doc.model_dump(exclude={"id"}),
+ "resource_usage": ResourceUsageDomain(**doc.resource_usage.model_dump())
+ if doc.resource_usage
+ else None,
+ }
)
- async def update_execution(self, execution_id: str, update_data: dict[str, Any]) -> bool:
- update_data.setdefault("updated_at", datetime.now(timezone.utc))
- update_payload = {"$set": update_data}
-
- result = await self.collection.update_one({"execution_id": execution_id}, update_payload)
- return result.matched_count > 0
-
- async def write_terminal_result(self, exec_result: ExecutionResultDomain) -> bool:
- base = await self.collection.find_one({"execution_id": exec_result.execution_id}, {"user_id": 1}) or {}
- user_id = base.get("user_id")
-
- doc = {
- "_id": exec_result.execution_id,
- "execution_id": exec_result.execution_id,
- "status": exec_result.status.value,
- "exit_code": exec_result.exit_code,
- "stdout": exec_result.stdout,
- "stderr": exec_result.stderr,
- "resource_usage": exec_result.resource_usage.to_dict(),
- "created_at": exec_result.created_at,
- "metadata": exec_result.metadata,
- }
- if exec_result.error_type is not None:
- doc["error_type"] = exec_result.error_type
- if user_id is not None:
- doc["user_id"] = user_id
-
- await self.results_collection.replace_one({"_id": exec_result.execution_id}, doc, upsert=True)
-
- update_data = {
- "status": exec_result.status.value,
- "updated_at": datetime.now(timezone.utc),
- "stdout": exec_result.stdout,
- "stderr": exec_result.stderr,
- "exit_code": exec_result.exit_code,
- "resource_usage": exec_result.resource_usage.to_dict(),
- }
- if exec_result.error_type is not None:
- update_data["error_type"] = exec_result.error_type
+ async def update_execution(self, execution_id: str, update_data: DomainExecutionUpdate) -> bool:
+ doc = await ExecutionDocument.find_one({"execution_id": execution_id})
+ if not doc:
+ return False
+
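+ # Apply only the fields explicitly provided on the update dataclass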
+ update_dict = {k: v for k, v in asdict(update_data).items() if v is not None}
+ if "resource_usage" in update_dict:
+ update_dict["resource_usage"] = ResourceUsage.model_validate(update_data.resource_usage)
+ if update_dict:
+ update_dict["updated_at"] = datetime.now(timezone.utc)
+ await doc.set(update_dict)
+ return True
- res = await self.collection.update_one({"execution_id": exec_result.execution_id}, {"$set": update_data})
- if res.matched_count == 0:
- logger.warning(f"No execution found to patch for {exec_result.execution_id} after result upsert")
+ async def write_terminal_result(self, result: ExecutionResultDomain) -> bool:
+ doc = await ExecutionDocument.find_one({"execution_id": result.execution_id})
+ if not doc:
+ self.logger.warning("No execution found", extra={"execution_id": result.execution_id})
+ return False
+
+ await doc.set(
+ {
+ "status": result.status,
+ "exit_code": result.exit_code,
+ "stdout": result.stdout,
+ "stderr": result.stderr,
+ "resource_usage": ResourceUsage.model_validate(result.resource_usage),
+ "error_type": result.error_type,
+ "updated_at": datetime.now(timezone.utc),
+ }
+ )
return True
async def get_executions(
self, query: dict[str, Any], limit: int = 50, skip: int = 0, sort: list[tuple[str, int]] | None = None
) -> list[DomainExecution]:
- cursor = self.collection.find(query)
+ find_query = ExecutionDocument.find(query)
if sort:
- cursor = cursor.sort(sort)
- cursor = cursor.skip(skip).limit(limit)
-
- executions: list[DomainExecution] = []
- async for doc in cursor:
- sv = doc.get("status")
- executions.append(
- DomainExecution(
- execution_id=doc.get("execution_id"),
- script=doc.get("script", ""),
- status=ExecutionStatus(str(sv)),
- stdout=doc.get("stdout"),
- stderr=doc.get("stderr"),
- lang=doc.get("lang", "python"),
- lang_version=doc.get("lang_version", "3.11"),
- created_at=doc.get("created_at", datetime.now(timezone.utc)),
- updated_at=doc.get("updated_at", datetime.now(timezone.utc)),
- resource_usage=(
- ResourceUsageDomain.from_dict(doc.get("resource_usage"))
- if doc.get("resource_usage") is not None
- else None
- ),
- user_id=doc.get("user_id"),
- exit_code=doc.get("exit_code"),
- error_type=doc.get("error_type"),
- )
+ beanie_sort = [
+ (field, SortDirection.ASCENDING if direction == 1 else SortDirection.DESCENDING)
+ for field, direction in sort
+ ]
+ find_query = find_query.sort(beanie_sort)
+ docs = await find_query.skip(skip).limit(limit).to_list()
+ return [
+ DomainExecution(
+ **{
+ **doc.model_dump(exclude={"id"}),
+ "resource_usage": ResourceUsageDomain(**doc.resource_usage.model_dump())
+ if doc.resource_usage
+ else None,
+ }
)
-
- return executions
+ for doc in docs
+ ]
async def count_executions(self, query: dict[str, Any]) -> int:
- return await self.collection.count_documents(query)
+ return await ExecutionDocument.find(query).count()
async def delete_execution(self, execution_id: str) -> bool:
- result = await self.collection.delete_one({"execution_id": execution_id})
- return result.deleted_count > 0
+ doc = await ExecutionDocument.find_one({"execution_id": execution_id})
+ if not doc:
+ return False
+ await doc.delete()
+ return True
diff --git a/backend/app/db/repositories/notification_repository.py b/backend/app/db/repositories/notification_repository.py
index bc0f3709..6facbe8e 100644
--- a/backend/app/db/repositories/notification_repository.py
+++ b/backend/app/db/repositories/notification_repository.py
@@ -1,92 +1,67 @@
+import logging
+from dataclasses import asdict
from datetime import UTC, datetime, timedelta
-from pymongo import ASCENDING, DESCENDING, IndexModel
+from beanie.odm.enums import SortDirection
+from beanie.operators import GTE, LT, LTE, ElemMatch, In, NotIn, Or
-from app.core.database_context import Collection, Database
-from app.core.logging import logger
-from app.domain.enums.notification import (
- NotificationChannel,
- NotificationStatus,
-)
+from app.db.docs import NotificationDocument, NotificationSubscriptionDocument, UserDocument
+from app.domain.enums.notification import NotificationChannel, NotificationStatus
from app.domain.enums.user import UserRole
-from app.domain.events.event_models import CollectionNames
-from app.domain.notification import DomainNotification, DomainNotificationSubscription
-from app.domain.user import UserFields
-from app.infrastructure.mappers import NotificationMapper
+from app.domain.notification import (
+ DomainNotification,
+ DomainNotificationCreate,
+ DomainNotificationSubscription,
+ DomainNotificationUpdate,
+ DomainSubscriptionUpdate,
+)
class NotificationRepository:
- def __init__(self, database: Database):
- self.db: Database = database
-
- self.notifications_collection: Collection = self.db.get_collection(CollectionNames.NOTIFICATIONS)
- self.subscriptions_collection: Collection = self.db.get_collection(CollectionNames.NOTIFICATION_SUBSCRIPTIONS)
- self.mapper = NotificationMapper()
-
- async def create_indexes(self) -> None:
- # Create indexes if only _id exists
- notif_indexes = await self.notifications_collection.list_indexes().to_list(None)
- if len(notif_indexes) <= 1:
- await self.notifications_collection.create_indexes(
- [
- IndexModel([("user_id", ASCENDING), ("created_at", DESCENDING)]),
- IndexModel([("status", ASCENDING), ("scheduled_for", ASCENDING)]),
- IndexModel([("created_at", ASCENDING)]),
- IndexModel([("notification_id", ASCENDING)], unique=True),
- # Multikey index to speed up tag queries (include/exclude/prefix)
- IndexModel([("tags", ASCENDING)]),
- ]
- )
-
- subs_indexes = await self.subscriptions_collection.list_indexes().to_list(None)
- if len(subs_indexes) <= 1:
- await self.subscriptions_collection.create_indexes(
- [
- IndexModel([("user_id", ASCENDING), ("channel", ASCENDING)], unique=True),
- IndexModel([("enabled", ASCENDING)]),
- IndexModel([("include_tags", ASCENDING)]),
- IndexModel([("severities", ASCENDING)]),
- ]
- )
-
- # Notifications
- async def create_notification(self, notification: DomainNotification) -> str:
- doc = self.mapper.to_mongo_document(notification)
- result = await self.notifications_collection.insert_one(doc)
- return str(result.inserted_id)
-
- async def update_notification(self, notification: DomainNotification) -> bool:
- update = self.mapper.to_update_dict(notification)
- result = await self.notifications_collection.update_one(
- {"notification_id": str(notification.notification_id)}, {"$set": update}
- )
- return result.modified_count > 0
+ def __init__(self, logger: logging.Logger):
+ self.logger = logger
+
+ async def create_notification(self, create_data: DomainNotificationCreate) -> DomainNotification:
+ doc = NotificationDocument(**asdict(create_data))
+ await doc.insert()
+ return DomainNotification(**doc.model_dump(exclude={"id"}))
+
+ async def update_notification(
+ self, notification_id: str, user_id: str, update_data: DomainNotificationUpdate
+ ) -> bool:
+ doc = await NotificationDocument.find_one({"notification_id": notification_id, "user_id": user_id})
+ if not doc:
+ return False
+ update_dict = {k: v for k, v in asdict(update_data).items() if v is not None}
+ if update_dict:
+ await doc.set(update_dict)
+ return True
async def get_notification(self, notification_id: str, user_id: str) -> DomainNotification | None:
- doc = await self.notifications_collection.find_one({"notification_id": notification_id, "user_id": user_id})
+ doc = await NotificationDocument.find_one({"notification_id": notification_id, "user_id": user_id})
if not doc:
return None
- return self.mapper.from_mongo_document(doc)
+ return DomainNotification(**doc.model_dump(exclude={"id"}))
async def mark_as_read(self, notification_id: str, user_id: str) -> bool:
- result = await self.notifications_collection.update_one(
- {"notification_id": notification_id, "user_id": user_id},
- {"$set": {"status": NotificationStatus.READ, "read_at": datetime.now(UTC)}},
- )
- return result.modified_count > 0
+ doc = await NotificationDocument.find_one({"notification_id": notification_id, "user_id": user_id})
+ if not doc:
+ return False
+ await doc.set({"status": NotificationStatus.READ, "read_at": datetime.now(UTC)})
+ return True
async def mark_all_as_read(self, user_id: str) -> int:
- result = await self.notifications_collection.update_many(
- {"user_id": user_id, "status": {"$in": [NotificationStatus.DELIVERED]}},
- {"$set": {"status": NotificationStatus.READ, "read_at": datetime.now(UTC)}},
- )
- return result.modified_count
+ result = await NotificationDocument.find(
+ {"user_id": user_id, "status": NotificationStatus.DELIVERED}
+ ).update_many({"$set": {"status": NotificationStatus.READ, "read_at": datetime.now(UTC)}})
+ return result.modified_count if result else 0
async def delete_notification(self, notification_id: str, user_id: str) -> bool:
- result = await self.notifications_collection.delete_one(
- {"notification_id": notification_id, "user_id": user_id}
- )
- return result.deleted_count > 0
+ doc = await NotificationDocument.find_one({"notification_id": notification_id, "user_id": user_id})
+ if not doc:
+ return False
+ await doc.delete()
+ return True
async def list_notifications(
self,
@@ -98,180 +73,160 @@ async def list_notifications(
exclude_tags: list[str] | None = None,
tag_prefix: str | None = None,
) -> list[DomainNotification]:
- base: dict[str, object] = {"user_id": user_id}
- if status:
- base["status"] = status
- query: dict[str, object] | None = base
- tag_filters: list[dict[str, object]] = []
- if include_tags:
- tag_filters.append({"tags": {"$in": include_tags}})
- if exclude_tags:
- tag_filters.append({"tags": {"$nin": exclude_tags}})
- if tag_prefix:
- tag_filters.append({"tags": {"$elemMatch": {"$regex": f"^{tag_prefix}"}}})
- if tag_filters:
- query = {"$and": [base] + tag_filters}
-
- cursor = (
- self.notifications_collection.find(query or base).sort("created_at", DESCENDING).skip(skip).limit(limit)
- )
-
- items: list[DomainNotification] = []
- async for doc in cursor:
- items.append(self.mapper.from_mongo_document(doc))
- return items
-
- async def list_notifications_by_tag(
- self,
- user_id: str,
- tag: str,
- skip: int = 0,
- limit: int = 20,
- ) -> list[DomainNotification]:
- """Convenience helper to list notifications filtered by a single exact tag."""
- return await self.list_notifications(
- user_id=user_id,
- skip=skip,
- limit=limit,
- include_tags=[tag],
+ conditions = [
+ NotificationDocument.user_id == user_id,
+ NotificationDocument.status == status if status else None,
+ In(NotificationDocument.tags, include_tags) if include_tags else None,
+ NotIn(NotificationDocument.tags, exclude_tags) if exclude_tags else None,
+ ElemMatch(NotificationDocument.tags, {"$regex": f"^{tag_prefix}"}) if tag_prefix else None,
+ ]
+ conditions = [c for c in conditions if c is not None]
+ docs = (
+ await NotificationDocument.find(*conditions)
+ .sort([("created_at", SortDirection.DESCENDING)])
+ .skip(skip)
+ .limit(limit)
+ .to_list()
)
+ return [DomainNotification(**doc.model_dump(exclude={"id"})) for doc in docs]
- async def count_notifications(self, user_id: str, additional_filters: dict[str, object] | None = None) -> int:
- query: dict[str, object] = {"user_id": user_id}
- if additional_filters:
- query.update(additional_filters)
- return await self.notifications_collection.count_documents(query)
+ async def count_notifications(self, user_id: str, *additional_conditions) -> int: # type: ignore[no-untyped-def]
+ conditions = [NotificationDocument.user_id == user_id, *additional_conditions]
+ return await NotificationDocument.find(*conditions).count()
async def get_unread_count(self, user_id: str) -> int:
- return await self.notifications_collection.count_documents(
- {
- "user_id": user_id,
- "status": {"$in": [NotificationStatus.DELIVERED]},
- }
- )
+ return await NotificationDocument.find(
+ NotificationDocument.user_id == user_id,
+ In(NotificationDocument.status, [NotificationStatus.DELIVERED]),
+ ).count()
async def try_claim_pending(self, notification_id: str) -> bool:
- """Atomically claim a pending notification for delivery.
-
- Transitions PENDING -> SENDING when scheduled_for is None or due.
- Returns True if the document was claimed by this caller.
- """
now = datetime.now(UTC)
- result = await self.notifications_collection.update_one(
- {
- "notification_id": notification_id,
- "status": NotificationStatus.PENDING,
- "$or": [{"scheduled_for": None}, {"scheduled_for": {"$lte": now}}],
- },
- {"$set": {"status": NotificationStatus.SENDING, "sent_at": now}},
+ doc = await NotificationDocument.find_one(
+ NotificationDocument.notification_id == notification_id,
+ NotificationDocument.status == NotificationStatus.PENDING,
+ Or(
+ NotificationDocument.scheduled_for == None, # noqa: E711
+ LTE(NotificationDocument.scheduled_for, now),
+ ),
)
- return result.modified_count > 0
+ if not doc:
+ return False
+ await doc.set({"status": NotificationStatus.SENDING, "sent_at": now})
+ return True
async def find_pending_notifications(self, batch_size: int = 10) -> list[DomainNotification]:
- cursor = self.notifications_collection.find(
- {
- "status": NotificationStatus.PENDING,
- "$or": [{"scheduled_for": None}, {"scheduled_for": {"$lte": datetime.now(UTC)}}],
- }
- ).limit(batch_size)
-
- items: list[DomainNotification] = []
- async for doc in cursor:
- items.append(self.mapper.from_mongo_document(doc))
- return items
+ now = datetime.now(UTC)
+ docs = (
+ await NotificationDocument.find(
+ NotificationDocument.status == NotificationStatus.PENDING,
+ Or(
+ NotificationDocument.scheduled_for == None, # noqa: E711
+ LTE(NotificationDocument.scheduled_for, now),
+ ),
+ )
+ .limit(batch_size)
+ .to_list()
+ )
+ return [DomainNotification(**doc.model_dump(exclude={"id"})) for doc in docs]
async def find_scheduled_notifications(self, batch_size: int = 10) -> list[DomainNotification]:
- cursor = self.notifications_collection.find(
- {
- "status": NotificationStatus.PENDING,
- "scheduled_for": {"$lte": datetime.now(UTC), "$ne": None},
- }
- ).limit(batch_size)
-
- items: list[DomainNotification] = []
- async for doc in cursor:
- items.append(self.mapper.from_mongo_document(doc))
- return items
+ now = datetime.now(UTC)
+ docs = (
+ await NotificationDocument.find(
+ NotificationDocument.status == NotificationStatus.PENDING,
+ LTE(NotificationDocument.scheduled_for, now),
+ NotificationDocument.scheduled_for != None, # noqa: E711
+ )
+ .limit(batch_size)
+ .to_list()
+ )
+ return [DomainNotification(**doc.model_dump(exclude={"id"})) for doc in docs]
async def cleanup_old_notifications(self, days: int = 30) -> int:
cutoff = datetime.now(UTC) - timedelta(days=days)
- result = await self.notifications_collection.delete_many({"created_at": {"$lt": cutoff}})
- return result.deleted_count
+ result = await NotificationDocument.find(
+ LT(NotificationDocument.created_at, cutoff),
+ ).delete()
+ return result.deleted_count if result else 0
# Subscriptions
async def get_subscription(
self, user_id: str, channel: NotificationChannel
) -> DomainNotificationSubscription | None:
- doc = await self.subscriptions_collection.find_one({"user_id": user_id, "channel": channel})
+ doc = await NotificationSubscriptionDocument.find_one({"user_id": user_id, "channel": channel})
if not doc:
return None
- return self.mapper.subscription_from_mongo_document(doc)
+ return DomainNotificationSubscription(**doc.model_dump(exclude={"id"}))
async def upsert_subscription(
- self,
- user_id: str,
- channel: NotificationChannel,
- subscription: DomainNotificationSubscription,
- ) -> None:
- subscription.user_id = user_id
- subscription.channel = channel
- subscription.updated_at = datetime.now(UTC)
- doc = self.mapper.subscription_to_mongo_document(subscription)
- await self.subscriptions_collection.replace_one({"user_id": user_id, "channel": channel}, doc, upsert=True)
+ self, user_id: str, channel: NotificationChannel, update_data: DomainSubscriptionUpdate
+ ) -> DomainNotificationSubscription:
+ existing = await NotificationSubscriptionDocument.find_one({"user_id": user_id, "channel": channel})
+ update_dict = {k: v for k, v in asdict(update_data).items() if v is not None}
+ update_dict["updated_at"] = datetime.now(UTC)
+
+ if existing:
+ await existing.set(update_dict)
+ return DomainNotificationSubscription(**existing.model_dump(exclude={"id"}))
+ else:
+ doc = NotificationSubscriptionDocument(
+ user_id=user_id,
+ channel=channel,
+ **update_dict,
+ )
+ await doc.insert()
+ return DomainNotificationSubscription(**doc.model_dump(exclude={"id"}))
- async def get_all_subscriptions(self, user_id: str) -> dict[str, DomainNotificationSubscription]:
- subs: dict[str, DomainNotificationSubscription] = {}
+ async def get_all_subscriptions(self, user_id: str) -> dict[NotificationChannel, DomainNotificationSubscription]:
+ subs: dict[NotificationChannel, DomainNotificationSubscription] = {}
for channel in NotificationChannel:
- doc = await self.subscriptions_collection.find_one({"user_id": user_id, "channel": channel})
+ doc = await NotificationSubscriptionDocument.find_one({"user_id": user_id, "channel": channel})
if doc:
- subs[channel] = self.mapper.subscription_from_mongo_document(doc)
+ subs[channel] = DomainNotificationSubscription(**doc.model_dump(exclude={"id"}))
else:
subs[channel] = DomainNotificationSubscription(user_id=user_id, channel=channel, enabled=True)
return subs
- # User query operations for system notifications
+ # User query operations
async def get_users_by_roles(self, roles: list[UserRole]) -> list[str]:
- users_collection = self.db.users
- role_values = [role.value for role in roles]
- cursor = users_collection.find(
- {UserFields.ROLE: {"$in": role_values}, UserFields.IS_ACTIVE: True},
- {UserFields.USER_ID: 1},
- )
-
- user_ids: list[str] = []
- async for user in cursor:
- if user.get("user_id"):
- user_ids.append(user["user_id"])
-
- logger.info(f"Found {len(user_ids)} users with roles {role_values}")
+ docs = await UserDocument.find(
+ In(UserDocument.role, roles),
+ UserDocument.is_active == True, # noqa: E712
+ ).to_list()
+ user_ids = [doc.user_id for doc in docs if doc.user_id]
+ self.logger.info(f"Found {len(user_ids)} users with roles {[r.value for r in roles]}")
return user_ids
async def get_active_users(self, days: int = 30) -> list[str]:
- cutoff_date = datetime.now(UTC) - timedelta(days=days)
+ from app.db.docs import ExecutionDocument
- users_collection = self.db.users
- cursor = users_collection.find(
- {
- "$or": [
- {"last_login": {"$gte": cutoff_date}},
- {"last_activity": {"$gte": cutoff_date}},
- {"updated_at": {"$gte": cutoff_date}},
- ],
- "is_active": True,
- },
- {"user_id": 1},
+ cutoff_date = datetime.now(UTC) - timedelta(days=days)
+ user_ids: set[str] = set()
+
+ # From users collection
+ docs = await UserDocument.find(
+ Or(
+ GTE(UserDocument.last_login, cutoff_date),
+ GTE(UserDocument.last_activity, cutoff_date),
+ GTE(UserDocument.updated_at, cutoff_date),
+ ),
+ UserDocument.is_active == True, # noqa: E712
+ ).to_list()
+ for doc in docs:
+ if doc.user_id:
+ user_ids.add(doc.user_id)
+
+ # From executions
+ exec_docs = (
+ await ExecutionDocument.find(
+ GTE(ExecutionDocument.created_at, cutoff_date),
+ )
+ .limit(1000)
+ .to_list()
)
-
- user_ids = set()
- async for user in cursor:
- if user.get("user_id"):
- user_ids.add(user["user_id"])
-
- executions_collection = self.db.executions
- exec_cursor = executions_collection.find({"created_at": {"$gte": cutoff_date}}, {"user_id": 1}).limit(1000)
-
- async for execution in exec_cursor:
- if execution.get("user_id"):
- user_ids.add(execution["user_id"])
+ for doc in exec_docs:
+ if doc.user_id:
+ user_ids.add(doc.user_id)
return list(user_ids)
diff --git a/backend/app/db/repositories/replay_repository.py b/backend/app/db/repositories/replay_repository.py
index fc21d3aa..e2c07846 100644
--- a/backend/app/db/repositories/replay_repository.py
+++ b/backend/app/db/repositories/replay_repository.py
@@ -1,109 +1,107 @@
-from typing import Any, AsyncIterator, Dict, List
+import logging
+from dataclasses import asdict
+from datetime import datetime
+from typing import Any, AsyncIterator
-from pymongo import ASCENDING, DESCENDING
+from beanie.odm.enums import SortDirection
+from beanie.operators import LT, In
-from app.core.database_context import Collection, Database
-from app.core.logging import logger
+from app.db.docs import EventStoreDocument, ReplaySessionDocument
from app.domain.admin.replay_updates import ReplaySessionUpdate
from app.domain.enums.replay import ReplayStatus
-from app.domain.events.event_models import CollectionNames
-from app.domain.replay import ReplayFilter, ReplaySessionState
-from app.infrastructure.mappers import ReplayStateMapper
+from app.domain.replay.models import ReplayConfig, ReplayFilter, ReplaySessionState
class ReplayRepository:
- def __init__(self, database: Database) -> None:
- self.db = database
- self.replay_collection: Collection = database.get_collection(CollectionNames.REPLAY_SESSIONS)
- self.events_collection: Collection = database.get_collection(CollectionNames.EVENTS)
- self._mapper = ReplayStateMapper()
-
- async def create_indexes(self) -> None:
- # Replay sessions indexes
- await self.replay_collection.create_index([("session_id", ASCENDING)], unique=True)
- await self.replay_collection.create_index([("status", ASCENDING)])
- await self.replay_collection.create_index([("created_at", DESCENDING)])
- await self.replay_collection.create_index([("user_id", ASCENDING)])
-
- # Events collection indexes for replay queries
- await self.events_collection.create_index([("execution_id", 1), ("timestamp", 1)])
- await self.events_collection.create_index([("event_type", 1), ("timestamp", 1)])
- await self.events_collection.create_index([("metadata.user_id", 1), ("timestamp", 1)])
-
- logger.info("Replay repository indexes created successfully")
+ def __init__(self, logger: logging.Logger) -> None:
+ self.logger = logger
async def save_session(self, session: ReplaySessionState) -> None:
- """Save or update a replay session (domain → persistence)."""
- doc = self._mapper.to_mongo_document(session)
- await self.replay_collection.update_one({"session_id": session.session_id}, {"$set": doc}, upsert=True)
+ existing = await ReplaySessionDocument.find_one({"session_id": session.session_id})
+ data = asdict(session)
+ # config is a Pydantic model, convert to dict for document
+ data["config"] = session.config.model_dump()
+ doc = ReplaySessionDocument(**data)
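+ # Reuse the existing document's id so save() replaces the stored session instead of inserting a duplicate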
+ if existing:
+ doc.id = existing.id
+ await doc.save()
async def get_session(self, session_id: str) -> ReplaySessionState | None:
- """Get a replay session by ID (persistence → domain)."""
- data = await self.replay_collection.find_one({"session_id": session_id})
- return self._mapper.from_mongo_document(data) if data else None
+ doc = await ReplaySessionDocument.find_one({"session_id": session_id})
+ if not doc:
+ return None
+ data = doc.model_dump(exclude={"id", "revision_id"})
+ data["config"] = ReplayConfig.model_validate(data["config"])
+ return ReplaySessionState(**data)
async def list_sessions(
self, status: ReplayStatus | None = None, user_id: str | None = None, limit: int = 100, skip: int = 0
) -> list[ReplaySessionState]:
- collection = self.replay_collection
-
- query: dict[str, object] = {}
- if status:
- query["status"] = status.value
- if user_id:
- query["config.filter.user_id"] = user_id
-
- cursor = collection.find(query).sort("created_at", DESCENDING).skip(skip).limit(limit)
- sessions: list[ReplaySessionState] = []
- async for doc in cursor:
- sessions.append(self._mapper.from_mongo_document(doc))
- return sessions
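+ # user_id lives inside the session's stored config filter, so match on the nested path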
+ conditions: list[Any] = [
+ ReplaySessionDocument.status == status if status else None,
+ ReplaySessionDocument.config.filter.user_id == user_id if user_id else None,
+ ]
+ conditions = [c for c in conditions if c is not None]
+ docs = (
+ await ReplaySessionDocument.find(*conditions)
+ .sort([("created_at", SortDirection.DESCENDING)])
+ .skip(skip)
+ .limit(limit)
+ .to_list()
+ )
+ results = []
+ for doc in docs:
+ data = doc.model_dump(exclude={"id", "revision_id"})
+ data["config"] = ReplayConfig.model_validate(data["config"])
+ results.append(ReplaySessionState(**data))
+ return results
async def update_session_status(self, session_id: str, status: ReplayStatus) -> bool:
- """Update the status of a replay session"""
- result = await self.replay_collection.update_one({"session_id": session_id}, {"$set": {"status": status.value}})
- return result.modified_count > 0
+ doc = await ReplaySessionDocument.find_one({"session_id": session_id})
+ if not doc:
+ return False
+ doc.status = status
+ await doc.save()
+ return True
- async def delete_old_sessions(self, cutoff_time: str) -> int:
- """Delete old completed/failed/cancelled sessions"""
+ async def delete_old_sessions(self, cutoff_time: datetime) -> int:
terminal_statuses = [
- ReplayStatus.COMPLETED.value,
- ReplayStatus.FAILED.value,
- ReplayStatus.CANCELLED.value,
+ ReplayStatus.COMPLETED,
+ ReplayStatus.FAILED,
+ ReplayStatus.CANCELLED,
]
- result = await self.replay_collection.delete_many(
- {"created_at": {"$lt": cutoff_time}, "status": {"$in": terminal_statuses}}
- )
- return result.deleted_count
+ result = await ReplaySessionDocument.find(
+ LT(ReplaySessionDocument.created_at, cutoff_time),
+ In(ReplaySessionDocument.status, terminal_statuses),
+ ).delete()
+ return result.deleted_count if result else 0
- async def count_sessions(self, query: dict[str, object] | None = None) -> int:
- """Count sessions matching the given query"""
- return await self.replay_collection.count_documents(query or {})
+ async def count_sessions(self, *conditions: Any) -> int:
+ return await ReplaySessionDocument.find(*conditions).count()
async def update_replay_session(self, session_id: str, updates: ReplaySessionUpdate) -> bool:
- """Update specific fields of a replay session"""
- if not updates.has_updates():
+ update_dict = {k: (v.value if hasattr(v, "value") else v) for k, v in asdict(updates).items() if v is not None}
+ if not update_dict:
return False
-
- mongo_updates = updates.to_dict()
- result = await self.replay_collection.update_one({"session_id": session_id}, {"$set": mongo_updates})
- return result.modified_count > 0
+ doc = await ReplaySessionDocument.find_one({"session_id": session_id})
+ if not doc:
+ return False
+ await doc.set(update_dict)
+ return True
async def count_events(self, replay_filter: ReplayFilter) -> int:
- """Count events matching the given filter"""
query = replay_filter.to_mongo_query()
- return await self.events_collection.count_documents(query)
+ return await EventStoreDocument.find(query).count()
async def fetch_events(
self, replay_filter: ReplayFilter, batch_size: int = 100, skip: int = 0
- ) -> AsyncIterator[List[Dict[str, Any]]]:
- """Fetch events in batches based on filter"""
+ ) -> AsyncIterator[list[dict[str, Any]]]:
query = replay_filter.to_mongo_query()
- cursor = self.events_collection.find(query).sort("timestamp", 1).skip(skip)
+ cursor = EventStoreDocument.find(query).sort([("timestamp", SortDirection.ASCENDING)]).skip(skip)
batch = []
async for doc in cursor:
- batch.append(doc)
+ batch.append(doc.model_dump(exclude={"id", "revision_id", "stored_at"}))
if len(batch) >= batch_size:
yield batch
batch = []
diff --git a/backend/app/db/repositories/resource_allocation_repository.py b/backend/app/db/repositories/resource_allocation_repository.py
index c0aa8454..c2d5e79c 100644
--- a/backend/app/db/repositories/resource_allocation_repository.py
+++ b/backend/app/db/repositories/resource_allocation_repository.py
@@ -1,51 +1,26 @@
+from dataclasses import asdict
from datetime import datetime, timezone
+from uuid import uuid4
-from app.core.database_context import Collection, Database
-from app.domain.events.event_models import CollectionNames
+from app.db.docs import ResourceAllocationDocument
+from app.domain.saga import DomainResourceAllocation, DomainResourceAllocationCreate
class ResourceAllocationRepository:
- """Repository for resource allocation bookkeeping used by saga steps."""
-
- def __init__(self, database: Database):
- self._db = database
- self._collection: Collection = self._db.get_collection(CollectionNames.RESOURCE_ALLOCATIONS)
-
async def count_active(self, language: str) -> int:
- return await self._collection.count_documents(
- {
- "status": "active",
- "language": language,
- }
- )
+ return await ResourceAllocationDocument.find({"status": "active", "language": language}).count()
- async def create_allocation(
- self,
- allocation_id: str,
- *,
- execution_id: str,
- language: str,
- cpu_request: str,
- memory_request: str,
- cpu_limit: str,
- memory_limit: str,
- ) -> bool:
- doc = {
- "_id": allocation_id,
- "execution_id": execution_id,
- "language": language,
- "cpu_request": cpu_request,
- "memory_request": memory_request,
- "cpu_limit": cpu_limit,
- "memory_limit": memory_limit,
- "status": "active",
- "allocated_at": datetime.now(timezone.utc),
- }
- result = await self._collection.insert_one(doc)
- return result.inserted_id is not None
+ async def create_allocation(self, create_data: DomainResourceAllocationCreate) -> DomainResourceAllocation:
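+ # allocation_id is generated here; the remaining fields come from the create DTO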
+ doc = ResourceAllocationDocument(
+ allocation_id=str(uuid4()),
+ **asdict(create_data),
+ )
+ await doc.insert()
+ return DomainResourceAllocation(**doc.model_dump(exclude={"id"}))
async def release_allocation(self, allocation_id: str) -> bool:
- result = await self._collection.update_one(
- {"_id": allocation_id}, {"$set": {"status": "released", "released_at": datetime.now(timezone.utc)}}
- )
- return result.modified_count > 0
+ doc = await ResourceAllocationDocument.find_one({"allocation_id": allocation_id})
+ if not doc:
+ return False
+ await doc.set({"status": "released", "released_at": datetime.now(timezone.utc)})
+ return True
diff --git a/backend/app/db/repositories/saga_repository.py b/backend/app/db/repositories/saga_repository.py
index 477b3c4e..5eaf1169 100644
--- a/backend/app/db/repositories/saga_repository.py
+++ b/backend/app/db/repositories/saga_repository.py
@@ -1,101 +1,104 @@
+from dataclasses import asdict
from datetime import datetime, timezone
+from typing import Any
-from pymongo import DESCENDING
+from beanie.odm.enums import SortDirection
+from beanie.odm.operators.find import BaseFindOperator
+from beanie.operators import GT, LT, In
-from app.core.database_context import Collection, Database
+from app.db.docs import ExecutionDocument, SagaDocument
from app.domain.enums.saga import SagaState
-from app.domain.events.event_models import CollectionNames
-from app.domain.saga.models import Saga, SagaFilter, SagaListResult
-from app.infrastructure.mappers import SagaFilterMapper, SagaMapper
+from app.domain.saga import Saga, SagaFilter, SagaListResult
class SagaRepository:
- """Repository for saga data access.
-
- This repository handles all database operations for sagas,
- following clean architecture principles with no business logic
- or HTTP-specific concerns.
- """
-
- def __init__(self, database: Database):
- self.db = database
- self.sagas: Collection = self.db.get_collection(CollectionNames.SAGAS)
- self.executions: Collection = self.db.get_collection(CollectionNames.EXECUTIONS)
- self.mapper = SagaMapper()
- self.filter_mapper = SagaFilterMapper()
+ def _filter_conditions(self, saga_filter: SagaFilter) -> list[BaseFindOperator]:
+ """Build Beanie query conditions from SagaFilter."""
+ conditions = [
+ SagaDocument.state == saga_filter.state if saga_filter.state else None,
+ In(SagaDocument.execution_id, saga_filter.execution_ids) if saga_filter.execution_ids else None,
+ SagaDocument.context_data["user_id"] == saga_filter.user_id if saga_filter.user_id else None,
+ SagaDocument.saga_name == saga_filter.saga_name if saga_filter.saga_name else None,
+ GT(SagaDocument.created_at, saga_filter.created_after) if saga_filter.created_after else None,
+ LT(SagaDocument.created_at, saga_filter.created_before) if saga_filter.created_before else None,
+ ]
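+ # error_status filters on the presence of error_message; Beanie builds the query from == / != None (hence noqa E711)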
+ if saga_filter.error_status is True:
+ conditions.append(SagaDocument.error_message != None) # noqa: E711
+ elif saga_filter.error_status is False:
+ conditions.append(SagaDocument.error_message == None) # noqa: E711
+ return [c for c in conditions if c is not None]
async def upsert_saga(self, saga: Saga) -> bool:
- doc = self.mapper.to_mongo(saga)
- result = await self.sagas.replace_one(
- {"saga_id": saga.saga_id},
- doc,
- upsert=True,
- )
- return result.modified_count > 0
+ existing = await SagaDocument.find_one({"saga_id": saga.saga_id})
+ doc = SagaDocument(**asdict(saga))
+ if existing:
+ doc.id = existing.id
+ await doc.save()
+ return existing is not None
async def get_saga_by_execution_and_name(self, execution_id: str, saga_name: str) -> Saga | None:
- doc = await self.sagas.find_one(
- {
- "execution_id": execution_id,
- "saga_name": saga_name,
- }
+ doc = await SagaDocument.find_one(
+ SagaDocument.execution_id == execution_id,
+ SagaDocument.saga_name == saga_name,
)
- return self.mapper.from_mongo(doc) if doc else None
+ return Saga(**doc.model_dump(exclude={"id"})) if doc else None
async def get_saga(self, saga_id: str) -> Saga | None:
- doc = await self.sagas.find_one({"saga_id": saga_id})
- return self.mapper.from_mongo(doc) if doc else None
+ doc = await SagaDocument.find_one({"saga_id": saga_id})
+ return Saga(**doc.model_dump(exclude={"id"})) if doc else None
async def get_sagas_by_execution(
self, execution_id: str, state: SagaState | None = None, limit: int = 100, skip: int = 0
) -> SagaListResult:
- query: dict[str, object] = {"execution_id": execution_id}
- if state:
- query["state"] = state.value
-
- total = await self.sagas.count_documents(query)
- cursor = self.sagas.find(query).sort("created_at", DESCENDING).skip(skip).limit(limit)
- docs = await cursor.to_list(length=limit)
- sagas = [self.mapper.from_mongo(doc) for doc in docs]
-
- return SagaListResult(sagas=sagas, total=total, skip=skip, limit=limit)
+ conditions = [
+ SagaDocument.execution_id == execution_id,
+ SagaDocument.state == state if state else None,
+ ]
+ conditions = [c for c in conditions if c is not None]
+
+ query = SagaDocument.find(*conditions)
+ total = await query.count()
+ docs = await query.sort([("created_at", SortDirection.DESCENDING)]).skip(skip).limit(limit).to_list()
+ return SagaListResult(
+ sagas=[Saga(**d.model_dump(exclude={"id"})) for d in docs],
+ total=total,
+ skip=skip,
+ limit=limit,
+ )
async def list_sagas(self, saga_filter: SagaFilter, limit: int = 100, skip: int = 0) -> SagaListResult:
- query = self.filter_mapper.to_mongodb_query(saga_filter)
-
- # Get total count
- total = await self.sagas.count_documents(query)
-
- # Get sagas with pagination
- cursor = self.sagas.find(query).sort("created_at", DESCENDING).skip(skip).limit(limit)
- docs = await cursor.to_list(length=limit)
-
- sagas = [self.mapper.from_mongo(doc) for doc in docs]
-
- return SagaListResult(sagas=sagas, total=total, skip=skip, limit=limit)
+ conditions = self._filter_conditions(saga_filter)
+ query = SagaDocument.find(*conditions)
+ total = await query.count()
+ docs = await query.sort([("created_at", SortDirection.DESCENDING)]).skip(skip).limit(limit).to_list()
+ return SagaListResult(
+ sagas=[Saga(**d.model_dump(exclude={"id"})) for d in docs],
+ total=total,
+ skip=skip,
+ limit=limit,
+ )
async def update_saga_state(self, saga_id: str, state: SagaState, error_message: str | None = None) -> bool:
- update_data: dict[str, object] = {"state": state.value, "updated_at": datetime.now(timezone.utc)}
+ doc = await SagaDocument.find_one({"saga_id": saga_id})
+ if not doc:
+ return False
+ doc.state = state
+ doc.updated_at = datetime.now(timezone.utc)
if error_message:
- update_data["error_message"] = error_message
-
- result = await self.sagas.update_one({"saga_id": saga_id}, {"$set": update_data})
-
- return result.modified_count > 0
+ doc.error_message = error_message
+ await doc.save()
+ return True
async def get_user_execution_ids(self, user_id: str) -> list[str]:
- cursor = self.executions.find({"user_id": user_id}, {"execution_id": 1})
- docs = await cursor.to_list(length=None)
- return [doc["execution_id"] for doc in docs]
+ docs = await ExecutionDocument.find({"user_id": user_id}).to_list()
+ return [doc.execution_id for doc in docs]
async def count_sagas_by_state(self) -> dict[str, int]:
pipeline = [{"$group": {"_id": "$state", "count": {"$sum": 1}}}]
-
result = {}
- async for doc in self.sagas.aggregate(pipeline):
+ async for doc in SagaDocument.aggregate(pipeline):
result[doc["_id"]] = doc["count"]
-
return result
async def find_timed_out_sagas(
@@ -105,37 +108,39 @@ async def find_timed_out_sagas(
limit: int = 100,
) -> list[Saga]:
states = states or [SagaState.RUNNING, SagaState.COMPENSATING]
- query = {
- "state": {"$in": [s.value for s in states]},
- "created_at": {"$lt": cutoff_time},
- }
- cursor = self.sagas.find(query)
- docs = await cursor.to_list(length=limit)
- return [self.mapper.from_mongo(doc) for doc in docs]
-
- async def get_saga_statistics(self, saga_filter: SagaFilter | None = None) -> dict[str, object]:
- query = self.filter_mapper.to_mongodb_query(saga_filter) if saga_filter else {}
-
- # Basic counts
- total = await self.sagas.count_documents(query)
+ docs = (
+ await SagaDocument.find(
+ In(SagaDocument.state, states),
+ LT(SagaDocument.created_at, cutoff_time),
+ )
+ .limit(limit)
+ .to_list()
+ )
+ return [Saga(**d.model_dump(exclude={"id"})) for d in docs]
- # State distribution
- state_pipeline = [{"$match": query}, {"$group": {"_id": "$state", "count": {"$sum": 1}}}]
+ async def get_saga_statistics(self, saga_filter: SagaFilter | None = None) -> dict[str, Any]:
+ conditions = self._filter_conditions(saga_filter) if saga_filter else []
+ base_query = SagaDocument.find(*conditions)
+ total = await base_query.count()
+ # Group by state
+ state_pipeline = [{"$group": {"_id": "$state", "count": {"$sum": 1}}}]
states = {}
- async for doc in self.sagas.aggregate(state_pipeline):
+ async for doc in base_query.aggregate(state_pipeline):
states[doc["_id"]] = doc["count"]
# Average duration for completed sagas
+ completed_conditions = [
+ *conditions,
+ SagaDocument.state == SagaState.COMPLETED,
+ SagaDocument.completed_at != None, # noqa: E711
+ ]
duration_pipeline = [
- {"$match": {**query, "state": "completed", "completed_at": {"$ne": None}}},
{"$project": {"duration": {"$subtract": ["$completed_at", "$created_at"]}}},
{"$group": {"_id": None, "avg_duration": {"$avg": "$duration"}}},
]
-
avg_duration = 0.0
- async for doc in self.sagas.aggregate(duration_pipeline):
- # Convert milliseconds to seconds
+ async for doc in SagaDocument.find(*completed_conditions).aggregate(duration_pipeline):
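+ # $subtract on two dates yields milliseconds; convert to seconds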
avg_duration = doc["avg_duration"] / 1000.0 if doc["avg_duration"] else 0.0
return {"total": total, "by_state": states, "average_duration_seconds": avg_duration}
diff --git a/backend/app/db/repositories/saved_script_repository.py b/backend/app/db/repositories/saved_script_repository.py
index eb26fa1d..a29cbb4f 100644
--- a/backend/app/db/repositories/saved_script_repository.py
+++ b/backend/app/db/repositories/saved_script_repository.py
@@ -1,45 +1,51 @@
-from app.core.database_context import Collection, Database
-from app.domain.events.event_models import CollectionNames
-from app.domain.saved_script import (
- DomainSavedScript,
- DomainSavedScriptCreate,
- DomainSavedScriptUpdate,
-)
-from app.infrastructure.mappers import SavedScriptMapper
+from dataclasses import asdict
+from beanie.operators import Eq
-class SavedScriptRepository:
- def __init__(self, database: Database):
- self.db = database
- self.collection: Collection = self.db.get_collection(CollectionNames.SAVED_SCRIPTS)
- self.mapper = SavedScriptMapper()
+from app.db.docs import SavedScriptDocument
+from app.domain.saved_script import DomainSavedScript, DomainSavedScriptCreate, DomainSavedScriptUpdate
- async def create_saved_script(self, saved_script: DomainSavedScriptCreate, user_id: str) -> DomainSavedScript:
- # Build DB document with defaults
- doc = self.mapper.to_insert_document(saved_script, user_id)
- result = await self.collection.insert_one(doc)
- if result.inserted_id is None:
- raise ValueError("Insert not acknowledged")
- return self.mapper.from_mongo_document(doc)
+class SavedScriptRepository:
+ async def create_saved_script(self, create_data: DomainSavedScriptCreate, user_id: str) -> DomainSavedScript:
+ doc = SavedScriptDocument(**asdict(create_data), user_id=user_id)
+ await doc.insert()
+ return DomainSavedScript(**doc.model_dump(exclude={"id", "revision_id"}))
async def get_saved_script(self, script_id: str, user_id: str) -> DomainSavedScript | None:
- saved_script = await self.collection.find_one({"script_id": script_id, "user_id": user_id})
- if not saved_script:
+ doc = await SavedScriptDocument.find_one(
+ Eq(SavedScriptDocument.script_id, script_id),
+ Eq(SavedScriptDocument.user_id, user_id),
+ )
+ return DomainSavedScript(**doc.model_dump(exclude={"id", "revision_id"})) if doc else None
+
+ async def update_saved_script(
+ self,
+ script_id: str,
+ user_id: str,
+ update_data: DomainSavedScriptUpdate,
+ ) -> DomainSavedScript | None:
+ doc = await SavedScriptDocument.find_one(
+ Eq(SavedScriptDocument.script_id, script_id),
+ Eq(SavedScriptDocument.user_id, user_id),
+ )
+ if not doc:
return None
- return self.mapper.from_mongo_document(saved_script)
-
- async def update_saved_script(self, script_id: str, user_id: str, update_data: DomainSavedScriptUpdate) -> None:
- update = self.mapper.to_update_dict(update_data)
- await self.collection.update_one({"script_id": script_id, "user_id": user_id}, {"$set": update})
+ update_dict = {k: v for k, v in asdict(update_data).items() if v is not None}
+ await doc.set(update_dict)
+ return DomainSavedScript(**doc.model_dump(exclude={"id", "revision_id"}))
- async def delete_saved_script(self, script_id: str, user_id: str) -> None:
- await self.collection.delete_one({"script_id": script_id, "user_id": user_id})
+ async def delete_saved_script(self, script_id: str, user_id: str) -> bool:
+ doc = await SavedScriptDocument.find_one(
+ Eq(SavedScriptDocument.script_id, script_id),
+ Eq(SavedScriptDocument.user_id, user_id),
+ )
+ if not doc:
+ return False
+ await doc.delete()
+ return True
async def list_saved_scripts(self, user_id: str) -> list[DomainSavedScript]:
- cursor = self.collection.find({"user_id": user_id})
- scripts: list[DomainSavedScript] = []
- async for script in cursor:
- scripts.append(self.mapper.from_mongo_document(script))
- return scripts
+ docs = await SavedScriptDocument.find(Eq(SavedScriptDocument.user_id, user_id)).to_list()
+ return [DomainSavedScript(**d.model_dump(exclude={"id", "revision_id"})) for d in docs]
diff --git a/backend/app/db/repositories/sse_repository.py b/backend/app/db/repositories/sse_repository.py
index d2112cb3..1c46dbe2 100644
--- a/backend/app/db/repositories/sse_repository.py
+++ b/backend/app/db/repositories/sse_repository.py
@@ -1,31 +1,30 @@
from datetime import datetime, timezone
-from app.core.database_context import Collection, Database
-from app.domain.enums.execution import ExecutionStatus
-from app.domain.events.event_models import CollectionNames
-from app.domain.execution import DomainExecution
+from app.db.docs import ExecutionDocument
+from app.domain.execution import DomainExecution, ResourceUsageDomain
from app.domain.sse import SSEExecutionStatusDomain
-from app.infrastructure.mappers import SSEMapper
class SSERepository:
- def __init__(self, database: Database) -> None:
- self.db = database
- self.executions_collection: Collection = self.db.get_collection(CollectionNames.EXECUTIONS)
- self.mapper = SSEMapper()
-
async def get_execution_status(self, execution_id: str) -> SSEExecutionStatusDomain | None:
- doc = await self.executions_collection.find_one({"execution_id": execution_id}, {"status": 1, "_id": 0})
+ doc = await ExecutionDocument.find_one({"execution_id": execution_id})
if not doc:
return None
return SSEExecutionStatusDomain(
execution_id=execution_id,
- status=ExecutionStatus(doc["status"]),
+ status=doc.status,
timestamp=datetime.now(timezone.utc).isoformat(),
)
async def get_execution(self, execution_id: str) -> DomainExecution | None:
- doc = await self.executions_collection.find_one({"execution_id": execution_id})
+ doc = await ExecutionDocument.find_one({"execution_id": execution_id})
if not doc:
return None
- return self.mapper.execution_from_mongo_document(doc)
+ data = doc.model_dump(exclude={"id", "revision_id"})
+ data["resource_usage"] = (
+ ResourceUsageDomain(**doc.resource_usage.model_dump()) if doc.resource_usage else None
+ )
+ return DomainExecution(**data)
diff --git a/backend/app/db/repositories/user_repository.py b/backend/app/db/repositories/user_repository.py
index 3ee60369..7f3e928f 100644
--- a/backend/app/db/repositories/user_repository.py
+++ b/backend/app/db/repositories/user_repository.py
@@ -1,82 +1,70 @@
import re
-import uuid
+from dataclasses import asdict
from datetime import datetime, timezone
-from app.core.database_context import Collection, Database
+from beanie.odm.operators.find import BaseFindOperator
+from beanie.operators import Eq, Or, RegEx
+
+from app.db.docs import UserDocument
from app.domain.enums.user import UserRole
-from app.domain.events.event_models import CollectionNames
-from app.domain.user import User as DomainAdminUser
-from app.domain.user import UserFields
-from app.domain.user import UserUpdate as DomainUserUpdate
-from app.infrastructure.mappers import UserMapper
+from app.domain.user import DomainUserCreate, DomainUserUpdate, User, UserListResult
class UserRepository:
- def __init__(self, db: Database):
- self.db = db
- self.collection: Collection = self.db.get_collection(CollectionNames.USERS)
- self.mapper = UserMapper()
-
- async def get_user(self, username: str) -> DomainAdminUser | None:
- user = await self.collection.find_one({UserFields.USERNAME: username})
- if user:
- return self.mapper.from_mongo_document(user)
- return None
+ async def get_user(self, username: str) -> User | None:
+ doc = await UserDocument.find_one({"username": username})
+ return User(**doc.model_dump(exclude={"id", "revision_id"})) if doc else None
- async def create_user(self, user: DomainAdminUser) -> DomainAdminUser:
- if not user.user_id:
- user.user_id = str(uuid.uuid4())
- # Ensure timestamps
- if not getattr(user, "created_at", None):
- user.created_at = datetime.now(timezone.utc)
- if not getattr(user, "updated_at", None):
- user.updated_at = user.created_at
- user_dict = self.mapper.to_mongo_document(user)
- await self.collection.insert_one(user_dict)
- return user
+ async def create_user(self, create_data: DomainUserCreate) -> User:
+ doc = UserDocument(**asdict(create_data))
+ await doc.insert()
+ return User(**doc.model_dump(exclude={"id", "revision_id"}))
- async def get_user_by_id(self, user_id: str) -> DomainAdminUser | None:
- user = await self.collection.find_one({UserFields.USER_ID: user_id})
- if user:
- return self.mapper.from_mongo_document(user)
- return None
+ async def get_user_by_id(self, user_id: str) -> User | None:
+ doc = await UserDocument.find_one({"user_id": user_id})
+ return User(**doc.model_dump(exclude={"id", "revision_id"})) if doc else None
async def list_users(
self, limit: int = 100, offset: int = 0, search: str | None = None, role: UserRole | None = None
- ) -> list[DomainAdminUser]:
- query: dict[str, object] = {}
+ ) -> UserListResult:
+ conditions: list[BaseFindOperator] = []
if search:
- # Escape special regex characters to prevent ReDoS attacks
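+ # Escape special regex characters to prevent ReDoS attacks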
escaped_search = re.escape(search)
- query["$or"] = [
- {"username": {"$regex": escaped_search, "$options": "i"}},
- {"email": {"$regex": escaped_search, "$options": "i"}},
- ]
+ conditions.append(
+ Or(
+ RegEx(UserDocument.username, escaped_search, options="i"),
+ RegEx(UserDocument.email, escaped_search, options="i"),
+ )
+ )
if role:
- query["role"] = role.value
+ conditions.append(Eq(UserDocument.role, role))
- cursor = self.collection.find(query).skip(offset).limit(limit)
- users: list[DomainAdminUser] = []
- async for user in cursor:
- users.append(self.mapper.from_mongo_document(user))
+ query = UserDocument.find(*conditions)
+ total = await query.count()
+ docs = await query.skip(offset).limit(limit).to_list()
+ return UserListResult(
+ users=[User(**d.model_dump(exclude={"id", "revision_id"})) for d in docs],
+ total=total,
+ offset=offset,
+ limit=limit,
+ )
- return users
+ async def update_user(self, user_id: str, update_data: DomainUserUpdate) -> User | None:
+ doc = await UserDocument.find_one({"user_id": user_id})
+ if not doc:
+ return None
- async def update_user(self, user_id: str, update_data: DomainUserUpdate) -> DomainAdminUser | None:
- update_dict = self.mapper.to_update_dict(update_data)
- if not update_dict and update_data.password is None:
- return await self.get_user_by_id(user_id)
- # Handle password update separately if provided
- if update_data.password:
- update_dict[UserFields.HASHED_PASSWORD] = update_data.password # caller should pass hashed if desired
- update_dict[UserFields.UPDATED_AT] = datetime.now(timezone.utc)
- result = await self.collection.update_one({UserFields.USER_ID: user_id}, {"$set": update_dict})
- if result.modified_count > 0:
- return await self.get_user_by_id(user_id)
- return None
+ update_dict = {k: v for k, v in asdict(update_data).items() if v is not None}
+ if update_dict:
+ update_dict["updated_at"] = datetime.now(timezone.utc)
+ await doc.set(update_dict)
+ return User(**doc.model_dump(exclude={"id", "revision_id"}))
async def delete_user(self, user_id: str) -> bool:
- result = await self.collection.delete_one({UserFields.USER_ID: user_id})
- return result.deleted_count > 0
+ doc = await UserDocument.find_one({"user_id": user_id})
+ if not doc:
+ return False
+ await doc.delete()
+ return True
diff --git a/backend/app/db/repositories/user_settings_repository.py b/backend/app/db/repositories/user_settings_repository.py
index f91f50b9..3cf2c844 100644
--- a/backend/app/db/repositories/user_settings_repository.py
+++ b/backend/app/db/repositories/user_settings_repository.py
@@ -1,55 +1,33 @@
+import logging
+from dataclasses import asdict
from datetime import datetime
-from typing import Any, Dict, List
+from typing import List
-from pymongo import ASCENDING, DESCENDING, IndexModel
+from beanie.odm.enums import SortDirection
+from beanie.operators import GT, LTE, In
-from app.core.database_context import Collection, Database
-from app.core.logging import logger
+from app.db.docs import EventDocument, UserSettingsDocument
from app.domain.enums.events import EventType
-from app.domain.events.event_models import CollectionNames
-from app.domain.user.settings_models import (
- DomainSettingsEvent,
- DomainUserSettings,
-)
-from app.infrastructure.mappers import UserSettingsMapper
+from app.domain.user.settings_models import DomainUserSettings
class UserSettingsRepository:
- def __init__(self, database: Database) -> None:
- self.db = database
- self.snapshots_collection: Collection = self.db.get_collection(CollectionNames.USER_SETTINGS_SNAPSHOTS)
- self.events_collection: Collection = self.db.get_collection(CollectionNames.EVENTS)
- self.mapper = UserSettingsMapper()
-
- async def create_indexes(self) -> None:
- # Create indexes for settings snapshots
- await self.snapshots_collection.create_indexes(
- [
- IndexModel([("user_id", ASCENDING)], unique=True),
- IndexModel([("updated_at", DESCENDING)]),
- ]
- )
-
- # Create indexes for settings events
- await self.events_collection.create_indexes(
- [
- IndexModel([("event_type", ASCENDING), ("aggregate_id", ASCENDING)]),
- IndexModel([("aggregate_id", ASCENDING), ("timestamp", ASCENDING)]),
- ]
- )
-
- logger.info("User settings repository indexes created successfully")
+ def __init__(self, logger: logging.Logger) -> None:
+ self.logger = logger
async def get_snapshot(self, user_id: str) -> DomainUserSettings | None:
- doc = await self.snapshots_collection.find_one({"user_id": user_id})
+ doc = await UserSettingsDocument.find_one({"user_id": user_id})
if not doc:
return None
- return self.mapper.from_snapshot_document(doc)
+ return DomainUserSettings(**doc.model_dump(exclude={"id", "revision_id"}))
async def create_snapshot(self, settings: DomainUserSettings) -> None:
- doc = self.mapper.to_snapshot_document(settings)
- await self.snapshots_collection.replace_one({"user_id": settings.user_id}, doc, upsert=True)
- logger.info(f"Created settings snapshot for user {settings.user_id}")
+ existing = await UserSettingsDocument.find_one({"user_id": settings.user_id})
+ doc = UserSettingsDocument(**asdict(settings))
+ if existing:
+ doc.id = existing.id
+ await doc.save()
+ self.logger.info(f"Created settings snapshot for user {settings.user_id}")
async def get_settings_events(
self,
@@ -58,47 +36,39 @@ async def get_settings_events(
since: datetime | None = None,
until: datetime | None = None,
limit: int | None = None,
- ) -> List[DomainSettingsEvent]:
- query: Dict[str, Any] = {
- "aggregate_id": f"user_settings_{user_id}",
- "event_type": {"$in": [str(et) for et in event_types]},
- }
-
- if since or until:
- timestamp_query: Dict[str, Any] = {}
- if since:
- timestamp_query["$gt"] = since
- if until:
- timestamp_query["$lte"] = until
- query["timestamp"] = timestamp_query
-
- cursor = self.events_collection.find(query).sort("timestamp", ASCENDING)
-
+ ) -> List[EventDocument]:
+ aggregate_id = f"user_settings_{user_id}"
+ conditions = [
+ EventDocument.aggregate_id == aggregate_id,
+ In(EventDocument.event_type, [str(et) for et in event_types]),
+ GT(EventDocument.timestamp, since) if since else None,
+ LTE(EventDocument.timestamp, until) if until else None,
+ ]
+ conditions = [c for c in conditions if c is not None]
+
+ find_query = EventDocument.find(*conditions).sort([("timestamp", SortDirection.ASCENDING)])
if limit:
- cursor = cursor.limit(limit)
+ find_query = find_query.limit(limit)
- docs = await cursor.to_list(None)
- return [self.mapper.event_from_mongo_document(d) for d in docs]
+ return await find_query.to_list()
async def count_events_since_snapshot(self, user_id: str) -> int:
+ aggregate_id = f"user_settings_{user_id}"
snapshot = await self.get_snapshot(user_id)
-
if not snapshot:
- return await self.events_collection.count_documents({"aggregate_id": f"user_settings_{user_id}"})
+ return await EventDocument.find(EventDocument.aggregate_id == aggregate_id).count()
- return await self.events_collection.count_documents(
- {"aggregate_id": f"user_settings_{user_id}", "timestamp": {"$gt": snapshot.updated_at}}
- )
+ return await EventDocument.find(
+ EventDocument.aggregate_id == aggregate_id,
+ GT(EventDocument.timestamp, snapshot.updated_at),
+ ).count()
async def count_events_for_user(self, user_id: str) -> int:
- return await self.events_collection.count_documents({"aggregate_id": f"user_settings_{user_id}"})
+ return await EventDocument.find(EventDocument.aggregate_id == f"user_settings_{user_id}").count()
async def delete_user_settings(self, user_id: str) -> None:
- """Delete all settings data for a user (snapshot and events)."""
- # Delete snapshot
- await self.snapshots_collection.delete_one({"user_id": user_id})
-
- # Delete all events
- await self.events_collection.delete_many({"aggregate_id": f"user_settings_{user_id}"})
-
- logger.info(f"Deleted all settings data for user {user_id}")
+ doc = await UserSettingsDocument.find_one({"user_id": user_id})
+ if doc:
+ await doc.delete()
+ await EventDocument.find(EventDocument.aggregate_id == f"user_settings_{user_id}").delete()
+ self.logger.info(f"Deleted all settings data for user {user_id}")
diff --git a/backend/app/db/schema/__init__.py b/backend/app/db/schema/__init__.py
deleted file mode 100644
index e3849b9b..00000000
--- a/backend/app/db/schema/__init__.py
+++ /dev/null
@@ -1,5 +0,0 @@
-from app.db.schema.schema_manager import SchemaManager
-
-__all__ = [
- "SchemaManager",
-]
diff --git a/backend/app/db/schema/schema_manager.py b/backend/app/db/schema/schema_manager.py
deleted file mode 100644
index c3eb40f0..00000000
--- a/backend/app/db/schema/schema_manager.py
+++ /dev/null
@@ -1,311 +0,0 @@
-from __future__ import annotations
-
-from datetime import datetime, timezone
-from typing import Any, Awaitable, Callable, Iterable
-
-from pymongo import ASCENDING, DESCENDING, IndexModel
-
-from app.core.database_context import Database
-from app.core.logging import logger
-from app.domain.events.event_models import EventFields
-
-
-class SchemaManager:
- """Applies idempotent, versioned MongoDB migrations per database."""
-
- def __init__(self, database: Database) -> None:
- self.db = database
- self._versions = self.db["schema_versions"]
-
- async def _is_applied(self, migration_id: str) -> bool:
- doc = await self._versions.find_one({"_id": migration_id})
- return doc is not None
-
- async def _mark_applied(self, migration_id: str, description: str) -> None:
- await self._versions.update_one(
- {"_id": migration_id},
- {
- "$set": {
- "description": description,
- "applied_at": datetime.now(timezone.utc),
- }
- },
- upsert=True,
- )
-
- async def apply_all(self) -> None:
- """Apply all pending migrations in order."""
- migrations: list[tuple[str, str, Callable[[], Awaitable[None]]]] = [
- ("0001_events_init", "Create events indexes and validator", self._m_0001_events_init),
- ("0002_user_settings_indexes", "Create user settings indexes", self._m_0002_user_settings),
- ("0003_replay_indexes", "Create replay indexes", self._m_0003_replay),
- ("0004_notification_indexes", "Create notification indexes", self._m_0004_notifications),
- ("0005_idempotency_indexes", "Create idempotency indexes", self._m_0005_idempotency),
- ("0006_saga_indexes", "Create saga indexes", self._m_0006_sagas),
- ("0007_execution_results_indexes", "Create execution results indexes", self._m_0007_execution_results),
- ("0008_dlq_indexes", "Create DLQ indexes", self._m_0008_dlq),
- (
- "0009_event_store_extra_indexes",
- "Additional events indexes for event_store",
- self._m_0009_event_store_extra,
- ),
- ]
-
- for mig_id, desc, func in migrations:
- if await self._is_applied(mig_id):
- continue
- logger.info(f"Applying migration {mig_id}: {desc}")
- await func()
- await self._mark_applied(mig_id, desc)
- logger.info(f"Migration {mig_id} applied")
-
- async def _m_0001_events_init(self) -> None:
- events = self.db["events"]
-
- # Create named, idempotent indexes
- indexes: Iterable[IndexModel] = [
- IndexModel([(EventFields.EVENT_ID, ASCENDING)], name="idx_event_id_unique", unique=True),
- IndexModel(
- [(EventFields.EVENT_TYPE, ASCENDING), (EventFields.TIMESTAMP, DESCENDING)], name="idx_event_type_ts"
- ),
- IndexModel(
- [(EventFields.AGGREGATE_ID, ASCENDING), (EventFields.TIMESTAMP, DESCENDING)], name="idx_aggregate_ts"
- ),
- IndexModel([(EventFields.METADATA_CORRELATION_ID, ASCENDING)], name="idx_meta_correlation"),
- IndexModel(
- [(EventFields.METADATA_USER_ID, ASCENDING), (EventFields.TIMESTAMP, DESCENDING)],
- name="idx_meta_user_ts",
- ),
- IndexModel(
- [(EventFields.METADATA_SERVICE_NAME, ASCENDING), (EventFields.TIMESTAMP, DESCENDING)],
- name="idx_meta_service_ts",
- ),
- IndexModel([(EventFields.STATUS, ASCENDING), (EventFields.TIMESTAMP, DESCENDING)], name="idx_status_ts"),
- IndexModel([(EventFields.PAYLOAD_EXECUTION_ID, ASCENDING)], name="idx_payload_execution", sparse=True),
- IndexModel([(EventFields.PAYLOAD_POD_NAME, ASCENDING)], name="idx_payload_pod", sparse=True),
- # Optional TTL on ttl_expires_at (no effect for nulls)
- IndexModel([(EventFields.TTL_EXPIRES_AT, ASCENDING)], name="idx_ttl", expireAfterSeconds=0),
- # Text search index to support $text queries
- # Use language_override: "none" to prevent MongoDB from interpreting
- # the "language" field as a text search language (which causes
- # "language override unsupported: python" errors)
- IndexModel(
- [
- (EventFields.EVENT_TYPE, "text"),
- (EventFields.METADATA_SERVICE_NAME, "text"),
- (EventFields.METADATA_USER_ID, "text"),
- (EventFields.PAYLOAD, "text"),
- ],
- name="idx_text_search",
- language_override="none",
- default_language="english",
- ),
- ]
-
- try:
- await events.create_indexes(list(indexes))
- logger.info("Events indexes ensured")
- except Exception as e:
- logger.warning(f"Failed ensuring events indexes: {e}")
-
- # Validator (moderate, warn) — non-blocking
- try:
- await self.db.command(
- {
- "collMod": "events",
- "validator": {"$jsonSchema": self._event_json_schema()},
- "validationLevel": "moderate",
- "validationAction": "warn",
- }
- )
- logger.info("Events collection validator ensured")
- except Exception as e:
- logger.warning(f"Could not set events validator: {e}")
-
- @staticmethod
- def _event_json_schema() -> dict[str, Any]:
- return {
- "bsonType": "object",
- "required": [
- EventFields.EVENT_ID,
- EventFields.EVENT_TYPE,
- EventFields.TIMESTAMP,
- EventFields.EVENT_VERSION,
- ],
- "properties": {
- EventFields.EVENT_ID: {"bsonType": "string"},
- EventFields.EVENT_TYPE: {"bsonType": "string"},
- EventFields.TIMESTAMP: {"bsonType": "date"},
- EventFields.EVENT_VERSION: {"bsonType": "string", "pattern": "^\\d+\\.\\d+$"},
- EventFields.AGGREGATE_ID: {"bsonType": ["string", "null"]},
- EventFields.METADATA: {"bsonType": "object"},
- EventFields.PAYLOAD: {"bsonType": "object"},
- EventFields.STORED_AT: {"bsonType": ["date", "null"]},
- EventFields.TTL_EXPIRES_AT: {"bsonType": ["date", "null"]},
- EventFields.STATUS: {"bsonType": ["string", "null"]},
- },
- }
-
- async def _m_0002_user_settings(self) -> None:
- snapshots = self.db["user_settings_snapshots"]
- events = self.db["events"]
- try:
- await snapshots.create_indexes(
- [
- IndexModel([("user_id", ASCENDING)], name="idx_settings_user_unique", unique=True),
- IndexModel([("updated_at", DESCENDING)], name="idx_settings_updated_at_desc"),
- ]
- )
- await events.create_indexes(
- [
- IndexModel([("event_type", ASCENDING), ("aggregate_id", ASCENDING)], name="idx_events_type_agg"),
- IndexModel([("aggregate_id", ASCENDING), ("timestamp", ASCENDING)], name="idx_events_agg_ts"),
- ]
- )
- logger.info("User settings indexes ensured")
- except Exception as e:
- logger.warning(f"Failed ensuring user settings indexes: {e}")
-
- async def _m_0003_replay(self) -> None:
- sessions = self.db["replay_sessions"]
- events = self.db["events"]
- try:
- await sessions.create_indexes(
- [
- IndexModel([("session_id", ASCENDING)], name="idx_replay_session_id", unique=True),
- IndexModel([("status", ASCENDING)], name="idx_replay_status"),
- IndexModel([("created_at", DESCENDING)], name="idx_replay_created_at_desc"),
- IndexModel([("user_id", ASCENDING)], name="idx_replay_user"),
- ]
- )
- await events.create_indexes(
- [
- IndexModel([("execution_id", ASCENDING), ("timestamp", ASCENDING)], name="idx_events_exec_ts"),
- IndexModel([("event_type", ASCENDING), ("timestamp", ASCENDING)], name="idx_events_type_ts"),
- IndexModel([("metadata.user_id", ASCENDING), ("timestamp", ASCENDING)], name="idx_events_user_ts"),
- ]
- )
- logger.info("Replay indexes ensured")
- except Exception as e:
- logger.warning(f"Failed ensuring replay indexes: {e}")
-
- async def _m_0004_notifications(self) -> None:
- notifications = self.db["notifications"]
- rules = self.db["notification_rules"]
- subs = self.db["notification_subscriptions"]
- try:
- await notifications.create_indexes(
- [
- IndexModel(
- [("user_id", ASCENDING), ("created_at", DESCENDING)], name="idx_notif_user_created_desc"
- ),
- IndexModel([("status", ASCENDING), ("scheduled_for", ASCENDING)], name="idx_notif_status_sched"),
- IndexModel([("created_at", ASCENDING)], name="idx_notif_created_at"),
- IndexModel([("notification_id", ASCENDING)], name="idx_notif_id_unique", unique=True),
- ]
- )
- await rules.create_indexes(
- [
- IndexModel([("event_types", ASCENDING)], name="idx_rules_event_types"),
- IndexModel([("enabled", ASCENDING)], name="idx_rules_enabled"),
- ]
- )
- await subs.create_indexes(
- [
- IndexModel(
- [("user_id", ASCENDING), ("channel", ASCENDING)],
- name="idx_sub_user_channel_unique",
- unique=True,
- ),
- IndexModel([("enabled", ASCENDING)], name="idx_sub_enabled"),
- ]
- )
- logger.info("Notification indexes ensured")
- except Exception as e:
- logger.warning(f"Failed ensuring notification indexes: {e}")
-
- async def _m_0005_idempotency(self) -> None:
- coll = self.db["idempotency_keys"]
- try:
- await coll.create_indexes(
- [
- IndexModel([("key", ASCENDING)], name="idx_idem_key_unique", unique=True),
- IndexModel([("created_at", ASCENDING)], name="idx_idem_created_ttl", expireAfterSeconds=3600),
- IndexModel([("status", ASCENDING)], name="idx_idem_status"),
- IndexModel([("event_type", ASCENDING)], name="idx_idem_event_type"),
- ]
- )
- logger.info("Idempotency indexes ensured")
- except Exception as e:
- logger.warning(f"Failed ensuring idempotency indexes: {e}")
-
- async def _m_0006_sagas(self) -> None:
- coll = self.db["sagas"]
- try:
- await coll.create_indexes(
- [
- IndexModel([("saga_id", ASCENDING)], name="idx_saga_id_unique", unique=True),
- IndexModel([("execution_id", ASCENDING)], name="idx_saga_execution"),
- IndexModel([("state", ASCENDING)], name="idx_saga_state"),
- IndexModel([("created_at", ASCENDING)], name="idx_saga_created_at"),
- IndexModel([("state", ASCENDING), ("created_at", ASCENDING)], name="idx_saga_state_created"),
- ]
- )
- logger.info("Saga indexes ensured")
- except Exception as e:
- logger.warning(f"Failed ensuring saga indexes: {e}")
-
- async def _m_0007_execution_results(self) -> None:
- coll = self.db["execution_results"]
- try:
- await coll.create_indexes(
- [
- IndexModel([("execution_id", ASCENDING)], name="idx_results_execution_unique", unique=True),
- IndexModel([("created_at", ASCENDING)], name="idx_results_created_at"),
- IndexModel(
- [("user_id", ASCENDING), ("created_at", DESCENDING)], name="idx_results_user_created_desc"
- ),
- ]
- )
- logger.info("Execution results indexes ensured")
- except Exception as e:
- logger.warning(f"Failed ensuring execution results indexes: {e}")
-
- async def _m_0008_dlq(self) -> None:
- coll = self.db["dlq_messages"]
- try:
- await coll.create_indexes(
- [
- IndexModel([("event_id", ASCENDING)], name="idx_dlq_event_id_unique", unique=True),
- IndexModel([("original_topic", ASCENDING)], name="idx_dlq_topic"),
- IndexModel([("event_type", ASCENDING)], name="idx_dlq_event_type"),
- IndexModel([("failed_at", DESCENDING)], name="idx_dlq_failed_desc"),
- IndexModel([("retry_count", ASCENDING)], name="idx_dlq_retry_count"),
- IndexModel([("status", ASCENDING)], name="idx_dlq_status"),
- IndexModel([("next_retry_at", ASCENDING)], name="idx_dlq_next_retry"),
- IndexModel(
- [("created_at", ASCENDING)], name="idx_dlq_created_ttl", expireAfterSeconds=7 * 24 * 3600
- ),
- ]
- )
- logger.info("DLQ indexes ensured")
- except Exception as e:
- logger.warning(f"Failed ensuring DLQ indexes: {e}")
-
- async def _m_0009_event_store_extra(self) -> None:
- events = self.db["events"]
- try:
- await events.create_indexes(
- [
- IndexModel(
- [("metadata.user_id", ASCENDING), ("event_type", ASCENDING)], name="idx_events_user_type"
- ),
- IndexModel(
- [("event_type", ASCENDING), ("metadata.user_id", ASCENDING), ("timestamp", DESCENDING)],
- name="idx_events_type_user_ts",
- ),
- ]
- )
- logger.info("Additional event store indexes ensured")
- except Exception as e:
- logger.warning(f"Failed ensuring event store extra indexes: {e}")
diff --git a/backend/app/dlq/__init__.py b/backend/app/dlq/__init__.py
index 084dbb3e..f047e9c4 100644
--- a/backend/app/dlq/__init__.py
+++ b/backend/app/dlq/__init__.py
@@ -7,7 +7,6 @@
from .models import (
AgeStatistics,
DLQBatchRetryResult,
- DLQFields,
DLQMessage,
DLQMessageFilter,
DLQMessageListResult,
@@ -23,15 +22,12 @@
)
__all__ = [
- # Core models
"DLQMessageStatus",
"RetryStrategy",
- "DLQFields",
"DLQMessage",
"DLQMessageUpdate",
"DLQMessageFilter",
"RetryPolicy",
- # Stats models
"TopicStatistic",
"EventTypeStatistic",
"AgeStatistics",
diff --git a/backend/app/dlq/manager.py b/backend/app/dlq/manager.py
index e6d15ed4..96f1528f 100644
--- a/backend/app/dlq/manager.py
+++ b/backend/app/dlq/manager.py
@@ -1,19 +1,18 @@
import asyncio
import json
+import logging
from datetime import datetime, timezone
-from typing import Any, Awaitable, Callable, Mapping, Sequence
+from typing import Any, Awaitable, Callable
from confluent_kafka import Consumer, KafkaError, Message, Producer
from opentelemetry.trace import SpanKind
-from app.core.database_context import Collection, Database
from app.core.lifecycle import LifecycleEnabled
-from app.core.logging import logger
from app.core.metrics.context import get_dlq_metrics
from app.core.tracing import EventAttributes
from app.core.tracing.utils import extract_trace_context, get_tracer, inject_trace_context
+from app.db.docs import DLQMessageDocument
from app.dlq.models import (
- DLQFields,
DLQMessage,
DLQMessageStatus,
DLQMessageUpdate,
@@ -21,23 +20,24 @@
RetryStrategy,
)
from app.domain.enums.kafka import GroupId, KafkaTopic
-from app.domain.events.event_models import CollectionNames
from app.events.schema.schema_registry import SchemaRegistryManager
-from app.infrastructure.mappers.dlq_mapper import DLQMapper
from app.settings import get_settings
class DLQManager(LifecycleEnabled):
def __init__(
self,
- database: Database,
consumer: Consumer,
producer: Producer,
+ schema_registry: SchemaRegistryManager,
+ logger: logging.Logger,
dlq_topic: KafkaTopic = KafkaTopic.DEAD_LETTER_QUEUE,
retry_topic_suffix: str = "-retry",
default_retry_policy: RetryPolicy | None = None,
):
self.metrics = get_dlq_metrics()
+ self.schema_registry = schema_registry
+ self.logger = logger
self.dlq_topic = dlq_topic
self.retry_topic_suffix = retry_topic_suffix
self.default_retry_policy = default_retry_policy or RetryPolicy(
@@ -45,7 +45,6 @@ def __init__(
)
self.consumer: Consumer = consumer
self.producer: Producer = producer
- self.dlq_collection: Collection = database.get_collection(CollectionNames.DLQ_MESSAGES)
self._running = False
self._process_task: asyncio.Task[None] | None = None
@@ -64,6 +63,85 @@ def __init__(
"on_discard": [],
}
+ def _doc_to_message(self, doc: DLQMessageDocument) -> DLQMessage:
+ """Convert DLQMessageDocument to DLQMessage domain model."""
+ event = self.schema_registry.deserialize_json(doc.event)
+ return DLQMessage(
+ event_id=doc.event_id,
+ event=event,
+ event_type=doc.event_type,
+ original_topic=doc.original_topic,
+ error=doc.error,
+ retry_count=doc.retry_count,
+ failed_at=doc.failed_at,
+ status=doc.status,
+ producer_id=doc.producer_id,
+ created_at=doc.created_at,
+ last_updated=doc.last_updated,
+ next_retry_at=doc.next_retry_at,
+ retried_at=doc.retried_at,
+ discarded_at=doc.discarded_at,
+ discard_reason=doc.discard_reason,
+ dlq_offset=doc.dlq_offset,
+ dlq_partition=doc.dlq_partition,
+ last_error=doc.last_error,
+ headers=doc.headers,
+ )
+
+ def _message_to_doc(self, message: DLQMessage) -> DLQMessageDocument:
+ """Convert DLQMessage domain model to DLQMessageDocument."""
+ return DLQMessageDocument(
+ event=message.event.model_dump(),
+ event_id=message.event_id,
+ event_type=message.event_type,
+ original_topic=message.original_topic,
+ error=message.error,
+ retry_count=message.retry_count,
+ failed_at=message.failed_at,
+ status=message.status,
+ producer_id=message.producer_id,
+ created_at=message.created_at or datetime.now(timezone.utc),
+ last_updated=message.last_updated,
+ next_retry_at=message.next_retry_at,
+ retried_at=message.retried_at,
+ discarded_at=message.discarded_at,
+ discard_reason=message.discard_reason,
+ dlq_offset=message.dlq_offset,
+ dlq_partition=message.dlq_partition,
+ last_error=message.last_error,
+ headers=message.headers,
+ )
+
+ def _kafka_msg_to_message(self, msg: Message) -> DLQMessage:
+ """Parse Kafka message into DLQMessage."""
+ raw_bytes = msg.value()
+ raw: str = raw_bytes.decode("utf-8") if isinstance(raw_bytes, (bytes, bytearray)) else str(raw_bytes or "")
+ data: dict[str, Any] = json.loads(raw) if raw else {}
+
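+ # Kafka header values may arrive as bytes; normalize keys and values to str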
+ headers_list = msg.headers() or []
+ headers: dict[str, str] = {}
+ for k, v in headers_list:
+ headers[str(k)] = v.decode("utf-8") if isinstance(v, (bytes, bytearray)) else (v or "")
+
+ event = self.schema_registry.deserialize_json(data.get("event", data))
+
+ return DLQMessage(
+ event_id=data.get("event_id", event.event_id),
+ event=event,
+ event_type=event.event_type,
+ original_topic=data.get("original_topic", headers.get("original_topic", "")),
+ error=data.get("error", headers.get("error", "Unknown error")),
+ retry_count=data.get("retry_count", int(headers.get("retry_count", 0))),
+ failed_at=datetime.fromisoformat(data["failed_at"])
+ if data.get("failed_at")
+ else datetime.now(timezone.utc),
+ status=DLQMessageStatus(data.get("status", DLQMessageStatus.PENDING)),
+ producer_id=data.get("producer_id", headers.get("producer_id", "unknown")),
+ dlq_offset=msg.offset(),
+ dlq_partition=msg.partition(),
+ headers=headers,
+ )
+
async def start(self) -> None:
"""Start DLQ manager"""
if self._running:
@@ -78,7 +156,7 @@ async def start(self) -> None:
self._process_task = asyncio.create_task(self._process_messages())
self._monitor_task = asyncio.create_task(self._monitor_dlq())
- logger.info("DLQ Manager started")
+ self.logger.info("DLQ Manager started")
async def stop(self) -> None:
"""Stop DLQ manager"""
@@ -100,7 +178,7 @@ async def stop(self) -> None:
self.consumer.close()
self.producer.flush(10)
- logger.info("DLQ Manager stopped")
+ self.logger.info("DLQ Manager stopped")
async def _process_messages(self) -> None:
while self._running:
@@ -113,14 +191,14 @@ async def _process_messages(self) -> None:
continue
start_time = asyncio.get_event_loop().time()
- dlq_message = await self._parse_message(msg)
+ dlq_message = self._kafka_msg_to_message(msg)
await self._record_message_metrics(dlq_message)
await self._process_message_with_tracing(msg, dlq_message)
await self._commit_and_record_duration(start_time)
except Exception as e:
- logger.error(f"Error in DLQ processing loop: {e}")
+ self.logger.error(f"Error in DLQ processing loop: {e}")
await asyncio.sleep(5)
async def _poll_message(self) -> Message | None:
@@ -133,15 +211,10 @@ async def _validate_message(self, msg: Message) -> bool:
error = msg.error()
if error and error.code() == KafkaError._PARTITION_EOF:
return False
- logger.error(f"Consumer error: {error}")
+ self.logger.error(f"Consumer error: {error}")
return False
return True
- async def _parse_message(self, msg: Message) -> DLQMessage:
- """Parse Kafka message into DLQMessage."""
- schema_registry = SchemaRegistryManager()
- return DLQMapper.from_kafka_message(msg, schema_registry)
-
def _extract_headers(self, msg: Message) -> dict[str, str]:
"""Extract headers from Kafka message."""
headers_list = msg.headers() or []
@@ -152,7 +225,7 @@ def _extract_headers(self, msg: Message) -> dict[str, str]:
async def _record_message_metrics(self, dlq_message: DLQMessage) -> None:
"""Record metrics for received DLQ message."""
- self.metrics.record_dlq_message_received(dlq_message.original_topic, dlq_message.event_type)
+ self.metrics.record_dlq_message_received(dlq_message.original_topic, str(dlq_message.event_type))
self.metrics.record_dlq_message_age(dlq_message.age_seconds)
async def _process_message_with_tracing(self, msg: Message, dlq_message: DLQMessage) -> None:
@@ -167,7 +240,7 @@ async def _process_message_with_tracing(self, msg: Message, dlq_message: DLQMess
kind=SpanKind.CONSUMER,
attributes={
str(EventAttributes.KAFKA_TOPIC): str(self.dlq_topic),
- str(EventAttributes.EVENT_TYPE): dlq_message.event_type,
+ str(EventAttributes.EVENT_TYPE): str(dlq_message.event_type),
str(EventAttributes.EVENT_ID): dlq_message.event_id or "",
},
):
@@ -183,10 +256,10 @@ async def _process_dlq_message(self, message: DLQMessage) -> None:
# Apply filters
for filter_func in self._filters:
if not filter_func(message):
- logger.info(f"Message {message.event_id} filtered out")
+ self.logger.info("Message filtered out", extra={"event_id": message.event_id})
return
- # Store in MongoDB
+ # Store in MongoDB via Beanie
await self._store_message(message)
# Get retry policy for topic
@@ -215,13 +288,34 @@ async def _store_message(self, message: DLQMessage) -> None:
message.status = DLQMessageStatus.PENDING
message.last_updated = datetime.now(timezone.utc)
- doc = DLQMapper.to_mongo_document(message)
+ doc = self._message_to_doc(message)
- await self.dlq_collection.update_one({DLQFields.EVENT_ID: message.event_id}, {"$set": doc}, upsert=True)
+ # Upsert using Beanie
+ existing = await DLQMessageDocument.find_one({"event_id": message.event_id})
+ if existing:
+ doc.id = existing.id
+ await doc.save()
async def _update_message_status(self, event_id: str, update: DLQMessageUpdate) -> None:
- update_doc = DLQMapper.update_to_mongo(update)
- await self.dlq_collection.update_one({DLQFields.EVENT_ID: event_id}, {"$set": update_doc})
+ doc = await DLQMessageDocument.find_one({"event_id": event_id})
+ if not doc:
+ return
+
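+ # status and last_updated always change; optional fields are copied only when provided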
+ update_dict: dict[str, Any] = {"status": update.status, "last_updated": datetime.now(timezone.utc)}
+ if update.next_retry_at is not None:
+ update_dict["next_retry_at"] = update.next_retry_at
+ if update.retried_at is not None:
+ update_dict["retried_at"] = update.retried_at
+ if update.discarded_at is not None:
+ update_dict["discarded_at"] = update.discarded_at
+ if update.retry_count is not None:
+ update_dict["retry_count"] = update.retry_count
+ if update.discard_reason is not None:
+ update_dict["discard_reason"] = update.discard_reason
+ if update.last_error is not None:
+ update_dict["last_error"] = update.last_error
+
+ await doc.set(update_dict)
async def _retry_message(self, message: DLQMessage) -> None:
# Trigger before_retry callbacks
@@ -264,7 +358,7 @@ async def _retry_message(self, message: DLQMessage) -> None:
await asyncio.to_thread(self.producer.flush, timeout=5)
# Update metrics
- self.metrics.record_dlq_message_retried(message.original_topic, message.event_type, "success")
+ self.metrics.record_dlq_message_retried(message.original_topic, str(message.event_type), "success")
# Update status
await self._update_message_status(
@@ -279,11 +373,11 @@ async def _retry_message(self, message: DLQMessage) -> None:
# Trigger after_retry callbacks
await self._trigger_callbacks("after_retry", message, success=True)
- logger.info(f"Successfully retried message {message.event_id}")
+ self.logger.info("Successfully retried message", extra={"event_id": message.event_id})
async def _discard_message(self, message: DLQMessage, reason: str) -> None:
# Update metrics
- self.metrics.record_dlq_message_discarded(message.original_topic, message.event_type, reason)
+ self.metrics.record_dlq_message_discarded(message.original_topic, str(message.event_type), reason)
# Update status
await self._update_message_status(
@@ -298,23 +392,27 @@ async def _discard_message(self, message: DLQMessage, reason: str) -> None:
# Trigger callbacks
await self._trigger_callbacks("on_discard", message, reason)
- logger.warning(f"Discarded message {message.event_id} due to {reason}")
+ self.logger.warning("Discarded message", extra={"event_id": message.event_id, "reason": reason})
async def _monitor_dlq(self) -> None:
while self._running:
try:
- # Find messages ready for retry
+ # Find messages ready for retry using Beanie
now = datetime.now(timezone.utc)
- cursor = self.dlq_collection.find(
- {"status": DLQMessageStatus.SCHEDULED, "next_retry_at": {"$lte": now}}
- ).limit(100)
-
- async for doc in cursor:
- # Recreate DLQ message from MongoDB document
- message = DLQMapper.from_mongo_document(doc)
-
- # Retry message
+ docs = (
+ await DLQMessageDocument.find(
+ {
+ "status": DLQMessageStatus.SCHEDULED,
+ "next_retry_at": {"$lte": now},
+ }
+ )
+ .limit(100)
+ .to_list()
+ )
+
+ for doc in docs:
+ message = self._doc_to_message(doc)
await self._retry_message(message)
# Update queue size metrics
@@ -324,18 +422,17 @@ async def _monitor_dlq(self) -> None:
await asyncio.sleep(10)
except Exception as e:
- logger.error(f"Error in DLQ monitor: {e}")
+ self.logger.error(f"Error in DLQ monitor: {e}")
await asyncio.sleep(60)
async def _update_queue_metrics(self) -> None:
- # Get counts by topic
- pipeline: Sequence[Mapping[str, Any]] = [
- {"$match": {str(DLQFields.STATUS): {"$in": [DLQMessageStatus.PENDING, DLQMessageStatus.SCHEDULED]}}},
- {"$group": {"_id": f"${DLQFields.ORIGINAL_TOPIC}", "count": {"$sum": 1}}},
+ # Get counts by topic using Beanie aggregation
+ pipeline: list[dict[str, Any]] = [
+ {"$match": {"status": {"$in": [DLQMessageStatus.PENDING, DLQMessageStatus.SCHEDULED]}}},
+ {"$group": {"_id": "$original_topic", "count": {"$sum": 1}}},
]
- async for result in self.dlq_collection.aggregate(pipeline):
- # Note: OpenTelemetry doesn't have direct gauge set, using delta tracking
+ async for result in DLQMessageDocument.aggregate(pipeline):
self.metrics.update_dlq_queue_size(result["_id"], result["count"])
def set_retry_policy(self, topic: str, policy: RetryPolicy) -> None:
@@ -353,28 +450,27 @@ async def _trigger_callbacks(self, event_type: str, *args: Any, **kwargs: Any) -
try:
await callback(*args, **kwargs)
except Exception as e:
- logger.error(f"Error in DLQ callback {callback.__name__}: {e}")
+ self.logger.error(f"Error in DLQ callback {callback.__name__}: {e}")
async def retry_message_manually(self, event_id: str) -> bool:
- doc = await self.dlq_collection.find_one({"event_id": event_id})
+ doc = await DLQMessageDocument.find_one({"event_id": event_id})
if not doc:
- logger.error(f"Message {event_id} not found in DLQ")
+ self.logger.error("Message not found in DLQ", extra={"event_id": event_id})
return False
# Guard against invalid states
- status = doc.get(str(DLQFields.STATUS))
- if status in {DLQMessageStatus.DISCARDED, DLQMessageStatus.RETRIED}:
- logger.info(f"Skipping manual retry for {event_id}: status={status}")
+ if doc.status in {DLQMessageStatus.DISCARDED, DLQMessageStatus.RETRIED}:
+ self.logger.info("Skipping manual retry", extra={"event_id": event_id, "status": str(doc.status)})
return False
- message = DLQMapper.from_mongo_document(doc)
-
+ message = self._doc_to_message(doc)
await self._retry_message(message)
return True
def create_dlq_manager(
- database: Database,
+ schema_registry: SchemaRegistryManager,
+ logger: logging.Logger,
dlq_topic: KafkaTopic = KafkaTopic.DEAD_LETTER_QUEUE,
retry_topic_suffix: str = "-retry",
default_retry_policy: RetryPolicy | None = None,
@@ -403,9 +499,10 @@ def create_dlq_manager(
if default_retry_policy is None:
default_retry_policy = RetryPolicy(topic="default", strategy=RetryStrategy.EXPONENTIAL_BACKOFF)
return DLQManager(
- database=database,
consumer=consumer,
producer=producer,
+ schema_registry=schema_registry,
+ logger=logger,
dlq_topic=dlq_topic,
retry_topic_suffix=retry_topic_suffix,
default_retry_policy=default_retry_policy,
diff --git a/backend/app/dlq/models.py b/backend/app/dlq/models.py
index 2f577e14..fc8dd8c0 100644
--- a/backend/app/dlq/models.py
+++ b/backend/app/dlq/models.py
@@ -26,74 +26,34 @@ class RetryStrategy(StringEnum):
MANUAL = "manual"
-class DLQFields(StringEnum):
- """Database field names for DLQ messages collection."""
-
- EVENT_ID = "event_id"
- EVENT = "event"
- EVENT_TYPE = "event.event_type"
- ORIGINAL_TOPIC = "original_topic"
- ERROR = "error"
- RETRY_COUNT = "retry_count"
- FAILED_AT = "failed_at"
- STATUS = "status"
- CREATED_AT = "created_at"
- LAST_UPDATED = "last_updated"
- NEXT_RETRY_AT = "next_retry_at"
- RETRIED_AT = "retried_at"
- DISCARDED_AT = "discarded_at"
- DISCARD_REASON = "discard_reason"
- PRODUCER_ID = "producer_id"
- DLQ_OFFSET = "dlq_offset"
- DLQ_PARTITION = "dlq_partition"
- LAST_ERROR = "last_error"
-
-
@dataclass
class DLQMessage:
"""Unified DLQ message model for the entire system."""
- # Core fields - always required
- event: BaseEvent # The original event that failed
- original_topic: str # Topic where the event originally failed
- error: str # Error message from the failure
- retry_count: int # Number of retry attempts
- failed_at: datetime # When the failure occurred (UTC)
- status: DLQMessageStatus # Current status
- producer_id: str # ID of the producer that sent to DLQ
-
- # Optional fields
- event_id: str = ""
- created_at: datetime | None = None # When added to DLQ (UTC)
- last_updated: datetime | None = None # Last status change (UTC)
- next_retry_at: datetime | None = None # Next scheduled retry (UTC)
- retried_at: datetime | None = None # When last retried (UTC)
- discarded_at: datetime | None = None # When discarded (UTC)
- discard_reason: str | None = None # Why it was discarded
- dlq_offset: int | None = None # Kafka offset in DLQ topic
- dlq_partition: int | None = None # Kafka partition in DLQ topic
- last_error: str | None = None # Most recent error message
-
- # Kafka message headers (optional)
+ event_id: str
+ event: BaseEvent
+ event_type: EventType
+ original_topic: str
+ error: str
+ retry_count: int
+ failed_at: datetime
+ status: DLQMessageStatus
+ producer_id: str
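+ # Optional lifecycle fields, populated as the message moves through the DLQ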
+ created_at: datetime | None = None
+ last_updated: datetime | None = None
+ next_retry_at: datetime | None = None
+ retried_at: datetime | None = None
+ discarded_at: datetime | None = None
+ discard_reason: str | None = None
+ dlq_offset: int | None = None
+ dlq_partition: int | None = None
+ last_error: str | None = None
headers: dict[str, str] = field(default_factory=dict)
- def __post_init__(self) -> None:
- """Initialize computed fields."""
- if not self.event_id:
- self.event_id = self.event.event_id
- if not self.created_at:
- self.created_at = datetime.now(timezone.utc)
-
@property
def age_seconds(self) -> float:
- """Get message age in seconds since failure."""
return (datetime.now(timezone.utc) - self.failed_at).total_seconds()
- @property
- def event_type(self) -> EventType:
- """Get event type from the event."""
- return self.event.event_type
-
@dataclass
class DLQMessageUpdate:
diff --git a/backend/app/domain/admin/__init__.py b/backend/app/domain/admin/__init__.py
index a419a035..beb7ab03 100644
--- a/backend/app/domain/admin/__init__.py
+++ b/backend/app/domain/admin/__init__.py
@@ -5,21 +5,17 @@
)
from .replay_models import (
ReplayQuery,
- ReplaySession,
ReplaySessionData,
- ReplaySessionFields,
ReplaySessionStatusDetail,
ReplaySessionStatusInfo,
)
from .settings_models import (
AuditAction,
AuditLogEntry,
- AuditLogFields,
ExecutionLimits,
LogLevel,
MonitoringSettings,
SecuritySettings,
- SettingsFields,
SystemSettings,
)
@@ -29,8 +25,6 @@
"DerivedCountsDomain",
"RateLimitSummaryDomain",
# Settings
- "SettingsFields",
- "AuditLogFields",
"AuditAction",
"LogLevel",
"ExecutionLimits",
@@ -40,9 +34,7 @@
"AuditLogEntry",
# Replay
"ReplayQuery",
- "ReplaySession",
"ReplaySessionData",
- "ReplaySessionFields",
"ReplaySessionStatusDetail",
"ReplaySessionStatusInfo",
]
diff --git a/backend/app/domain/admin/overview_models.py b/backend/app/domain/admin/overview_models.py
index a208c953..23f91408 100644
--- a/backend/app/domain/admin/overview_models.py
+++ b/backend/app/domain/admin/overview_models.py
@@ -1,8 +1,10 @@
from __future__ import annotations
-from dataclasses import dataclass, field
+from dataclasses import field
from typing import List
+from pydantic.dataclasses import dataclass
+
from app.domain.events import Event, EventStatistics
from app.domain.user import User as DomainAdminUser
diff --git a/backend/app/domain/admin/replay_models.py b/backend/app/domain/admin/replay_models.py
index 220b3c8b..44d7d79c 100644
--- a/backend/app/domain/admin/replay_models.py
+++ b/backend/app/domain/admin/replay_models.py
@@ -1,88 +1,27 @@
-from dataclasses import dataclass, field, replace
-from datetime import datetime, timezone
+from dataclasses import field
+from datetime import datetime
from typing import Any
-from app.core.utils import StringEnum
+from pydantic.dataclasses import dataclass
+
from app.domain.enums.replay import ReplayStatus
from app.domain.events.event_models import EventSummary
-
-
-class ReplaySessionFields(StringEnum):
- """Database field names for replay sessions."""
-
- SESSION_ID = "session_id"
- TYPE = "type"
- STATUS = "status"
- TOTAL_EVENTS = "total_events"
- REPLAYED_EVENTS = "replayed_events"
- FAILED_EVENTS = "failed_events"
- SKIPPED_EVENTS = "skipped_events"
- CORRELATION_ID = "correlation_id"
- CREATED_AT = "created_at"
- STARTED_AT = "started_at"
- COMPLETED_AT = "completed_at"
- ERROR = "error"
- CREATED_BY = "created_by"
- TARGET_SERVICE = "target_service"
- DRY_RUN = "dry_run"
-
-
-@dataclass
-class ReplaySession:
- session_id: str
- status: ReplayStatus
- total_events: int
- correlation_id: str
- created_at: datetime
- type: str = "replay_session"
- replayed_events: int = 0
- failed_events: int = 0
- skipped_events: int = 0
- started_at: datetime | None = None
- completed_at: datetime | None = None
- error: str | None = None
- created_by: str | None = None
- target_service: str | None = None
- dry_run: bool = False
- triggered_executions: list[str] = field(default_factory=list) # Track execution IDs created by replay
-
- @property
- def progress_percentage(self) -> float:
- """Calculate progress percentage."""
- if self.total_events == 0:
- return 0.0
- return round((self.replayed_events / self.total_events) * 100, 2)
-
- @property
- def is_completed(self) -> bool:
- """Check if session is completed."""
- return self.status in [ReplayStatus.COMPLETED, ReplayStatus.FAILED, ReplayStatus.CANCELLED]
-
- @property
- def is_running(self) -> bool:
- """Check if session is running."""
- return self.status == ReplayStatus.RUNNING
-
- def update_progress(self, replayed: int, failed: int = 0, skipped: int = 0) -> "ReplaySession":
- # Create new instance with updated values
- new_session = replace(self, replayed_events=replayed, failed_events=failed, skipped_events=skipped)
-
- # Check if completed and update status
- if new_session.replayed_events >= new_session.total_events:
- new_session = replace(new_session, status=ReplayStatus.COMPLETED, completed_at=datetime.now(timezone.utc))
-
- return new_session
+from app.domain.replay.models import ReplaySessionState
@dataclass
class ReplaySessionStatusDetail:
- session: ReplaySession
+ """Status detail with computed metadata for admin API."""
+
+ session: ReplaySessionState
estimated_completion: datetime | None = None
execution_results: list[dict[str, Any]] = field(default_factory=list)
@dataclass
class ReplaySessionStatusInfo:
+ """Lightweight status info for API responses."""
+
session_id: str
status: ReplayStatus
total_events: int
@@ -116,5 +55,5 @@ class ReplaySessionData:
total_events: int
replay_correlation_id: str
dry_run: bool
- query: dict[str, Any]
+ query: ReplayQuery
events_preview: list[EventSummary] = field(default_factory=list)
diff --git a/backend/app/domain/admin/replay_updates.py b/backend/app/domain/admin/replay_updates.py
index c4450a1c..c326565b 100644
--- a/backend/app/domain/admin/replay_updates.py
+++ b/backend/app/domain/admin/replay_updates.py
@@ -1,8 +1,9 @@
"""Domain models for replay session updates."""
-from dataclasses import dataclass
from datetime import datetime
+from pydantic.dataclasses import dataclass
+
from app.domain.enums.replay import ReplayStatus
@@ -21,36 +22,3 @@ class ReplaySessionUpdate:
error: str | None = None
target_service: str | None = None
dry_run: bool | None = None
-
- def to_dict(self) -> dict[str, object]:
- """Convert to dictionary, excluding None values."""
- result: dict[str, object] = {}
-
- if self.status is not None:
- result["status"] = self.status.value if hasattr(self.status, "value") else self.status
- if self.total_events is not None:
- result["total_events"] = self.total_events
- if self.replayed_events is not None:
- result["replayed_events"] = self.replayed_events
- if self.failed_events is not None:
- result["failed_events"] = self.failed_events
- if self.skipped_events is not None:
- result["skipped_events"] = self.skipped_events
- if self.correlation_id is not None:
- result["correlation_id"] = self.correlation_id
- if self.started_at is not None:
- result["started_at"] = self.started_at
- if self.completed_at is not None:
- result["completed_at"] = self.completed_at
- if self.error is not None:
- result["error"] = self.error
- if self.target_service is not None:
- result["target_service"] = self.target_service
- if self.dry_run is not None:
- result["dry_run"] = self.dry_run
-
- return result
-
- def has_updates(self) -> bool:
- """Check if there are any updates to apply."""
- return bool(self.to_dict())
diff --git a/backend/app/domain/admin/settings_models.py b/backend/app/domain/admin/settings_models.py
index e3b0398a..cad09f3c 100644
--- a/backend/app/domain/admin/settings_models.py
+++ b/backend/app/domain/admin/settings_models.py
@@ -1,30 +1,10 @@
-from dataclasses import dataclass, field
+from dataclasses import field
from datetime import datetime, timezone
from typing import Any
-from app.core.utils import StringEnum
-
-
-class SettingsFields(StringEnum):
- """Database field names for settings collection."""
-
- ID = "_id"
- CREATED_AT = "created_at"
- UPDATED_AT = "updated_at"
- UPDATED_BY = "updated_by"
- EXECUTION_LIMITS = "execution_limits"
- SECURITY_SETTINGS = "security_settings"
- MONITORING_SETTINGS = "monitoring_settings"
+from pydantic.dataclasses import dataclass
-
-class AuditLogFields(StringEnum):
- """Database field names for audit log collection."""
-
- ACTION = "action"
- USER_ID = "user_id"
- USERNAME = "username"
- TIMESTAMP = "timestamp"
- CHANGES = "changes"
+from app.core.utils import StringEnum
class AuditAction(StringEnum):
diff --git a/backend/app/domain/events/__init__.py b/backend/app/domain/events/__init__.py
index d96d26bf..9216b541 100644
--- a/backend/app/domain/events/__init__.py
+++ b/backend/app/domain/events/__init__.py
@@ -3,7 +3,6 @@
ArchivedEvent,
Event,
EventAggregationResult,
- EventFields,
EventFilter,
EventListResult,
EventProjection,
@@ -18,7 +17,6 @@
"ArchivedEvent",
"Event",
"EventAggregationResult",
- "EventFields",
"EventFilter",
"EventListResult",
"EventMetadata",
diff --git a/backend/app/domain/events/event_metadata.py b/backend/app/domain/events/event_metadata.py
index ad44c8ed..c3a57440 100644
--- a/backend/app/domain/events/event_metadata.py
+++ b/backend/app/domain/events/event_metadata.py
@@ -1,7 +1,8 @@
-from dataclasses import asdict, dataclass, field, replace
-from typing import Any
+from dataclasses import field, replace
from uuid import uuid4
+from pydantic.dataclasses import dataclass
+
from app.domain.enums.common import Environment
@@ -17,29 +18,6 @@ class EventMetadata:
user_agent: str | None = None
environment: Environment = Environment.PRODUCTION
- def to_dict(self, exclude_none: bool = True) -> dict[str, Any]:
- result = asdict(self)
- if isinstance(result.get("environment"), Environment):
- result["environment"] = result["environment"].value
- if exclude_none:
- return {k: v for k, v in result.items() if v is not None}
- return result
-
- @classmethod
- def from_dict(cls, data: dict[str, Any]) -> "EventMetadata":
- env = data.get("environment", Environment.PRODUCTION)
- if isinstance(env, str):
- env = Environment(env)
- return cls(
- service_name=data.get("service_name", "unknown"),
- service_version=data.get("service_version", "1.0"),
- correlation_id=data.get("correlation_id", str(uuid4())),
- user_id=data.get("user_id"),
- ip_address=data.get("ip_address"),
- user_agent=data.get("user_agent"),
- environment=env,
- )
-
def with_correlation(self, correlation_id: str) -> "EventMetadata":
return replace(self, correlation_id=correlation_id)
diff --git a/backend/app/domain/events/event_models.py b/backend/app/domain/events/event_models.py
index 4c87c332..3dc57627 100644
--- a/backend/app/domain/events/event_models.py
+++ b/backend/app/domain/events/event_models.py
@@ -1,7 +1,9 @@
-from dataclasses import dataclass, field
+from dataclasses import field
from datetime import datetime
from typing import Any
+from pydantic.dataclasses import dataclass
+
from app.core.utils import StringEnum
from app.domain.enums.events import EventType
from app.domain.events.event_metadata import EventMetadata
@@ -10,41 +12,6 @@
MongoQuery = dict[str, MongoQueryValue]
-class EventFields(StringEnum):
- """Database field names for events collection."""
-
- ID = "_id"
- EVENT_ID = "event_id"
- EVENT_TYPE = "event_type"
- EVENT_VERSION = "event_version"
- TIMESTAMP = "timestamp"
- AGGREGATE_ID = "aggregate_id"
- METADATA = "metadata"
- PAYLOAD = "payload"
- STORED_AT = "stored_at"
- TTL_EXPIRES_AT = "ttl_expires_at"
- STATUS = "status"
- ERROR = "error"
-
- # Metadata sub-fields
- METADATA_CORRELATION_ID = "metadata.correlation_id"
- METADATA_USER_ID = "metadata.user_id"
- METADATA_SERVICE_NAME = "metadata.service_name"
- METADATA_SERVICE_VERSION = "metadata.service_version"
- METADATA_IP_ADDRESS = "metadata.ip_address"
- METADATA_USER_AGENT = "metadata.user_agent"
-
- # Payload sub-fields for common queries
- PAYLOAD_EXECUTION_ID = "payload.execution_id"
- PAYLOAD_POD_NAME = "payload.pod_name"
- PAYLOAD_DURATION_SECONDS = "payload.duration_seconds"
-
- # Archive fields
- DELETED_AT = "_deleted_at"
- DELETED_BY = "_deleted_by"
- DELETION_REASON = "_deletion_reason"
-
-
class EventSortOrder(StringEnum):
ASC = "asc"
DESC = "desc"
@@ -116,7 +83,6 @@ class EventFilter:
start_time: datetime | None = None
end_time: datetime | None = None
search_text: str | None = None
- text_search: str | None = None
status: str | None = None
@@ -125,7 +91,7 @@ class EventQuery:
"""Query parameters for event search."""
filter: EventFilter
- sort_by: str = EventFields.TIMESTAMP
+ sort_by: str = "timestamp"
sort_order: EventSortOrder = EventSortOrder.DESC
limit: int = 100
skip: int = 0
diff --git a/backend/app/domain/events/query_builders.py b/backend/app/domain/events/query_builders.py
index dc7549f3..1be73e4e 100644
--- a/backend/app/domain/events/query_builders.py
+++ b/backend/app/domain/events/query_builders.py
@@ -1,8 +1,6 @@
from datetime import datetime
from typing import Any
-from app.domain.events.event_models import EventFields
-
class AggregationStages:
@staticmethod
@@ -60,14 +58,14 @@ class EventStatsAggregation:
@staticmethod
def build_overview_pipeline(start_time: datetime) -> list[dict[str, Any]]:
return [
- AggregationStages.match({EventFields.TIMESTAMP: {"$gte": start_time}}),
+ AggregationStages.match({"timestamp": {"$gte": start_time}}),
AggregationStages.group(
{
"_id": None,
"total_events": AggregationStages.sum(),
- "event_types": AggregationStages.add_to_set(f"${EventFields.EVENT_TYPE}"),
- "unique_users": AggregationStages.add_to_set(f"${EventFields.METADATA_USER_ID}"),
- "services": AggregationStages.add_to_set(f"${EventFields.METADATA_SERVICE_NAME}"),
+ "event_types": AggregationStages.add_to_set("$event_type"),
+ "unique_users": AggregationStages.add_to_set("$metadata.user_id"),
+ "services": AggregationStages.add_to_set("$metadata.service_name"),
}
),
AggregationStages.project(
@@ -84,8 +82,8 @@ def build_overview_pipeline(start_time: datetime) -> list[dict[str, Any]]:
@staticmethod
def build_event_types_pipeline(start_time: datetime, limit: int = 10) -> list[dict[str, Any]]:
return [
- AggregationStages.match({EventFields.TIMESTAMP: {"$gte": start_time}}),
- AggregationStages.group({"_id": f"${EventFields.EVENT_TYPE}", "count": AggregationStages.sum()}),
+ AggregationStages.match({"timestamp": {"$gte": start_time}}),
+ AggregationStages.group({"_id": "$event_type", "count": AggregationStages.sum()}),
AggregationStages.sort({"count": -1}),
AggregationStages.limit(limit),
]
@@ -93,9 +91,9 @@ def build_event_types_pipeline(start_time: datetime, limit: int = 10) -> list[di
@staticmethod
def build_hourly_events_pipeline(start_time: datetime) -> list[dict[str, Any]]:
return [
- AggregationStages.match({EventFields.TIMESTAMP: {"$gte": start_time}}),
+ AggregationStages.match({"timestamp": {"$gte": start_time}}),
AggregationStages.group(
- {"_id": AggregationStages.date_to_string(f"${EventFields.TIMESTAMP}"), "count": AggregationStages.sum()}
+ {"_id": AggregationStages.date_to_string("$timestamp"), "count": AggregationStages.sum()}
),
AggregationStages.sort({"_id": 1}),
]
@@ -103,8 +101,8 @@ def build_hourly_events_pipeline(start_time: datetime) -> list[dict[str, Any]]:
@staticmethod
def build_top_users_pipeline(start_time: datetime, limit: int = 10) -> list[dict[str, Any]]:
return [
- AggregationStages.match({EventFields.TIMESTAMP: {"$gte": start_time}}),
- AggregationStages.group({"_id": f"${EventFields.METADATA_USER_ID}", "count": AggregationStages.sum()}),
+ AggregationStages.match({"timestamp": {"$gte": start_time}}),
+ AggregationStages.group({"_id": "$metadata.user_id", "count": AggregationStages.sum()}),
AggregationStages.sort({"count": -1}),
AggregationStages.limit(limit),
]
@@ -114,14 +112,12 @@ def build_avg_duration_pipeline(start_time: datetime, event_type: str) -> list[d
return [
AggregationStages.match(
{
- EventFields.TIMESTAMP: {"$gte": start_time},
- EventFields.EVENT_TYPE: event_type,
- EventFields.PAYLOAD_DURATION_SECONDS: {"$exists": True},
+ "timestamp": {"$gte": start_time},
+ "event_type": event_type,
+ "payload.duration_seconds": {"$exists": True},
}
),
- AggregationStages.group(
- {"_id": None, "avg_duration": AggregationStages.avg(f"${EventFields.PAYLOAD_DURATION_SECONDS}")}
- ),
+ AggregationStages.group({"_id": None, "avg_duration": AggregationStages.avg("$payload.duration_seconds")}),
]
diff --git a/backend/app/domain/exceptions.py b/backend/app/domain/exceptions.py
new file mode 100644
index 00000000..79b46b43
--- /dev/null
+++ b/backend/app/domain/exceptions.py
@@ -0,0 +1,57 @@
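+"""Shared domain exception hierarchy.
+
+Illustrative usage (a sketch; `execution_id` is a placeholder):
+
+    raise NotFoundError("Execution", execution_id)
+
+The status codes noted in each class docstring indicate the HTTP response the
+API layer is expected to map the error to.
+"""
+
+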
+class DomainError(Exception):
+ """Base for all domain errors."""
+
+ def __init__(self, message: str) -> None:
+ self.message = message
+ super().__init__(message)
+
+
+class NotFoundError(DomainError):
+ """Entity not found (maps to 404)."""
+
+ def __init__(self, entity: str, identifier: str) -> None:
+ self.entity = entity
+ self.identifier = identifier
+ super().__init__(f"{entity} '{identifier}' not found")
+
+
+class ValidationError(DomainError):
+ """Business validation failed (maps to 422)."""
+
+ pass
+
+
+class ThrottledError(DomainError):
+ """Rate limit exceeded (maps to 429)."""
+
+ pass
+
+
+class ConflictError(DomainError):
+ """State conflict - duplicate, already exists, etc (maps to 409)."""
+
+ pass
+
+
+class UnauthorizedError(DomainError):
+ """Authentication required (maps to 401)."""
+
+ pass
+
+
+class ForbiddenError(DomainError):
+ """Authenticated but not permitted (maps to 403)."""
+
+ pass
+
+
+class InvalidStateError(DomainError):
+ """Invalid state for operation (maps to 400)."""
+
+ pass
+
+
+class InfrastructureError(DomainError):
+ """Infrastructure failure - DB, Kafka, K8s, etc (maps to 500)."""
+
+ pass
diff --git a/backend/app/domain/execution/__init__.py b/backend/app/domain/execution/__init__.py
index fb725417..d4275c9b 100644
--- a/backend/app/domain/execution/__init__.py
+++ b/backend/app/domain/execution/__init__.py
@@ -1,11 +1,12 @@
from .exceptions import (
EventPublishError,
ExecutionNotFoundError,
- ExecutionServiceError,
RuntimeNotSupportedError,
)
from .models import (
DomainExecution,
+ DomainExecutionCreate,
+ DomainExecutionUpdate,
ExecutionResultDomain,
LanguageInfoDomain,
ResourceLimitsDomain,
@@ -14,11 +15,12 @@
__all__ = [
"DomainExecution",
+ "DomainExecutionCreate",
+ "DomainExecutionUpdate",
"ExecutionResultDomain",
"LanguageInfoDomain",
"ResourceLimitsDomain",
"ResourceUsageDomain",
- "ExecutionServiceError",
"RuntimeNotSupportedError",
"EventPublishError",
"ExecutionNotFoundError",
diff --git a/backend/app/domain/execution/exceptions.py b/backend/app/domain/execution/exceptions.py
index 03d3b3b7..25e0ceb5 100644
--- a/backend/app/domain/execution/exceptions.py
+++ b/backend/app/domain/execution/exceptions.py
@@ -1,22 +1,25 @@
-class ExecutionServiceError(Exception):
- """Base exception for execution service errors."""
+from app.domain.exceptions import InfrastructureError, NotFoundError, ValidationError
- pass
+class ExecutionNotFoundError(NotFoundError):
+ """Raised when execution is not found."""
-class RuntimeNotSupportedError(ExecutionServiceError):
- """Raised when requested runtime is not supported."""
-
- pass
+ def __init__(self, execution_id: str) -> None:
+ super().__init__("Execution", execution_id)
-class EventPublishError(ExecutionServiceError):
- """Raised when event publishing fails."""
+class RuntimeNotSupportedError(ValidationError):
+ """Raised when requested runtime is not supported."""
- pass
+ def __init__(self, lang: str, version: str) -> None:
+ self.lang = lang
+ self.version = version
+ super().__init__(f"Runtime not supported: {lang} {version}")
-class ExecutionNotFoundError(ExecutionServiceError):
- """Raised when execution is not found."""
+class EventPublishError(InfrastructureError):
+ """Raised when event publishing fails."""
- pass
+ def __init__(self, event_type: str, reason: str) -> None:
+ self.event_type = event_type
+ super().__init__(f"Failed to publish {event_type}: {reason}")
diff --git a/backend/app/domain/execution/models.py b/backend/app/domain/execution/models.py
index ab49ff6a..2bd30956 100644
--- a/backend/app/domain/execution/models.py
+++ b/backend/app/domain/execution/models.py
@@ -1,14 +1,24 @@
from __future__ import annotations
-from dataclasses import dataclass, field
+from dataclasses import field
from datetime import datetime, timezone
from typing import Any, Optional
from uuid import uuid4
+from pydantic.dataclasses import dataclass
+
from app.domain.enums.execution import ExecutionStatus
from app.domain.enums.storage import ExecutionErrorType
+@dataclass
+class ResourceUsageDomain:
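+ # cpu_time_jiffies / clk_tck_hertz approximates CPU seconds (Linux clock-tick accounting)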
+ execution_time_wall_seconds: float = 0.0
+ cpu_time_jiffies: int = 0
+ clk_tck_hertz: int = 0
+ peak_memory_kb: int = 0
+
+
@dataclass
class DomainExecution:
execution_id: str = field(default_factory=lambda: str(uuid4()))
@@ -33,35 +43,10 @@ class ExecutionResultDomain:
exit_code: int
stdout: str
stderr: str
- resource_usage: ResourceUsageDomain
+ resource_usage: ResourceUsageDomain | None = None
created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
metadata: dict[str, Any] = field(default_factory=dict)
- error_type: Optional[ExecutionErrorType] = None
-
-
-@dataclass
-class ResourceUsageDomain:
- execution_time_wall_seconds: float
- cpu_time_jiffies: int
- clk_tck_hertz: int
- peak_memory_kb: int
-
- def to_dict(self) -> dict[str, Any]:
- return {
- "execution_time_wall_seconds": float(self.execution_time_wall_seconds),
- "cpu_time_jiffies": int(self.cpu_time_jiffies),
- "clk_tck_hertz": int(self.clk_tck_hertz),
- "peak_memory_kb": int(self.peak_memory_kb),
- }
-
- @staticmethod
- def from_dict(data: dict[str, Any]) -> "ResourceUsageDomain":
- return ResourceUsageDomain(
- execution_time_wall_seconds=float(data.get("execution_time_wall_seconds", 0.0)),
- cpu_time_jiffies=int(data.get("cpu_time_jiffies", 0)),
- clk_tck_hertz=int(data.get("clk_tck_hertz", 0)),
- peak_memory_kb=int(data.get("peak_memory_kb", 0)),
- )
+ error_type: ExecutionErrorType | None = None
@dataclass
@@ -82,3 +67,26 @@ class ResourceLimitsDomain:
memory_request: str
execution_timeout: int
supported_runtimes: dict[str, LanguageInfoDomain]
+
+
+@dataclass
+class DomainExecutionCreate:
+ """Execution creation data for repository."""
+
+ script: str
+ user_id: str
+ lang: str = "python"
+ lang_version: str = "3.11"
+ status: ExecutionStatus = ExecutionStatus.QUEUED
+
+
+@dataclass
+class DomainExecutionUpdate:
+ """Execution update data for repository."""
+
+ status: Optional[ExecutionStatus] = None
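+ # Fields left as None are treated as "no change" when the update is applied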
+ stdout: Optional[str] = None
+ stderr: Optional[str] = None
+ exit_code: Optional[int] = None
+ error_type: Optional[ExecutionErrorType] = None
+ resource_usage: Optional[ResourceUsageDomain] = None
diff --git a/backend/app/domain/idempotency/models.py b/backend/app/domain/idempotency/models.py
index f3001c8f..38fba578 100644
--- a/backend/app/domain/idempotency/models.py
+++ b/backend/app/domain/idempotency/models.py
@@ -1,9 +1,10 @@
from __future__ import annotations
-from dataclasses import dataclass
from datetime import datetime
from typing import Dict, Optional
+from pydantic.dataclasses import dataclass
+
from app.core.utils import StringEnum
diff --git a/backend/app/domain/notification/__init__.py b/backend/app/domain/notification/__init__.py
index bf8fba98..04dfa3dc 100644
--- a/backend/app/domain/notification/__init__.py
+++ b/backend/app/domain/notification/__init__.py
@@ -1,11 +1,25 @@
+from .exceptions import (
+ NotificationNotFoundError,
+ NotificationThrottledError,
+ NotificationValidationError,
+)
from .models import (
DomainNotification,
+ DomainNotificationCreate,
DomainNotificationListResult,
DomainNotificationSubscription,
+ DomainNotificationUpdate,
+ DomainSubscriptionUpdate,
)
__all__ = [
"DomainNotification",
- "DomainNotificationSubscription",
+ "DomainNotificationCreate",
"DomainNotificationListResult",
+ "DomainNotificationSubscription",
+ "DomainNotificationUpdate",
+ "DomainSubscriptionUpdate",
+ "NotificationNotFoundError",
+ "NotificationThrottledError",
+ "NotificationValidationError",
]
diff --git a/backend/app/domain/notification/exceptions.py b/backend/app/domain/notification/exceptions.py
new file mode 100644
index 00000000..e39a3473
--- /dev/null
+++ b/backend/app/domain/notification/exceptions.py
@@ -0,0 +1,24 @@
+from app.domain.exceptions import NotFoundError, ThrottledError, ValidationError
+
+
+class NotificationNotFoundError(NotFoundError):
+ """Raised when a notification is not found."""
+
+ def __init__(self, notification_id: str) -> None:
+ super().__init__("Notification", notification_id)
+
+
+class NotificationThrottledError(ThrottledError):
+ """Raised when notification rate limit is exceeded."""
+
+ def __init__(self, user_id: str, limit: int, window_hours: int) -> None:
+ self.user_id = user_id
+ self.limit = limit
+ self.window_hours = window_hours
+ super().__init__(f"Rate limit exceeded for user '{user_id}': max {limit} per {window_hours}h")
+
+
+class NotificationValidationError(ValidationError):
+ """Raised when notification validation fails."""
+
+ pass
diff --git a/backend/app/domain/notification/models.py b/backend/app/domain/notification/models.py
index f46c1bc8..8a1bac45 100644
--- a/backend/app/domain/notification/models.py
+++ b/backend/app/domain/notification/models.py
@@ -1,10 +1,12 @@
from __future__ import annotations
-from dataclasses import dataclass, field
+from dataclasses import field
from datetime import UTC, datetime
from typing import Any
from uuid import uuid4
+from pydantic.dataclasses import dataclass
+
from app.domain.enums.notification import (
NotificationChannel,
NotificationSeverity,
@@ -69,3 +71,51 @@ class DomainNotificationListResult:
notifications: list[DomainNotification]
total: int
unread_count: int
+
+
+@dataclass
+class DomainNotificationCreate:
+ """Data for creating a notification."""
+
+ user_id: str
+ channel: NotificationChannel
+ subject: str
+ body: str
+ severity: NotificationSeverity = NotificationSeverity.MEDIUM
+ action_url: str | None = None
+ tags: list[str] = field(default_factory=list)
+ scheduled_for: datetime | None = None
+ webhook_url: str | None = None
+ webhook_headers: dict[str, str] | None = None
+ metadata: dict[str, Any] = field(default_factory=dict)
+
+
+@dataclass
+class DomainNotificationUpdate:
+ """Data for updating a notification."""
+
+ status: NotificationStatus | None = None
+ sent_at: datetime | None = None
+ delivered_at: datetime | None = None
+ read_at: datetime | None = None
+ clicked_at: datetime | None = None
+ failed_at: datetime | None = None
+ retry_count: int | None = None
+ error_message: str | None = None
+
+
+@dataclass
+class DomainSubscriptionUpdate:
+ """Data for updating a subscription."""
+
+ enabled: bool | None = None
+ severities: list[NotificationSeverity] | None = None
+ include_tags: list[str] | None = None
+ exclude_tags: list[str] | None = None
+ webhook_url: str | None = None
+ slack_webhook: str | None = None
+ quiet_hours_enabled: bool | None = None
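+ # quiet_hours_start/end are assumed to be "HH:MM" strings in the subscription's timezone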
+ quiet_hours_start: str | None = None
+ quiet_hours_end: str | None = None
+ timezone: str | None = None
+ batch_interval_minutes: int | None = None
diff --git a/backend/app/domain/rate_limit/rate_limit_models.py b/backend/app/domain/rate_limit/rate_limit_models.py
index 59713554..08ef9460 100644
--- a/backend/app/domain/rate_limit/rate_limit_models.py
+++ b/backend/app/domain/rate_limit/rate_limit_models.py
@@ -1,8 +1,10 @@
import re
-from dataclasses import dataclass, field
+from dataclasses import field
from datetime import datetime, timezone
from typing import Dict, List, Optional
+from pydantic.dataclasses import dataclass
+
from app.core.utils import StringEnum
diff --git a/backend/app/domain/replay/__init__.py b/backend/app/domain/replay/__init__.py
index b3f36291..28259c97 100644
--- a/backend/app/domain/replay/__init__.py
+++ b/backend/app/domain/replay/__init__.py
@@ -1,3 +1,4 @@
+from .exceptions import ReplayOperationError, ReplaySessionNotFoundError
from .models import (
CleanupResult,
ReplayConfig,
@@ -7,9 +8,11 @@
)
__all__ = [
- "ReplayFilter",
+ "CleanupResult",
"ReplayConfig",
- "ReplaySessionState",
+ "ReplayFilter",
+ "ReplayOperationError",
"ReplayOperationResult",
- "CleanupResult",
+ "ReplaySessionNotFoundError",
+ "ReplaySessionState",
]
diff --git a/backend/app/domain/replay/exceptions.py b/backend/app/domain/replay/exceptions.py
new file mode 100644
index 00000000..45514df0
--- /dev/null
+++ b/backend/app/domain/replay/exceptions.py
@@ -0,0 +1,17 @@
+from app.domain.exceptions import InfrastructureError, NotFoundError
+
+
+class ReplaySessionNotFoundError(NotFoundError):
+ """Raised when a replay session is not found."""
+
+ def __init__(self, session_id: str) -> None:
+ super().__init__("Replay session", session_id)
+
+
+class ReplayOperationError(InfrastructureError):
+ """Raised when a replay operation fails."""
+
+ def __init__(self, session_id: str, operation: str, reason: str) -> None:
+ self.session_id = session_id
+ self.operation = operation
+ super().__init__(f"Failed to {operation} session '{session_id}': {reason}")
diff --git a/backend/app/domain/replay/models.py b/backend/app/domain/replay/models.py
index 36195dcd..17e241b3 100644
--- a/backend/app/domain/replay/models.py
+++ b/backend/app/domain/replay/models.py
@@ -1,8 +1,10 @@
-from dataclasses import dataclass, field
+from dataclasses import field
from datetime import datetime, timezone
from typing import Any, Dict, List
+from uuid import uuid4
from pydantic import BaseModel, Field, PrivateAttr
+from pydantic.dataclasses import dataclass
from app.domain.enums.events import EventType
from app.domain.enums.replay import ReplayStatus, ReplayTarget, ReplayType
@@ -56,7 +58,7 @@ def to_mongo_query(self) -> Dict[str, Any]:
class ReplayConfig(BaseModel):
replay_type: ReplayType
target: ReplayTarget = ReplayTarget.KAFKA
- filter: ReplayFilter
+ filter: ReplayFilter = Field(default_factory=ReplayFilter)
speed_multiplier: float = Field(default=1.0, ge=0.1, le=100.0)
preserve_timestamps: bool = False
@@ -83,7 +85,7 @@ def get_progress_callback(self) -> Any:
@dataclass
class ReplaySessionState:
- """Domain replay session model used by services only."""
+ """Domain replay session model used by services and repository."""
session_id: str
config: ReplayConfig
@@ -101,6 +103,14 @@ class ReplaySessionState:
errors: list[dict[str, Any]] = field(default_factory=list)
+ # Tracking and admin fields
+ correlation_id: str = field(default_factory=lambda: str(uuid4()))
+ created_by: str | None = None
+ target_service: str | None = None
+ dry_run: bool = False
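+ # Execution IDs created by this replay run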
+ triggered_executions: list[str] = field(default_factory=list)
+ error: str | None = None
+
@dataclass
class ReplayOperationResult:
diff --git a/backend/app/domain/saga/__init__.py b/backend/app/domain/saga/__init__.py
index 0d485d3e..e489f3b1 100644
--- a/backend/app/domain/saga/__init__.py
+++ b/backend/app/domain/saga/__init__.py
@@ -1,10 +1,13 @@
from app.domain.saga.exceptions import (
SagaAccessDeniedError,
- SagaError,
+ SagaConcurrencyError,
SagaInvalidStateError,
SagaNotFoundError,
+ SagaTimeoutError,
)
from app.domain.saga.models import (
+ DomainResourceAllocation,
+ DomainResourceAllocationCreate,
Saga,
SagaConfig,
SagaFilter,
@@ -14,14 +17,17 @@
)
__all__ = [
+ "DomainResourceAllocation",
+ "DomainResourceAllocationCreate",
"Saga",
"SagaConfig",
"SagaInstance",
"SagaFilter",
"SagaListResult",
"SagaQuery",
- "SagaError",
"SagaNotFoundError",
"SagaAccessDeniedError",
"SagaInvalidStateError",
+ "SagaConcurrencyError",
+ "SagaTimeoutError",
]
diff --git a/backend/app/domain/saga/exceptions.py b/backend/app/domain/saga/exceptions.py
index f7080368..ccf433a0 100644
--- a/backend/app/domain/saga/exceptions.py
+++ b/backend/app/domain/saga/exceptions.py
@@ -1,40 +1,53 @@
-class SagaError(Exception):
- """Base exception for saga-related errors."""
+from app.domain.exceptions import ConflictError, ForbiddenError, InfrastructureError, InvalidStateError, NotFoundError
- pass
-
-class SagaNotFoundError(SagaError):
+class SagaNotFoundError(NotFoundError):
"""Raised when a saga is not found."""
- pass
+ def __init__(self, saga_id: str) -> None:
+ super().__init__("Saga", saga_id)
-class SagaAccessDeniedError(SagaError):
+class SagaAccessDeniedError(ForbiddenError):
"""Raised when access to a saga is denied."""
- pass
+ def __init__(self, saga_id: str, user_id: str) -> None:
+ self.saga_id = saga_id
+ self.user_id = user_id
+ super().__init__(f"Access denied to saga '{saga_id}' for user '{user_id}'")
-class SagaInvalidStateError(SagaError):
+class SagaInvalidStateError(InvalidStateError):
"""Raised when a saga operation is invalid for the current state."""
- pass
+ def __init__(self, saga_id: str, current_state: str, operation: str) -> None:
+ self.saga_id = saga_id
+ self.current_state = current_state
+ self.operation = operation
+ super().__init__(f"Cannot {operation} saga '{saga_id}' in state '{current_state}'")
-class SagaCompensationError(SagaError):
+class SagaCompensationError(InfrastructureError):
"""Raised when saga compensation fails."""
- pass
+ def __init__(self, saga_id: str, step: str, reason: str) -> None:
+ self.saga_id = saga_id
+ self.step = step
+ super().__init__(f"Compensation failed for saga '{saga_id}' at step '{step}': {reason}")
-class SagaTimeoutError(SagaError):
+class SagaTimeoutError(InfrastructureError):
"""Raised when a saga times out."""
- pass
+ def __init__(self, saga_id: str, timeout_seconds: int) -> None:
+ self.saga_id = saga_id
+ self.timeout_seconds = timeout_seconds
+ super().__init__(f"Saga '{saga_id}' timed out after {timeout_seconds}s")
-class SagaConcurrencyError(SagaError):
+class SagaConcurrencyError(ConflictError):
"""Raised when there's a concurrency conflict with saga operations."""
- pass
+ def __init__(self, saga_id: str) -> None:
+ self.saga_id = saga_id
+ super().__init__(f"Concurrency conflict for saga '{saga_id}'")
diff --git a/backend/app/domain/saga/models.py b/backend/app/domain/saga/models.py
index 1998258b..a885c3bd 100644
--- a/backend/app/domain/saga/models.py
+++ b/backend/app/domain/saga/models.py
@@ -1,8 +1,10 @@
-from dataclasses import dataclass, field
+from dataclasses import field
from datetime import datetime, timezone
from typing import Any
from uuid import uuid4
+from pydantic.dataclasses import dataclass
+
from app.domain.enums.saga import SagaState
@@ -119,3 +121,31 @@ class SagaInstance:
updated_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
completed_at: datetime | None = None
retry_count: int = 0
+
+
+@dataclass
+class DomainResourceAllocation:
+ """Domain model for resource allocation."""
+
+ allocation_id: str
+ execution_id: str
+ language: str
+ cpu_request: str
+ memory_request: str
+ cpu_limit: str
+ memory_limit: str
+ status: str = "active"
+ allocated_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
+ released_at: datetime | None = None
+
+
+@dataclass
+class DomainResourceAllocationCreate:
+ """Data for creating a resource allocation."""
+
+ execution_id: str
+ language: str
+ cpu_request: str
+ memory_request: str
+ cpu_limit: str
+ memory_limit: str
diff --git a/backend/app/domain/saved_script/__init__.py b/backend/app/domain/saved_script/__init__.py
index 444470f3..f2bede41 100644
--- a/backend/app/domain/saved_script/__init__.py
+++ b/backend/app/domain/saved_script/__init__.py
@@ -1,3 +1,4 @@
+from .exceptions import SavedScriptNotFoundError
from .models import (
DomainSavedScript,
DomainSavedScriptCreate,
@@ -8,4 +9,5 @@
"DomainSavedScript",
"DomainSavedScriptCreate",
"DomainSavedScriptUpdate",
+ "SavedScriptNotFoundError",
]
diff --git a/backend/app/domain/saved_script/exceptions.py b/backend/app/domain/saved_script/exceptions.py
new file mode 100644
index 00000000..d6c6d916
--- /dev/null
+++ b/backend/app/domain/saved_script/exceptions.py
@@ -0,0 +1,8 @@
+from app.domain.exceptions import NotFoundError
+
+
+class SavedScriptNotFoundError(NotFoundError):
+ """Raised when a saved script is not found."""
+
+ def __init__(self, script_id: str) -> None:
+ super().__init__("Script", script_id)
diff --git a/backend/app/domain/saved_script/models.py b/backend/app/domain/saved_script/models.py
index ba819cbd..08622426 100644
--- a/backend/app/domain/saved_script/models.py
+++ b/backend/app/domain/saved_script/models.py
@@ -1,8 +1,10 @@
from __future__ import annotations
-from dataclasses import dataclass, field
+from dataclasses import field
from datetime import datetime, timezone
+from pydantic.dataclasses import dataclass
+
@dataclass
class DomainSavedScriptBase:
diff --git a/backend/app/domain/sse/models.py b/backend/app/domain/sse/models.py
index e4dfa5fe..c8a59e8c 100644
--- a/backend/app/domain/sse/models.py
+++ b/backend/app/domain/sse/models.py
@@ -1,8 +1,9 @@
from __future__ import annotations
-from dataclasses import dataclass
from datetime import datetime
+from pydantic.dataclasses import dataclass
+
from app.domain.enums.execution import ExecutionStatus
diff --git a/backend/app/domain/user/__init__.py b/backend/app/domain/user/__init__.py
index e81c436d..54601b31 100644
--- a/backend/app/domain/user/__init__.py
+++ b/backend/app/domain/user/__init__.py
@@ -1,5 +1,13 @@
from app.domain.enums.user import UserRole
+from .exceptions import (
+ AdminAccessRequiredError,
+ AuthenticationRequiredError,
+ CSRFValidationError,
+ InvalidCredentialsError,
+ TokenExpiredError,
+ UserNotFoundError,
+)
from .settings_models import (
CachedSettings,
DomainEditorSettings,
@@ -10,6 +18,8 @@
DomainUserSettingsUpdate,
)
from .user_models import (
+ DomainUserCreate,
+ DomainUserUpdate,
PasswordReset,
User,
UserCreation,
@@ -21,20 +31,28 @@
)
__all__ = [
+ "AdminAccessRequiredError",
+ "AuthenticationRequiredError",
+ "CachedSettings",
+ "CSRFValidationError",
+ "DomainEditorSettings",
+ "DomainNotificationSettings",
+ "DomainSettingsEvent",
+ "DomainSettingsHistoryEntry",
+ "DomainUserCreate",
+ "DomainUserSettings",
+ "DomainUserSettingsUpdate",
+ "DomainUserUpdate",
+ "InvalidCredentialsError",
+ "PasswordReset",
+ "TokenExpiredError",
"User",
- "UserUpdate",
- "UserListResult",
"UserCreation",
- "PasswordReset",
"UserFields",
"UserFilterType",
- "UserSearchFilter",
+ "UserListResult",
+ "UserNotFoundError",
"UserRole",
- "DomainNotificationSettings",
- "DomainEditorSettings",
- "DomainUserSettings",
- "DomainUserSettingsUpdate",
- "DomainSettingsEvent",
- "DomainSettingsHistoryEntry",
- "CachedSettings",
+ "UserSearchFilter",
+ "UserUpdate",
]
diff --git a/backend/app/domain/user/exceptions.py b/backend/app/domain/user/exceptions.py
new file mode 100644
index 00000000..dc1b9acb
--- /dev/null
+++ b/backend/app/domain/user/exceptions.py
@@ -0,0 +1,45 @@
+from app.domain.exceptions import ForbiddenError, NotFoundError, UnauthorizedError
+
+
+class AuthenticationRequiredError(UnauthorizedError):
+ """Raised when authentication is required but not provided."""
+
+ def __init__(self, message: str = "Not authenticated") -> None:
+ super().__init__(message)
+
+
+class InvalidCredentialsError(UnauthorizedError):
+ """Raised when credentials are invalid."""
+
+ def __init__(self, message: str = "Could not validate credentials") -> None:
+ super().__init__(message)
+
+
+class TokenExpiredError(UnauthorizedError):
+ """Raised when a token has expired."""
+
+ def __init__(self) -> None:
+ super().__init__("Token has expired")
+
+
+class CSRFValidationError(ForbiddenError):
+ """Raised when CSRF validation fails."""
+
+ def __init__(self, reason: str = "CSRF validation failed") -> None:
+ super().__init__(reason)
+
+
+class AdminAccessRequiredError(ForbiddenError):
+ """Raised when admin access is required."""
+
+ def __init__(self, username: str | None = None) -> None:
+ self.username = username
+ msg = f"Admin access required for user '{username}'" if username else "Admin access required"
+ super().__init__(msg)
+
+
+class UserNotFoundError(NotFoundError):
+ """Raised when a user is not found."""
+
+ def __init__(self, identifier: str) -> None:
+ super().__init__("User", identifier)
diff --git a/backend/app/domain/user/settings_models.py b/backend/app/domain/user/settings_models.py
index 171f1b17..10a730d2 100644
--- a/backend/app/domain/user/settings_models.py
+++ b/backend/app/domain/user/settings_models.py
@@ -1,9 +1,11 @@
from __future__ import annotations
-from dataclasses import dataclass, field
+from dataclasses import field
from datetime import datetime, timezone
from typing import Any, Dict, List, Optional
+from pydantic.dataclasses import dataclass
+
from app.domain.enums.common import Theme
from app.domain.enums.events import EventType
from app.domain.enums.notification import NotificationChannel
diff --git a/backend/app/domain/user/user_models.py b/backend/app/domain/user/user_models.py
index da91d34a..fa34d066 100644
--- a/backend/app/domain/user/user_models.py
+++ b/backend/app/domain/user/user_models.py
@@ -1,8 +1,9 @@
import re
-from dataclasses import dataclass
from datetime import datetime
from typing import List
+from pydantic.dataclasses import dataclass
+
from app.core.utils import StringEnum
from app.domain.enums.user import UserRole
@@ -99,7 +100,7 @@ def is_valid(self) -> bool:
@dataclass
class UserCreation:
- """User creation domain model."""
+ """User creation domain model (API-facing, with plain password)."""
username: str
email: str
@@ -117,3 +118,26 @@ def is_valid(self) -> bool:
EMAIL_PATTERN.match(self.email) is not None, # Proper email validation
]
)
+
+
+@dataclass
+class DomainUserCreate:
+ """User creation data for repository (with hashed password)."""
+
+ username: str
+ email: str
+ hashed_password: str
+ role: UserRole = UserRole.USER
+ is_active: bool = True
+ is_superuser: bool = False
+
+
+@dataclass
+class DomainUserUpdate:
+ """User update data for repository (with hashed password)."""
+
+ username: str | None = None
+ email: str | None = None
+ role: UserRole | None = None
+ is_active: bool | None = None
+ hashed_password: str | None = None
diff --git a/backend/app/events/admin_utils.py b/backend/app/events/admin_utils.py
index 3aef289a..759c3630 100644
--- a/backend/app/events/admin_utils.py
+++ b/backend/app/events/admin_utils.py
@@ -1,16 +1,17 @@
import asyncio
+import logging
from typing import Dict, List
from confluent_kafka.admin import AdminClient, NewTopic
-from app.core.logging import logger
from app.settings import get_settings
class AdminUtils:
"""Minimal admin utilities using native AdminClient."""
- def __init__(self, bootstrap_servers: str | None = None):
+ def __init__(self, logger: logging.Logger, bootstrap_servers: str | None = None):
+ self.logger = logger
settings = get_settings()
self._admin = AdminClient(
{
@@ -30,7 +31,7 @@ async def check_topic_exists(self, topic: str) -> bool:
metadata = self._admin.list_topics(timeout=5.0)
return topic in metadata.topics
except Exception as e:
- logger.error(f"Failed to check topic {topic}: {e}")
+ self.logger.error(f"Failed to check topic {topic}: {e}")
return False
async def create_topic(self, topic: str, num_partitions: int = 1, replication_factor: int = 1) -> bool:
@@ -41,10 +42,10 @@ async def create_topic(self, topic: str, num_partitions: int = 1, replication_fa
# Wait for result - result() returns None on success, raises exception on failure
await asyncio.get_event_loop().run_in_executor(None, lambda: futures[topic].result(timeout=30.0))
- logger.info(f"Topic {topic} created successfully")
+ self.logger.info(f"Topic {topic} created successfully")
return True
except Exception as e:
- logger.error(f"Failed to create topic {topic}: {e}")
+ self.logger.error(f"Failed to create topic {topic}: {e}")
return False
async def ensure_topics_exist(self, topics: List[tuple[str, int]]) -> Dict[str, bool]:
@@ -62,6 +63,6 @@ def get_admin_client(self) -> AdminClient:
return self._admin
-def create_admin_utils(bootstrap_servers: str | None = None) -> AdminUtils:
+def create_admin_utils(logger: logging.Logger, bootstrap_servers: str | None = None) -> AdminUtils:
"""Create admin utilities."""
- return AdminUtils(bootstrap_servers)
+ return AdminUtils(logger, bootstrap_servers)
diff --git a/backend/app/events/consumer_group_monitor.py b/backend/app/events/consumer_group_monitor.py
index cf424834..25e759cf 100644
--- a/backend/app/events/consumer_group_monitor.py
+++ b/backend/app/events/consumer_group_monitor.py
@@ -1,4 +1,5 @@
import asyncio
+import logging
from dataclasses import dataclass
from datetime import datetime, timezone
from typing import Any, Dict, List, cast
@@ -6,7 +7,6 @@
from confluent_kafka import Consumer, ConsumerGroupState, KafkaError, TopicPartition
from confluent_kafka.admin import ConsumerGroupDescription
-from app.core.logging import logger
from app.core.utils import StringEnum
from app.events.admin_utils import AdminUtils
from app.settings import get_settings
@@ -75,6 +75,7 @@ class NativeConsumerGroupMonitor:
def __init__(
self,
+ logger: logging.Logger,
bootstrap_servers: str | None = None,
client_id: str = "integr8scode-consumer-group-monitor",
request_timeout_ms: int = 30000,
@@ -84,10 +85,11 @@ def __init__(
warning_lag_threshold: int = 1000,
min_members_threshold: int = 1,
):
+ self.logger = logger
settings = get_settings()
self.bootstrap_servers = bootstrap_servers or settings.KAFKA_BOOTSTRAP_SERVERS
- self.admin_client = AdminUtils(bootstrap_servers=self.bootstrap_servers)
+ self.admin_client = AdminUtils(logger=logger, bootstrap_servers=self.bootstrap_servers)
# Health thresholds
self.max_rebalance_time = max_rebalance_time_seconds
@@ -151,7 +153,7 @@ async def get_consumer_group_status(
total_lag = lag_info.get("total_lag", 0)
partition_lags = lag_info.get("partition_lags", {})
except Exception as e:
- logger.warning(f"Failed to get lag info for group {group_id}: {e}")
+ self.logger.warning(f"Failed to get lag info for group {group_id}: {e}")
# Create status object
status = ConsumerGroupStatus(
@@ -177,7 +179,7 @@ async def get_consumer_group_status(
return status
except Exception as e:
- logger.error(f"Failed to get consumer group status for {group_id}: {e}")
+ self.logger.error(f"Failed to get consumer group status for {group_id}: {e}")
# Return minimal status with error
return ConsumerGroupStatus(
@@ -208,7 +210,7 @@ async def get_multiple_group_status(
for group_id, status in zip(group_ids, statuses, strict=False):
if isinstance(status, Exception):
- logger.error(f"Failed to get status for group {group_id}: {status}")
+ self.logger.error(f"Failed to get status for group {group_id}: {status}")
results[group_id] = ConsumerGroupStatus(
group_id=group_id,
state="ERROR",
@@ -226,7 +228,7 @@ async def get_multiple_group_status(
results[group_id] = status
except Exception as e:
- logger.error(f"Failed to get multiple group status: {e}")
+ self.logger.error(f"Failed to get multiple group status: {e}")
# Return error status for all groups
for group_id in group_ids:
results[group_id] = ConsumerGroupStatus(
@@ -264,12 +266,12 @@ async def list_consumer_groups(self, timeout: float = 10.0) -> List[str]:
# Log any errors that occurred
if hasattr(result, "errors") and result.errors:
for error in result.errors:
- logger.warning(f"Error listing some consumer groups: {error}")
+ self.logger.warning(f"Error listing some consumer groups: {error}")
return group_ids
except Exception as e:
- logger.error(f"Failed to list consumer groups: {e}")
+ self.logger.error(f"Failed to list consumer groups: {e}")
return []
async def _describe_consumer_group(self, group_id: str, timeout: float) -> ConsumerGroupDescription:
@@ -293,7 +295,7 @@ async def _describe_consumer_group(self, group_id: str, timeout: float) -> Consu
except Exception as e:
if hasattr(e, "args") and e.args and isinstance(e.args[0], KafkaError):
kafka_err = e.args[0]
- logger.error(
+ self.logger.error(
f"Kafka error describing group {group_id}: "
f"code={kafka_err.code()}, "
f"name={kafka_err.name()}, "
@@ -361,7 +363,7 @@ async def _get_consumer_group_lag(self, group_id: str, timeout: float) -> Dict[s
total_lag += lag
except Exception as e:
- logger.debug(f"Failed to get lag for {topic}:{partition_id}: {e}")
+ self.logger.debug(f"Failed to get lag for {topic}:{partition_id}: {e}")
continue
return {"total_lag": total_lag, "partition_lags": partition_lags}
@@ -370,7 +372,7 @@ async def _get_consumer_group_lag(self, group_id: str, timeout: float) -> Dict[s
consumer.close()
except Exception as e:
- logger.warning(f"Failed to get consumer group lag for {group_id}: {e}")
+ self.logger.warning(f"Failed to get consumer group lag for {group_id}: {e}")
return {"total_lag": 0, "partition_lags": {}}
def _assess_group_health(self, status: ConsumerGroupStatus) -> tuple[ConsumerGroupHealth, str]:
@@ -431,5 +433,7 @@ def clear_cache(self) -> None:
self._group_status_cache.clear()
-def create_consumer_group_monitor(bootstrap_servers: str | None = None, **kwargs: Any) -> NativeConsumerGroupMonitor:
- return NativeConsumerGroupMonitor(bootstrap_servers=bootstrap_servers, **kwargs)
+def create_consumer_group_monitor(
+ logger: logging.Logger, bootstrap_servers: str | None = None, **kwargs: Any
+) -> NativeConsumerGroupMonitor:
+ return NativeConsumerGroupMonitor(logger=logger, bootstrap_servers=bootstrap_servers, **kwargs)
diff --git a/backend/app/events/core/consumer.py b/backend/app/events/core/consumer.py
index 2a482ac3..aad22ac3 100644
--- a/backend/app/events/core/consumer.py
+++ b/backend/app/events/core/consumer.py
@@ -1,5 +1,6 @@
import asyncio
import json
+import logging
from collections.abc import Awaitable, Callable
from datetime import datetime, timezone
from typing import Any
@@ -8,7 +9,6 @@
from confluent_kafka.error import KafkaError
from opentelemetry.trace import SpanKind
-from app.core.logging import logger
from app.core.metrics.context import get_event_metrics
from app.core.tracing import EventAttributes
from app.core.tracing.utils import extract_trace_context, get_tracer
@@ -26,10 +26,12 @@ def __init__(
self,
config: ConsumerConfig,
event_dispatcher: EventDispatcher,
+ logger: logging.Logger,
stats_callback: Callable[[dict[str, Any]], None] | None = None,
):
self._config = config
- self._schema_registry = SchemaRegistryManager()
+ self.logger = logger
+ self._schema_registry = SchemaRegistryManager(logger=logger)
self._dispatcher = event_dispatcher
self._stats_callback = stats_callback
self._consumer: Consumer | None = None
@@ -56,7 +58,7 @@ async def start(self, topics: list[KafkaTopic]) -> None:
self._state = ConsumerState.RUNNING
- logger.info(f"Consumer started for topics: {topic_strings}")
+ self.logger.info(f"Consumer started for topics: {topic_strings}")
async def stop(self) -> None:
self._state = (
@@ -81,14 +83,14 @@ async def _cleanup(self) -> None:
self._consumer = None
async def _consume_loop(self) -> None:
- logger.info(f"Consumer loop started for group {self._config.group_id}")
+ self.logger.info(f"Consumer loop started for group {self._config.group_id}")
poll_count = 0
message_count = 0
while self._running and self._consumer:
poll_count += 1
if poll_count % 100 == 0: # Log every 100 polls
- logger.debug(f"Consumer loop active: polls={poll_count}, messages={message_count}")
+ self.logger.debug(f"Consumer loop active: polls={poll_count}, messages={message_count}")
msg = await asyncio.to_thread(self._consumer.poll, timeout=0.1)
@@ -96,11 +98,11 @@ async def _consume_loop(self) -> None:
error = msg.error()
if error:
if error.code() != KafkaError._PARTITION_EOF:
- logger.error(f"Consumer error: {error}")
+ self.logger.error(f"Consumer error: {error}")
self._metrics.processing_errors += 1
else:
message_count += 1
- logger.debug(
+ self.logger.debug(
f"Message received from topic {msg.topic()}, partition {msg.partition()}, offset {msg.offset()}"
)
await self._process_message(msg)
@@ -109,7 +111,7 @@ async def _consume_loop(self) -> None:
else:
await asyncio.sleep(0.01)
- logger.warning(
+ self.logger.warning(
f"Consumer loop ended for group {self._config.group_id}: "
f"running={self._running}, consumer={self._consumer is not None}"
)
@@ -117,17 +119,17 @@ async def _consume_loop(self) -> None:
async def _process_message(self, message: Message) -> None:
topic = message.topic()
if not topic:
- logger.warning("Message with no topic received")
+ self.logger.warning("Message with no topic received")
return
raw_value = message.value()
if not raw_value:
- logger.warning(f"Empty message from topic {topic}")
+ self.logger.warning(f"Empty message from topic {topic}")
return
- logger.debug(f"Deserializing message from topic {topic}, size={len(raw_value)} bytes")
+ self.logger.debug(f"Deserializing message from topic {topic}, size={len(raw_value)} bytes")
event = self._schema_registry.deserialize_event(raw_value, topic)
- logger.info(f"Deserialized event: type={event.event_type}, id={event.event_id}")
+ self.logger.info(f"Deserialized event: type={event.event_type}, id={event.event_id}")
# Extract trace context from Kafka headers and start a consumer span
header_list = message.headers() or []
@@ -139,7 +141,7 @@ async def _process_message(self, message: Message) -> None:
# Dispatch event through EventDispatcher
try:
- logger.debug(f"Dispatching {event.event_type} to handlers")
+ self.logger.debug(f"Dispatching {event.event_type} to handlers")
partition_val = message.partition()
offset_val = message.offset()
part_attr = partition_val if partition_val is not None else -1
@@ -157,7 +159,7 @@ async def _process_message(self, message: Message) -> None:
},
):
await self._dispatcher.dispatch(event)
- logger.debug(f"Successfully dispatched {event.event_type}")
+ self.logger.debug(f"Successfully dispatched {event.event_type}")
# Update metrics on successful dispatch
self._metrics.messages_consumed += 1
self._metrics.bytes_consumed += len(raw_value)
@@ -165,7 +167,7 @@ async def _process_message(self, message: Message) -> None:
# Record Kafka consumption metrics
self._event_metrics.record_kafka_message_consumed(topic=topic, consumer_group=self._config.group_id)
except Exception as e:
- logger.error(f"Dispatcher error for event {event.event_type}: {e}")
+ self.logger.error(f"Dispatcher error for event {event.event_type}: {e}")
self._metrics.processing_errors += 1
# Record Kafka consumption error
self._event_metrics.record_kafka_consumption_error(
@@ -237,7 +239,7 @@ async def seek_to_end(self) -> None:
def _seek_all_partitions(self, offset_type: int) -> None:
if not self._consumer:
- logger.warning("Cannot seek: consumer not initialized")
+ self.logger.warning("Cannot seek: consumer not initialized")
return
assignment = self._consumer.assignment()
@@ -247,7 +249,7 @@ def _seek_all_partitions(self, offset_type: int) -> None:
async def seek_to_offset(self, topic: str, partition: int, offset: int) -> None:
if not self._consumer:
- logger.warning("Cannot seek to offset: consumer not initialized")
+ self.logger.warning("Cannot seek to offset: consumer not initialized")
return
self._consumer.seek(TopicPartition(topic, partition, offset))
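For orientation, a minimal wiring sketch for the new constructor, with the logger injected instead of imported from app.core.logging. The ConsumerConfig fields and the exact import locations of UnifiedConsumer/ConsumerConfig are assumptions for illustration, not taken from this diff.

```python
import asyncio
import logging

from app.domain.enums.kafka import KafkaTopic
from app.events.core.dispatcher import EventDispatcher
from app.events.core.consumer import UnifiedConsumer  # import path assumed
from app.events.core.types import ConsumerConfig      # import path assumed


async def run_consumer() -> None:
    logger = logging.getLogger("events.consumer")
    dispatcher = EventDispatcher(logger)

    # ConsumerConfig field names are assumed; adjust to the real config type.
    config = ConsumerConfig(bootstrap_servers="localhost:9092", group_id="example-group")

    consumer = UnifiedConsumer(config, event_dispatcher=dispatcher, logger=logger)
    await consumer.start([KafkaTopic.EXECUTION_COMPLETED])
    try:
        await asyncio.sleep(10)  # let the consume loop poll for a while
    finally:
        await consumer.stop()


if __name__ == "__main__":
    asyncio.run(run_consumer())
```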
diff --git a/backend/app/events/core/dispatcher.py b/backend/app/events/core/dispatcher.py
index 1727922b..7f972524 100644
--- a/backend/app/events/core/dispatcher.py
+++ b/backend/app/events/core/dispatcher.py
@@ -1,9 +1,9 @@
import asyncio
+import logging
from collections import defaultdict
from collections.abc import Awaitable, Callable
from typing import TypeAlias, TypeVar
-from app.core.logging import logger
from app.domain.enums.events import EventType
from app.infrastructure.kafka.events.base import BaseEvent
from app.infrastructure.kafka.mappings import get_event_class_for_type
@@ -20,7 +20,8 @@ class EventDispatcher:
a direct mapping from event types to their handlers.
"""
- def __init__(self) -> None:
+ def __init__(self, logger: logging.Logger) -> None:
+ self.logger = logger
# Map event types to their handlers
self._handlers: dict[EventType, list[Callable[[BaseEvent], Awaitable[None]]]] = defaultdict(list)
@@ -41,7 +42,7 @@ def _build_topic_mapping(self) -> None:
if hasattr(event_class, "topic"):
topic = str(event_class.topic)
self._topic_event_types[topic].add(event_class)
- logger.debug(f"Mapped {event_class.__name__} to topic {topic}")
+ self.logger.debug(f"Mapped {event_class.__name__} to topic {topic}")
def register(self, event_type: EventType) -> Callable[[EventHandler], EventHandler]:
"""
@@ -54,7 +55,7 @@ async def handle_execution(event: ExecutionRequestedEvent) -> None:
"""
def decorator(handler: EventHandler) -> EventHandler:
- logger.info(f"Registering handler '{handler.__name__}' for event type '{event_type.value}'")
+ self.logger.info(f"Registering handler '{handler.__name__}' for event type '{event_type.value}'")
self._handlers[event_type].append(handler)
return handler
@@ -68,7 +69,7 @@ def register_handler(self, event_type: EventType, handler: EventHandler) -> None
event_type: The event type this handler processes
handler: The async handler function
"""
- logger.info(f"Registering handler '{handler.__name__}' for event type '{event_type.value}'")
+ self.logger.info(f"Registering handler '{handler.__name__}' for event type '{event_type.value}'")
self._handlers[event_type].append(handler)
def remove_handler(self, event_type: EventType, handler: EventHandler) -> bool:
@@ -84,7 +85,7 @@ def remove_handler(self, event_type: EventType, handler: EventHandler) -> bool:
"""
if event_type in self._handlers and handler in self._handlers[event_type]:
self._handlers[event_type].remove(handler)
- logger.info(f"Removed handler '{handler.__name__}' for event type '{event_type.value}'")
+ self.logger.info(f"Removed handler '{handler.__name__}' for event type '{event_type.value}'")
# Clean up empty lists
if not self._handlers[event_type]:
del self._handlers[event_type]
@@ -100,17 +101,17 @@ async def dispatch(self, event: BaseEvent) -> None:
"""
event_type = event.event_type
handlers = self._handlers.get(event_type, [])
- logger.debug(f"Dispatcher has {len(self._handlers)} event types registered")
- logger.debug(
+ self.logger.debug(f"Dispatcher has {len(self._handlers)} event types registered")
+ self.logger.debug(
f"For event type {event_type}, found {len(handlers)} handlers: {[h.__class__.__name__ for h in handlers]}"
)
if not handlers:
self._event_metrics[event_type]["skipped"] += 1
- logger.debug(f"No handlers registered for event type {event_type.value}")
+ self.logger.debug(f"No handlers registered for event type {event_type.value}")
return
- logger.debug(f"Dispatching {event_type.value} to {len(handlers)} handler(s)")
+ self.logger.debug(f"Dispatching {event_type.value} to {len(handlers)} handler(s)")
# Run handlers concurrently for better performance
tasks = []
@@ -135,11 +136,11 @@ async def _execute_handler(self, handler: EventHandler, event: BaseEvent) -> Non
event: The event to process
"""
try:
- logger.debug(f"Executing handler {handler.__class__.__name__} for event {event.event_id}")
+ self.logger.debug(f"Executing handler {handler.__class__.__name__} for event {event.event_id}")
await handler(event)
- logger.debug(f"Handler {handler.__class__.__name__} completed")
+ self.logger.debug(f"Handler {handler.__class__.__name__} completed")
except Exception as e:
- logger.error(
+ self.logger.error(
f"Handler '{handler.__class__.__name__}' failed for event {event.event_id}: {e}", exc_info=True
)
raise
@@ -166,7 +167,7 @@ def get_metrics(self) -> dict[str, dict[str, int]]:
def clear_handlers(self) -> None:
"""Clear all registered handlers (useful for testing)."""
self._handlers.clear()
- logger.info("All event handlers cleared")
+ self.logger.info("All event handlers cleared")
def get_handlers(self, event_type: EventType) -> list[Callable[[BaseEvent], Awaitable[None]]]:
"""Get all handlers for a specific event type."""
diff --git a/backend/app/events/core/dlq_handler.py b/backend/app/events/core/dlq_handler.py
index a674b5a7..0e035b2e 100644
--- a/backend/app/events/core/dlq_handler.py
+++ b/backend/app/events/core/dlq_handler.py
@@ -1,13 +1,13 @@
+import logging
from typing import Awaitable, Callable
-from app.core.logging import logger
from app.infrastructure.kafka.events.base import BaseEvent
from .producer import UnifiedProducer
def create_dlq_error_handler(
- producer: UnifiedProducer, original_topic: str, max_retries: int = 3
+ producer: UnifiedProducer, original_topic: str, logger: logging.Logger, max_retries: int = 3
) -> Callable[[Exception, BaseEvent], Awaitable[None]]:
"""
Create an error handler that sends failed events to DLQ.
@@ -15,6 +15,7 @@ def create_dlq_error_handler(
Args:
producer: The Kafka producer to use for sending to DLQ
original_topic: The topic where the event originally failed
+ logger: Logger used to record DLQ activity
max_retries: Maximum number of retries before sending to DLQ
Returns:
@@ -61,7 +62,7 @@ async def handle_error_with_dlq(error: Exception, event: BaseEvent) -> None:
def create_immediate_dlq_handler(
- producer: UnifiedProducer, original_topic: str
+ producer: UnifiedProducer, original_topic: str, logger: logging.Logger
) -> Callable[[Exception, BaseEvent], Awaitable[None]]:
"""
Create an error handler that immediately sends failed events to DLQ.
@@ -71,6 +72,7 @@ def create_immediate_dlq_handler(
Args:
producer: The Kafka producer to use for sending to DLQ
original_topic: The topic where the event originally failed
+ logger: Logger used to record DLQ activity
Returns:
An async error handler function suitable for UnifiedConsumer.register_error_callback
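A sketch of wiring the updated factory signatures onto a consumer; UnifiedConsumer's import path is assumed and the topic name is illustrative.

```python
import logging

from app.events.core.consumer import UnifiedConsumer  # import path assumed
from app.events.core.dlq_handler import create_dlq_error_handler, create_immediate_dlq_handler
from app.events.core.producer import UnifiedProducer


def attach_dlq_handling(producer: UnifiedProducer, consumer: UnifiedConsumer) -> None:
    logger = logging.getLogger("events.dlq")

    # Retry up to three times before the event lands in the DLQ.
    retrying_handler = create_dlq_error_handler(
        producer=producer,
        original_topic="execution-events",  # illustrative topic name
        logger=logger,
        max_retries=3,
    )
    consumer.register_error_callback(retrying_handler)

    # For events that should never be retried, the immediate variant skips the retry budget.
    fail_fast_handler = create_immediate_dlq_handler(
        producer=producer,
        original_topic="execution-events",
        logger=logger,
    )
```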
diff --git a/backend/app/events/core/producer.py b/backend/app/events/core/producer.py
index b174d1e2..76115e56 100644
--- a/backend/app/events/core/producer.py
+++ b/backend/app/events/core/producer.py
@@ -1,5 +1,6 @@
import asyncio
import json
+import logging
import socket
import threading
from datetime import datetime, timezone
@@ -9,12 +10,11 @@
from confluent_kafka.error import KafkaError
from app.core.lifecycle import LifecycleEnabled
-from app.core.logging import logger
from app.core.metrics.context import get_event_metrics
+from app.dlq.models import DLQMessage, DLQMessageStatus
from app.domain.enums.kafka import KafkaTopic
from app.events.schema.schema_registry import SchemaRegistryManager
from app.infrastructure.kafka.events import BaseEvent
-from app.infrastructure.mappers.dlq_mapper import DLQMapper
from app.settings import get_settings
from .types import ProducerConfig, ProducerMetrics, ProducerState
@@ -32,10 +32,12 @@ def __init__(
self,
config: ProducerConfig,
schema_registry_manager: SchemaRegistryManager,
+ logger: logging.Logger,
stats_callback: StatsCallback | None = None,
):
self._config = config
self._schema_registry = schema_registry_manager
+ self.logger = logger
self._producer: Producer | None = None
self._stats_callback = stats_callback
self._state = ProducerState.STOPPED
@@ -72,13 +74,13 @@ def _handle_delivery(self, error: KafkaError | None, message: Message) -> None:
self._event_metrics.record_kafka_production_error(
topic=topic if topic is not None else "unknown", error_type=str(error.code())
)
- logger.error(f"Message delivery failed: {error}")
+ self.logger.error(f"Message delivery failed: {error}")
else:
self._metrics.messages_sent += 1
message_value = message.value()
if message_value:
self._metrics.bytes_sent += len(message_value)
- logger.debug(f"Message delivered to {message.topic()}[{message.partition()}]@{message.offset()}")
+ self.logger.debug(f"Message delivered to {message.topic()}[{message.partition()}]@{message.offset()}")
def _handle_stats(self, stats_json: str) -> None:
try:
@@ -104,15 +106,15 @@ def _handle_stats(self, stats_json: str) -> None:
if self._stats_callback:
self._stats_callback(stats)
except Exception as e:
- logger.error(f"Error parsing producer stats: {e}")
+ self.logger.error(f"Error parsing producer stats: {e}")
async def start(self) -> None:
if self._state not in (ProducerState.STOPPED, ProducerState.ERROR):
- logger.warning(f"Producer already in state {self._state}, skipping start")
+ self.logger.warning(f"Producer already in state {self._state}, skipping start")
return
self._state = ProducerState.STARTING
- logger.info("Starting producer...")
+ self.logger.info("Starting producer...")
producer_config = self._config.to_producer_config()
producer_config["stats_cb"] = self._handle_stats
@@ -125,7 +127,7 @@ async def start(self) -> None:
self._poll_task = asyncio.create_task(self._poll_loop())
self._state = ProducerState.RUNNING
- logger.info(f"Producer started: {self._config.bootstrap_servers}")
+ self.logger.info(f"Producer started: {self._config.bootstrap_servers}")
def get_status(self) -> dict[str, Any]:
return {
@@ -150,11 +152,11 @@ def get_status(self) -> dict[str, Any]:
async def stop(self) -> None:
if self._state in (ProducerState.STOPPED, ProducerState.STOPPING):
- logger.info(f"Producer already in state {self._state}, skipping stop")
+ self.logger.info(f"Producer already in state {self._state}, skipping stop")
return
self._state = ProducerState.STOPPING
- logger.info("Stopping producer...")
+ self.logger.info("Stopping producer...")
self._running = False
if self._poll_task:
@@ -167,16 +169,16 @@ async def stop(self) -> None:
self._producer = None
self._state = ProducerState.STOPPED
- logger.info("Producer stopped")
+ self.logger.info("Producer stopped")
async def _poll_loop(self) -> None:
- logger.info("Started producer poll loop")
+ self.logger.info("Started producer poll loop")
while self._running and self._producer:
self._producer.poll(timeout=0.1)
await asyncio.sleep(0.01)
- logger.info("Producer poll loop ended")
+ self.logger.info("Producer poll loop ended")
async def produce(
self, event_to_produce: BaseEvent, key: str | None = None, headers: dict[str, str] | None = None
@@ -191,7 +193,7 @@ async def produce(
headers: Message headers
"""
if not self._producer:
- logger.error("Producer not running")
+ self.logger.error("Producer not running")
return
# Serialize value
@@ -209,7 +211,7 @@ async def produce(
# Record Kafka metrics
self._event_metrics.record_kafka_message_produced(topic)
- logger.debug(f"Message [{event_to_produce}] queued for topic: {topic}")
+ self.logger.debug(f"Message [{event_to_produce}] queued for topic: {topic}")
async def send_to_dlq(
self, original_event: BaseEvent, original_topic: str, error: Exception, retry_count: int = 0
@@ -224,7 +226,7 @@ async def send_to_dlq(
retry_count: Number of retry attempts already made
"""
if not self._producer:
- logger.error("Producer not running, cannot send to DLQ")
+ self.logger.error("Producer not running, cannot send to DLQ")
return
try:
@@ -233,18 +235,22 @@ async def send_to_dlq(
task_name = current_task.get_name() if current_task else "main"
producer_id = f"{socket.gethostname()}-{task_name}"
- # Create DLQ message
- dlq_message = DLQMapper.from_failed_event(
+ # Create DLQ message directly
+ dlq_message = DLQMessage(
+ event_id=original_event.event_id,
event=original_event,
+ event_type=original_event.event_type,
original_topic=original_topic,
error=str(error),
- producer_id=producer_id,
retry_count=retry_count,
+ failed_at=datetime.now(timezone.utc),
+ status=DLQMessageStatus.PENDING,
+ producer_id=producer_id,
)
# Create DLQ event wrapper
dlq_event_data = {
- "event_id": dlq_message.event_id or original_event.event_id,
+ "event_id": dlq_message.event_id,
"event_type": "dlq.message",
"event": dlq_message.event.to_dict(),
"original_topic": dlq_message.original_topic,
@@ -277,7 +283,7 @@ async def send_to_dlq(
)
self._metrics.messages_sent += 1
- logger.warning(
+ self.logger.warning(
f"Event {original_event.event_id} sent to DLQ. "
f"Original topic: {original_topic}, Error: {error}, "
f"Retry count: {retry_count}"
@@ -285,7 +291,7 @@ async def send_to_dlq(
except Exception as e:
# If we can't send to DLQ, log critically but don't crash
- logger.critical(
+ self.logger.critical(
f"Failed to send event {original_event.event_id} to DLQ: {e}. Original error: {error}", exc_info=True
)
self._metrics.messages_failed += 1
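Because the producer now builds DLQMessage directly instead of going through DLQMapper.from_failed_event, the same construction can be reused wherever a DLQ entry has to be fabricated (tests, scripts). The field names below mirror send_to_dlq above; the producer_id suffix is illustrative.

```python
import socket
from datetime import datetime, timezone

from app.dlq.models import DLQMessage, DLQMessageStatus
from app.infrastructure.kafka.events import BaseEvent


def build_dlq_message(event: BaseEvent, topic: str, error: Exception, retry_count: int = 0) -> DLQMessage:
    # Mirrors the direct construction in UnifiedProducer.send_to_dlq; the
    # producer_id follows the same hostname-based convention used there.
    return DLQMessage(
        event_id=event.event_id,
        event=event,
        event_type=event.event_type,
        original_topic=topic,
        error=str(error),
        retry_count=retry_count,
        failed_at=datetime.now(timezone.utc),
        status=DLQMessageStatus.PENDING,
        producer_id=f"{socket.gethostname()}-example",
    )
```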
diff --git a/backend/app/events/event_store.py b/backend/app/events/event_store.py
index b3d175eb..491d2c90 100644
--- a/backend/app/events/event_store.py
+++ b/backend/app/events/event_store.py
@@ -1,16 +1,17 @@
import asyncio
+import logging
from collections.abc import Awaitable, Callable
from datetime import datetime, timezone
from typing import Any, Dict, List
-from pymongo import ASCENDING, DESCENDING, IndexModel
+from beanie.odm.enums import SortDirection
from pymongo.errors import BulkWriteError, DuplicateKeyError
-from app.core.database_context import Collection, Cursor, Database
-from app.core.logging import logger
from app.core.metrics.context import get_event_metrics
from app.core.tracing import EventAttributes
from app.core.tracing.utils import add_span_attributes
+from app.db.docs import EventStoreDocument
+from app.db.docs.event import EventMetadata
from app.domain.enums.events import EventType
from app.events.schema.schema_registry import SchemaRegistryManager
from app.infrastructure.kafka.events.base import BaseEvent
@@ -19,23 +20,20 @@
class EventStore:
def __init__(
self,
- db: Database,
schema_registry: SchemaRegistryManager,
- collection_name: str = "events",
+ logger: logging.Logger,
ttl_days: int = 90,
batch_size: int = 100,
):
- self.db = db
self.metrics = get_event_metrics()
self.schema_registry = schema_registry
- self.collection_name = collection_name
- self.collection: Collection = db[collection_name]
+ self.logger = logger
self.ttl_days = ttl_days
self.batch_size = batch_size
self._initialized = False
self._PROJECTION = {"stored_at": 0, "_id": 0}
- self._SECURITY_TYPES = [ # stringified once
+ self._SECURITY_TYPES = [
EventType.USER_LOGIN,
EventType.USER_LOGGED_OUT,
EventType.SECURITY_VIOLATION,
@@ -44,44 +42,44 @@ def __init__(
async def initialize(self) -> None:
if self._initialized:
return
-
- event_indexes = [
- IndexModel("event_id", unique=True),
- IndexModel([("timestamp", DESCENDING)]),
- IndexModel([("event_type", ASCENDING), ("timestamp", DESCENDING)]),
- IndexModel([("metadata.user_id", ASCENDING), ("timestamp", DESCENDING)]),
- IndexModel([("metadata.user_id", ASCENDING), ("event_type", ASCENDING)]),
- IndexModel([("execution_id", ASCENDING), ("timestamp", ASCENDING)]),
- IndexModel("metadata.correlation_id"),
- IndexModel("metadata.service_name"),
- IndexModel(
- [
- ("event_type", ASCENDING),
- ("metadata.user_id", ASCENDING),
- ("timestamp", DESCENDING),
- ]
- ),
- IndexModel(
- "timestamp",
- expireAfterSeconds=self.ttl_days * 24 * 60 * 60,
- name="timestamp_ttl",
- ),
- ]
-
- existing = await self.collection.list_indexes().to_list(None)
- if len(existing) <= 1:
- await self.collection.create_indexes(event_indexes)
- logger.info(f"Created {len(event_indexes)} indexes for events collection")
-
+ # Beanie handles index creation via Document.Settings.indexes
self._initialized = True
- logger.info("Streamlined event store initialized")
+ self.logger.info("Event store initialized with Beanie")
+
+ def _event_to_doc(self, event: BaseEvent) -> EventStoreDocument:
+ """Convert BaseEvent to EventStoreDocument."""
+ event_dict = event.model_dump()
+ metadata_dict = event_dict.pop("metadata", {})
+ metadata = EventMetadata(**metadata_dict)
+ base_fields = set(BaseEvent.model_fields.keys())
+ payload = {k: v for k, v in event_dict.items() if k not in base_fields}
+
+ return EventStoreDocument(
+ event_id=event.event_id,
+ event_type=event.event_type,
+ event_version=event.event_version,
+ timestamp=event.timestamp,
+ aggregate_id=event.aggregate_id,
+ metadata=metadata,
+ payload=payload,
+ stored_at=datetime.now(timezone.utc),
+ )
+
+ def _doc_to_dict(self, doc: EventStoreDocument) -> Dict[str, Any]:
+ """Convert EventStoreDocument to dict for schema_registry deserialization."""
+ result: Dict[str, Any] = doc.model_dump(exclude={"id", "revision_id", "stored_at"})
+ # Ensure metadata is a plain dict for schema_registry deserialization
+ metadata_value = result.get("metadata")
+ if hasattr(metadata_value, "model_dump"):
+ result["metadata"] = metadata_value.model_dump()
+ return result
async def store_event(self, event: BaseEvent) -> bool:
start = asyncio.get_event_loop().time()
try:
- doc = event.model_dump()
- doc["stored_at"] = datetime.now(timezone.utc)
- await self.collection.insert_one(doc)
+ doc = self._event_to_doc(event)
+ await doc.insert()
add_span_attributes(
**{
@@ -92,14 +90,14 @@ async def store_event(self, event: BaseEvent) -> bool:
)
duration = asyncio.get_event_loop().time() - start
- self.metrics.record_event_store_duration(duration, "store_single", self.collection_name)
- self.metrics.record_event_stored(event.event_type, self.collection_name)
+ self.metrics.record_event_store_duration(duration, "store_single", "event_store")
+ self.metrics.record_event_stored(event.event_type, "event_store")
return True
except DuplicateKeyError:
- logger.warning(f"Event {event.event_id} already exists")
+ self.logger.warning(f"Event {event.event_id} already exists")
return True
except Exception as e:
- logger.error(f"Failed to store event {event.event_id}: {e.__class__.__name__}: {e}", exc_info=True)
+ self.logger.error(f"Failed to store event {event.event_id}: {e.__class__.__name__}: {e}", exc_info=True)
self.metrics.record_event_store_failed(event.event_type, type(e).__name__)
return False
@@ -110,16 +108,11 @@ async def store_batch(self, events: List[BaseEvent]) -> Dict[str, int]:
return results
try:
- docs = []
- now = datetime.now(timezone.utc)
- for e in events:
- d = e.model_dump()
- d["stored_at"] = now
- docs.append(d)
+ docs = [self._event_to_doc(e) for e in events]
try:
- res = await self.collection.insert_many(docs, ordered=False)
- results["stored"] = len(res.inserted_ids)
+ await EventStoreDocument.insert_many(docs)
+ results["stored"] = len(docs)
except Exception as e:
if isinstance(e, BulkWriteError) and e.details:
errs = e.details.get("writeErrors", [])
@@ -133,26 +126,28 @@ async def store_batch(self, events: List[BaseEvent]) -> Dict[str, int]:
raise
duration = asyncio.get_event_loop().time() - start
- self.metrics.record_event_store_duration(duration, "store_batch", self.collection_name)
+ self.metrics.record_event_store_duration(duration, "store_batch", "event_store")
add_span_attributes(**{"events.batch.count": len(events)})
if results["stored"] > 0:
for event in events:
- self.metrics.record_event_stored(event.event_type, self.collection_name)
+ self.metrics.record_event_stored(event.event_type, "event_store")
return results
except Exception as e:
- logger.error(f"Failed to store batch: {e.__class__.__name__}: {e}", exc_info=True)
+ self.logger.error(f"Failed to store batch: {e.__class__.__name__}: {e}", exc_info=True)
results["failed"] = results["total"] - results["stored"]
return results
async def get_event(self, event_id: str) -> BaseEvent | None:
start = asyncio.get_event_loop().time()
- doc = await self.collection.find_one({"event_id": event_id}, self._PROJECTION)
+ doc = await EventStoreDocument.find_one({"event_id": event_id})
if not doc:
return None
- event = self.schema_registry.deserialize_json(doc)
+
+ event_dict = self._doc_to_dict(doc)
+ event = self.schema_registry.deserialize_json(event_dict)
duration = asyncio.get_event_loop().time() - start
- self.metrics.record_event_query_duration(duration, "get_by_id", self.collection_name)
+ self.metrics.record_event_query_duration(duration, "get_by_id", "event_store")
return event
async def get_events_by_type(
@@ -164,14 +159,21 @@ async def get_events_by_type(
offset: int = 0,
) -> List[BaseEvent]:
start = asyncio.get_event_loop().time()
- q: Dict[str, Any] = {"event_type": str(event_type)}
+ query: Dict[str, Any] = {"event_type": event_type}
if tr := self._time_range(start_time, end_time):
- q["timestamp"] = tr
-
- events = await self._find_events(q, sort=("timestamp", DESCENDING), limit=limit, offset=offset)
+ query["timestamp"] = tr
+
+ docs = await (
+ EventStoreDocument.find(query)
+ .sort([("timestamp", SortDirection.DESCENDING)])
+ .skip(offset)
+ .limit(limit)
+ .to_list()
+ )
+ events = [self.schema_registry.deserialize_json(self._doc_to_dict(d)) for d in docs]
duration = asyncio.get_event_loop().time() - start
- self.metrics.record_event_query_duration(duration, "get_by_type", self.collection_name)
+ self.metrics.record_event_query_duration(duration, "get_by_type", "event_store")
return events
async def get_execution_events(
@@ -180,14 +182,15 @@ async def get_execution_events(
event_types: List[EventType] | None = None,
) -> List[BaseEvent]:
start = asyncio.get_event_loop().time()
- q: Dict[str, Any] = {"execution_id": execution_id}
+ query: Dict[str, Any] = {"execution_id": execution_id}
if event_types:
- q["event_type"] = {"$in": [str(et) for et in event_types]}
+ query["event_type"] = {"$in": event_types}
- events = await self._find_events(q, sort=("timestamp", ASCENDING))
+ docs = await EventStoreDocument.find(query).sort([("timestamp", SortDirection.ASCENDING)]).to_list()
+ events = [self.schema_registry.deserialize_json(self._doc_to_dict(d)) for d in docs]
duration = asyncio.get_event_loop().time() - start
- self.metrics.record_event_query_duration(duration, "get_execution_events", self.collection_name)
+ self.metrics.record_event_query_duration(duration, "get_execution_events", "event_store")
return events
async def get_user_events(
@@ -199,16 +202,19 @@ async def get_user_events(
limit: int = 100,
) -> List[BaseEvent]:
start = asyncio.get_event_loop().time()
- q: Dict[str, Any] = {"metadata.user_id": str(user_id)}
+ query: Dict[str, Any] = {"metadata.user_id": str(user_id)}
if event_types:
- q["event_type"] = {"$in": event_types}
+ query["event_type"] = {"$in": event_types}
if tr := self._time_range(start_time, end_time):
- q["timestamp"] = tr
+ query["timestamp"] = tr
- events = await self._find_events(q, sort=("timestamp", DESCENDING), limit=limit)
+ docs = (
+ await EventStoreDocument.find(query).sort([("timestamp", SortDirection.DESCENDING)]).limit(limit).to_list()
+ )
+ events = [self.schema_registry.deserialize_json(self._doc_to_dict(d)) for d in docs]
duration = asyncio.get_event_loop().time() - start
- self.metrics.record_event_query_duration(duration, "get_user_events", self.collection_name)
+ self.metrics.record_event_query_duration(duration, "get_user_events", "event_store")
return events
async def get_security_events(
@@ -219,25 +225,32 @@ async def get_security_events(
limit: int = 100,
) -> List[BaseEvent]:
start = asyncio.get_event_loop().time()
- q: Dict[str, Any] = {"event_type": {"$in": self._SECURITY_TYPES}}
+ query: Dict[str, Any] = {"event_type": {"$in": self._SECURITY_TYPES}}
if user_id:
- q["metadata.user_id"] = str(user_id)
+ query["metadata.user_id"] = str(user_id)
if tr := self._time_range(start_time, end_time):
- q["timestamp"] = tr
+ query["timestamp"] = tr
- events = await self._find_events(q, sort=("timestamp", DESCENDING), limit=limit)
+ docs = (
+ await EventStoreDocument.find(query).sort([("timestamp", SortDirection.DESCENDING)]).limit(limit).to_list()
+ )
+ events = [self.schema_registry.deserialize_json(self._doc_to_dict(d)) for d in docs]
duration = asyncio.get_event_loop().time() - start
- self.metrics.record_event_query_duration(duration, "get_security_events", self.collection_name)
+ self.metrics.record_event_query_duration(duration, "get_security_events", "event_store")
return events
async def get_correlation_chain(self, correlation_id: str) -> List[BaseEvent]:
start = asyncio.get_event_loop().time()
- q = {"metadata.correlation_id": str(correlation_id)}
- events = await self._find_events(q, sort=("timestamp", ASCENDING))
+ docs = await (
+ EventStoreDocument.find({"metadata.correlation_id": str(correlation_id)})
+ .sort([("timestamp", SortDirection.ASCENDING)])
+ .to_list()
+ )
+ events = [self.schema_registry.deserialize_json(self._doc_to_dict(d)) for d in docs]
duration = asyncio.get_event_loop().time() - start
- self.metrics.record_event_query_duration(duration, "get_correlation_chain", self.collection_name)
+ self.metrics.record_event_query_duration(duration, "get_correlation_chain", "event_store")
return events
async def replay_events(
@@ -251,25 +264,25 @@ async def replay_events(
count = 0
try:
- q: Dict[str, Any] = {"timestamp": {"$gte": start_time}}
+ query: Dict[str, Any] = {"timestamp": {"$gte": start_time}}
if end_time:
- q["timestamp"]["$lte"] = end_time
+ query["timestamp"]["$lte"] = end_time
if event_types:
- q["event_type"] = {"$in": [str(et) for et in event_types]}
+ query["event_type"] = {"$in": event_types}
- cursor = self.collection.find(q, self._PROJECTION).sort("timestamp", ASCENDING)
- async for doc in cursor:
- event = self.schema_registry.deserialize_json(doc)
+ async for doc in EventStoreDocument.find(query).sort([("timestamp", SortDirection.ASCENDING)]):
+ event_dict = self._doc_to_dict(doc)
+ event = self.schema_registry.deserialize_json(event_dict)
if callback:
await callback(event)
count += 1
duration = asyncio.get_event_loop().time() - start
- self.metrics.record_event_query_duration(duration, "replay_events", self.collection_name)
- logger.info(f"Replayed {count} events from {start_time} to {end_time}")
+ self.metrics.record_event_query_duration(duration, "replay_events", "event_store")
+ self.logger.info(f"Replayed {count} events from {start_time} to {end_time}")
return count
except Exception as e:
- logger.error(f"Failed to replay events: {e}")
+ self.logger.error(f"Failed to replay events: {e}")
return count
async def get_event_stats(
@@ -300,9 +313,8 @@ async def get_event_stats(
]
)
- cursor = self.collection.aggregate(pipeline)
stats: Dict[str, Any] = {"total_events": 0, "event_types": {}, "start_time": start_time, "end_time": end_time}
- async for r in cursor:
+ async for r in EventStoreDocument.aggregate(pipeline):
et = r["_id"]
c = r["count"]
stats["event_types"][et] = {
@@ -313,9 +325,6 @@ async def get_event_stats(
stats["total_events"] += c
return stats
- async def _deserialize_cursor(self, cursor: Cursor) -> list[BaseEvent]:
- return [self.schema_registry.deserialize_json(doc) async for doc in cursor]
-
def _time_range(self, start_time: datetime | None, end_time: datetime | None) -> Dict[str, Any] | None:
if not start_time and not end_time:
return None
@@ -326,45 +335,29 @@ def _time_range(self, start_time: datetime | None, end_time: datetime | None) ->
tr["$lte"] = end_time
return tr
- async def _find_events(
- self,
- query: Dict[str, Any],
- *,
- sort: tuple[str, int],
- limit: int | None = None,
- offset: int = 0,
- ) -> List[BaseEvent]:
- cur = self.collection.find(query, self._PROJECTION).sort(*sort).skip(offset)
- if limit is not None:
- cur = cur.limit(limit)
- return await self._deserialize_cursor(cur)
-
async def health_check(self) -> Dict[str, Any]:
try:
- await self.db.command("ping")
- event_count = await self.collection.count_documents({})
+ event_count = await EventStoreDocument.count()
return {
"healthy": True,
"event_count": event_count,
- "collection": self.collection_name,
+ "collection": "event_store",
"initialized": self._initialized,
}
except Exception as e:
- logger.error(f"Event store health check failed: {e}")
+ self.logger.error(f"Event store health check failed: {e}")
return {"healthy": False, "error": str(e)}
def create_event_store(
- db: Database,
schema_registry: SchemaRegistryManager,
- collection_name: str = "events",
+ logger: logging.Logger,
ttl_days: int = 90,
batch_size: int = 100,
) -> EventStore:
return EventStore(
- db=db,
schema_registry=schema_registry,
- collection_name=collection_name,
+ logger=logger,
ttl_days=ttl_days,
batch_size=batch_size,
)
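Construction of the Beanie-backed store reduces to the factory plus an injected logger; the query at the end restates the find/sort/limit chain the class now uses internally. The sketch assumes Beanie has already been initialised with EventStoreDocument, and the one-hour window is illustrative.

```python
import logging
from datetime import datetime, timedelta, timezone

from beanie.odm.enums import SortDirection

from app.db.docs import EventStoreDocument
from app.events.event_store import create_event_store
from app.events.schema.schema_registry import create_schema_registry_manager


async def recent_event_docs() -> list[EventStoreDocument]:
    # Assumes init_beanie() has already registered EventStoreDocument.
    logger = logging.getLogger("events.store")
    schema_registry = create_schema_registry_manager(logger)
    store = create_event_store(schema_registry=schema_registry, logger=logger, ttl_days=90)
    await store.initialize()

    # The same find/sort/limit chain the store methods use internally.
    since = datetime.now(timezone.utc) - timedelta(hours=1)
    return await (
        EventStoreDocument.find({"timestamp": {"$gte": since}})
        .sort([("timestamp", SortDirection.DESCENDING)])
        .limit(50)
        .to_list()
    )
```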
diff --git a/backend/app/events/event_store_consumer.py b/backend/app/events/event_store_consumer.py
index 9276a476..bb26612e 100644
--- a/backend/app/events/event_store_consumer.py
+++ b/backend/app/events/event_store_consumer.py
@@ -1,9 +1,9 @@
import asyncio
+import logging
from opentelemetry.trace import SpanKind
from app.core.lifecycle import LifecycleEnabled
-from app.core.logging import logger
from app.core.tracing.utils import trace_span
from app.domain.enums.events import EventType
from app.domain.enums.kafka import GroupId, KafkaTopic
@@ -22,6 +22,7 @@ def __init__(
event_store: EventStore,
topics: list[KafkaTopic],
schema_registry_manager: SchemaRegistryManager,
+ logger: logging.Logger,
producer: UnifiedProducer | None = None,
group_id: GroupId = GroupId.EVENT_STORE_CONSUMER,
batch_size: int = 100,
@@ -32,9 +33,10 @@ def __init__(
self.group_id = group_id
self.batch_size = batch_size
self.batch_timeout = batch_timeout_seconds
+ self.logger = logger
self.consumer: UnifiedConsumer | None = None
self.schema_registry_manager = schema_registry_manager
- self.dispatcher = EventDispatcher()
+ self.dispatcher = EventDispatcher(logger)
self.producer = producer # For DLQ handling
self._batch_buffer: list[BaseEvent] = []
self._batch_lock = asyncio.Lock()
@@ -55,7 +57,7 @@ async def start(self) -> None:
max_poll_records=self.batch_size,
)
- self.consumer = UnifiedConsumer(config, event_dispatcher=self.dispatcher)
+ self.consumer = UnifiedConsumer(config, event_dispatcher=self.dispatcher, logger=self.logger)
# Register handler for all event types - store everything
for event_type in EventType:
@@ -67,6 +69,7 @@ async def start(self) -> None:
dlq_handler = create_dlq_error_handler(
producer=self.producer,
original_topic="event-store", # Generic topic name for event store
+ logger=self.logger,
max_retries=3,
)
self.consumer.register_error_callback(dlq_handler)
@@ -79,7 +82,7 @@ async def start(self) -> None:
self._batch_task = asyncio.create_task(self._batch_processor())
- logger.info(f"Event store consumer started for topics: {self.topics}")
+ self.logger.info(f"Event store consumer started for topics: {self.topics}")
async def stop(self) -> None:
"""Stop consumer."""
@@ -100,11 +103,11 @@ async def stop(self) -> None:
if self.consumer:
await self.consumer.stop()
- logger.info("Event store consumer stopped")
+ self.logger.info("Event store consumer stopped")
async def _handle_event(self, event: BaseEvent) -> None:
"""Handle incoming event from dispatcher."""
- logger.info(f"Event store received event: {event.event_type} - {event.event_id}")
+ self.logger.info(f"Event store received event: {event.event_type} - {event.event_id}")
async with self._batch_lock:
self._batch_buffer.append(event)
@@ -114,7 +117,7 @@ async def _handle_event(self, event: BaseEvent) -> None:
async def _handle_error_with_event(self, error: Exception, event: BaseEvent) -> None:
"""Handle processing errors with event context."""
- logger.error(f"Error processing event {event.event_id} ({event.event_type}): {error}", exc_info=True)
+ self.logger.error(f"Error processing event {event.event_id} ({event.event_type}): {error}", exc_info=True)
async def _batch_processor(self) -> None:
"""Periodically flush batches based on timeout."""
@@ -129,7 +132,7 @@ async def _batch_processor(self) -> None:
await self._flush_batch()
except Exception as e:
- logger.error(f"Error in batch processor: {e}")
+ self.logger.error(f"Error in batch processor: {e}")
async def _flush_batch(self) -> None:
if not self._batch_buffer:
@@ -139,7 +142,7 @@ async def _flush_batch(self) -> None:
self._batch_buffer.clear()
self._last_batch_time = asyncio.get_event_loop().time()
- logger.info(f"Event store flushing batch of {len(batch)} events")
+ self.logger.info(f"Event store flushing batch of {len(batch)} events")
with trace_span(
name="event_store.flush_batch",
kind=SpanKind.CONSUMER,
@@ -147,7 +150,7 @@ async def _flush_batch(self) -> None:
):
results = await self.event_store.store_batch(batch)
- logger.info(
+ self.logger.info(
f"Stored event batch: total={results['total']}, "
f"stored={results['stored']}, duplicates={results['duplicates']}, "
f"failed={results['failed']}"
@@ -158,6 +161,7 @@ def create_event_store_consumer(
event_store: EventStore,
topics: list[KafkaTopic],
schema_registry_manager: SchemaRegistryManager,
+ logger: logging.Logger,
producer: UnifiedProducer | None = None,
group_id: GroupId = GroupId.EVENT_STORE_CONSUMER,
batch_size: int = 100,
@@ -170,5 +174,6 @@ def create_event_store_consumer(
batch_size=batch_size,
batch_timeout_seconds=batch_timeout_seconds,
schema_registry_manager=schema_registry_manager,
+ logger=logger,
producer=producer,
)
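End-to-end, startup wiring now threads a single logger through every collaborator; the topic subset and batch settings below are illustrative.

```python
import logging

from app.domain.enums.kafka import KafkaTopic
from app.events.event_store import create_event_store
from app.events.event_store_consumer import create_event_store_consumer
from app.events.schema.schema_registry import create_schema_registry_manager


async def start_event_store_consumer():
    logger = logging.getLogger("events")
    schema_registry = create_schema_registry_manager(logger)
    event_store = create_event_store(schema_registry=schema_registry, logger=logger)

    consumer = create_event_store_consumer(
        event_store=event_store,
        topics=[KafkaTopic.EXECUTION_COMPLETED],  # illustrative subset
        schema_registry_manager=schema_registry,
        logger=logger,
        producer=None,  # pass a UnifiedProducer to enable DLQ handling
        batch_size=100,
        batch_timeout_seconds=5.0,
    )
    await consumer.start()
    return consumer
```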
diff --git a/backend/app/events/schema/schema_registry.py b/backend/app/events/schema/schema_registry.py
index 09a036c5..fa5b28da 100644
--- a/backend/app/events/schema/schema_registry.py
+++ b/backend/app/events/schema/schema_registry.py
@@ -1,4 +1,5 @@
import json
+import logging
import os
import struct
from functools import lru_cache
@@ -9,7 +10,6 @@
from confluent_kafka.schema_registry.avro import AvroDeserializer, AvroSerializer
from confluent_kafka.serialization import MessageField, SerializationContext
-from app.core.logging import logger
from app.domain.enums.events import EventType
from app.infrastructure.kafka.events.base import BaseEvent
from app.settings import get_settings
@@ -54,7 +54,8 @@ def _get_event_type_to_class_mapping() -> Dict[EventType, Type[BaseEvent]]:
class SchemaRegistryManager:
"""Schema registry manager for Avro serialization with Confluent wire format."""
- def __init__(self, schema_registry_url: str | None = None):
+ def __init__(self, logger: logging.Logger, schema_registry_url: str | None = None):
+ self.logger = logger
settings = get_settings()
self.url = schema_registry_url or settings.SCHEMA_REGISTRY_URL
self.namespace = "com.integr8scode.events"
@@ -81,7 +82,7 @@ def register_schema(self, subject: str, event_class: Type[BaseEvent]) -> int:
schema_id: int = self.client.register_schema(subject, Schema(schema_str, "AVRO"))
self._schema_id_cache[event_class] = schema_id
self._id_to_class_cache[schema_id] = event_class
- logger.info(f"Registered schema for {event_class.__name__}: ID {schema_id}")
+ self.logger.info(f"Registered schema for {event_class.__name__}: ID {schema_id}")
return schema_id
def _get_schema_id(self, event_class: Type[BaseEvent]) -> int:
@@ -213,7 +214,7 @@ def set_compatibility(self, subject: str, mode: str) -> None:
url = f"{self.url}/config/{subject}"
response = httpx.put(url, json={"compatibility": mode})
response.raise_for_status()
- logger.info(f"Set {subject} compatibility to {mode}")
+ self.logger.info(f"Set {subject} compatibility to {mode}")
async def initialize_schemas(self) -> None:
"""Initialize all event schemas in the registry (set compat + register)."""
@@ -227,11 +228,13 @@ async def initialize_schemas(self) -> None:
self.register_schema(subject, event_class)
self._initialized = True
- logger.info(f"Initialized {len(_get_all_event_classes())} event schemas")
+ self.logger.info(f"Initialized {len(_get_all_event_classes())} event schemas")
-def create_schema_registry_manager(schema_registry_url: str | None = None) -> SchemaRegistryManager:
- return SchemaRegistryManager(schema_registry_url)
+def create_schema_registry_manager(
+ logger: logging.Logger, schema_registry_url: str | None = None
+) -> SchemaRegistryManager:
+ return SchemaRegistryManager(logger, schema_registry_url)
async def initialize_event_schemas(registry: SchemaRegistryManager) -> None:
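Schema bootstrap with the injected logger remains a two-step affair; the registry URL is illustrative and may be omitted to fall back to settings.SCHEMA_REGISTRY_URL.

```python
import asyncio
import logging

from app.events.schema.schema_registry import (
    create_schema_registry_manager,
    initialize_event_schemas,
)


async def bootstrap_schemas() -> None:
    logger = logging.getLogger("events.schema")
    # Passing None for the URL falls back to settings.SCHEMA_REGISTRY_URL.
    registry = create_schema_registry_manager(logger, "http://localhost:8081")
    await initialize_event_schemas(registry)


if __name__ == "__main__":
    asyncio.run(bootstrap_schemas())
```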
diff --git a/backend/app/infrastructure/kafka/events/execution.py b/backend/app/infrastructure/kafka/events/execution.py
index f596d03a..3030b4eb 100644
--- a/backend/app/infrastructure/kafka/events/execution.py
+++ b/backend/app/infrastructure/kafka/events/execution.py
@@ -87,10 +87,10 @@ class ExecutionCompletedEvent(BaseEvent):
event_type: Literal[EventType.EXECUTION_COMPLETED] = EventType.EXECUTION_COMPLETED
topic: ClassVar[KafkaTopic] = KafkaTopic.EXECUTION_COMPLETED
execution_id: str
- stdout: str = ""
- stderr: str = ""
exit_code: int
resource_usage: ResourceUsageDomain
+ stdout: str = ""
+ stderr: str = ""
class ExecutionFailedEvent(BaseEvent):
@@ -102,17 +102,17 @@ class ExecutionFailedEvent(BaseEvent):
exit_code: int
error_type: ExecutionErrorType
error_message: str
- resource_usage: ResourceUsageDomain
+ resource_usage: ResourceUsageDomain | None = None
class ExecutionTimeoutEvent(BaseEvent):
event_type: Literal[EventType.EXECUTION_TIMEOUT] = EventType.EXECUTION_TIMEOUT
topic: ClassVar[KafkaTopic] = KafkaTopic.EXECUTION_TIMEOUT
execution_id: str
- stdout: str = ""
- stderr: str = ""
timeout_seconds: int
resource_usage: ResourceUsageDomain
+ stdout: str = ""
+ stderr: str = ""
class ExecutionCancelledEvent(BaseEvent):
diff --git a/backend/app/infrastructure/kafka/events/metadata.py b/backend/app/infrastructure/kafka/events/metadata.py
index 23805032..71cba2bf 100644
--- a/backend/app/infrastructure/kafka/events/metadata.py
+++ b/backend/app/infrastructure/kafka/events/metadata.py
@@ -1,4 +1,3 @@
-from typing import Any, Dict
from uuid import uuid4
from pydantic import ConfigDict, Field
@@ -20,21 +19,6 @@ class AvroEventMetadata(AvroBase):
model_config = ConfigDict(extra="allow", str_strip_whitespace=True, use_enum_values=True)
- def to_dict(self, exclude_none: bool = True) -> Dict[str, Any]:
- return self.model_dump(exclude_none=exclude_none)
-
- @classmethod
- def from_dict(cls, data: Dict[str, Any]) -> "AvroEventMetadata":
- return cls(
- service_name=data.get("service_name", "unknown"),
- service_version=data.get("service_version", "1.0"),
- correlation_id=data.get("correlation_id", str(uuid4())),
- user_id=data.get("user_id"),
- ip_address=data.get("ip_address"),
- user_agent=data.get("user_agent"),
- environment=data.get("environment", Environment.PRODUCTION),
- )
-
def with_correlation(self, correlation_id: str) -> "AvroEventMetadata":
return self.model_copy(update={"correlation_id": correlation_id})
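With the hand-rolled to_dict/from_dict helpers removed, callers lean on Pydantic's built-ins. The field names below come from the deleted from_dict; the concrete values are illustrative, and the remaining fields are assumed to carry defaults on the model itself.

```python
from app.infrastructure.kafka.events.metadata import AvroEventMetadata

# model_dump replaces the removed to_dict(); model_validate replaces from_dict().
meta = AvroEventMetadata(
    service_name="execution-api",   # illustrative values
    service_version="1.0",
    correlation_id="corr-123",
    user_id="user-123",
)
payload = meta.model_dump(exclude_none=True)
restored = AvroEventMetadata.model_validate(payload)
assert restored == meta
```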
diff --git a/backend/app/infrastructure/kafka/events/user.py b/backend/app/infrastructure/kafka/events/user.py
index 6378bb1f..32d98abf 100644
--- a/backend/app/infrastructure/kafka/events/user.py
+++ b/backend/app/infrastructure/kafka/events/user.py
@@ -51,7 +51,7 @@ class UserSettingsUpdatedEvent(BaseEvent):
topic: ClassVar[KafkaTopic] = KafkaTopic.USER_SETTINGS_EVENTS
user_id: str
settings_type: SettingsType
- changes: dict[str, str]
+ updated: dict[str, str]
class UserThemeChangedEvent(BaseEvent):
diff --git a/backend/app/infrastructure/mappers/__init__.py b/backend/app/infrastructure/mappers/__init__.py
deleted file mode 100644
index 7dcd86e4..00000000
--- a/backend/app/infrastructure/mappers/__init__.py
+++ /dev/null
@@ -1,68 +0,0 @@
-from .admin_mapper import (
- AuditLogMapper,
- SettingsMapper,
- UserMapper,
-)
-from .event_mapper import (
- ArchivedEventMapper,
- EventExportRowMapper,
- EventFilterMapper,
- EventMapper,
- EventSummaryMapper,
-)
-from .notification_mapper import NotificationMapper
-from .rate_limit_mapper import (
- RateLimitConfigMapper,
- RateLimitRuleMapper,
- UserRateLimitMapper,
-)
-from .replay_api_mapper import ReplayApiMapper
-from .replay_mapper import ReplayApiMapper as AdminReplayApiMapper
-from .replay_mapper import (
- ReplayQueryMapper,
- ReplaySessionMapper,
- ReplayStateMapper,
-)
-from .saga_mapper import (
- SagaFilterMapper,
- SagaInstanceMapper,
- SagaMapper,
-)
-from .saved_script_mapper import SavedScriptMapper
-from .sse_mapper import SSEMapper
-from .user_settings_mapper import UserSettingsMapper
-
-__all__ = [
- # Admin
- "UserMapper",
- "SettingsMapper",
- "AuditLogMapper",
- # Events
- "EventMapper",
- "EventSummaryMapper",
- "ArchivedEventMapper",
- "EventExportRowMapper",
- "EventFilterMapper",
- # Notification
- "NotificationMapper",
- # Rate limit
- "RateLimitRuleMapper",
- "UserRateLimitMapper",
- "RateLimitConfigMapper",
- # Replay
- "ReplayApiMapper",
- "AdminReplayApiMapper",
- "ReplaySessionMapper",
- "ReplayQueryMapper",
- "ReplayStateMapper",
- # Saved scripts
- "SavedScriptMapper",
- # SSE
- "SSEMapper",
- # User settings
- "UserSettingsMapper",
- # Saga
- "SagaMapper",
- "SagaFilterMapper",
- "SagaInstanceMapper",
-]
diff --git a/backend/app/infrastructure/mappers/admin_mapper.py b/backend/app/infrastructure/mappers/admin_mapper.py
deleted file mode 100644
index 0800048e..00000000
--- a/backend/app/infrastructure/mappers/admin_mapper.py
+++ /dev/null
@@ -1,251 +0,0 @@
-import re
-from datetime import datetime, timezone
-from typing import Any, Dict
-
-from app.domain.admin import (
- AuditAction,
- AuditLogEntry,
- AuditLogFields,
- ExecutionLimits,
- LogLevel,
- MonitoringSettings,
- SecuritySettings,
- SettingsFields,
- SystemSettings,
-)
-from app.domain.user import (
- User as DomainAdminUser,
-)
-from app.domain.user import (
- UserCreation,
- UserFields,
- UserRole,
- UserSearchFilter,
- UserUpdate,
-)
-from app.schemas_pydantic.user import User as ServiceUser
-
-EMAIL_PATTERN = re.compile(r"^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$")
-
-
-class UserMapper:
- @staticmethod
- def to_mongo_document(user: DomainAdminUser) -> Dict[str, Any]:
- return {
- UserFields.USER_ID: user.user_id,
- UserFields.USERNAME: user.username,
- UserFields.EMAIL: user.email,
- UserFields.ROLE: user.role.value,
- UserFields.IS_ACTIVE: user.is_active,
- UserFields.IS_SUPERUSER: user.is_superuser,
- UserFields.HASHED_PASSWORD: user.hashed_password,
- UserFields.CREATED_AT: user.created_at,
- UserFields.UPDATED_AT: user.updated_at,
- }
-
- @staticmethod
- def from_mongo_document(data: Dict[str, Any]) -> DomainAdminUser:
- required_fields = [UserFields.USER_ID, UserFields.USERNAME, UserFields.EMAIL]
- for field in required_fields:
- if field not in data or not data[field]:
- raise ValueError(f"Missing required field: {field}")
-
- email = data[UserFields.EMAIL]
- if not EMAIL_PATTERN.match(email):
- raise ValueError(f"Invalid email format: {email}")
-
- return DomainAdminUser(
- user_id=data[UserFields.USER_ID],
- username=data[UserFields.USERNAME],
- email=email,
- role=UserRole(data.get(UserFields.ROLE, UserRole.USER)),
- is_active=data.get(UserFields.IS_ACTIVE, True),
- is_superuser=data.get(UserFields.IS_SUPERUSER, False),
- hashed_password=data.get(UserFields.HASHED_PASSWORD, ""),
- created_at=data.get(UserFields.CREATED_AT, datetime.now(timezone.utc)),
- updated_at=data.get(UserFields.UPDATED_AT, datetime.now(timezone.utc)),
- )
-
- @staticmethod
- def from_pydantic_service_user(user: ServiceUser) -> DomainAdminUser:
- """Convert internal service Pydantic user to domain admin user."""
- return DomainAdminUser(
- user_id=user.user_id,
- username=user.username,
- email=str(user.email),
- role=user.role,
- is_active=user.is_active,
- is_superuser=user.is_superuser,
- hashed_password="",
- created_at=user.created_at or datetime.now(timezone.utc),
- updated_at=user.updated_at or datetime.now(timezone.utc),
- )
-
- @staticmethod
- def to_update_dict(update: UserUpdate) -> Dict[str, Any]:
- update_dict: Dict[str, Any] = {}
-
- if update.username is not None:
- update_dict[UserFields.USERNAME] = update.username
- if update.email is not None:
- if not EMAIL_PATTERN.match(update.email):
- raise ValueError(f"Invalid email format: {update.email}")
- update_dict[UserFields.EMAIL] = update.email
- if update.role is not None:
- update_dict[UserFields.ROLE] = update.role.value
- if update.is_active is not None:
- update_dict[UserFields.IS_ACTIVE] = update.is_active
-
- return update_dict
-
- @staticmethod
- def search_filter_to_query(f: UserSearchFilter) -> Dict[str, Any]:
- query: Dict[str, Any] = {}
- if f.search_text:
- query["$or"] = [
- {UserFields.USERNAME.value: {"$regex": f.search_text, "$options": "i"}},
- {UserFields.EMAIL.value: {"$regex": f.search_text, "$options": "i"}},
- ]
- if f.role:
- query[UserFields.ROLE] = f.role
- return query
-
- @staticmethod
- def user_creation_to_dict(creation: UserCreation) -> Dict[str, Any]:
- return {
- UserFields.USERNAME: creation.username,
- UserFields.EMAIL: creation.email,
- UserFields.ROLE: creation.role.value,
- UserFields.IS_ACTIVE: creation.is_active,
- UserFields.IS_SUPERUSER: creation.is_superuser,
- UserFields.CREATED_AT: datetime.now(timezone.utc),
- UserFields.UPDATED_AT: datetime.now(timezone.utc),
- }
-
-
-class SettingsMapper:
- @staticmethod
- def execution_limits_to_dict(limits: ExecutionLimits) -> dict[str, int]:
- return {
- "max_timeout_seconds": limits.max_timeout_seconds,
- "max_memory_mb": limits.max_memory_mb,
- "max_cpu_cores": limits.max_cpu_cores,
- "max_concurrent_executions": limits.max_concurrent_executions,
- }
-
- @staticmethod
- def execution_limits_from_dict(data: dict[str, Any] | None) -> ExecutionLimits:
- if not data:
- return ExecutionLimits()
- return ExecutionLimits(
- max_timeout_seconds=data.get("max_timeout_seconds", 300),
- max_memory_mb=data.get("max_memory_mb", 512),
- max_cpu_cores=data.get("max_cpu_cores", 2),
- max_concurrent_executions=data.get("max_concurrent_executions", 10),
- )
-
- @staticmethod
- def security_settings_to_dict(settings: SecuritySettings) -> dict[str, int]:
- return {
- "password_min_length": settings.password_min_length,
- "session_timeout_minutes": settings.session_timeout_minutes,
- "max_login_attempts": settings.max_login_attempts,
- "lockout_duration_minutes": settings.lockout_duration_minutes,
- }
-
- @staticmethod
- def security_settings_from_dict(data: dict[str, Any] | None) -> SecuritySettings:
- if not data:
- return SecuritySettings()
- return SecuritySettings(
- password_min_length=data.get("password_min_length", 8),
- session_timeout_minutes=data.get("session_timeout_minutes", 60),
- max_login_attempts=data.get("max_login_attempts", 5),
- lockout_duration_minutes=data.get("lockout_duration_minutes", 15),
- )
-
- @staticmethod
- def monitoring_settings_to_dict(settings: MonitoringSettings) -> dict[str, Any]:
- return {
- "metrics_retention_days": settings.metrics_retention_days,
- "log_level": settings.log_level.value,
- "enable_tracing": settings.enable_tracing,
- "sampling_rate": settings.sampling_rate,
- }
-
- @staticmethod
- def monitoring_settings_from_dict(data: dict[str, Any] | None) -> MonitoringSettings:
- if not data:
- return MonitoringSettings()
- return MonitoringSettings(
- metrics_retention_days=data.get("metrics_retention_days", 30),
- log_level=LogLevel(data.get("log_level", LogLevel.INFO)),
- enable_tracing=data.get("enable_tracing", True),
- sampling_rate=data.get("sampling_rate", 0.1),
- )
-
- @staticmethod
- def system_settings_to_dict(settings: SystemSettings) -> dict[str, Any]:
- mapper = SettingsMapper()
- return {
- SettingsFields.EXECUTION_LIMITS: mapper.execution_limits_to_dict(settings.execution_limits),
- SettingsFields.SECURITY_SETTINGS: mapper.security_settings_to_dict(settings.security_settings),
- SettingsFields.MONITORING_SETTINGS: mapper.monitoring_settings_to_dict(settings.monitoring_settings),
- SettingsFields.CREATED_AT: settings.created_at,
- SettingsFields.UPDATED_AT: settings.updated_at,
- }
-
- @staticmethod
- def system_settings_from_dict(data: dict[str, Any] | None) -> SystemSettings:
- if not data:
- return SystemSettings()
- mapper = SettingsMapper()
- return SystemSettings(
- execution_limits=mapper.execution_limits_from_dict(data.get(SettingsFields.EXECUTION_LIMITS)),
- security_settings=mapper.security_settings_from_dict(data.get(SettingsFields.SECURITY_SETTINGS)),
- monitoring_settings=mapper.monitoring_settings_from_dict(data.get(SettingsFields.MONITORING_SETTINGS)),
- created_at=data.get(SettingsFields.CREATED_AT, datetime.now(timezone.utc)),
- updated_at=data.get(SettingsFields.UPDATED_AT, datetime.now(timezone.utc)),
- )
-
- @staticmethod
- def system_settings_to_pydantic_dict(settings: SystemSettings) -> dict[str, Any]:
- mapper = SettingsMapper()
- return {
- "execution_limits": mapper.execution_limits_to_dict(settings.execution_limits),
- "security_settings": mapper.security_settings_to_dict(settings.security_settings),
- "monitoring_settings": mapper.monitoring_settings_to_dict(settings.monitoring_settings),
- }
-
- @staticmethod
- def system_settings_from_pydantic(data: dict[str, Any]) -> SystemSettings:
- mapper = SettingsMapper()
- return SystemSettings(
- execution_limits=mapper.execution_limits_from_dict(data.get("execution_limits")),
- security_settings=mapper.security_settings_from_dict(data.get("security_settings")),
- monitoring_settings=mapper.monitoring_settings_from_dict(data.get("monitoring_settings")),
- )
-
-
-class AuditLogMapper:
- @staticmethod
- def to_dict(entry: AuditLogEntry) -> dict[str, Any]:
- return {
- AuditLogFields.TIMESTAMP: entry.timestamp,
- AuditLogFields.ACTION: entry.action.value,
- AuditLogFields.USER_ID: entry.user_id,
- AuditLogFields.USERNAME: entry.username,
- AuditLogFields.CHANGES: entry.changes,
- "reason": entry.reason, # reason is not in the enum but used as additional field
- }
-
- @staticmethod
- def from_dict(data: dict[str, Any]) -> AuditLogEntry:
- return AuditLogEntry(
- timestamp=data.get(AuditLogFields.TIMESTAMP, datetime.now(timezone.utc)),
- action=AuditAction(data[AuditLogFields.ACTION]),
- user_id=data[AuditLogFields.USER_ID],
- username=data.get(AuditLogFields.USERNAME, ""),
- changes=data.get(AuditLogFields.CHANGES, {}),
- reason=data.get("reason", ""),
- )
diff --git a/backend/app/infrastructure/mappers/dlq_mapper.py b/backend/app/infrastructure/mappers/dlq_mapper.py
deleted file mode 100644
index ed9db45e..00000000
--- a/backend/app/infrastructure/mappers/dlq_mapper.py
+++ /dev/null
@@ -1,199 +0,0 @@
-from __future__ import annotations
-
-import json
-from datetime import datetime, timezone
-from typing import Mapping
-
-from confluent_kafka import Message
-
-from app.dlq.models import (
- DLQFields,
- DLQMessage,
- DLQMessageFilter,
- DLQMessageStatus,
- DLQMessageUpdate,
-)
-from app.events.schema.schema_registry import SchemaRegistryManager
-from app.infrastructure.kafka.events import BaseEvent
-
-
-class DLQMapper:
- """Mongo/Kafka ↔ DLQMessage conversions."""
-
- @staticmethod
- def to_mongo_document(message: DLQMessage) -> dict[str, object]:
- doc: dict[str, object] = {
- DLQFields.EVENT: message.event.to_dict(),
- DLQFields.ORIGINAL_TOPIC: message.original_topic,
- DLQFields.ERROR: message.error,
- DLQFields.RETRY_COUNT: message.retry_count,
- DLQFields.FAILED_AT: message.failed_at,
- DLQFields.STATUS: message.status,
- DLQFields.PRODUCER_ID: message.producer_id,
- }
- if message.event_id:
- doc[DLQFields.EVENT_ID] = message.event_id
- if message.created_at:
- doc[DLQFields.CREATED_AT] = message.created_at
- if message.last_updated:
- doc[DLQFields.LAST_UPDATED] = message.last_updated
- if message.next_retry_at:
- doc[DLQFields.NEXT_RETRY_AT] = message.next_retry_at
- if message.retried_at:
- doc[DLQFields.RETRIED_AT] = message.retried_at
- if message.discarded_at:
- doc[DLQFields.DISCARDED_AT] = message.discarded_at
- if message.discard_reason:
- doc[DLQFields.DISCARD_REASON] = message.discard_reason
- if message.dlq_offset is not None:
- doc[DLQFields.DLQ_OFFSET] = message.dlq_offset
- if message.dlq_partition is not None:
- doc[DLQFields.DLQ_PARTITION] = message.dlq_partition
- if message.last_error:
- doc[DLQFields.LAST_ERROR] = message.last_error
- return doc
-
- @staticmethod
- def from_mongo_document(data: Mapping[str, object]) -> DLQMessage:
- schema_registry = SchemaRegistryManager()
-
- def parse_dt(value: object) -> datetime | None:
- if value is None:
- return None
- if isinstance(value, datetime):
- return value if value.tzinfo else value.replace(tzinfo=timezone.utc)
- if isinstance(value, str):
- return datetime.fromisoformat(value).replace(tzinfo=timezone.utc)
- raise ValueError("Invalid datetime type")
-
- failed_at_raw = data.get(DLQFields.FAILED_AT)
- if failed_at_raw is None:
- raise ValueError("Missing failed_at")
- failed_at = parse_dt(failed_at_raw)
- if failed_at is None:
- raise ValueError("Invalid failed_at value")
-
- event_data = data.get(DLQFields.EVENT)
- if not isinstance(event_data, dict):
- raise ValueError("Missing or invalid event data")
- event = schema_registry.deserialize_json(event_data)
-
- status_raw = data.get(DLQFields.STATUS, DLQMessageStatus.PENDING)
- status = DLQMessageStatus(str(status_raw))
-
- retry_count_value: int = data.get(DLQFields.RETRY_COUNT, 0) # type: ignore[assignment]
- dlq_offset_value: int | None = data.get(DLQFields.DLQ_OFFSET) # type: ignore[assignment]
- dlq_partition_value: int | None = data.get(DLQFields.DLQ_PARTITION) # type: ignore[assignment]
-
- return DLQMessage(
- event=event,
- original_topic=str(data.get(DLQFields.ORIGINAL_TOPIC, "")),
- error=str(data.get(DLQFields.ERROR, "")),
- retry_count=retry_count_value,
- failed_at=failed_at,
- status=status,
- producer_id=str(data.get(DLQFields.PRODUCER_ID, "unknown")),
- event_id=str(data.get(DLQFields.EVENT_ID, "") or event.event_id),
- created_at=parse_dt(data.get(DLQFields.CREATED_AT)),
- last_updated=parse_dt(data.get(DLQFields.LAST_UPDATED)),
- next_retry_at=parse_dt(data.get(DLQFields.NEXT_RETRY_AT)),
- retried_at=parse_dt(data.get(DLQFields.RETRIED_AT)),
- discarded_at=parse_dt(data.get(DLQFields.DISCARDED_AT)),
- discard_reason=str(data.get(DLQFields.DISCARD_REASON, "")) or None,
- dlq_offset=dlq_offset_value,
- dlq_partition=dlq_partition_value,
- last_error=str(data.get(DLQFields.LAST_ERROR, "")) or None,
- )
-
- @staticmethod
- def from_kafka_message(message: Message, schema_registry: SchemaRegistryManager) -> DLQMessage:
- record_value = message.value()
- if record_value is None:
- raise ValueError("Message has no value")
-
- data = json.loads(record_value.decode("utf-8"))
- event_data = data.get("event", {})
- event = schema_registry.deserialize_json(event_data)
-
- headers: dict[str, str] = {}
- msg_headers = message.headers()
- if msg_headers:
- for key, value in msg_headers:
- headers[key] = value.decode("utf-8") if value else ""
-
- failed_at_str = data.get("failed_at")
- failed_at = (
- datetime.fromisoformat(failed_at_str).replace(tzinfo=timezone.utc)
- if failed_at_str
- else datetime.now(timezone.utc)
- )
-
- offset: int = message.offset() # type: ignore[assignment]
- partition: int = message.partition() # type: ignore[assignment]
-
- return DLQMessage(
- event=event,
- original_topic=data.get("original_topic", "unknown"),
- error=data.get("error", "Unknown error"),
- retry_count=data.get("retry_count", 0),
- failed_at=failed_at,
- status=DLQMessageStatus.PENDING,
- producer_id=data.get("producer_id", "unknown"),
- event_id=event.event_id,
- headers=headers,
- dlq_offset=offset if offset >= 0 else None,
- dlq_partition=partition if partition >= 0 else None,
- )
-
- # Domain construction and updates
- @staticmethod
- def from_failed_event(
- event: BaseEvent,
- original_topic: str,
- error: str,
- producer_id: str,
- retry_count: int = 0,
- ) -> DLQMessage:
- return DLQMessage(
- event=event,
- original_topic=original_topic,
- error=error,
- retry_count=retry_count,
- failed_at=datetime.now(timezone.utc),
- status=DLQMessageStatus.PENDING,
- producer_id=producer_id,
- )
-
- @staticmethod
- def update_to_mongo(update: DLQMessageUpdate) -> dict[str, object]:
- now = datetime.now(timezone.utc)
- doc: dict[str, object] = {
- str(DLQFields.STATUS): update.status,
- str(DLQFields.LAST_UPDATED): now,
- }
- if update.next_retry_at is not None:
- doc[str(DLQFields.NEXT_RETRY_AT)] = update.next_retry_at
- if update.retried_at is not None:
- doc[str(DLQFields.RETRIED_AT)] = update.retried_at
- if update.discarded_at is not None:
- doc[str(DLQFields.DISCARDED_AT)] = update.discarded_at
- if update.retry_count is not None:
- doc[str(DLQFields.RETRY_COUNT)] = update.retry_count
- if update.discard_reason is not None:
- doc[str(DLQFields.DISCARD_REASON)] = update.discard_reason
- if update.last_error is not None:
- doc[str(DLQFields.LAST_ERROR)] = update.last_error
- if update.extra:
- doc.update(update.extra)
- return doc
-
- @staticmethod
- def filter_to_query(f: DLQMessageFilter) -> dict[str, object]:
- query: dict[str, object] = {}
- if f.status:
- query[DLQFields.STATUS] = f.status
- if f.topic:
- query[DLQFields.ORIGINAL_TOPIC] = f.topic
- if f.event_type:
- query[DLQFields.EVENT_TYPE] = f.event_type
- return query
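
# Note on the deleted DLQ mapper above: rehydrating documents relies on normalizing
# stored datetimes to timezone-aware UTC. A minimal sketch of that pattern in isolation
# (assumes PyMongo's default behaviour of returning naive datetimes; not the full mapper):
from datetime import datetime, timezone

def parse_dt(value: object) -> datetime | None:
    """Normalize a Mongo-stored value to a timezone-aware UTC datetime."""
    if value is None:
        return None
    if isinstance(value, datetime):
        # Naive datetimes are assumed to be UTC, mirroring the deleted mapper.
        return value if value.tzinfo else value.replace(tzinfo=timezone.utc)
    if isinstance(value, str):
        return datetime.fromisoformat(value).replace(tzinfo=timezone.utc)
    raise ValueError(f"Unsupported datetime value: {value!r}")

# Both naive datetimes and ISO strings come back as aware UTC values.
assert parse_dt(datetime(2024, 1, 1)) == datetime(2024, 1, 1, tzinfo=timezone.utc)
assert parse_dt("2024-01-01T00:00:00") == datetime(2024, 1, 1, tzinfo=timezone.utc)
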
diff --git a/backend/app/infrastructure/mappers/event_mapper.py b/backend/app/infrastructure/mappers/event_mapper.py
deleted file mode 100644
index de616077..00000000
--- a/backend/app/infrastructure/mappers/event_mapper.py
+++ /dev/null
@@ -1,212 +0,0 @@
-from datetime import datetime, timezone
-from typing import Any
-
-from app.domain.events.event_metadata import EventMetadata
-from app.domain.events.event_models import (
- ArchivedEvent,
- Event,
- EventExportRow,
- EventFields,
- EventFilter,
- EventSummary,
-)
-from app.schemas_pydantic.admin_events import EventFilter as AdminEventFilter
-
-
-class EventMapper:
- """Handles all Event serialization/deserialization."""
-
- @staticmethod
- def to_mongo_document(event: Event) -> dict[str, Any]:
- """Convert domain event to MongoDB document."""
- doc: dict[str, Any] = {
- EventFields.EVENT_ID: event.event_id,
- EventFields.EVENT_TYPE: event.event_type,
- EventFields.EVENT_VERSION: event.event_version,
- EventFields.TIMESTAMP: event.timestamp,
- EventFields.METADATA: event.metadata.to_dict(exclude_none=True),
- EventFields.PAYLOAD: event.payload,
- }
-
- if event.aggregate_id is not None:
- doc[EventFields.AGGREGATE_ID] = event.aggregate_id
- if event.stored_at is not None:
- doc[EventFields.STORED_AT] = event.stored_at
- if event.ttl_expires_at is not None:
- doc[EventFields.TTL_EXPIRES_AT] = event.ttl_expires_at
- if event.status is not None:
- doc[EventFields.STATUS] = event.status
- if event.error is not None:
- doc[EventFields.ERROR] = event.error
-
- return doc
-
- @staticmethod
- def from_mongo_document(document: dict[str, Any]) -> Event:
- """Create domain event from MongoDB document."""
- # Define base event fields that should NOT be in payload
- base_fields = {
- EventFields.EVENT_ID,
- EventFields.EVENT_TYPE,
- EventFields.EVENT_VERSION,
- EventFields.TIMESTAMP,
- EventFields.METADATA,
- EventFields.AGGREGATE_ID,
- EventFields.STORED_AT,
- EventFields.TTL_EXPIRES_AT,
- EventFields.STATUS,
- EventFields.ERROR,
- "_id",
- "stored_at",
- }
-
- # Extract all non-base fields as payload
- payload = {k: v for k, v in document.items() if k not in base_fields}
-
- return Event(
- event_id=document[EventFields.EVENT_ID],
- event_type=document[EventFields.EVENT_TYPE],
- event_version=document.get(EventFields.EVENT_VERSION, "1.0"),
- timestamp=document.get(EventFields.TIMESTAMP, datetime.now(timezone.utc)),
- metadata=EventMetadata.from_dict(document.get(EventFields.METADATA, {})),
- payload=payload,
- aggregate_id=document.get(EventFields.AGGREGATE_ID),
- stored_at=document.get(EventFields.STORED_AT),
- ttl_expires_at=document.get(EventFields.TTL_EXPIRES_AT),
- status=document.get(EventFields.STATUS),
- error=document.get(EventFields.ERROR),
- )
-
-
-class EventSummaryMapper:
- """Handles EventSummary serialization."""
-
- @staticmethod
- def from_mongo_document(document: dict[str, Any]) -> EventSummary:
- return EventSummary(
- event_id=document[EventFields.EVENT_ID],
- event_type=document[EventFields.EVENT_TYPE],
- timestamp=document[EventFields.TIMESTAMP],
- aggregate_id=document.get(EventFields.AGGREGATE_ID),
- )
-
-
-class ArchivedEventMapper:
- """Handles ArchivedEvent serialization."""
-
- @staticmethod
- def to_mongo_document(event: ArchivedEvent) -> dict[str, Any]:
- event_mapper = EventMapper()
- doc = event_mapper.to_mongo_document(event)
-
- if event.deleted_at is not None:
- doc[EventFields.DELETED_AT] = event.deleted_at
- if event.deleted_by is not None:
- doc[EventFields.DELETED_BY] = event.deleted_by
- if event.deletion_reason is not None:
- doc[EventFields.DELETION_REASON] = event.deletion_reason
-
- return doc
-
- @staticmethod
- def from_event(event: Event, deleted_by: str, deletion_reason: str) -> ArchivedEvent:
- return ArchivedEvent(
- event_id=event.event_id,
- event_type=event.event_type,
- event_version=event.event_version,
- timestamp=event.timestamp,
- metadata=event.metadata,
- payload=event.payload,
- aggregate_id=event.aggregate_id,
- stored_at=event.stored_at,
- ttl_expires_at=event.ttl_expires_at,
- status=event.status,
- error=event.error,
- deleted_at=datetime.now(timezone.utc),
- deleted_by=deleted_by,
- deletion_reason=deletion_reason,
- )
-
-
-class EventExportRowMapper:
- """Handles EventExportRow serialization."""
-
- @staticmethod
- def to_dict(row: EventExportRow) -> dict[str, str]:
- return {
- "Event ID": row.event_id,
- "Event Type": row.event_type,
- "Timestamp": row.timestamp,
- "Correlation ID": row.correlation_id,
- "Aggregate ID": row.aggregate_id,
- "User ID": row.user_id,
- "Service": row.service,
- "Status": row.status,
- "Error": row.error,
- }
-
- @staticmethod
- def from_event(event: Event) -> EventExportRow:
- return EventExportRow(
- event_id=event.event_id,
- event_type=event.event_type,
- timestamp=event.timestamp.isoformat(),
- correlation_id=event.metadata.correlation_id or "",
- aggregate_id=event.aggregate_id or "",
- user_id=event.metadata.user_id or "",
- service=event.metadata.service_name,
- status=event.status or "",
- error=event.error or "",
- )
-
-
-class EventFilterMapper:
- """Converts EventFilter domain model into MongoDB queries."""
-
- @staticmethod
- def to_mongo_query(flt: EventFilter) -> dict[str, Any]:
- query: dict[str, Any] = {}
-
- if flt.event_types:
- query[EventFields.EVENT_TYPE] = {"$in": flt.event_types}
- if flt.aggregate_id:
- query[EventFields.AGGREGATE_ID] = flt.aggregate_id
- if flt.correlation_id:
- query[EventFields.METADATA_CORRELATION_ID] = flt.correlation_id
- if flt.user_id:
- query[EventFields.METADATA_USER_ID] = flt.user_id
- if flt.service_name:
- query[EventFields.METADATA_SERVICE_NAME] = flt.service_name
- if getattr(flt, "status", None):
- query[EventFields.STATUS] = flt.status
-
- if flt.start_time or flt.end_time:
- time_query: dict[str, Any] = {}
- if flt.start_time:
- time_query["$gte"] = flt.start_time
- if flt.end_time:
- time_query["$lte"] = flt.end_time
- query[EventFields.TIMESTAMP] = time_query
-
- search = getattr(flt, "text_search", None) or getattr(flt, "search_text", None)
- if search:
- query["$text"] = {"$search": search}
-
- return query
-
- @staticmethod
- def from_admin_pydantic(pflt: AdminEventFilter) -> EventFilter:
- ev_types: list[str] | None = None
- if pflt.event_types is not None:
- ev_types = [str(et) for et in pflt.event_types]
- return EventFilter(
- event_types=ev_types,
- aggregate_id=pflt.aggregate_id,
- correlation_id=pflt.correlation_id,
- user_id=pflt.user_id,
- service_name=pflt.service_name,
- start_time=pflt.start_time,
- end_time=pflt.end_time,
- search_text=pflt.search_text,
- text_search=pflt.search_text,
- )
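
# The deleted EventMapper recovers an event's payload by subtracting a known set of
# envelope fields from the raw document. A condensed sketch of that split; the field
# names below are illustrative, not the full list used above:
from typing import Any

BASE_FIELDS = {"event_id", "event_type", "event_version", "timestamp", "metadata", "_id"}

def split_payload(document: dict[str, Any]) -> dict[str, Any]:
    """Everything that is not an envelope field is treated as payload."""
    return {k: v for k, v in document.items() if k not in BASE_FIELDS}

doc = {"event_id": "e-1", "event_type": "execution.completed", "_id": "...", "exit_code": 0}
assert split_payload(doc) == {"exit_code": 0}
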
diff --git a/backend/app/infrastructure/mappers/notification_mapper.py b/backend/app/infrastructure/mappers/notification_mapper.py
deleted file mode 100644
index f5e7e63b..00000000
--- a/backend/app/infrastructure/mappers/notification_mapper.py
+++ /dev/null
@@ -1,39 +0,0 @@
-from dataclasses import asdict, fields
-from typing import Any
-
-from app.domain.notification import (
- DomainNotification,
- DomainNotificationSubscription,
-)
-
-
-class NotificationMapper:
- """Map Notification domain models to/from MongoDB documents."""
-
- # DomainNotification
- @staticmethod
- def to_mongo_document(notification: DomainNotification) -> dict[str, Any]:
- return asdict(notification)
-
- @staticmethod
- def to_update_dict(notification: DomainNotification) -> dict[str, Any]:
- doc = asdict(notification)
- doc.pop("notification_id", None)
- return doc
-
- @staticmethod
- def from_mongo_document(doc: dict[str, Any]) -> DomainNotification:
- allowed = {f.name for f in fields(DomainNotification)}
- filtered = {k: v for k, v in doc.items() if k in allowed}
- return DomainNotification(**filtered)
-
- # DomainNotificationSubscription
- @staticmethod
- def subscription_to_mongo_document(subscription: DomainNotificationSubscription) -> dict[str, Any]:
- return asdict(subscription)
-
- @staticmethod
- def subscription_from_mongo_document(doc: dict[str, Any]) -> DomainNotificationSubscription:
- allowed = {f.name for f in fields(DomainNotificationSubscription)}
- filtered = {k: v for k, v in doc.items() if k in allowed}
- return DomainNotificationSubscription(**filtered)
diff --git a/backend/app/infrastructure/mappers/rate_limit_mapper.py b/backend/app/infrastructure/mappers/rate_limit_mapper.py
deleted file mode 100644
index 2dcb359f..00000000
--- a/backend/app/infrastructure/mappers/rate_limit_mapper.py
+++ /dev/null
@@ -1,124 +0,0 @@
-import json
-from datetime import datetime, timezone
-from typing import Any, Dict
-
-from app.domain.rate_limit import (
- EndpointGroup,
- RateLimitAlgorithm,
- RateLimitConfig,
- RateLimitRule,
- UserRateLimit,
-)
-
-
-class RateLimitRuleMapper:
- @staticmethod
- def to_dict(rule: RateLimitRule) -> Dict[str, Any]:
- return {
- "endpoint_pattern": rule.endpoint_pattern,
- "group": rule.group.value,
- "requests": rule.requests,
- "window_seconds": rule.window_seconds,
- "burst_multiplier": rule.burst_multiplier,
- "algorithm": rule.algorithm.value,
- "priority": rule.priority,
- "enabled": rule.enabled,
- }
-
- @staticmethod
- def from_dict(data: Dict[str, Any]) -> RateLimitRule:
- return RateLimitRule(
- endpoint_pattern=data["endpoint_pattern"],
- group=EndpointGroup(data["group"]),
- requests=data["requests"],
- window_seconds=data["window_seconds"],
- burst_multiplier=data.get("burst_multiplier", 1.5),
- algorithm=RateLimitAlgorithm(data.get("algorithm", RateLimitAlgorithm.SLIDING_WINDOW)),
- priority=data.get("priority", 0),
- enabled=data.get("enabled", True),
- )
-
-
-class UserRateLimitMapper:
- @staticmethod
- def to_dict(user_limit: UserRateLimit) -> Dict[str, Any]:
- rule_mapper = RateLimitRuleMapper()
- return {
- "user_id": user_limit.user_id,
- "bypass_rate_limit": user_limit.bypass_rate_limit,
- "global_multiplier": user_limit.global_multiplier,
- "rules": [rule_mapper.to_dict(rule) for rule in user_limit.rules],
- "created_at": user_limit.created_at.isoformat() if user_limit.created_at else None,
- "updated_at": user_limit.updated_at.isoformat() if user_limit.updated_at else None,
- "notes": user_limit.notes,
- }
-
- @staticmethod
- def from_dict(data: Dict[str, Any]) -> UserRateLimit:
- rule_mapper = RateLimitRuleMapper()
-
- created_at = data.get("created_at")
- if created_at and isinstance(created_at, str):
- created_at = datetime.fromisoformat(created_at)
- elif not created_at:
- created_at = datetime.now(timezone.utc)
-
- updated_at = data.get("updated_at")
- if updated_at and isinstance(updated_at, str):
- updated_at = datetime.fromisoformat(updated_at)
- elif not updated_at:
- updated_at = datetime.now(timezone.utc)
-
- return UserRateLimit(
- user_id=data["user_id"],
- bypass_rate_limit=data.get("bypass_rate_limit", False),
- global_multiplier=data.get("global_multiplier", 1.0),
- rules=[rule_mapper.from_dict(rule_data) for rule_data in data.get("rules", [])],
- created_at=created_at,
- updated_at=updated_at,
- notes=data.get("notes"),
- )
-
- @staticmethod
- def model_dump(user_limit: UserRateLimit) -> Dict[str, Any]:
- """Pydantic-compatible method for serialization."""
- return UserRateLimitMapper.to_dict(user_limit)
-
-
-class RateLimitConfigMapper:
- @staticmethod
- def to_dict(config: RateLimitConfig) -> Dict[str, Any]:
- rule_mapper = RateLimitRuleMapper()
- user_mapper = UserRateLimitMapper()
- return {
- "default_rules": [rule_mapper.to_dict(rule) for rule in config.default_rules],
- "user_overrides": {
- uid: user_mapper.to_dict(user_limit) for uid, user_limit in config.user_overrides.items()
- },
- "global_enabled": config.global_enabled,
- "redis_ttl": config.redis_ttl,
- }
-
- @staticmethod
- def from_dict(data: Dict[str, Any]) -> RateLimitConfig:
- rule_mapper = RateLimitRuleMapper()
- user_mapper = UserRateLimitMapper()
- return RateLimitConfig(
- default_rules=[rule_mapper.from_dict(rule_data) for rule_data in data.get("default_rules", [])],
- user_overrides={
- uid: user_mapper.from_dict(user_data) for uid, user_data in data.get("user_overrides", {}).items()
- },
- global_enabled=data.get("global_enabled", True),
- redis_ttl=data.get("redis_ttl", 3600),
- )
-
- @staticmethod
- def model_validate_json(json_str: str | bytes) -> RateLimitConfig:
- """Pydantic-compatible method for deserialization from JSON."""
- data = json.loads(json_str)
- return RateLimitConfigMapper.from_dict(data)
-
- @staticmethod
- def model_dump_json(config: RateLimitConfig) -> str:
- """Pydantic-compatible method for serialization to JSON."""
- return json.dumps(RateLimitConfigMapper.to_dict(config))
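
# The deleted rate-limit mapper exposes model_dump_json/model_validate_json shims so call
# sites written against Pydantic keep working with plain dict-based mappers. A reduced
# sketch of that round-trip, using a hypothetical two-field config (not RateLimitConfig):
import json
from dataclasses import dataclass

@dataclass
class MiniConfig:  # hypothetical stand-in for RateLimitConfig
    global_enabled: bool = True
    redis_ttl: int = 3600

class MiniConfigMapper:
    @staticmethod
    def model_dump_json(config: MiniConfig) -> str:
        return json.dumps({"global_enabled": config.global_enabled, "redis_ttl": config.redis_ttl})

    @staticmethod
    def model_validate_json(json_str: str | bytes) -> MiniConfig:
        data = json.loads(json_str)
        return MiniConfig(
            global_enabled=data.get("global_enabled", True),
            redis_ttl=data.get("redis_ttl", 3600),
        )

# Round-trip: the mapper mimics the Pydantic v2 method names it replaces.
restored = MiniConfigMapper.model_validate_json(MiniConfigMapper.model_dump_json(MiniConfig(redis_ttl=60)))
assert restored.redis_ttl == 60
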
diff --git a/backend/app/infrastructure/mappers/replay_api_mapper.py b/backend/app/infrastructure/mappers/replay_api_mapper.py
deleted file mode 100644
index 37aabe96..00000000
--- a/backend/app/infrastructure/mappers/replay_api_mapper.py
+++ /dev/null
@@ -1,79 +0,0 @@
-from __future__ import annotations
-
-from app.domain.enums.replay import ReplayStatus
-from app.domain.replay import ReplayConfig, ReplayFilter, ReplaySessionState
-from app.schemas_pydantic.replay import CleanupResponse, ReplayRequest, ReplayResponse, SessionSummary
-
-
-class ReplayApiMapper:
- @staticmethod
- def session_to_summary(state: ReplaySessionState) -> SessionSummary:
- duration = None
- throughput = None
- if state.started_at and state.completed_at:
- d = (state.completed_at - state.started_at).total_seconds()
- duration = d
- if state.replayed_events > 0 and d > 0:
- throughput = state.replayed_events / d
- return SessionSummary(
- session_id=state.session_id,
- replay_type=state.config.replay_type,
- target=state.config.target,
- status=state.status,
- total_events=state.total_events,
- replayed_events=state.replayed_events,
- failed_events=state.failed_events,
- skipped_events=state.skipped_events,
- created_at=state.created_at,
- started_at=state.started_at,
- completed_at=state.completed_at,
- duration_seconds=duration,
- throughput_events_per_second=throughput,
- )
-
- # Request/Response mapping for HTTP
- @staticmethod
- def request_to_filter(req: ReplayRequest) -> ReplayFilter:
- return ReplayFilter(
- execution_id=req.execution_id,
- event_types=req.event_types,
- start_time=req.start_time if req.start_time else None,
- end_time=req.end_time if req.end_time else None,
- user_id=req.user_id,
- service_name=req.service_name,
- custom_query=req.custom_query,
- exclude_event_types=req.exclude_event_types,
- )
-
- @staticmethod
- def request_to_config(req: ReplayRequest) -> ReplayConfig:
- # Convert string keys to EventType for target_topics if provided
- target_topics = None
- if req.target_topics:
- from app.domain.enums.events import EventType
-
- target_topics = {EventType(k): v for k, v in req.target_topics.items()}
-
- return ReplayConfig(
- replay_type=req.replay_type,
- target=req.target,
- filter=ReplayApiMapper.request_to_filter(req),
- speed_multiplier=req.speed_multiplier,
- preserve_timestamps=req.preserve_timestamps,
- batch_size=req.batch_size,
- max_events=req.max_events,
- skip_errors=req.skip_errors,
- target_file_path=req.target_file_path,
- target_topics=target_topics,
- retry_failed=req.retry_failed,
- retry_attempts=req.retry_attempts,
- enable_progress_tracking=req.enable_progress_tracking,
- )
-
- @staticmethod
- def op_to_response(session_id: str, status: ReplayStatus, message: str) -> ReplayResponse:
- return ReplayResponse(session_id=session_id, status=status, message=message)
-
- @staticmethod
- def cleanup_to_response(removed_sessions: int, message: str) -> CleanupResponse:
- return CleanupResponse(removed_sessions=removed_sessions, message=message)
diff --git a/backend/app/infrastructure/mappers/replay_mapper.py b/backend/app/infrastructure/mappers/replay_mapper.py
deleted file mode 100644
index cab31899..00000000
--- a/backend/app/infrastructure/mappers/replay_mapper.py
+++ /dev/null
@@ -1,212 +0,0 @@
-from datetime import datetime, timezone
-from typing import Any
-
-from app.domain.admin import (
- ReplayQuery,
- ReplaySession,
- ReplaySessionFields,
- ReplaySessionStatusDetail,
- ReplaySessionStatusInfo,
-)
-from app.domain.enums.replay import ReplayStatus
-from app.domain.events.event_models import EventFields
-from app.domain.replay import ReplayConfig as DomainReplayConfig
-from app.domain.replay import ReplaySessionState
-from app.schemas_pydantic.admin_events import EventReplayRequest
-
-
-class ReplaySessionMapper:
- @staticmethod
- def to_dict(session: ReplaySession) -> dict[str, Any]:
- doc: dict[str, Any] = {
- ReplaySessionFields.SESSION_ID: session.session_id,
- ReplaySessionFields.TYPE: session.type,
- ReplaySessionFields.STATUS: session.status,
- ReplaySessionFields.TOTAL_EVENTS: session.total_events,
- ReplaySessionFields.REPLAYED_EVENTS: session.replayed_events,
- ReplaySessionFields.FAILED_EVENTS: session.failed_events,
- ReplaySessionFields.SKIPPED_EVENTS: session.skipped_events,
- ReplaySessionFields.CORRELATION_ID: session.correlation_id,
- ReplaySessionFields.CREATED_AT: session.created_at,
- ReplaySessionFields.DRY_RUN: session.dry_run,
- "triggered_executions": session.triggered_executions,
- }
-
- if session.started_at:
- doc[ReplaySessionFields.STARTED_AT] = session.started_at
- if session.completed_at:
- doc[ReplaySessionFields.COMPLETED_AT] = session.completed_at
- if session.error:
- doc[ReplaySessionFields.ERROR] = session.error
- if session.created_by:
- doc[ReplaySessionFields.CREATED_BY] = session.created_by
- if session.target_service:
- doc[ReplaySessionFields.TARGET_SERVICE] = session.target_service
-
- return doc
-
- @staticmethod
- def from_dict(data: dict[str, Any]) -> ReplaySession:
- return ReplaySession(
- session_id=data.get(ReplaySessionFields.SESSION_ID, ""),
- type=data.get(ReplaySessionFields.TYPE, "replay_session"),
- status=ReplayStatus(data.get(ReplaySessionFields.STATUS, ReplayStatus.SCHEDULED)),
- total_events=data.get(ReplaySessionFields.TOTAL_EVENTS, 0),
- replayed_events=data.get(ReplaySessionFields.REPLAYED_EVENTS, 0),
- failed_events=data.get(ReplaySessionFields.FAILED_EVENTS, 0),
- skipped_events=data.get(ReplaySessionFields.SKIPPED_EVENTS, 0),
- correlation_id=data.get(ReplaySessionFields.CORRELATION_ID, ""),
- created_at=data.get(ReplaySessionFields.CREATED_AT, datetime.now(timezone.utc)),
- started_at=data.get(ReplaySessionFields.STARTED_AT),
- completed_at=data.get(ReplaySessionFields.COMPLETED_AT),
- error=data.get(ReplaySessionFields.ERROR),
- created_by=data.get(ReplaySessionFields.CREATED_BY),
- target_service=data.get(ReplaySessionFields.TARGET_SERVICE),
- dry_run=data.get(ReplaySessionFields.DRY_RUN, False),
- triggered_executions=data.get("triggered_executions", []),
- )
-
- @staticmethod
- def status_detail_to_dict(detail: ReplaySessionStatusDetail) -> dict[str, Any]:
- result = {
- "session_id": detail.session.session_id,
- "status": detail.session.status.value,
- "total_events": detail.session.total_events,
- "replayed_events": detail.session.replayed_events,
- "failed_events": detail.session.failed_events,
- "skipped_events": detail.session.skipped_events,
- "correlation_id": detail.session.correlation_id,
- "created_at": detail.session.created_at,
- "started_at": detail.session.started_at,
- "completed_at": detail.session.completed_at,
- "error": detail.session.error,
- "progress_percentage": detail.session.progress_percentage,
- "execution_results": detail.execution_results,
- }
-
- if detail.estimated_completion:
- result["estimated_completion"] = detail.estimated_completion
-
- return result
-
- @staticmethod
- def to_status_info(session: ReplaySession) -> ReplaySessionStatusInfo:
- return ReplaySessionStatusInfo(
- session_id=session.session_id,
- status=session.status,
- total_events=session.total_events,
- replayed_events=session.replayed_events,
- failed_events=session.failed_events,
- skipped_events=session.skipped_events,
- correlation_id=session.correlation_id,
- created_at=session.created_at,
- started_at=session.started_at,
- completed_at=session.completed_at,
- error=session.error,
- progress_percentage=session.progress_percentage,
- )
-
- @staticmethod
- def status_info_to_dict(info: ReplaySessionStatusInfo) -> dict[str, Any]:
- return {
- "session_id": info.session_id,
- "status": info.status.value,
- "total_events": info.total_events,
- "replayed_events": info.replayed_events,
- "failed_events": info.failed_events,
- "skipped_events": info.skipped_events,
- "correlation_id": info.correlation_id,
- "created_at": info.created_at,
- "started_at": info.started_at,
- "completed_at": info.completed_at,
- "error": info.error,
- "progress_percentage": info.progress_percentage,
- }
-
-
-class ReplayQueryMapper:
- @staticmethod
- def to_mongodb_query(query: ReplayQuery) -> dict[str, Any]:
- mongo_query: dict[str, Any] = {}
-
- if query.event_ids:
- mongo_query[EventFields.EVENT_ID] = {"$in": query.event_ids}
-
- if query.correlation_id:
- mongo_query[EventFields.METADATA_CORRELATION_ID] = query.correlation_id
-
- if query.aggregate_id:
- mongo_query[EventFields.AGGREGATE_ID] = query.aggregate_id
-
- if query.start_time or query.end_time:
- time_query = {}
- if query.start_time:
- time_query["$gte"] = query.start_time
- if query.end_time:
- time_query["$lte"] = query.end_time
- mongo_query[EventFields.TIMESTAMP] = time_query
-
- return mongo_query
-
-
-class ReplayApiMapper:
- """API-level mapper for converting replay requests to domain queries."""
-
- @staticmethod
- def request_to_query(req: EventReplayRequest) -> ReplayQuery:
- return ReplayQuery(
- event_ids=req.event_ids,
- correlation_id=req.correlation_id,
- aggregate_id=req.aggregate_id,
- start_time=req.start_time,
- end_time=req.end_time,
- )
-
-
-class ReplayStateMapper:
- """Mapper for service-level replay session state (domain.replay.models).
-
- Moves all domain↔Mongo conversion out of the repository.
- Assumes datetimes are stored as datetimes (no epoch/ISO fallback logic).
- """
-
- @staticmethod
- def to_mongo_document(session: ReplaySessionState | Any) -> dict[str, Any]: # noqa: ANN401
- cfg = session.config
- # Both DomainReplayConfig and schema config are Pydantic models; use model_dump
- cfg_dict = cfg.model_dump()
- return {
- "session_id": session.session_id,
- "status": session.status,
- "total_events": getattr(session, "total_events", 0),
- "replayed_events": getattr(session, "replayed_events", 0),
- "failed_events": getattr(session, "failed_events", 0),
- "skipped_events": getattr(session, "skipped_events", 0),
- "created_at": session.created_at,
- "started_at": getattr(session, "started_at", None),
- "completed_at": getattr(session, "completed_at", None),
- "last_event_at": getattr(session, "last_event_at", None),
- "errors": getattr(session, "errors", []),
- "config": cfg_dict,
- }
-
- @staticmethod
- def from_mongo_document(doc: dict[str, Any]) -> ReplaySessionState:
- cfg_dict = doc.get("config", {})
- cfg = DomainReplayConfig(**cfg_dict)
- raw_status = doc.get("status", ReplayStatus.SCHEDULED)
- status = raw_status if isinstance(raw_status, ReplayStatus) else ReplayStatus(str(raw_status))
-
- return ReplaySessionState(
- session_id=doc.get("session_id", ""),
- config=cfg,
- status=status,
- total_events=doc.get("total_events", 0),
- replayed_events=doc.get("replayed_events", 0),
- failed_events=doc.get("failed_events", 0),
- skipped_events=doc.get("skipped_events", 0),
- started_at=doc.get("started_at"),
- completed_at=doc.get("completed_at"),
- last_event_at=doc.get("last_event_at"),
- errors=doc.get("errors", []),
- )
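
# Several of the deleted mappers coerce a raw stored status back into its enum, accepting
# either the enum member or its string value (the saga mapper adds a defensive fallback).
# The pattern in isolation, with a hypothetical status enum standing in for ReplayStatus:
from enum import Enum

class Status(str, Enum):  # hypothetical stand-in
    SCHEDULED = "scheduled"
    RUNNING = "running"
    COMPLETED = "completed"

def coerce_status(raw: object, default: Status = Status.SCHEDULED) -> Status:
    """Accept an enum member or its value; fall back to the default on bad data."""
    if isinstance(raw, Status):
        return raw
    try:
        return Status(str(raw))
    except ValueError:
        return default

assert coerce_status("running") is Status.RUNNING
assert coerce_status(Status.COMPLETED) is Status.COMPLETED
assert coerce_status("garbage") is Status.SCHEDULED
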
diff --git a/backend/app/infrastructure/mappers/saga_mapper.py b/backend/app/infrastructure/mappers/saga_mapper.py
deleted file mode 100644
index 26e92bbd..00000000
--- a/backend/app/infrastructure/mappers/saga_mapper.py
+++ /dev/null
@@ -1,173 +0,0 @@
-from typing import Any
-
-from app.domain.enums.saga import SagaState
-from app.domain.saga.models import Saga, SagaFilter, SagaInstance
-
-
-class SagaMapper:
- """Maps between saga domain models and persistence representations."""
-
- def from_mongo(self, doc: dict[str, Any]) -> Saga:
- """Convert MongoDB document to domain model."""
- return Saga(
- saga_id=doc["saga_id"],
- saga_name=doc["saga_name"],
- execution_id=doc["execution_id"],
- state=SagaState(doc["state"]),
- current_step=doc.get("current_step"),
- completed_steps=doc.get("completed_steps", []),
- compensated_steps=doc.get("compensated_steps", []),
- context_data=doc.get("context_data", {}),
- error_message=doc.get("error_message"),
- created_at=doc["created_at"],
- updated_at=doc["updated_at"],
- completed_at=doc.get("completed_at"),
- retry_count=doc.get("retry_count", 0),
- )
-
- def to_mongo(self, saga: Saga) -> dict[str, Any]:
- """Convert domain model to MongoDB document.
-
- Assumes context_data is already sanitized at the source. As a light
- guardrail, exclude private keys (prefixed with "_") if any slip in.
- """
- context = saga.context_data or {}
- if isinstance(context, dict):
- context = {k: v for k, v in context.items() if not (isinstance(k, str) and k.startswith("_"))}
-
- return {
- "saga_id": saga.saga_id,
- "saga_name": saga.saga_name,
- "execution_id": saga.execution_id,
- "state": saga.state.value,
- "current_step": saga.current_step,
- "completed_steps": saga.completed_steps,
- "compensated_steps": saga.compensated_steps,
- "context_data": context,
- "error_message": saga.error_message,
- "created_at": saga.created_at,
- "updated_at": saga.updated_at,
- "completed_at": saga.completed_at,
- "retry_count": saga.retry_count,
- }
-
- def from_instance(self, instance: SagaInstance) -> Saga:
- """Convert a SagaInstance (live orchestrator view) to Saga domain model."""
- return Saga(
- saga_id=instance.saga_id,
- saga_name=instance.saga_name,
- execution_id=instance.execution_id,
- state=instance.state,
- current_step=instance.current_step,
- completed_steps=instance.completed_steps,
- compensated_steps=instance.compensated_steps,
- context_data=instance.context_data,
- error_message=instance.error_message,
- created_at=instance.created_at,
- updated_at=instance.updated_at,
- completed_at=instance.completed_at,
- retry_count=instance.retry_count,
- )
-
-
-class SagaInstanceMapper:
- """Maps SagaInstance domain <-> Mongo documents."""
-
- @staticmethod
- def from_mongo(doc: dict[str, Any]) -> SagaInstance:
- # Robust state conversion
- raw_state = doc.get("state", SagaState.CREATED)
- try:
- state = raw_state if isinstance(raw_state, SagaState) else SagaState(str(raw_state))
- except Exception:
- state = SagaState.CREATED
-
- # Build kwargs conditionally
- kwargs: dict[str, Any] = {
- "saga_id": str(doc.get("saga_id")),
- "saga_name": str(doc.get("saga_name")),
- "execution_id": str(doc.get("execution_id")),
- "state": state,
- "current_step": doc.get("current_step"),
- "completed_steps": list(doc.get("completed_steps", [])),
- "compensated_steps": list(doc.get("compensated_steps", [])),
- "context_data": dict(doc.get("context_data", {})),
- "error_message": doc.get("error_message"),
- "completed_at": doc.get("completed_at"),
- "retry_count": int(doc.get("retry_count", 0)),
- }
-
- # Only add datetime fields if they exist and are valid
- if doc.get("created_at"):
- kwargs["created_at"] = doc.get("created_at")
- if doc.get("updated_at"):
- kwargs["updated_at"] = doc.get("updated_at")
-
- return SagaInstance(**kwargs)
-
- @staticmethod
- def to_mongo(instance: SagaInstance) -> dict[str, Any]:
- # Clean context to ensure it's serializable and skip internal keys
- clean_context: dict[str, Any] = {}
- for k, v in (instance.context_data or {}).items():
- if k.startswith("_"):
- continue
- if isinstance(v, (str, int, float, bool, list, dict, type(None))):
- clean_context[k] = v
- else:
- try:
- clean_context[k] = str(v)
- except Exception:
- continue
-
- return {
- "saga_id": str(instance.saga_id),
- "saga_name": instance.saga_name,
- "execution_id": instance.execution_id,
- "state": instance.state.value if hasattr(instance.state, "value") else str(instance.state),
- "current_step": instance.current_step,
- "completed_steps": instance.completed_steps,
- "compensated_steps": instance.compensated_steps,
- "context_data": clean_context,
- "error_message": instance.error_message,
- "created_at": instance.created_at,
- "updated_at": instance.updated_at,
- "completed_at": instance.completed_at,
- "retry_count": instance.retry_count,
- }
-
-
-class SagaFilterMapper:
- """Maps saga filters to MongoDB queries."""
-
- def to_mongodb_query(self, saga_filter: SagaFilter | None) -> dict[str, Any]:
- """Convert filter to MongoDB query."""
- query: dict[str, Any] = {}
-
- if not saga_filter:
- return query
-
- if saga_filter.state:
- query["state"] = saga_filter.state.value
-
- if saga_filter.execution_ids:
- query["execution_id"] = {"$in": saga_filter.execution_ids}
-
- if saga_filter.saga_name:
- query["saga_name"] = saga_filter.saga_name
-
- if saga_filter.error_status is not None:
- if saga_filter.error_status:
- query["error_message"] = {"$ne": None}
- else:
- query["error_message"] = None
-
- if saga_filter.created_after or saga_filter.created_before:
- time_query: dict[str, Any] = {}
- if saga_filter.created_after:
- time_query["$gte"] = saga_filter.created_after
- if saga_filter.created_before:
- time_query["$lte"] = saga_filter.created_before
- query["created_at"] = time_query
-
- return query
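
# The filter-to-query mappers removed in this patch all build the same shape of MongoDB
# query: equality terms for scalars, "$in" for lists, and a "$gte"/"$lte" pair for time
# bounds. A condensed sketch with a hypothetical filter (not SagaFilter itself):
from dataclasses import dataclass
from datetime import datetime
from typing import Any

@dataclass
class MiniFilter:  # hypothetical, condensed from SagaFilter / EventFilter
    state: str | None = None
    execution_ids: list[str] | None = None
    created_after: datetime | None = None
    created_before: datetime | None = None

def to_mongo_query(f: MiniFilter) -> dict[str, Any]:
    query: dict[str, Any] = {}
    if f.state:
        query["state"] = f.state
    if f.execution_ids:
        query["execution_id"] = {"$in": f.execution_ids}
    if f.created_after or f.created_before:
        time_query: dict[str, Any] = {}
        if f.created_after:
            time_query["$gte"] = f.created_after
        if f.created_before:
            time_query["$lte"] = f.created_before
        query["created_at"] = time_query
    return query

assert to_mongo_query(MiniFilter(execution_ids=["x"])) == {"execution_id": {"$in": ["x"]}}
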
diff --git a/backend/app/infrastructure/mappers/saved_script_mapper.py b/backend/app/infrastructure/mappers/saved_script_mapper.py
deleted file mode 100644
index 5d4ff774..00000000
--- a/backend/app/infrastructure/mappers/saved_script_mapper.py
+++ /dev/null
@@ -1,54 +0,0 @@
-from dataclasses import asdict, fields
-from datetime import datetime, timezone
-from typing import Any
-from uuid import uuid4
-
-from app.domain.saved_script import (
- DomainSavedScript,
- DomainSavedScriptCreate,
- DomainSavedScriptUpdate,
-)
-
-
-class SavedScriptMapper:
- """Mapper for Saved Script domain models to/from MongoDB docs."""
-
- @staticmethod
- def to_insert_document(create: DomainSavedScriptCreate, user_id: str) -> dict[str, Any]:
- now = datetime.now(timezone.utc)
- return {
- "script_id": str(uuid4()),
- "user_id": user_id,
- "name": create.name,
- "script": create.script,
- "lang": create.lang,
- "lang_version": create.lang_version,
- "description": create.description,
- "created_at": now,
- "updated_at": now,
- }
-
- @staticmethod
- def to_update_dict(update: DomainSavedScriptUpdate) -> dict[str, Any]:
- # Convert to dict and drop None fields; keep updated_at
- raw = asdict(update)
- return {k: v for k, v in raw.items() if v is not None}
-
- @staticmethod
- def from_mongo_document(doc: dict[str, Any]) -> DomainSavedScript:
- allowed = {f.name for f in fields(DomainSavedScript)}
- filtered = {k: v for k, v in doc.items() if k in allowed}
- # Coerce required fields to str where applicable for safety
- if "script_id" in filtered:
- filtered["script_id"] = str(filtered["script_id"])
- if "user_id" in filtered:
- filtered["user_id"] = str(filtered["user_id"])
- if "name" in filtered:
- filtered["name"] = str(filtered["name"])
- if "script" in filtered:
- filtered["script"] = str(filtered["script"])
- if "lang" in filtered:
- filtered["lang"] = str(filtered["lang"])
- if "lang_version" in filtered:
- filtered["lang_version"] = str(filtered["lang_version"])
- return DomainSavedScript(**filtered) # dataclass defaults cover missing timestamps
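
# Both the notification and saved-script mappers hydrate dataclasses by keeping only keys
# that match declared fields, so stray Mongo keys such as "_id" never reach the constructor.
# The idiom in isolation, using a hypothetical trimmed-down dataclass:
from dataclasses import dataclass, field, fields
from typing import Any

@dataclass
class Script:  # hypothetical stand-in for DomainSavedScript
    script_id: str
    name: str
    tags: list[str] = field(default_factory=list)

def from_mongo_document(doc: dict[str, Any]) -> Script:
    allowed = {f.name for f in fields(Script)}
    return Script(**{k: v for k, v in doc.items() if k in allowed})

# "_id" and any other unexpected keys are silently dropped; defaults cover missing fields.
s = from_mongo_document({"_id": "abc", "script_id": "s-1", "name": "hello"})
assert s.tags == []
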
diff --git a/backend/app/infrastructure/mappers/sse_mapper.py b/backend/app/infrastructure/mappers/sse_mapper.py
deleted file mode 100644
index 5a391e5e..00000000
--- a/backend/app/infrastructure/mappers/sse_mapper.py
+++ /dev/null
@@ -1,28 +0,0 @@
-from datetime import datetime, timezone
-from typing import Any
-
-from app.domain.enums.execution import ExecutionStatus
-from app.domain.execution import DomainExecution, ResourceUsageDomain
-
-
-class SSEMapper:
- """Mapper for SSE-related domain models from MongoDB documents."""
-
- @staticmethod
- def execution_from_mongo_document(doc: dict[str, Any]) -> DomainExecution:
- resource_usage_data = doc.get("resource_usage")
- return DomainExecution(
- execution_id=str(doc.get("execution_id")),
- script=str(doc.get("script", "")),
- status=ExecutionStatus(str(doc.get("status"))),
- stdout=doc.get("stdout"),
- stderr=doc.get("stderr"),
- lang=str(doc.get("lang", "python")),
- lang_version=str(doc.get("lang_version", "3.11")),
- created_at=doc.get("created_at", datetime.now(timezone.utc)),
- updated_at=doc.get("updated_at", datetime.now(timezone.utc)),
- resource_usage=ResourceUsageDomain.from_dict(resource_usage_data) if resource_usage_data else None,
- user_id=doc.get("user_id"),
- exit_code=doc.get("exit_code"),
- error_type=doc.get("error_type"),
- )
diff --git a/backend/app/infrastructure/mappers/user_settings_mapper.py b/backend/app/infrastructure/mappers/user_settings_mapper.py
deleted file mode 100644
index f3428e6c..00000000
--- a/backend/app/infrastructure/mappers/user_settings_mapper.py
+++ /dev/null
@@ -1,100 +0,0 @@
-from datetime import datetime, timezone
-from typing import Any
-
-from app.domain.enums import Theme
-from app.domain.enums.events import EventType
-from app.domain.enums.notification import NotificationChannel
-from app.domain.user.settings_models import (
- DomainEditorSettings,
- DomainNotificationSettings,
- DomainSettingsEvent,
- DomainUserSettings,
-)
-
-
-class UserSettingsMapper:
- """Map user settings snapshot/event documents to domain and back."""
-
- @staticmethod
- def from_snapshot_document(doc: dict[str, Any]) -> DomainUserSettings:
- notifications = doc.get("notifications", {})
- editor = doc.get("editor", {})
- theme = Theme(doc.get("theme", Theme.AUTO))
-
- # Use domain dataclass defaults for fallback values
- default_notifications = DomainNotificationSettings()
- default_editor = DomainEditorSettings()
-
- # Coerce channels to NotificationChannel list, using domain default if not present
- channels_raw = notifications.get("channels")
- if channels_raw is not None:
- channels: list[NotificationChannel] = [NotificationChannel(c) for c in channels_raw]
- else:
- channels = default_notifications.channels
-
- return DomainUserSettings(
- user_id=str(doc.get("user_id")),
- theme=theme,
- timezone=doc.get("timezone", "UTC"),
- date_format=doc.get("date_format", "YYYY-MM-DD"),
- time_format=doc.get("time_format", "24h"),
- notifications=DomainNotificationSettings(
- execution_completed=notifications.get("execution_completed", default_notifications.execution_completed),
- execution_failed=notifications.get("execution_failed", default_notifications.execution_failed),
- system_updates=notifications.get("system_updates", default_notifications.system_updates),
- security_alerts=notifications.get("security_alerts", default_notifications.security_alerts),
- channels=channels,
- ),
- editor=DomainEditorSettings(
- theme=editor.get("theme", default_editor.theme),
- font_size=editor.get("font_size", default_editor.font_size),
- tab_size=editor.get("tab_size", default_editor.tab_size),
- use_tabs=editor.get("use_tabs", default_editor.use_tabs),
- word_wrap=editor.get("word_wrap", default_editor.word_wrap),
- show_line_numbers=editor.get("show_line_numbers", default_editor.show_line_numbers),
- ),
- custom_settings=doc.get("custom_settings", {}),
- version=doc.get("version", 1),
- created_at=doc.get("created_at", datetime.now(timezone.utc)),
- updated_at=doc.get("updated_at", datetime.now(timezone.utc)),
- )
-
- @staticmethod
- def to_snapshot_document(settings: DomainUserSettings) -> dict[str, Any]:
- return {
- "user_id": settings.user_id,
- "theme": str(settings.theme),
- "timezone": settings.timezone,
- "date_format": settings.date_format,
- "time_format": settings.time_format,
- "notifications": {
- "execution_completed": settings.notifications.execution_completed,
- "execution_failed": settings.notifications.execution_failed,
- "system_updates": settings.notifications.system_updates,
- "security_alerts": settings.notifications.security_alerts,
- "channels": [str(c) for c in settings.notifications.channels],
- },
- "editor": {
- "theme": settings.editor.theme,
- "font_size": settings.editor.font_size,
- "tab_size": settings.editor.tab_size,
- "use_tabs": settings.editor.use_tabs,
- "word_wrap": settings.editor.word_wrap,
- "show_line_numbers": settings.editor.show_line_numbers,
- },
- "custom_settings": settings.custom_settings,
- "version": settings.version,
- "created_at": settings.created_at,
- "updated_at": settings.updated_at,
- }
-
- @staticmethod
- def event_from_mongo_document(doc: dict[str, Any]) -> DomainSettingsEvent:
- et_parsed: EventType = EventType(str(doc.get("event_type")))
-
- return DomainSettingsEvent(
- event_type=et_parsed,
- timestamp=doc.get("timestamp"), # type: ignore[arg-type]
- payload=doc.get("payload", {}),
- correlation_id=(doc.get("metadata", {}) or {}).get("correlation_id"),
- )
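
# The user-settings mapper above falls back to freshly constructed domain dataclasses
# rather than hard-coded literals, keeping snapshot defaults in one place. The idea in
# miniature, with a hypothetical settings type (not DomainEditorSettings itself):
from dataclasses import dataclass
from typing import Any

@dataclass
class EditorSettings:  # hypothetical stand-in
    font_size: int = 14
    tab_size: int = 4

def editor_from_doc(doc: dict[str, Any]) -> EditorSettings:
    defaults = EditorSettings()  # single source of truth for fallback values
    editor = doc.get("editor", {})
    return EditorSettings(
        font_size=editor.get("font_size", defaults.font_size),
        tab_size=editor.get("tab_size", defaults.tab_size),
    )

assert editor_from_doc({}) == EditorSettings()
assert editor_from_doc({"editor": {"font_size": 16}}).font_size == 16
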
diff --git a/backend/app/main.py b/backend/app/main.py
index 510c2c40..f1c58209 100644
--- a/backend/app/main.py
+++ b/backend/app/main.py
@@ -30,7 +30,7 @@
from app.core.correlation import CorrelationMiddleware
from app.core.dishka_lifespan import lifespan
from app.core.exceptions import configure_exception_handlers
-from app.core.logging import logger
+from app.core.logging import setup_logger
from app.core.middlewares import (
CacheControlMiddleware,
MetricsMiddleware,
@@ -43,6 +43,7 @@
def create_app() -> FastAPI:
settings = get_settings()
+ logger = setup_logger(settings.LOG_LEVEL)
# Disable OpenAPI/Docs in production for security; health endpoints provide readiness
app = FastAPI(
title=settings.PROJECT_NAME,
@@ -55,7 +56,7 @@ def create_app() -> FastAPI:
container = create_app_container()
setup_dishka(container, app)
- setup_metrics(app)
+ setup_metrics(app, logger)
app.add_middleware(MetricsMiddleware)
if settings.RATE_LIMIT_ENABLED:
app.add_middleware(RateLimitMiddleware)
@@ -123,7 +124,7 @@ def create_app() -> FastAPI:
if __name__ == "__main__":
settings = get_settings()
-
+ logger = setup_logger(settings.LOG_LEVEL)
logger.info(
"Starting uvicorn server",
extra={
diff --git a/backend/app/schemas_pydantic/admin_events.py b/backend/app/schemas_pydantic/admin_events.py
index 2b679a91..0a78c606 100644
--- a/backend/app/schemas_pydantic/admin_events.py
+++ b/backend/app/schemas_pydantic/admin_events.py
@@ -1,10 +1,20 @@
from datetime import datetime
from typing import Any, Dict, List
-from pydantic import BaseModel, ConfigDict, Field
+from pydantic import BaseModel, ConfigDict, Field, computed_field
from app.domain.enums.events import EventType
from app.schemas_pydantic.events import HourlyEventCountSchema
+from app.schemas_pydantic.execution import ExecutionResult
+
+
+class ReplayErrorInfo(BaseModel):
+ """Error info for replay operations."""
+
+ timestamp: datetime
+ error: str
+ event_id: str | None = None
+ error_type: str | None = None
class EventFilter(BaseModel):
@@ -85,10 +95,14 @@ class EventReplayStatusResponse(BaseModel):
created_at: datetime
started_at: datetime | None = None
completed_at: datetime | None = None
- error: str | None = None
- progress_percentage: float
+ errors: List[ReplayErrorInfo] | None = None
estimated_completion: datetime | None = None
- execution_results: List[Dict[str, Any]] | None = None # Results from replayed executions
+ execution_results: List[ExecutionResult] | None = None
+
+ @computed_field # type: ignore[prop-decorator]
+ @property
+ def progress_percentage(self) -> float:
+ return round(self.replayed_events / max(self.total_events, 1) * 100, 2)
class EventDeleteResponse(BaseModel):
diff --git a/backend/app/schemas_pydantic/admin_settings.py b/backend/app/schemas_pydantic/admin_settings.py
index 23645420..c33b6f36 100644
--- a/backend/app/schemas_pydantic/admin_settings.py
+++ b/backend/app/schemas_pydantic/admin_settings.py
@@ -4,6 +4,8 @@
class ExecutionLimitsSchema(BaseModel):
"""Execution resource limits schema."""
+ model_config = ConfigDict(from_attributes=True)
+
max_timeout_seconds: int = Field(default=300, ge=10, le=3600, description="Maximum execution timeout")
max_memory_mb: int = Field(default=512, ge=128, le=4096, description="Maximum memory in MB")
max_cpu_cores: int = Field(default=2, ge=1, le=8, description="Maximum CPU cores")
@@ -13,6 +15,8 @@ class ExecutionLimitsSchema(BaseModel):
class SecuritySettingsSchema(BaseModel):
"""Security configuration schema."""
+ model_config = ConfigDict(from_attributes=True)
+
password_min_length: int = Field(default=8, ge=6, le=32, description="Minimum password length")
session_timeout_minutes: int = Field(default=60, ge=5, le=1440, description="Session timeout in minutes")
max_login_attempts: int = Field(default=5, ge=3, le=10, description="Maximum login attempts")
@@ -22,6 +26,8 @@ class SecuritySettingsSchema(BaseModel):
class MonitoringSettingsSchema(BaseModel):
"""Monitoring and observability schema."""
+ model_config = ConfigDict(from_attributes=True, use_enum_values=True)
+
metrics_retention_days: int = Field(default=30, ge=7, le=90, description="Metrics retention in days")
log_level: str = Field(default="INFO", pattern="^(DEBUG|INFO|WARNING|ERROR|CRITICAL)$", description="Log level")
enable_tracing: bool = Field(default=True, description="Enable distributed tracing")
@@ -31,7 +37,7 @@ class MonitoringSettingsSchema(BaseModel):
class SystemSettings(BaseModel):
"""System-wide settings model."""
- model_config = ConfigDict(extra="ignore")
+ model_config = ConfigDict(extra="ignore", from_attributes=True)
execution_limits: ExecutionLimitsSchema = Field(default_factory=ExecutionLimitsSchema)
security_settings: SecuritySettingsSchema = Field(default_factory=SecuritySettingsSchema)
diff --git a/backend/app/schemas_pydantic/events.py b/backend/app/schemas_pydantic/events.py
index 346b91a7..854c6a21 100644
--- a/backend/app/schemas_pydantic/events.py
+++ b/backend/app/schemas_pydantic/events.py
@@ -72,7 +72,7 @@ class EventFilterRequest(BaseModel):
service_name: str | None = Field(None, description="Filter by service name")
start_time: datetime | None = Field(None, description="Filter events after this time")
end_time: datetime | None = Field(None, description="Filter events before this time")
- text_search: str | None = Field(None, description="Full-text search in event data")
+ search_text: str | None = Field(None, description="Full-text search in event data")
sort_by: str = Field("timestamp", description="Field to sort by")
sort_order: SortOrder = Field(SortOrder.DESC, description="Sort order")
limit: int = Field(100, ge=1, le=1000, description="Maximum events to return")
@@ -190,7 +190,7 @@ class EventQuery(BaseModel):
service_name: str | None = None
start_time: datetime | None = None
end_time: datetime | None = None
- text_search: str | None = None
+ search_text: str | None = None
limit: int = Field(default=100, ge=1, le=1000)
skip: int = Field(default=0, ge=0)
diff --git a/backend/app/schemas_pydantic/execution.py b/backend/app/schemas_pydantic/execution.py
index 74fc8b18..843ca75f 100644
--- a/backend/app/schemas_pydantic/execution.py
+++ b/backend/app/schemas_pydantic/execution.py
@@ -57,12 +57,10 @@ class ExecutionUpdate(BaseModel):
class ResourceUsage(BaseModel):
"""Model for execution resource usage."""
- execution_time_wall_seconds: float | None = Field(default=None, description="Wall clock execution time in seconds")
- cpu_time_jiffies: int | None = Field(
- default=None, description="CPU time in jiffies (multiply by 10 for milliseconds)"
- )
- clk_tck_hertz: int | None = Field(default=None, description="Clock ticks per second (usually 100)")
- peak_memory_kb: int | None = Field(default=None, description="Peak memory usage in KB")
+ execution_time_wall_seconds: float = 0.0
+ cpu_time_jiffies: int = 0
+ clk_tck_hertz: int = 0
+ peak_memory_kb: int = 0
model_config = ConfigDict(from_attributes=True)
diff --git a/backend/app/schemas_pydantic/replay.py b/backend/app/schemas_pydantic/replay.py
index 9949884e..5101a7ac 100644
--- a/backend/app/schemas_pydantic/replay.py
+++ b/backend/app/schemas_pydantic/replay.py
@@ -1,10 +1,11 @@
from datetime import datetime
-from typing import Any, Dict, List
+from typing import Dict
-from pydantic import BaseModel, Field
+from pydantic import BaseModel, ConfigDict, Field, computed_field
from app.domain.enums.events import EventType
from app.domain.enums.replay import ReplayStatus, ReplayTarget, ReplayType
+from app.domain.replay import ReplayFilter
class ReplayRequest(BaseModel):
@@ -12,25 +13,15 @@ class ReplayRequest(BaseModel):
replay_type: ReplayType
target: ReplayTarget = ReplayTarget.KAFKA
+ filter: ReplayFilter = Field(default_factory=ReplayFilter)
- # Filter options
- execution_id: str | None = None
- event_types: List[EventType] | None = None
- start_time: datetime | None = None
- end_time: datetime | None = None
- user_id: str | None = None
- service_name: str | None = None
- custom_query: Dict[str, Any] | None = None
- exclude_event_types: List[EventType] | None = None
-
- # Replay configuration
speed_multiplier: float = Field(default=1.0, ge=0.1, le=100.0)
preserve_timestamps: bool = False
batch_size: int = Field(default=100, ge=1, le=1000)
max_events: int | None = Field(default=None, ge=1)
skip_errors: bool = True
target_file_path: str | None = None
- target_topics: Dict[str, str] | None = None
+ target_topics: Dict[EventType, str] | None = None
retry_failed: bool = False
retry_attempts: int = Field(default=3, ge=1, le=10)
enable_progress_tracking: bool = True
@@ -47,6 +38,8 @@ class ReplayResponse(BaseModel):
class SessionSummary(BaseModel):
"""Summary information for replay sessions"""
+ model_config = ConfigDict(from_attributes=True)
+
session_id: str
replay_type: ReplayType
target: ReplayTarget
@@ -58,8 +51,20 @@ class SessionSummary(BaseModel):
created_at: datetime
started_at: datetime | None
completed_at: datetime | None
- duration_seconds: float | None = None
- throughput_events_per_second: float | None = None
+
+ @computed_field # type: ignore[prop-decorator]
+ @property
+ def duration_seconds(self) -> float | None:
+ if self.started_at and self.completed_at:
+ return (self.completed_at - self.started_at).total_seconds()
+ return None
+
+ @computed_field # type: ignore[prop-decorator]
+ @property
+ def throughput_events_per_second(self) -> float | None:
+ if self.duration_seconds and self.duration_seconds > 0 and self.replayed_events > 0:
+ return self.replayed_events / self.duration_seconds
+ return None
class CleanupResponse(BaseModel):
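
# SessionSummary now derives duration_seconds and throughput_events_per_second with
# Pydantic's @computed_field instead of storing them; computed fields are included in
# model_dump() and the generated schema. A minimal sketch of the mechanism, using a
# simplified stand-in model rather than the full SessionSummary:
from datetime import datetime, timezone
from pydantic import BaseModel, computed_field

class MiniSummary(BaseModel):  # simplified stand-in
    replayed_events: int
    started_at: datetime | None = None
    completed_at: datetime | None = None

    @computed_field  # type: ignore[prop-decorator]
    @property
    def duration_seconds(self) -> float | None:
        if self.started_at and self.completed_at:
            return (self.completed_at - self.started_at).total_seconds()
        return None

s = MiniSummary(
    replayed_events=50,
    started_at=datetime(2024, 1, 1, tzinfo=timezone.utc),
    completed_at=datetime(2024, 1, 1, 0, 0, 10, tzinfo=timezone.utc),
)
assert s.model_dump()["duration_seconds"] == 10.0
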
diff --git a/backend/app/schemas_pydantic/replay_models.py b/backend/app/schemas_pydantic/replay_models.py
index d5fa4c2d..19fa0273 100644
--- a/backend/app/schemas_pydantic/replay_models.py
+++ b/backend/app/schemas_pydantic/replay_models.py
@@ -2,37 +2,23 @@
from typing import Any, Dict, List
from uuid import uuid4
-from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
+from pydantic import BaseModel, ConfigDict, Field
+from app.domain.enums.events import EventType
from app.domain.enums.replay import ReplayStatus, ReplayTarget, ReplayType
-from app.domain.replay import ReplayConfig as DomainReplayConfig
-from app.domain.replay import ReplayFilter as DomainReplayFilter
class ReplayFilterSchema(BaseModel):
model_config = ConfigDict(from_attributes=True)
execution_id: str | None = None
- event_types: List[str] | None = None
+ event_types: List[EventType] | None = None
start_time: datetime | None = None
end_time: datetime | None = None
user_id: str | None = None
service_name: str | None = None
custom_query: Dict[str, Any] | None = None
- exclude_event_types: List[str] | None = None
-
- @classmethod
- def from_domain(cls, f: DomainReplayFilter) -> "ReplayFilterSchema":
- return cls(
- execution_id=f.execution_id,
- event_types=[str(et) for et in (f.event_types or [])] or None,
- start_time=f.start_time,
- end_time=f.end_time,
- user_id=f.user_id,
- service_name=f.service_name,
- custom_query=f.custom_query,
- exclude_event_types=[str(et) for et in (f.exclude_event_types or [])] or None,
- )
+ exclude_event_types: List[EventType] | None = None
class ReplayConfigSchema(BaseModel):
@@ -40,15 +26,14 @@ class ReplayConfigSchema(BaseModel):
replay_type: ReplayType
target: ReplayTarget = ReplayTarget.KAFKA
- filter: ReplayFilterSchema
+ filter: ReplayFilterSchema = Field(default_factory=ReplayFilterSchema)
speed_multiplier: float = Field(default=1.0, ge=0.1, le=100.0)
preserve_timestamps: bool = False
batch_size: int = Field(default=100, ge=1, le=1000)
max_events: int | None = Field(default=None, ge=1)
- # Use string keys for event types for clean JSON
- target_topics: Dict[str, str] | None = None
+ target_topics: Dict[EventType, str] | None = None
target_file_path: str | None = None
skip_errors: bool = True
@@ -57,28 +42,6 @@ class ReplayConfigSchema(BaseModel):
enable_progress_tracking: bool = True
- @field_validator("filter", mode="before")
- @classmethod
- def _coerce_filter(cls, v: Any) -> Any: # noqa: ANN001
- if isinstance(v, DomainReplayFilter):
- return ReplayFilterSchema.from_domain(v).model_dump()
- return v
-
- @model_validator(mode="before")
- @classmethod
- def _from_domain(cls, data: Any) -> Any: # noqa: ANN001
- if isinstance(data, DomainReplayConfig):
- # Convert DomainReplayConfig to dict compatible with this schema
- d = data.model_dump()
- # Convert filter
- filt = data.filter
- d["filter"] = ReplayFilterSchema.from_domain(filt).model_dump()
- # Convert target_topics keys to strings if present
- if d.get("target_topics"):
- d["target_topics"] = {str(k): v for k, v in d["target_topics"].items()}
- return d
- return data
-
class ReplaySession(BaseModel):
model_config = ConfigDict(from_attributes=True)
@@ -98,12 +61,3 @@ class ReplaySession(BaseModel):
last_event_at: datetime | None = None
errors: List[Dict[str, Any]] = Field(default_factory=list)
-
- @field_validator("config", mode="before")
- @classmethod
- def _coerce_config(cls, v: Any) -> Any: # noqa: ANN001
- if isinstance(v, DomainReplayConfig):
- return ReplayConfigSchema.model_validate(v).model_dump()
- if isinstance(v, dict):
- return v
- return v
diff --git a/backend/app/schemas_pydantic/saga.py b/backend/app/schemas_pydantic/saga.py
index 217b469d..130c8296 100644
--- a/backend/app/schemas_pydantic/saga.py
+++ b/backend/app/schemas_pydantic/saga.py
@@ -1,12 +1,15 @@
-from pydantic import BaseModel
+from datetime import datetime
+
+from pydantic import BaseModel, ConfigDict
from app.domain.enums.saga import SagaState
-from app.domain.saga.models import Saga
class SagaStatusResponse(BaseModel):
"""Response schema for saga status"""
+ model_config = ConfigDict(from_attributes=True)
+
saga_id: str
saga_name: str
execution_id: str
@@ -15,29 +18,11 @@ class SagaStatusResponse(BaseModel):
completed_steps: list[str]
compensated_steps: list[str]
error_message: str | None
- created_at: str
- updated_at: str
- completed_at: str | None
+ created_at: datetime
+ updated_at: datetime
+ completed_at: datetime | None
retry_count: int
- @classmethod
- def from_domain(cls, saga: "Saga") -> "SagaStatusResponse":
- """Create response from domain model."""
- return cls(
- saga_id=saga.saga_id,
- saga_name=saga.saga_name,
- execution_id=saga.execution_id,
- state=saga.state,
- current_step=saga.current_step,
- completed_steps=saga.completed_steps,
- compensated_steps=saga.compensated_steps,
- error_message=saga.error_message,
- created_at=saga.created_at.isoformat(),
- updated_at=saga.updated_at.isoformat(),
- completed_at=saga.completed_at.isoformat() if saga.completed_at else None,
- retry_count=saga.retry_count,
- )
-
class SagaListResponse(BaseModel):
"""Response schema for saga list"""
diff --git a/backend/app/schemas_pydantic/user.py b/backend/app/schemas_pydantic/user.py
index 2a450977..7ef35d84 100644
--- a/backend/app/schemas_pydantic/user.py
+++ b/backend/app/schemas_pydantic/user.py
@@ -84,20 +84,6 @@ class User(BaseModel):
arbitrary_types_allowed=True,
)
- @classmethod
- def from_response(cls, user_response: UserResponse) -> "User":
- """Create User from UserResponse"""
- return cls(
- user_id=user_response.user_id,
- username=user_response.username,
- email=user_response.email,
- role=user_response.role,
- is_active=user_response.is_active,
- is_superuser=user_response.is_superuser,
- created_at=user_response.created_at,
- updated_at=user_response.updated_at,
- )
-
class UserListResponse(BaseModel):
"""Response model for listing users"""
diff --git a/backend/app/services/admin/admin_events_service.py b/backend/app/services/admin/admin_events_service.py
index 2a9ed6e8..684ab7d7 100644
--- a/backend/app/services/admin/admin_events_service.py
+++ b/backend/app/services/admin/admin_events_service.py
@@ -1,16 +1,15 @@
import csv
import json
+import logging
from dataclasses import asdict, dataclass
from datetime import datetime, timezone
from io import StringIO
from typing import Any, Dict, List
-from app.core.logging import logger
+from beanie.odm.enums import SortDirection
+
from app.db.repositories.admin import AdminEventsRepository
-from app.domain.admin import (
- ReplayQuery,
- ReplaySessionStatusDetail,
-)
+from app.domain.admin import ReplayQuery, ReplaySessionStatusDetail
from app.domain.admin.replay_updates import ReplaySessionUpdate
from app.domain.enums.replay import ReplayStatus, ReplayTarget, ReplayType
from app.domain.events.event_models import (
@@ -21,10 +20,24 @@
EventStatistics,
)
from app.domain.replay import ReplayConfig, ReplayFilter
-from app.infrastructure.mappers import EventExportRowMapper
from app.services.replay_service import ReplayService
+def _export_row_to_dict(row: EventExportRow) -> dict[str, str]:
+ """Convert EventExportRow to dict with display names."""
+ return {
+ "Event ID": row.event_id,
+ "Event Type": row.event_type,
+ "Timestamp": row.timestamp,
+ "Correlation ID": row.correlation_id,
+ "Aggregate ID": row.aggregate_id,
+ "User ID": row.user_id,
+ "Service": row.service,
+ "Status": row.status,
+ "Error": row.error,
+ }
+
+
class AdminReplayResult:
def __init__(
self,
@@ -52,9 +65,12 @@ class ExportResult:
class AdminEventsService:
- def __init__(self, repository: AdminEventsRepository, replay_service: ReplayService) -> None:
+ def __init__(
+ self, repository: AdminEventsRepository, replay_service: ReplayService, logger: logging.Logger
+ ) -> None:
self._repo = repository
self._replay_service = replay_service
+ self.logger = logger
async def browse_events(
self,
@@ -65,8 +81,9 @@ async def browse_events(
sort_by: str,
sort_order: int,
) -> EventBrowseResult:
+ direction = SortDirection.DESCENDING if sort_order == -1 else SortDirection.ASCENDING
return await self._repo.browse_events(
- event_filter=event_filter, skip=skip, limit=limit, sort_by=sort_by, sort_order=sort_order
+ event_filter=event_filter, skip=skip, limit=limit, sort_by=sort_by, sort_order=direction
)
async def get_event_detail(self, event_id: str) -> EventDetail | None:
@@ -83,12 +100,11 @@ async def prepare_or_schedule_replay(
replay_correlation_id: str,
target_service: str | None,
) -> AdminReplayResult:
- query = self._repo.build_replay_query(replay_query)
- if not query:
+ if replay_query.is_empty():
raise ValueError("Must specify at least one filter for replay")
# Prepare and optionally preview
- logger.info(
+ self.logger.info(
"Preparing replay session",
extra={
"dry_run": dry_run,
@@ -96,7 +112,7 @@ async def prepare_or_schedule_replay(
},
)
session_data = await self._repo.prepare_replay_session(
- query=query,
+ replay_query=replay_query,
dry_run=dry_run,
replay_correlation_id=replay_correlation_id,
max_events=1000,
@@ -120,7 +136,7 @@ async def prepare_or_schedule_replay(
status="Preview",
events_preview=previews,
)
- logger.info(
+ self.logger.info(
"Replay dry-run prepared",
extra={
"total_events": result.total_events,
@@ -130,7 +146,10 @@ async def prepare_or_schedule_replay(
return result
# Build config for actual replay and create session via replay service
- replay_filter = ReplayFilter(custom_query=query)
+ replay_filter = ReplayFilter(
+ start_time=replay_query.start_time,
+ end_time=replay_query.end_time,
+ )
config = ReplayConfig(
replay_type=ReplayType.QUERY,
target=ReplayTarget.KAFKA if target_service else ReplayTarget.TEST,
@@ -163,7 +182,7 @@ async def prepare_or_schedule_replay(
session_id=session_id,
status="Replay scheduled",
)
- logger.info(
+ self.logger.info(
"Replay scheduled",
extra={
"session_id": result.session_id,
@@ -202,12 +221,11 @@ async def export_events_csv_content(self, *, event_filter: EventFilter, limit: i
],
)
writer.writeheader()
- row_mapper = EventExportRowMapper()
for row in rows[:limit]:
- writer.writerow(row_mapper.to_dict(row))
+ writer.writerow(_export_row_to_dict(row))
output.seek(0)
filename = f"events_export_{datetime.now(timezone.utc).strftime('%Y%m%d_%H%M%S')}.csv"
- logger.info(
+ self.logger.info(
"Exported events CSV",
extra={
"row_count": len(rows),
@@ -218,7 +236,7 @@ async def export_events_csv_content(self, *, event_filter: EventFilter, limit: i
async def export_events_json_content(self, *, event_filter: EventFilter, limit: int) -> ExportResult:
result = await self._repo.browse_events(
- event_filter=event_filter, skip=0, limit=limit, sort_by="timestamp", sort_order=-1
+ event_filter=event_filter, skip=0, limit=limit, sort_by="timestamp", sort_order=SortDirection.DESCENDING
)
events_data: list[dict[str, Any]] = []
for event in result.events:
@@ -247,7 +265,7 @@ async def export_events_json_content(self, *, event_filter: EventFilter, limit:
}
json_content = json.dumps(export_data, indent=2, default=str)
filename = f"events_export_{datetime.now(timezone.utc).strftime('%Y%m%d_%H%M%S')}.json"
- logger.info(
+ self.logger.info(
"Exported events JSON",
extra={
"event_count": len(events_data),
@@ -258,14 +276,14 @@ async def export_events_json_content(self, *, event_filter: EventFilter, limit:
async def delete_event(self, *, event_id: str, deleted_by: str) -> bool:
# Load event for archival; archive then delete
- logger.warning("Admin attempting to delete event", extra={"event_id": event_id, "deleted_by": deleted_by})
+ self.logger.warning("Admin attempting to delete event", extra={"event_id": event_id, "deleted_by": deleted_by})
detail = await self._repo.get_event_detail(event_id)
if not detail:
return False
await self._repo.archive_event(detail.event, deleted_by)
deleted = await self._repo.delete_event(event_id)
if deleted:
- logger.info(
+ self.logger.info(
"Event deleted",
extra={
"event_id": event_id,
diff --git a/backend/app/services/admin/admin_settings_service.py b/backend/app/services/admin/admin_settings_service.py
index 88754c80..674ad8de 100644
--- a/backend/app/services/admin/admin_settings_service.py
+++ b/backend/app/services/admin/admin_settings_service.py
@@ -1,14 +1,16 @@
-from app.core.logging import logger
+import logging
+
from app.db.repositories.admin.admin_settings_repository import AdminSettingsRepository
from app.domain.admin import SystemSettings
class AdminSettingsService:
- def __init__(self, repository: AdminSettingsRepository):
+ def __init__(self, repository: AdminSettingsRepository, logger: logging.Logger):
self._repo = repository
+ self.logger = logger
async def get_system_settings(self, admin_username: str) -> SystemSettings:
- logger.info(
+ self.logger.info(
"Admin retrieving system settings",
extra={"admin_username": admin_username},
)
@@ -21,21 +23,21 @@ async def update_system_settings(
updated_by: str,
user_id: str,
) -> SystemSettings:
- logger.info(
+ self.logger.info(
"Admin updating system settings",
extra={"admin_username": updated_by},
)
updated = await self._repo.update_system_settings(settings=settings, updated_by=updated_by, user_id=user_id)
- logger.info("System settings updated successfully")
+ self.logger.info("System settings updated successfully")
return updated
async def reset_system_settings(self, username: str, user_id: str) -> SystemSettings:
# Reset (with audit) and return fresh defaults persisted via get
- logger.info(
+ self.logger.info(
"Admin resetting system settings to defaults",
extra={"admin_username": username},
)
await self._repo.reset_system_settings(username=username, user_id=user_id)
settings = await self._repo.get_system_settings()
- logger.info("System settings reset to defaults")
+ self.logger.info("System settings reset to defaults")
return settings
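
A minimal sketch of the wiring this constructor change implies: callers now pass a stdlib logging.Logger explicitly instead of relying on a module-level logger. The repository instance is assumed to come from the application's existing dependency wiring; the logger name is illustrative.

    import logging

    from app.db.repositories.admin.admin_settings_repository import AdminSettingsRepository
    from app.services.admin.admin_settings_service import AdminSettingsService

    def build_settings_service(repo: AdminSettingsRepository) -> AdminSettingsService:
        # Any stdlib logger satisfies the new parameter.
        return AdminSettingsService(repository=repo, logger=logging.getLogger("admin.settings"))
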
diff --git a/backend/app/services/admin/admin_user_service.py b/backend/app/services/admin/admin_user_service.py
index 8914f270..88e2e8e2 100644
--- a/backend/app/services/admin/admin_user_service.py
+++ b/backend/app/services/admin/admin_user_service.py
@@ -1,7 +1,7 @@
+import logging
+from dataclasses import asdict
from datetime import datetime, timedelta, timezone
-from uuid import uuid4
-from app.core.logging import logger
from app.core.security import SecurityService
from app.db.repositories.admin.admin_user_repository import AdminUserRepository
from app.domain.admin import AdminUserOverviewDomain, DerivedCountsDomain, RateLimitSummaryDomain
@@ -9,8 +9,7 @@
from app.domain.enums.execution import ExecutionStatus
from app.domain.enums.user import UserRole
from app.domain.rate_limit import RateLimitUpdateResult, UserRateLimit, UserRateLimitsResult
-from app.domain.user import PasswordReset, User, UserListResult, UserUpdate
-from app.infrastructure.mappers import UserRateLimitMapper
+from app.domain.user import DomainUserCreate, PasswordReset, User, UserListResult, UserUpdate
from app.schemas_pydantic.user import UserCreate
from app.services.event_service import EventService
from app.services.execution_service import ExecutionService
@@ -24,14 +23,16 @@ def __init__(
event_service: EventService,
execution_service: ExecutionService,
rate_limit_service: RateLimitService,
+ logger: logging.Logger,
) -> None:
self._users = user_repository
self._events = event_service
self._executions = execution_service
self._rate_limits = rate_limit_service
+ self.logger = logger
async def get_user_overview(self, user_id: str, hours: int = 24) -> AdminUserOverviewDomain:
- logger.info("Admin getting user overview", extra={"target_user_id": user_id, "hours": hours})
+ self.logger.info("Admin getting user overview", extra={"target_user_id": user_id, "hours": hours})
user = await self._users.get_user_by_id(user_id)
if not user:
raise ValueError("User not found")
@@ -101,7 +102,7 @@ def _count(status: ExecutionStatus) -> int:
async def list_users(
self, *, admin_username: str, limit: int, offset: int, search: str | None, role: UserRole | None
) -> UserListResult:
- logger.info(
+ self.logger.info(
"Admin listing users",
extra={
"admin_username": admin_username,
@@ -116,7 +117,7 @@ async def list_users(
async def create_user(self, *, admin_username: str, user_data: UserCreate) -> User:
"""Create a new user and return domain user."""
- logger.info(
+ self.logger.info(
"Admin creating new user", extra={"admin_username": admin_username, "new_username": user_data.username}
)
# Ensure not exists
@@ -128,42 +129,35 @@ async def create_user(self, *, admin_username: str, user_data: UserCreate) -> Us
security = SecurityService()
hashed_password = security.get_password_hash(user_data.password)
- user_id = str(uuid4()) # imported where defined
- now = datetime.now(timezone.utc)
- user_doc = {
- "user_id": user_id,
- "username": user_data.username,
- "email": user_data.email,
- "hashed_password": hashed_password,
- "role": getattr(user_data, "role", UserRole.USER),
- "is_active": getattr(user_data, "is_active", True),
- "is_superuser": False,
- "created_at": now,
- "updated_at": now,
- }
- await self._users.users_collection.insert_one(user_doc)
- logger.info(
+ create_data = DomainUserCreate(
+ username=user_data.username,
+ email=user_data.email,
+ hashed_password=hashed_password,
+ role=getattr(user_data, "role", UserRole.USER),
+ is_active=getattr(user_data, "is_active", True),
+ is_superuser=False,
+ )
+ created = await self._users.create_user(create_data)
+ self.logger.info(
"User created successfully", extra={"new_username": user_data.username, "admin_username": admin_username}
)
- # Return fresh domain user
- created = await self._users.get_user_by_id(user_id)
- if not created:
- raise ValueError("Failed to fetch created user")
return created
async def get_user(self, *, admin_username: str, user_id: str) -> User | None:
- logger.info("Admin getting user details", extra={"admin_username": admin_username, "target_user_id": user_id})
+ self.logger.info(
+ "Admin getting user details", extra={"admin_username": admin_username, "target_user_id": user_id}
+ )
return await self._users.get_user_by_id(user_id)
async def update_user(self, *, admin_username: str, user_id: str, update: UserUpdate) -> User | None:
- logger.info(
+ self.logger.info(
"Admin updating user",
extra={"admin_username": admin_username, "target_user_id": user_id},
)
return await self._users.update_user(user_id, update)
async def delete_user(self, *, admin_username: str, user_id: str, cascade: bool) -> dict[str, int]:
- logger.info(
+ self.logger.info(
"Admin deleting user",
extra={"admin_username": admin_username, "target_user_id": user_id, "cascade": cascade},
)
@@ -171,21 +165,21 @@ async def delete_user(self, *, admin_username: str, user_id: str, cascade: bool)
await self._rate_limits.reset_user_limits(user_id)
deleted_counts = await self._users.delete_user(user_id, cascade=cascade)
if deleted_counts.get("user", 0) > 0:
- logger.info("User deleted successfully", extra={"target_user_id": user_id})
+ self.logger.info("User deleted successfully", extra={"target_user_id": user_id})
return deleted_counts
async def reset_user_password(self, *, admin_username: str, user_id: str, new_password: str) -> bool:
- logger.info(
+ self.logger.info(
"Admin resetting user password", extra={"admin_username": admin_username, "target_user_id": user_id}
)
pr = PasswordReset(user_id=user_id, new_password=new_password)
ok = await self._users.reset_user_password(pr)
if ok:
- logger.info("User password reset successfully", extra={"target_user_id": user_id})
+ self.logger.info("User password reset successfully", extra={"target_user_id": user_id})
return ok
async def get_user_rate_limits(self, *, admin_username: str, user_id: str) -> UserRateLimitsResult:
- logger.info(
+ self.logger.info(
"Admin getting user rate limits", extra={"admin_username": admin_username, "target_user_id": user_id}
)
user_limit = await self._rate_limits.get_user_rate_limit(user_id)
@@ -199,17 +193,16 @@ async def get_user_rate_limits(self, *, admin_username: str, user_id: str) -> Us
async def update_user_rate_limits(
self, *, admin_username: str, user_id: str, config: UserRateLimit
) -> RateLimitUpdateResult:
- mapper = UserRateLimitMapper()
- logger.info(
+ self.logger.info(
"Admin updating user rate limits",
- extra={"admin_username": admin_username, "target_user_id": user_id, "config": mapper.to_dict(config)},
+ extra={"admin_username": admin_username, "target_user_id": user_id, "config": asdict(config)},
)
config.user_id = user_id
await self._rate_limits.update_user_rate_limit(user_id, config)
return RateLimitUpdateResult(user_id=user_id, updated=True, config=config)
async def reset_user_rate_limits(self, *, admin_username: str, user_id: str) -> bool:
- logger.info(
+ self.logger.info(
"Admin resetting user rate limits", extra={"admin_username": admin_username, "target_user_id": user_id}
)
await self._rate_limits.reset_user_limits(user_id)
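
A sketch of the new creation path that replaces the raw insert_one above: hash the password, build a DomainUserCreate, and let the repository persist it. Field values are illustrative, and the commented call is assumed to run inside an async context.

    from app.core.security import SecurityService
    from app.domain.enums.user import UserRole
    from app.domain.user import DomainUserCreate

    hashed = SecurityService().get_password_hash("example-password")  # illustrative secret
    create_data = DomainUserCreate(
        username="alice",
        email="alice@example.com",
        hashed_password=hashed,
        role=UserRole.USER,
        is_active=True,
        is_superuser=False,
    )
    # created = await user_repository.create_user(create_data)  # returns the persisted domain User
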
diff --git a/backend/app/services/auth_service.py b/backend/app/services/auth_service.py
index 68caeb1c..50e6f98f 100644
--- a/backend/app/services/auth_service.py
+++ b/backend/app/services/auth_service.py
@@ -1,24 +1,23 @@
-from fastapi import HTTPException, Request, status
+import logging
+
+from fastapi import Request
-from app.core.logging import logger
from app.core.security import security_service
from app.db.repositories.user_repository import UserRepository
from app.domain.enums.user import UserRole
+from app.domain.user import AdminAccessRequiredError, AuthenticationRequiredError
from app.schemas_pydantic.user import UserResponse
class AuthService:
- def __init__(self, user_repo: UserRepository):
+ def __init__(self, user_repo: UserRepository, logger: logging.Logger):
self.user_repo = user_repo
+ self.logger = logger
async def get_current_user(self, request: Request) -> UserResponse:
token = request.cookies.get("access_token")
if not token:
- raise HTTPException(
- status_code=status.HTTP_401_UNAUTHORIZED,
- detail="Not authenticated",
- headers={"WWW-Authenticate": "Bearer"},
- )
+ raise AuthenticationRequiredError()
user = await security_service.get_current_user(token, self.user_repo)
@@ -35,9 +34,6 @@ async def get_current_user(self, request: Request) -> UserResponse:
async def get_admin(self, request: Request) -> UserResponse:
user = await self.get_current_user(request)
if user.role != UserRole.ADMIN:
- logger.warning(f"Admin access denied for user: {user.username} (role: {user.role})")
- raise HTTPException(
- status_code=status.HTTP_403_FORBIDDEN,
- detail="Admin access required",
- )
+ self.logger.warning(f"Admin access denied for user: {user.username} (role: {user.role})")
+ raise AdminAccessRequiredError(user.username)
return user
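
The service now raises domain errors instead of building HTTPException responses inline. The HTTP mapping layer is not part of this diff, so the handlers below are a hypothetical sketch of how those errors might be translated back to 401/403 at the FastAPI boundary.

    from fastapi import FastAPI, Request
    from fastapi.responses import JSONResponse

    from app.domain.user import AdminAccessRequiredError, AuthenticationRequiredError

    app = FastAPI()

    @app.exception_handler(AuthenticationRequiredError)
    async def handle_auth_required(request: Request, exc: AuthenticationRequiredError) -> JSONResponse:
        return JSONResponse(
            status_code=401,
            content={"detail": "Not authenticated"},
            headers={"WWW-Authenticate": "Bearer"},
        )

    @app.exception_handler(AdminAccessRequiredError)
    async def handle_admin_required(request: Request, exc: AdminAccessRequiredError) -> JSONResponse:
        return JSONResponse(status_code=403, content={"detail": "Admin access required"})
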
diff --git a/backend/app/services/coordinator/coordinator.py b/backend/app/services/coordinator/coordinator.py
index b7843425..fb4cb8ba 100644
--- a/backend/app/services/coordinator/coordinator.py
+++ b/backend/app/services/coordinator/coordinator.py
@@ -1,4 +1,5 @@
import asyncio
+import logging
import signal
import time
from collections.abc import Coroutine
@@ -6,18 +7,17 @@
from uuid import uuid4
import redis.asyncio as redis
-from motor.motor_asyncio import AsyncIOMotorClient
+from beanie import init_beanie
+from pymongo.asynchronous.mongo_client import AsyncMongoClient
from app.core.database_context import DBClient
from app.core.lifecycle import LifecycleEnabled
-from app.core.logging import logger
from app.core.metrics.context import get_coordinator_metrics
+from app.db.docs import ALL_DOCUMENTS
from app.db.repositories.execution_repository import ExecutionRepository
-from app.db.schema.schema_manager import SchemaManager
from app.domain.enums.events import EventType
from app.domain.enums.kafka import KafkaTopic
from app.domain.enums.storage import ExecutionErrorType
-from app.domain.execution import ResourceUsageDomain
from app.events.core import ConsumerConfig, EventDispatcher, ProducerConfig, UnifiedConsumer, UnifiedProducer
from app.events.event_store import EventStore, create_event_store
from app.events.schema.schema_registry import (
@@ -66,10 +66,12 @@ def __init__(
event_store: EventStore,
execution_repository: ExecutionRepository,
idempotency_manager: IdempotencyManager,
+ logger: logging.Logger,
consumer_group: str = "execution-coordinator",
max_concurrent_scheduling: int = 10,
scheduling_interval_seconds: float = 0.5,
):
+ self.logger = logger
self.metrics = get_coordinator_metrics()
settings = get_settings()
@@ -78,9 +80,13 @@ def __init__(
self.consumer_group = consumer_group
# Components
- self.queue_manager = QueueManager(max_queue_size=10000, max_executions_per_user=100, stale_timeout_seconds=3600)
+ self.queue_manager = QueueManager(
+ logger=self.logger, max_queue_size=10000, max_executions_per_user=100, stale_timeout_seconds=3600
+ )
- self.resource_manager = ResourceManager(total_cpu_cores=32.0, total_memory_mb=65536, total_gpu_count=0)
+ self.resource_manager = ResourceManager(
+ logger=self.logger, total_cpu_cores=32.0, total_memory_mb=65536, total_gpu_count=0
+ )
# Kafka components
self.consumer: UnifiedConsumer | None = None
@@ -103,15 +109,15 @@ def __init__(
self._active_executions: set[str] = set()
self._execution_resources: ExecutionMap = {}
self._schema_registry_manager = schema_registry_manager
- self.dispatcher = EventDispatcher()
+ self.dispatcher = EventDispatcher(logger=self.logger)
async def start(self) -> None:
"""Start the coordinator service"""
if self._running:
- logger.warning("ExecutionCoordinator already running")
+ self.logger.warning("ExecutionCoordinator already running")
return
- logger.info("Starting ExecutionCoordinator service...")
+ self.logger.info("Starting ExecutionCoordinator service...")
await self.queue_manager.start()
@@ -130,7 +136,7 @@ async def start(self) -> None:
fetch_min_bytes=1, # Return immediately if any data available
)
- self.consumer = UnifiedConsumer(consumer_config, event_dispatcher=self.dispatcher)
+ self.consumer = UnifiedConsumer(consumer_config, event_dispatcher=self.dispatcher, logger=self.logger)
# Register handlers with EventDispatcher BEFORE wrapping with idempotency
@self.dispatcher.register(EventType.EXECUTION_REQUESTED)
@@ -153,12 +159,13 @@ async def handle_cancelled(event: BaseEvent) -> None:
consumer=self.consumer,
idempotency_manager=self.idempotency_manager,
dispatcher=self.dispatcher,
+ logger=self.logger,
default_key_strategy="event_based", # Use event ID for deduplication
default_ttl_seconds=7200, # 2 hours TTL for coordinator events
enable_for_all_handlers=True, # Enable idempotency for ALL handlers
)
- logger.info("COORDINATOR: Event handlers registered with idempotency protection")
+ self.logger.info("COORDINATOR: Event handlers registered with idempotency protection")
await self.idempotent_consumer.start([KafkaTopic.EXECUTION_EVENTS])
@@ -166,14 +173,14 @@ async def handle_cancelled(event: BaseEvent) -> None:
self._running = True
self._scheduling_task = asyncio.create_task(self._scheduling_loop())
- logger.info("ExecutionCoordinator service started successfully")
+ self.logger.info("ExecutionCoordinator service started successfully")
async def stop(self) -> None:
"""Stop the coordinator service"""
if not self._running:
return
- logger.info("Stopping ExecutionCoordinator service...")
+ self.logger.info("Stopping ExecutionCoordinator service...")
self._running = False
# Stop scheduling task
@@ -194,11 +201,11 @@ async def stop(self) -> None:
if hasattr(self, "idempotency_manager") and self.idempotency_manager:
await self.idempotency_manager.close()
- logger.info(f"ExecutionCoordinator service stopped. Active executions: {len(self._active_executions)}")
+ self.logger.info(f"ExecutionCoordinator service stopped. Active executions: {len(self._active_executions)}")
async def _route_execution_event(self, event: BaseEvent) -> None:
"""Route execution events to appropriate handlers based on event type"""
- logger.info(
+ self.logger.info(
f"COORDINATOR: Routing execution event - type: {event.event_type}, "
f"id: {event.event_id}, "
f"actual class: {type(event).__name__}"
@@ -209,7 +216,7 @@ async def _route_execution_event(self, event: BaseEvent) -> None:
elif event.event_type == EventType.EXECUTION_CANCELLED:
await self._handle_execution_cancelled(event) # type: ignore
else:
- logger.debug(f"Ignoring execution event type: {event.event_type}")
+ self.logger.debug(f"Ignoring execution event type: {event.event_type}")
async def _route_execution_result(self, event: BaseEvent) -> None:
"""Route execution result events to appropriate handlers based on event type"""
@@ -218,11 +225,11 @@ async def _route_execution_result(self, event: BaseEvent) -> None:
elif event.event_type == EventType.EXECUTION_FAILED:
await self._handle_execution_failed(event) # type: ignore
else:
- logger.debug(f"Ignoring execution result event type: {event.event_type}")
+ self.logger.debug(f"Ignoring execution result event type: {event.event_type}")
async def _handle_execution_requested(self, event: ExecutionRequestedEvent) -> None:
"""Handle execution requested event - add to queue for processing"""
- logger.info(f"HANDLER CALLED: _handle_execution_requested for event {event.event_id}")
+ self.logger.info(f"HANDLER CALLED: _handle_execution_requested for event {event.event_id}")
start_time = time.time()
try:
@@ -248,10 +255,10 @@ async def _handle_execution_requested(self, event: ExecutionRequestedEvent) -> N
self.metrics.record_coordinator_scheduling_duration(duration)
self.metrics.record_coordinator_execution_scheduled("queued")
- logger.info(f"Execution {event.execution_id} added to queue at position {position}")
+ self.logger.info(f"Execution {event.execution_id} added to queue at position {position}")
except Exception as e:
- logger.error(f"Failed to handle execution request {event.execution_id}: {e}", exc_info=True)
+ self.logger.error(f"Failed to handle execution request {event.execution_id}: {e}", exc_info=True)
self.metrics.record_coordinator_execution_scheduled("error")
async def _handle_execution_cancelled(self, event: ExecutionCancelledEvent) -> None:
@@ -268,7 +275,7 @@ async def _handle_execution_cancelled(self, event: ExecutionCancelledEvent) -> N
self.metrics.update_coordinator_active_executions(len(self._active_executions))
if removed:
- logger.info(f"Execution {execution_id} cancelled and removed from queue")
+ self.logger.info(f"Execution {execution_id} cancelled and removed from queue")
async def _handle_execution_completed(self, event: ExecutionCompletedEvent) -> None:
"""Handle execution completed event"""
@@ -282,7 +289,7 @@ async def _handle_execution_completed(self, event: ExecutionCompletedEvent) -> N
self._active_executions.discard(execution_id)
self.metrics.update_coordinator_active_executions(len(self._active_executions))
- logger.info(f"Execution {execution_id} completed, resources released")
+ self.logger.info(f"Execution {execution_id} completed, resources released")
async def _handle_execution_failed(self, event: ExecutionFailedEvent) -> None:
"""Handle execution failed event"""
@@ -312,7 +319,7 @@ async def _scheduling_loop(self) -> None:
await asyncio.sleep(self.scheduling_interval)
except Exception as e:
- logger.error(f"Error in scheduling loop: {e}", exc_info=True)
+ self.logger.error(f"Error in scheduling loop: {e}", exc_info=True)
await asyncio.sleep(5) # Wait before retrying
async def _schedule_execution(self, event: ExecutionRequestedEvent) -> None:
@@ -323,7 +330,7 @@ async def _schedule_execution(self, event: ExecutionRequestedEvent) -> None:
try:
# Check if already active (shouldn't happen, but be safe)
if event.execution_id in self._active_executions:
- logger.warning(f"Execution {event.execution_id} already active, skipping")
+ self.logger.warning(f"Execution {event.execution_id} already active, skipping")
return
# Request resource allocation
@@ -338,7 +345,7 @@ async def _schedule_execution(self, event: ExecutionRequestedEvent) -> None:
if not allocation:
# No resources available, requeue
await self.queue_manager.requeue_execution(event, increment_retry=False)
- logger.info(f"No resources available for {event.execution_id}, requeued")
+ self.logger.info(f"No resources available for {event.execution_id}, requeued")
return
# Track allocation
@@ -347,12 +354,12 @@ async def _schedule_execution(self, event: ExecutionRequestedEvent) -> None:
self.metrics.update_coordinator_active_executions(len(self._active_executions))
# Publish execution started event for workers
- logger.info(f"About to publish ExecutionStartedEvent for {event.execution_id}")
+ self.logger.info(f"About to publish ExecutionStartedEvent for {event.execution_id}")
try:
await self._publish_execution_started(event)
- logger.info(f"Successfully published ExecutionStartedEvent for {event.execution_id}")
+ self.logger.info(f"Successfully published ExecutionStartedEvent for {event.execution_id}")
except Exception as publish_error:
- logger.error(
+ self.logger.error(
f"Failed to publish ExecutionStartedEvent for {event.execution_id}: {publish_error}",
exc_info=True,
)
@@ -367,7 +374,7 @@ async def _schedule_execution(self, event: ExecutionRequestedEvent) -> None:
self.metrics.record_coordinator_scheduling_duration(scheduling_duration)
self.metrics.record_coordinator_execution_scheduled("scheduled")
- logger.info(
+ self.logger.info(
f"Scheduled execution {event.execution_id}. "
f"Queue time: {queue_time:.2f}s, "
f"Resources: {allocation.cpu_cores} CPU, "
@@ -375,7 +382,7 @@ async def _schedule_execution(self, event: ExecutionRequestedEvent) -> None:
)
except Exception as e:
- logger.error(f"Failed to schedule execution {event.execution_id}: {e}", exc_info=True)
+ self.logger.error(f"Failed to schedule execution {event.execution_id}: {e}", exc_info=True)
# Release any allocated resources
if event.execution_id in self._execution_resources:
@@ -429,7 +436,7 @@ async def _publish_execution_started(self, request: ExecutionRequestedEvent) ->
async def _publish_execution_accepted(self, request: ExecutionRequestedEvent, position: int, priority: int) -> None:
"""Publish execution accepted event to notify that request was valid and queued"""
- logger.info(f"Publishing ExecutionAcceptedEvent for execution {request.execution_id}")
+ self.logger.info(f"Publishing ExecutionAcceptedEvent for execution {request.execution_id}")
event = ExecutionAcceptedEvent(
execution_id=request.execution_id,
@@ -440,7 +447,7 @@ async def _publish_execution_accepted(self, request: ExecutionRequestedEvent, po
)
await self.producer.produce(event_to_produce=event)
- logger.info(f"ExecutionAcceptedEvent published for {request.execution_id}")
+ self.logger.info(f"ExecutionAcceptedEvent published for {request.execution_id}")
async def _publish_queue_full(self, request: ExecutionRequestedEvent, error: str) -> None:
"""Publish queue full event"""
@@ -452,7 +459,7 @@ async def _publish_queue_full(self, request: ExecutionRequestedEvent, error: str
error_type=ExecutionErrorType.RESOURCE_LIMIT,
exit_code=-1,
stderr=f"Queue full: {error}. Queue size: {queue_stats.get('total_size', 'unknown')}",
- resource_usage=ResourceUsageDomain.from_dict({}),
+ resource_usage=None,
metadata=request.metadata,
error_message=error,
)
@@ -471,7 +478,7 @@ async def _publish_scheduling_failed(self, request: ExecutionRequestedEvent, err
stderr=f"Failed to schedule execution: {error}. "
f"Available resources: CPU={resource_stats.available.cpu_cores}, "
f"Memory={resource_stats.available.memory_mb}MB",
- resource_usage=ResourceUsageDomain.from_dict({}),
+ resource_usage=None,
metadata=request.metadata,
error_message=error,
)
@@ -490,26 +497,31 @@ async def get_status(self) -> dict[str, Any]:
async def run_coordinator() -> None:
"""Run the execution coordinator service"""
+ import os
from contextlib import AsyncExitStack
+ from app.core.logging import setup_logger
+
+ logger = setup_logger(os.environ.get("LOG_LEVEL", "INFO"))
logger.info("Initializing schema registry for coordinator...")
- schema_registry_manager = create_schema_registry_manager()
+ schema_registry_manager = create_schema_registry_manager(logger)
await initialize_event_schemas(schema_registry_manager)
settings = get_settings()
config = ProducerConfig(bootstrap_servers=settings.KAFKA_BOOTSTRAP_SERVERS)
- producer = UnifiedProducer(config, schema_registry_manager)
+ producer = UnifiedProducer(config, schema_registry_manager, logger)
- db_client: DBClient = AsyncIOMotorClient(settings.MONGODB_URL, tz_aware=True, serverSelectionTimeoutMS=5000)
+ db_client: DBClient = AsyncMongoClient(settings.MONGODB_URL, tz_aware=True, serverSelectionTimeoutMS=5000)
db_name = settings.DATABASE_NAME
database = db_client[db_name]
- await SchemaManager(database).apply_all()
+ # Initialize Beanie ODM (indexes are idempotently created via Document.Settings.indexes)
+ await init_beanie(database=database, document_models=ALL_DOCUMENTS)
logger.info("Creating event store for coordinator...")
- event_store = create_event_store(db=database, schema_registry=schema_registry_manager, ttl_days=90)
+ event_store = create_event_store(schema_registry_manager, logger, ttl_days=90)
- exec_repo = ExecutionRepository(database)
+ exec_repo = ExecutionRepository(logger)
r = redis.Redis(
host=settings.REDIS_HOST,
port=settings.REDIS_PORT,
@@ -522,7 +534,7 @@ async def run_coordinator() -> None:
socket_timeout=5,
)
idem_repo = RedisIdempotencyRepository(r, key_prefix="idempotency")
- idem_manager = create_idempotency_manager(repository=idem_repo, config=IdempotencyConfig())
+ idem_manager = create_idempotency_manager(repository=idem_repo, config=IdempotencyConfig(), logger=logger)
await idem_manager.initialize()
coordinator = ExecutionCoordinator(
@@ -531,6 +543,7 @@ async def run_coordinator() -> None:
event_store=event_store,
execution_repository=exec_repo,
idempotency_manager=idem_manager,
+ logger=logger,
)
def signal_handler(sig: int, frame: Any) -> None:
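
The startup sequence above replaces SchemaManager with Beanie. A condensed sketch of the same order (client, database, init_beanie), assuming a reachable MongoDB instance and the document list from app.db.docs:

    from beanie import init_beanie
    from pymongo.asynchronous.mongo_client import AsyncMongoClient

    from app.db.docs import ALL_DOCUMENTS
    from app.settings import get_settings

    async def init_odm() -> None:
        settings = get_settings()
        client = AsyncMongoClient(settings.MONGODB_URL, tz_aware=True, serverSelectionTimeoutMS=5000)
        # init_beanie registers each Document model and creates its declared indexes; re-running is safe.
        await init_beanie(database=client[settings.DATABASE_NAME], document_models=ALL_DOCUMENTS)
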
diff --git a/backend/app/services/coordinator/queue_manager.py b/backend/app/services/coordinator/queue_manager.py
index d1a9ccd1..e43ec861 100644
--- a/backend/app/services/coordinator/queue_manager.py
+++ b/backend/app/services/coordinator/queue_manager.py
@@ -1,12 +1,12 @@
import asyncio
import heapq
+import logging
import time
from collections import defaultdict
from dataclasses import dataclass, field
from enum import IntEnum
from typing import Any, Dict, List, Tuple
-from app.core.logging import logger
from app.core.metrics.context import get_coordinator_metrics
from app.infrastructure.kafka.events import ExecutionRequestedEvent
@@ -42,10 +42,12 @@ def age_seconds(self) -> float:
class QueueManager:
def __init__(
self,
+ logger: logging.Logger,
max_queue_size: int = 10000,
max_executions_per_user: int = 100,
stale_timeout_seconds: int = 3600,
) -> None:
+ self.logger = logger
self.metrics = get_coordinator_metrics()
self.max_queue_size = max_queue_size
self.max_executions_per_user = max_executions_per_user
@@ -64,7 +66,7 @@ async def start(self) -> None:
self._running = True
self._cleanup_task = asyncio.create_task(self._cleanup_stale_executions())
- logger.info("Queue manager started")
+ self.logger.info("Queue manager started")
async def stop(self) -> None:
if not self._running:
@@ -79,7 +81,7 @@ async def stop(self) -> None:
except asyncio.CancelledError:
pass
- logger.info(f"Queue manager stopped. Final queue size: {len(self._queue)}")
+ self.logger.info(f"Queue manager stopped. Final queue size: {len(self._queue)}")
async def add_execution(
self, event: ExecutionRequestedEvent, priority: QueuePriority | None = None
@@ -105,7 +107,7 @@ async def add_execution(
# Update single authoritative metric for execution request queue depth
self.metrics.update_execution_request_queue_size(len(self._queue))
- logger.info(
+ self.logger.info(
f"Added execution {event.execution_id} to queue. "
f"Priority: {priority.name}, Position: {position}, "
f"Queue size: {len(self._queue)}"
@@ -128,7 +130,7 @@ async def get_next_execution(self) -> ExecutionRequestedEvent | None:
# Update metric after removal from the queue
self.metrics.update_execution_request_queue_size(len(self._queue))
- logger.info(
+ self.logger.info(
f"Retrieved execution {queued.execution_id} from queue. "
f"Wait time: {queued.age_seconds:.2f}s, Queue size: {len(self._queue)}"
)
@@ -147,7 +149,7 @@ async def remove_execution(self, execution_id: str) -> bool:
self._untrack_execution(execution_id)
# Update metric after explicit removal
self.metrics.update_execution_request_queue_size(len(self._queue))
- logger.info(f"Removed execution {execution_id} from queue")
+ self.logger.info(f"Removed execution {execution_id} from queue")
return True
return False
@@ -262,7 +264,7 @@ async def _cleanup_stale_executions(self) -> None:
# Update metric after stale cleanup
self.metrics.update_execution_request_queue_size(len(self._queue))
- logger.info(f"Cleaned {len(stale_executions)} stale executions from queue")
+ self.logger.info(f"Cleaned {len(stale_executions)} stale executions from queue")
except Exception as e:
- logger.error(f"Error in queue cleanup: {e}")
+ self.logger.error(f"Error in queue cleanup: {e}")
diff --git a/backend/app/services/coordinator/resource_manager.py b/backend/app/services/coordinator/resource_manager.py
index 8bfe9478..8910852f 100644
--- a/backend/app/services/coordinator/resource_manager.py
+++ b/backend/app/services/coordinator/resource_manager.py
@@ -1,8 +1,8 @@
import asyncio
+import logging
from dataclasses import dataclass
from typing import Dict, List
-from app.core.logging import logger
from app.core.metrics.context import get_coordinator_metrics
@@ -85,11 +85,13 @@ class ResourceManager:
def __init__(
self,
+ logger: logging.Logger,
total_cpu_cores: float = 32.0,
total_memory_mb: int = 65536, # 64GB
total_gpu_count: int = 0,
overcommit_factor: float = 1.2, # Allow 20% overcommit
):
+ self.logger = logger
self.metrics = get_coordinator_metrics()
self.pool = ResourcePool(
total_cpu_cores=total_cpu_cores * overcommit_factor,
@@ -147,7 +149,7 @@ async def request_allocation(
async with self._allocation_lock:
# Check if already allocated
if execution_id in self._allocations:
- logger.warning(f"Execution {execution_id} already has allocation")
+ self.logger.warning(f"Execution {execution_id} already has allocation")
return self._allocations[execution_id]
# Determine requested resources
@@ -172,7 +174,7 @@ async def request_allocation(
or memory_after < self.pool.min_available_memory_mb
or gpu_after < 0
):
- logger.warning(
+ self.logger.warning(
f"Insufficient resources for execution {execution_id}. "
f"Requested: {requested_cpu} CPU, {requested_memory_mb}MB RAM, "
f"{requested_gpu} GPU. Available: {self.pool.available_cpu_cores} CPU, "
@@ -196,7 +198,7 @@ async def request_allocation(
# Update metrics
self._update_metrics()
- logger.info(
+ self.logger.info(
f"Allocated resources for execution {execution_id}: "
f"{allocation.cpu_cores} CPU, {allocation.memory_mb}MB RAM, "
f"{allocation.gpu_count} GPU"
@@ -208,7 +210,7 @@ async def release_allocation(self, execution_id: str) -> bool:
"""Release resource allocation"""
async with self._allocation_lock:
if execution_id not in self._allocations:
- logger.warning(f"No allocation found for execution {execution_id}")
+ self.logger.warning(f"No allocation found for execution {execution_id}")
return False
allocation = self._allocations[execution_id]
@@ -224,7 +226,7 @@ async def release_allocation(self, execution_id: str) -> bool:
# Update metrics
self._update_metrics()
- logger.info(
+ self.logger.info(
f"Released resources for execution {execution_id}: "
f"{allocation.cpu_cores} CPU, {allocation.memory_mb}MB RAM, "
f"{allocation.gpu_count} GPU"
diff --git a/backend/app/services/event_bus.py b/backend/app/services/event_bus.py
index a0449c5f..78c875ca 100644
--- a/backend/app/services/event_bus.py
+++ b/backend/app/services/event_bus.py
@@ -1,6 +1,7 @@
import asyncio
import fnmatch
import json
+import logging
from dataclasses import dataclass, field
from datetime import datetime, timezone
from typing import Any, Callable, Optional
@@ -10,7 +11,6 @@
from fastapi import Request
from app.core.lifecycle import LifecycleEnabled
-from app.core.logging import logger
from app.core.metrics.context import get_connection_metrics
from app.domain.enums.kafka import KafkaTopic
from app.settings import get_settings
@@ -45,7 +45,8 @@ class EventBus(LifecycleEnabled):
- *.completed - matches all completed events
"""
- def __init__(self) -> None:
+ def __init__(self, logger: logging.Logger) -> None:
+ self.logger = logger
self.settings = get_settings()
self.metrics = get_connection_metrics()
self.producer: Optional[Producer] = None
@@ -68,7 +69,7 @@ async def start(self) -> None:
await self._initialize_kafka()
self._consumer_task = asyncio.create_task(self._kafka_listener())
self._running = True
- logger.info("Event bus started with Kafka backing")
+ self.logger.info("Event bus started with Kafka backing")
async def _initialize_kafka(self) -> None:
"""Initialize Kafka producer and consumer."""
@@ -101,7 +102,7 @@ async def _initialize_kafka(self) -> None:
async def stop(self) -> None:
"""Stop the event bus and clean up resources."""
await self._cleanup()
- logger.info("Event bus stopped")
+ self.logger.info("Event bus stopped")
async def _cleanup(self) -> None:
"""Clean up all resources."""
@@ -157,7 +158,7 @@ async def publish(self, event_type: str, data: dict[str, Any]) -> None:
self.producer.produce(self._topic, value=value, key=key)
self.producer.poll(0)
except Exception as e:
- logger.error(f"Failed to publish to Kafka: {e}")
+ self.logger.error(f"Failed to publish to Kafka: {e}")
# Publish to local subscribers for immediate handling
await self._distribute_event(event_type, event)
@@ -196,7 +197,7 @@ async def subscribe(self, pattern: str, handler: Callable[[EventBusEvent], Any])
# Update metrics
self._update_metrics(pattern)
- logger.debug(f"Created subscription {subscription.id} for pattern: {pattern}")
+ self.logger.debug(f"Created subscription {subscription.id} for pattern: {pattern}")
return subscription.id
async def unsubscribe(self, pattern: str, handler: Callable[[EventBusEvent], Any]) -> None:
@@ -208,12 +209,12 @@ async def unsubscribe(self, pattern: str, handler: Callable[[EventBusEvent], Any
await self._remove_subscription(sub_id)
return
- logger.warning(f"No subscription found for pattern {pattern} with given handler")
+ self.logger.warning(f"No subscription found for pattern {pattern} with given handler")
async def _remove_subscription(self, subscription_id: str) -> None:
"""Remove a subscription by ID (must be called within lock)."""
if subscription_id not in self._subscriptions:
- logger.warning(f"Subscription {subscription_id} not found")
+ self.logger.warning(f"Subscription {subscription_id} not found")
return
subscription = self._subscriptions[subscription_id]
@@ -231,7 +232,7 @@ async def _remove_subscription(self, subscription_id: str) -> None:
# Update metrics
self._update_metrics(pattern)
- logger.debug(f"Removed subscription {subscription_id} for pattern: {pattern}")
+ self.logger.debug(f"Removed subscription {subscription_id} for pattern: {pattern}")
async def _distribute_event(self, event_type: str, event: EventBusEvent) -> None:
"""Distribute event to all matching local subscribers."""
@@ -249,7 +250,7 @@ async def _distribute_event(self, event_type: str, event: EventBusEvent) -> None
# Log any errors
for _i, result in enumerate(results):
if isinstance(result, Exception):
- logger.error(f"Handler failed for event {event_type}: {result}")
+ self.logger.error(f"Handler failed for event {event_type}: {result}")
async def _find_matching_handlers(self, event_type: str) -> list[Callable[[EventBusEvent], Any]]:
"""Find all handlers matching the event type."""
@@ -274,7 +275,7 @@ async def _kafka_listener(self) -> None:
if not self.consumer:
return
- logger.info("Kafka listener started")
+ self.logger.info("Kafka listener started")
try:
while self._running:
@@ -291,7 +292,7 @@ async def _kafka_listener(self) -> None:
if msg.error():
if msg.error().code() != KafkaError._PARTITION_EOF:
- logger.error(f"Consumer error: {msg.error()}")
+ self.logger.error(f"Consumer error: {msg.error()}")
continue
try:
@@ -305,12 +306,12 @@ async def _kafka_listener(self) -> None:
)
await self._distribute_event(event.event_type, event)
except Exception as e:
- logger.error(f"Error processing Kafka message: {e}")
+ self.logger.error(f"Error processing Kafka message: {e}")
except asyncio.CancelledError:
- logger.info("Kafka listener cancelled")
+ self.logger.info("Kafka listener cancelled")
except Exception as e:
- logger.error(f"Fatal error in Kafka listener: {e}")
+ self.logger.error(f"Fatal error in Kafka listener: {e}")
self._running = False
def _update_metrics(self, pattern: str) -> None:
@@ -334,7 +335,8 @@ async def get_statistics(self) -> dict[str, Any]:
class EventBusManager:
"""Manages EventBus lifecycle as a singleton."""
- def __init__(self) -> None:
+ def __init__(self, logger: logging.Logger) -> None:
+ self.logger = logger
self._event_bus: Optional[EventBus] = None
self._lock = asyncio.Lock()
@@ -342,7 +344,7 @@ async def get_event_bus(self) -> EventBus:
"""Get or create the event bus instance."""
async with self._lock:
if self._event_bus is None:
- self._event_bus = EventBus()
+ self._event_bus = EventBus(self.logger)
await self._event_bus.start()
return self._event_bus
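
The EventBus docstring above describes glob-style subscription patterns, and the module imports fnmatch, so matching presumably follows shell-style globbing:

    import fnmatch

    assert fnmatch.fnmatch("execution.completed", "*.completed")   # suffix pattern matches
    assert fnmatch.fnmatch("execution.completed", "execution.*")   # prefix pattern matches
    assert not fnmatch.fnmatch("execution.failed", "*.completed")  # non-matching event type
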
diff --git a/backend/app/services/event_replay/replay_service.py b/backend/app/services/event_replay/replay_service.py
index 1f87dc49..aa00c7ec 100644
--- a/backend/app/services/event_replay/replay_service.py
+++ b/backend/app/services/event_replay/replay_service.py
@@ -1,13 +1,13 @@
import asyncio
import inspect
import json
+import logging
from datetime import datetime, timedelta, timezone
from typing import Any, AsyncIterator, Callable, Dict, List
from uuid import uuid4
from opentelemetry.trace import SpanKind
-from app.core.logging import logger
from app.core.metrics import ReplayMetrics
from app.core.tracing.utils import trace_span
from app.db.repositories.replay_repository import ReplayRepository
@@ -25,22 +25,27 @@ def __init__(
repository: ReplayRepository,
producer: UnifiedProducer,
event_store: EventStore,
+ logger: logging.Logger,
) -> None:
self._sessions: Dict[str, ReplaySessionState] = {}
self._active_tasks: Dict[str, asyncio.Task[None]] = {}
self._repository = repository
self._producer = producer
self._event_store = event_store
+ self.logger = logger
self._callbacks: Dict[ReplayTarget, Callable[..., Any]] = {}
self._file_locks: Dict[str, asyncio.Lock] = {}
self._metrics = ReplayMetrics()
- logger.info("Event replay service initialized")
+ self.logger.info("Event replay service initialized")
async def create_replay_session(self, config: ReplayConfig) -> str:
state = ReplaySessionState(session_id=str(uuid4()), config=config)
self._sessions[state.session_id] = state
- logger.info(f"Created replay session {state.session_id} type={config.replay_type} target={config.target}")
+ self.logger.info(
+ "Created replay session",
+ extra={"session_id": state.session_id, "type": config.replay_type, "target": config.target},
+ )
return state.session_id
@@ -59,7 +64,7 @@ async def start_replay(self, session_id: str) -> None:
session.started_at = datetime.now(timezone.utc)
self._metrics.increment_active_replays()
- logger.info(f"Started replay session {session_id}")
+ self.logger.info("Started replay session", extra={"session_id": session_id})
async def _run_replay(self, session: ReplaySessionState) -> None:
start_time = asyncio.get_event_loop().time()
@@ -95,7 +100,10 @@ async def _prepare_session(self, session: ReplaySessionState) -> None:
total_count = await self._repository.count_events(session.config.filter)
session.total_events = min(total_count, session.config.max_events or total_count)
- logger.info(f"Replay session {session.session_id} will process {session.total_events} events")
+ self.logger.info(
+ "Replay session will process events",
+ extra={"session_id": session.session_id, "total_events": session.total_events},
+ )
async def _handle_progress_callback(self, session: ReplaySessionState) -> None:
cb = session.config.get_progress_callback()
@@ -105,7 +113,7 @@ async def _handle_progress_callback(self, session: ReplaySessionState) -> None:
if inspect.isawaitable(result):
await result
except Exception as e:
- logger.error(f"Progress callback error: {e}")
+ self.logger.error(f"Progress callback error: {e}")
async def _complete_session(self, session: ReplaySessionState, start_time: float) -> None:
session.status = ReplayStatus.COMPLETED
@@ -116,16 +124,23 @@ async def _complete_session(self, session: ReplaySessionState, start_time: float
await self._update_session_in_db(session)
- logger.info(
- f"Replay session {session.session_id} completed. "
- f"Replayed: {session.replayed_events}, "
- f"Failed: {session.failed_events}, "
- f"Skipped: {session.skipped_events}, "
- f"Duration: {duration:.2f}s"
+ self.logger.info(
+ "Replay session completed",
+ extra={
+ "session_id": session.session_id,
+ "replayed_events": session.replayed_events,
+ "failed_events": session.failed_events,
+ "skipped_events": session.skipped_events,
+ "duration_seconds": round(duration, 2),
+ },
)
async def _handle_session_error(self, session: ReplaySessionState, error: Exception) -> None:
- logger.error(f"Replay session {session.session_id} failed: {error}", exc_info=True)
+ self.logger.error(
+ "Replay session failed",
+ extra={"session_id": session.session_id, "error": str(error)},
+ exc_info=True,
+ )
session.status = ReplayStatus.FAILED
session.completed_at = datetime.now(timezone.utc)
session.errors.append(
@@ -152,7 +167,7 @@ def _update_replay_metrics(self, session: ReplaySessionState, event: BaseEvent,
self._metrics.record_event_replayed(session.config.replay_type, event.event_type, status)
async def _handle_replay_error(self, session: ReplaySessionState, event: BaseEvent, error: Exception) -> None:
- logger.error(f"Failed to replay event {event.event_id}: {error}")
+ self.logger.error("Failed to replay event", extra={"event_id": event.event_id, "error": str(error)})
session.failed_events += 1
session.errors.append(
{"timestamp": datetime.now(timezone.utc).isoformat(), "event_id": str(event.event_id), "error": str(error)}
@@ -176,13 +191,13 @@ async def _replay_to_callback(self, event: BaseEvent, session: ReplaySessionStat
async def _replay_to_file(self, event: BaseEvent, file_path: str | None) -> bool:
if not file_path:
- logger.error("No target file path specified")
+ self.logger.error("No target file path specified")
return False
await self._write_event_to_file(event, file_path)
return True
async def _fetch_event_batches(self, session: ReplaySessionState) -> AsyncIterator[List[BaseEvent]]:
- logger.info(f"Fetching events for session {session.session_id}")
+ self.logger.info("Fetching events for session", extra={"session_id": session.session_id})
events_processed = 0
max_events = session.config.max_events
@@ -248,10 +263,13 @@ async def _replay_event(self, session: ReplaySessionState, event: BaseEvent) ->
elif config.target == ReplayTarget.TEST:
return True
else:
- logger.error(f"Unknown replay target: {config.target}")
+ self.logger.error("Unknown replay target", extra={"target": config.target})
return False
except Exception as e:
- logger.error(f"Failed to replay event (attempt {attempt + 1}/{attempts}): {e}")
+ self.logger.error(
+ "Failed to replay event",
+ extra={"attempt": attempt + 1, "max_attempts": attempts, "error": str(e)},
+ )
if attempt < attempts - 1:
await asyncio.sleep(min(2**attempt, 10))
continue
@@ -277,7 +295,7 @@ async def pause_replay(self, session_id: str) -> None:
if session.status == ReplayStatus.RUNNING:
session.status = ReplayStatus.PAUSED
- logger.info(f"Paused replay session {session_id}")
+ self.logger.info("Paused replay session", extra={"session_id": session_id})
async def resume_replay(self, session_id: str) -> None:
session = self._sessions.get(session_id)
@@ -286,7 +304,7 @@ async def resume_replay(self, session_id: str) -> None:
if session.status == ReplayStatus.PAUSED:
session.status = ReplayStatus.RUNNING
- logger.info(f"Resumed replay session {session_id}")
+ self.logger.info("Resumed replay session", extra={"session_id": session_id})
async def cancel_replay(self, session_id: str) -> None:
session = self._sessions.get(session_id)
@@ -299,7 +317,7 @@ async def cancel_replay(self, session_id: str) -> None:
if task and not task.done():
task.cancel()
- logger.info(f"Cancelled replay session {session_id}")
+ self.logger.info("Cancelled replay session", extra={"session_id": session_id})
def get_session(self, session_id: str) -> ReplaySessionState | None:
return self._sessions.get(session_id)
@@ -328,7 +346,7 @@ async def cleanup_old_sessions(self, older_than_hours: int = 24) -> int:
del self._sessions[session_id]
removed += 1
- logger.info(f"Cleaned up {removed} old replay sessions")
+ self.logger.info("Cleaned up old replay sessions", extra={"removed_count": removed})
return removed
async def _update_session_in_db(self, session: ReplaySessionState) -> None:
@@ -345,4 +363,4 @@ async def _update_session_in_db(self, session: ReplaySessionState) -> None:
# If needed, add it to the domain model
await self._repository.update_replay_session(session_id=session.session_id, updates=session_update)
except Exception as e:
- logger.error(f"Failed to update session in database: {e}")
+ self.logger.error(f"Failed to update session in database: {e}")
diff --git a/backend/app/services/event_service.py b/backend/app/services/event_service.py
index 77211181..e7dce14f 100644
--- a/backend/app/services/event_service.py
+++ b/backend/app/services/event_service.py
@@ -1,10 +1,7 @@
from datetime import datetime
from typing import Any
-from pymongo import ASCENDING, DESCENDING
-
from app.db.repositories.event_repository import EventRepository
-from app.domain.enums.common import SortOrder
from app.domain.enums.events import EventType
from app.domain.enums.user import UserRole
from app.domain.events import (
@@ -15,7 +12,37 @@
EventReplayInfo,
EventStatistics,
)
-from app.infrastructure.mappers import EventFilterMapper
+
+
+def _filter_to_mongo_query(flt: EventFilter) -> dict[str, Any]:
+ """Convert EventFilter to MongoDB query dict."""
+ query: dict[str, Any] = {}
+
+ if flt.event_types:
+ query["event_type"] = {"$in": flt.event_types}
+ if flt.aggregate_id:
+ query["aggregate_id"] = flt.aggregate_id
+ if flt.correlation_id:
+ query["metadata.correlation_id"] = flt.correlation_id
+ if flt.user_id:
+ query["metadata.user_id"] = flt.user_id
+ if flt.service_name:
+ query["metadata.service_name"] = flt.service_name
+ if getattr(flt, "status", None):
+ query["status"] = flt.status
+
+ if flt.start_time or flt.end_time:
+ time_query: dict[str, Any] = {}
+ if flt.start_time:
+ time_query["$gte"] = flt.start_time
+ if flt.end_time:
+ time_query["$lte"] = flt.end_time
+ query["timestamp"] = time_query
+
+ if flt.search_text:
+ query["$text"] = {"$search": flt.search_text}
+
+ return query
class EventService:
@@ -84,7 +111,6 @@ async def query_events_advanced(
user_role: UserRole,
filters: EventFilter,
sort_by: str = "timestamp",
- sort_order: SortOrder = SortOrder.DESC,
limit: int = 100,
skip: int = 0,
) -> EventListResult | None:
@@ -92,7 +118,7 @@ async def query_events_advanced(
if filters.user_id and filters.user_id != user_id and user_role != UserRole.ADMIN:
return None
- query = EventFilterMapper.to_mongo_query(filters)
+ query = _filter_to_mongo_query(filters)
if not filters.user_id and user_role != UserRole.ADMIN:
query["metadata.user_id"] = user_id
@@ -105,13 +131,10 @@ async def query_events_advanced(
"stored_at": "stored_at",
}
sort_field = field_map.get(sort_by, "timestamp")
- direction = DESCENDING if sort_order == SortOrder.DESC else ASCENDING
- # Pagination and sorting from request
- return await self.repository.query_events_generic(
+ return await self.repository.query_events(
query=query,
sort_field=sort_field,
- sort_direction=direction,
skip=skip,
limit=limit,
)
@@ -146,10 +169,10 @@ async def get_event_statistics(
include_all_users: bool = False,
) -> EventStatistics:
match = {} if include_all_users else self._build_user_filter(user_id, user_role)
- return await self.repository.get_event_statistics_filtered(
- match=match,
+ return await self.repository.get_event_statistics(
start_time=start_time,
end_time=end_time,
+ match=match or None,
)
async def get_event(
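
A quick sketch of the query shape _filter_to_mongo_query produces. A SimpleNamespace stands in for EventFilter because the real model's constructor is not shown here; only the attributes the helper reads are assumed.

    from datetime import datetime, timezone
    from types import SimpleNamespace

    from app.services.event_service import _filter_to_mongo_query

    flt = SimpleNamespace(
        event_types=["execution.completed"], aggregate_id=None, correlation_id=None,
        user_id="user-1", service_name=None, status=None, search_text=None,
        start_time=datetime(2024, 1, 1, tzinfo=timezone.utc), end_time=None,
    )
    print(_filter_to_mongo_query(flt))
    # {'event_type': {'$in': ['execution.completed']}, 'metadata.user_id': 'user-1',
    #  'timestamp': {'$gte': datetime.datetime(2024, 1, 1, 0, 0, tzinfo=datetime.timezone.utc)}}
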
diff --git a/backend/app/services/execution_service.py b/backend/app/services/execution_service.py
index 33001a73..09ca7922 100644
--- a/backend/app/services/execution_service.py
+++ b/backend/app/services/execution_service.py
@@ -1,20 +1,21 @@
+import logging
from contextlib import contextmanager
from datetime import datetime
from time import time
from typing import Any, Generator, TypeAlias
from app.core.correlation import CorrelationContext
-from app.core.exceptions import IntegrationException, ServiceError
-from app.core.logging import logger
from app.core.metrics.context import get_execution_metrics
from app.db.repositories.execution_repository import ExecutionRepository
from app.domain.enums.events import EventType
from app.domain.enums.execution import ExecutionStatus
+from app.domain.exceptions import InfrastructureError
from app.domain.execution import (
DomainExecution,
+ DomainExecutionCreate,
+ ExecutionNotFoundError,
ExecutionResultDomain,
ResourceLimitsDomain,
- ResourceUsageDomain,
)
from app.events.core import UnifiedProducer
from app.events.event_store import EventStore
@@ -50,6 +51,7 @@ def __init__(
producer: UnifiedProducer,
event_store: EventStore,
settings: Settings,
+ logger: logging.Logger,
) -> None:
"""
Initialize execution service.
@@ -59,11 +61,13 @@ def __init__(
producer: Kafka producer for publishing events.
event_store: Event store for event persistence.
settings: Application settings.
+ logger: Logger instance.
"""
self.execution_repo = execution_repo
self.producer = producer
self.event_store = event_store
self.settings = settings
+ self.logger = logger
self.metrics = get_execution_metrics()
@contextmanager
@@ -143,13 +147,13 @@ async def execute_script(
DomainExecution record with queued status.
Raises:
- IntegrationException: If validation fails or event publishing fails.
+ InfrastructureError: If validation fails or event publishing fails.
"""
lang_and_version = f"{lang}-{lang_version}"
start_time = time()
# Log incoming request
- logger.info(
+ self.logger.info(
"Received script execution request",
extra={
"lang": lang,
@@ -165,16 +169,15 @@ async def execute_script(
with self._track_active_execution():
# Create execution record
created_execution = await self.execution_repo.create_execution(
- DomainExecution(
+ DomainExecutionCreate(
script=script,
lang=lang,
lang_version=lang_version,
- status=ExecutionStatus.QUEUED,
user_id=user_id,
)
)
- logger.info(
+ self.logger.info(
"Created execution record",
extra={
"execution_id": str(created_execution.execution_id),
@@ -215,13 +218,13 @@ async def execute_script(
created_execution.execution_id,
f"Failed to submit execution: {str(e)}",
)
- raise IntegrationException(status_code=500, detail="Failed to submit execution request") from e
+ raise InfrastructureError("Failed to submit execution request") from e
# Success metrics and return
duration = time() - start_time
self.metrics.record_script_execution(ExecutionStatus.QUEUED, lang_and_version)
self.metrics.record_queue_wait_time(duration, lang_and_version)
- logger.info(
+ self.logger.info(
"Script execution submitted successfully",
extra={
"execution_id": str(created_execution.execution_id),
@@ -238,7 +241,7 @@ async def _update_execution_error(self, execution_id: str, error_message: str) -
exit_code=-1,
stdout="",
stderr=error_message,
- resource_usage=ResourceUsageDomain(0.0, 0, 0, 0),
+ resource_usage=None,
metadata={},
)
await self.execution_repo.write_terminal_result(result)
@@ -258,14 +261,14 @@ async def get_execution_result(self, execution_id: str) -> DomainExecution:
Current execution state.
Raises:
- IntegrationException: If execution not found.
+ ExecutionNotFoundError: If execution not found.
"""
execution = await self.execution_repo.get_execution(execution_id)
if not execution:
- logger.warning("Execution not found", extra={"execution_id": execution_id})
- raise IntegrationException(status_code=404, detail=f"Execution {execution_id} not found")
+ self.logger.warning("Execution not found", extra={"execution_id": execution_id})
+ raise ExecutionNotFoundError(execution_id)
- logger.info(
+ self.logger.info(
"Execution result retrieved successfully",
extra={
"execution_id": execution_id,
@@ -304,7 +307,7 @@ async def get_execution_events(
if len(events) > limit:
events = events[:limit]
- logger.debug(
+ self.logger.debug(
f"Retrieved {len(events)} events for execution {execution_id}",
extra={
"execution_id": execution_id,
@@ -346,7 +349,7 @@ async def get_user_executions(
query=query, limit=limit, skip=skip, sort=[("created_at", -1)]
)
- logger.debug(
+ self.logger.debug(
f"Retrieved {len(executions)} executions for user",
extra={
"user_id": str(user_id),
@@ -435,10 +438,10 @@ async def delete_execution(self, execution_id: str) -> bool:
deleted = await self.execution_repo.delete_execution(execution_id)
if not deleted:
- logger.warning(f"Execution {execution_id} not found for deletion")
- raise ServiceError("Execution not found", status_code=404)
+ self.logger.warning("Execution not found for deletion", extra={"execution_id": execution_id})
+ raise ExecutionNotFoundError(execution_id)
- logger.info("Deleted execution", extra={"execution_id": execution_id})
+ self.logger.info("Deleted execution", extra={"execution_id": execution_id})
await self._publish_deletion_event(execution_id)
@@ -459,7 +462,7 @@ async def _publish_deletion_event(self, execution_id: str) -> None:
await self.producer.produce(event_to_produce=event, key=execution_id)
- logger.info(
+ self.logger.info(
"Published cancellation event",
extra={
"execution_id": execution_id,
diff --git a/backend/app/services/grafana_alert_processor.py b/backend/app/services/grafana_alert_processor.py
index 64e0181f..a78d6d6c 100644
--- a/backend/app/services/grafana_alert_processor.py
+++ b/backend/app/services/grafana_alert_processor.py
@@ -1,8 +1,8 @@
"""Grafana alert processing service."""
+import logging
from typing import Any
-from app.core.logging import logger
from app.domain.enums.notification import NotificationSeverity
from app.schemas_pydantic.grafana import GrafanaAlertItem, GrafanaWebhook
from app.services.notification_service import NotificationService
@@ -23,10 +23,11 @@ class GrafanaAlertProcessor:
DEFAULT_TITLE = "Grafana Alert"
DEFAULT_MESSAGE = "Alert triggered"
- def __init__(self, notification_service: NotificationService) -> None:
+ def __init__(self, notification_service: NotificationService, logger: logging.Logger) -> None:
"""Initialize the processor with required services."""
self.notification_service = notification_service
- logger.info("GrafanaAlertProcessor initialized")
+ self.logger = logger
+ self.logger.info("GrafanaAlertProcessor initialized")
@classmethod
def extract_severity(cls, alert: GrafanaAlertItem, webhook: GrafanaWebhook) -> str:
@@ -103,7 +104,7 @@ async def process_single_alert(
except Exception as e:
error_msg = f"Failed to process Grafana alert: {e}"
- logger.error(error_msg, extra={"correlation_id": correlation_id}, exc_info=True)
+ self.logger.error(error_msg, extra={"correlation_id": correlation_id}, exc_info=True)
return False, error_msg
async def process_webhook(self, webhook_payload: GrafanaWebhook, correlation_id: str) -> tuple[int, list[str]]:
@@ -120,7 +121,7 @@ async def process_webhook(self, webhook_payload: GrafanaWebhook, correlation_id:
errors: list[str] = []
processed_count = 0
- logger.info(
+ self.logger.info(
"Processing Grafana webhook",
extra={
"correlation_id": correlation_id,
@@ -136,7 +137,7 @@ async def process_webhook(self, webhook_payload: GrafanaWebhook, correlation_id:
elif error_msg:
errors.append(error_msg)
- logger.info(
+ self.logger.info(
"Grafana webhook processing completed",
extra={
"correlation_id": correlation_id,
diff --git a/backend/app/services/idempotency/idempotency_manager.py b/backend/app/services/idempotency/idempotency_manager.py
index f06467c1..a49a4c62 100644
--- a/backend/app/services/idempotency/idempotency_manager.py
+++ b/backend/app/services/idempotency/idempotency_manager.py
@@ -1,13 +1,13 @@
import asyncio
import hashlib
import json
+import logging
from datetime import datetime, timedelta, timezone
from typing import Protocol
from pydantic import BaseModel
from pymongo.errors import DuplicateKeyError
-from app.core.logging import logger
from app.core.metrics.context import get_database_metrics
from app.domain.idempotency import IdempotencyRecord, IdempotencyStats, IdempotencyStatus
from app.infrastructure.kafka.events import BaseEvent
@@ -67,16 +67,17 @@ async def health_check(self) -> None: ...
class IdempotencyManager:
- def __init__(self, config: IdempotencyConfig, repository: IdempotencyRepoProtocol) -> None:
+ def __init__(self, config: IdempotencyConfig, repository: IdempotencyRepoProtocol, logger: logging.Logger) -> None:
self.config = config
self.metrics = get_database_metrics()
self._repo: IdempotencyRepoProtocol = repository
self._stats_update_task: asyncio.Task[None] | None = None
+ self.logger = logger
async def initialize(self) -> None:
if self.config.enable_metrics and self._stats_update_task is None:
self._stats_update_task = asyncio.create_task(self._update_stats_loop())
- logger.info("Idempotency manager ready")
+ self.logger.info("Idempotency manager ready")
async def close(self) -> None:
if self._stats_update_task:
@@ -85,7 +86,7 @@ async def close(self) -> None:
await self._stats_update_task
except asyncio.CancelledError:
pass
- logger.info("Closed idempotency manager")
+ self.logger.info("Closed idempotency manager")
def _generate_key(
self, event: BaseEvent, key_strategy: str, custom_key: str | None = None, fields: set[str] | None = None
@@ -152,7 +153,7 @@ async def _handle_processing_key(
now = datetime.now(timezone.utc)
if now - created_at > timedelta(seconds=self.config.processing_timeout_seconds):
- logger.warning(f"Idempotency key {full_key} processing timeout, allowing retry")
+ self.logger.warning(f"Idempotency key {full_key} processing timeout, allowing retry")
existing.created_at = now
existing.status = IdempotencyStatus.PROCESSING
await self._repo.update_record(existing)
@@ -214,7 +215,7 @@ async def _update_key_status(
if len(cached_json.encode()) <= self.config.max_result_size_bytes:
existing.result_json = cached_json
else:
- logger.warning(f"Result too large to cache for key {full_key}")
+ self.logger.warning(f"Result too large to cache for key {full_key}")
return (await self._repo.update_record(existing)) > 0
async def mark_completed(
@@ -228,10 +229,10 @@ async def mark_completed(
try:
existing = await self._repo.find_by_key(full_key)
except Exception as e: # Narrow DB op
- logger.error(f"Failed to load idempotency key for completion: {e}")
+ self.logger.error(f"Failed to load idempotency key for completion: {e}")
return False
if not existing:
- logger.warning(f"Idempotency key {full_key} not found when marking completed")
+ self.logger.warning(f"Idempotency key {full_key} not found when marking completed")
return False
# mark_completed does not accept an arbitrary result today; use mark_completed_with_json for cached payloads
return await self._update_key_status(full_key, existing, IdempotencyStatus.COMPLETED, cached_json=None)
@@ -247,7 +248,7 @@ async def mark_failed(
full_key = self._generate_key(event, key_strategy, custom_key, fields)
existing = await self._repo.find_by_key(full_key)
if not existing:
- logger.warning(f"Idempotency key {full_key} not found when marking failed")
+ self.logger.warning(f"Idempotency key {full_key} not found when marking failed")
return False
return await self._update_key_status(
full_key, existing, IdempotencyStatus.FAILED, cached_json=None, error=error
@@ -264,7 +265,7 @@ async def mark_completed_with_json(
full_key = self._generate_key(event, key_strategy, custom_key, fields)
existing = await self._repo.find_by_key(full_key)
if not existing:
- logger.warning(f"Idempotency key {full_key} not found when marking completed with cache")
+ self.logger.warning(f"Idempotency key {full_key} not found when marking completed with cache")
return False
return await self._update_key_status(full_key, existing, IdempotencyStatus.COMPLETED, cached_json=cached_json)
@@ -288,7 +289,7 @@ async def remove(
deleted = await self._repo.delete_key(full_key)
return deleted > 0
except Exception as e:
- logger.error(f"Failed to remove idempotency key: {e}")
+ self.logger.error(f"Failed to remove idempotency key: {e}")
return False
async def get_stats(self) -> IdempotencyStats:
@@ -310,7 +311,7 @@ async def _update_stats_loop(self) -> None:
except asyncio.CancelledError:
break
except Exception as e:
- logger.error(f"Failed to update idempotency stats: {e}")
+ self.logger.error(f"Failed to update idempotency stats: {e}")
await asyncio.sleep(300)
@@ -318,5 +319,6 @@ def create_idempotency_manager(
*,
repository: IdempotencyRepoProtocol,
config: IdempotencyConfig | None = None,
+ logger: logging.Logger,
) -> IdempotencyManager:
- return IdempotencyManager(config or IdempotencyConfig(), repository)
+ return IdempotencyManager(config or IdempotencyConfig(), repository, logger)
diff --git a/backend/app/services/idempotency/middleware.py b/backend/app/services/idempotency/middleware.py
index b31afe7b..fe6e3a9e 100644
--- a/backend/app/services/idempotency/middleware.py
+++ b/backend/app/services/idempotency/middleware.py
@@ -1,9 +1,9 @@
"""Idempotent event processing middleware"""
import asyncio
+import logging
from typing import Any, Awaitable, Callable, Dict, Set
-from app.core.logging import logger
from app.domain.enums.events import EventType
from app.domain.enums.kafka import KafkaTopic
from app.events.core import EventDispatcher, UnifiedConsumer
@@ -18,6 +18,7 @@ def __init__(
self,
handler: Callable[[BaseEvent], Awaitable[None]],
idempotency_manager: IdempotencyManager,
+ logger: logging.Logger,
key_strategy: str = "event_based",
custom_key_func: Callable[[BaseEvent], str] | None = None,
fields: Set[str] | None = None,
@@ -27,6 +28,7 @@ def __init__(
):
self.handler = handler
self.idempotency_manager = idempotency_manager
+ self.logger = logger
self.key_strategy = key_strategy
self.custom_key_func = custom_key_func
self.fields = fields
@@ -36,7 +38,7 @@ def __init__(
async def __call__(self, event: BaseEvent) -> None:
"""Process event with idempotency check"""
- logger.info(
+ self.logger.info(
f"IdempotentEventHandler called for event {event.event_type}, "
f"id={event.event_id}, handler={self.handler.__name__}"
)
@@ -56,7 +58,7 @@ async def __call__(self, event: BaseEvent) -> None:
if idempotency_result.is_duplicate:
# Handle duplicate
- logger.info(
+ self.logger.info(
f"Duplicate event detected: {event.event_type} ({event.event_id}), status: {idempotency_result.status}"
)
@@ -90,6 +92,7 @@ async def __call__(self, event: BaseEvent) -> None:
def idempotent_handler(
idempotency_manager: IdempotencyManager,
+ logger: logging.Logger,
key_strategy: str = "event_based",
custom_key_func: Callable[[BaseEvent], str] | None = None,
fields: Set[str] | None = None,
@@ -103,6 +106,7 @@ def decorator(func: Callable[[BaseEvent], Awaitable[None]]) -> Callable[[BaseEve
handler = IdempotentEventHandler(
handler=func,
idempotency_manager=idempotency_manager,
+ logger=logger,
key_strategy=key_strategy,
custom_key_func=custom_key_func,
fields=fields,
@@ -123,6 +127,7 @@ def __init__(
consumer: UnifiedConsumer,
idempotency_manager: IdempotencyManager,
dispatcher: EventDispatcher,
+ logger: logging.Logger,
default_key_strategy: str = "event_based",
default_ttl_seconds: int = 3600,
enable_for_all_handlers: bool = True,
@@ -130,6 +135,7 @@ def __init__(
self.consumer = consumer
self.idempotency_manager = idempotency_manager
self.dispatcher = dispatcher
+ self.logger = logger
self.default_key_strategy = default_key_strategy
self.default_ttl_seconds = default_ttl_seconds
self.enable_for_all_handlers = enable_for_all_handlers
@@ -137,17 +143,17 @@ def __init__(
def make_handlers_idempotent(self) -> None:
"""Wrap all registered handlers with idempotency"""
- logger.info(
+ self.logger.info(
f"make_handlers_idempotent called: enable_for_all={self.enable_for_all_handlers}, "
f"dispatcher={self.dispatcher is not None}"
)
if not self.enable_for_all_handlers or not self.dispatcher:
- logger.warning("Skipping handler wrapping - conditions not met")
+ self.logger.warning("Skipping handler wrapping - conditions not met")
return
# Store original handlers using public API
self._original_handlers = self.dispatcher.get_all_handlers()
- logger.info(f"Got {len(self._original_handlers)} event types with handlers to wrap")
+ self.logger.info(f"Got {len(self._original_handlers)} event types with handlers to wrap")
# Wrap each handler
for event_type, handlers in self._original_handlers.items():
@@ -157,18 +163,19 @@ def make_handlers_idempotent(self) -> None:
wrapped = IdempotentEventHandler(
handler=handler,
idempotency_manager=self.idempotency_manager,
+ logger=self.logger,
key_strategy=self.default_key_strategy,
ttl_seconds=self.default_ttl_seconds,
)
wrapped_handlers.append(wrapped)
# Replace handlers using public API
- logger.info(
+ self.logger.info(
f"Replacing {len(handlers)} handlers for {event_type} with {len(wrapped_handlers)} wrapped handlers"
)
self.dispatcher.replace_handlers(event_type, wrapped_handlers)
- logger.info("Handler wrapping complete")
+ self.logger.info("Handler wrapping complete")
def subscribe_idempotent_handler(
self,
@@ -186,6 +193,7 @@ def subscribe_idempotent_handler(
idempotent_wrapper = IdempotentEventHandler(
handler=handler,
idempotency_manager=self.idempotency_manager,
+ logger=self.logger,
key_strategy=key_strategy or self.default_key_strategy,
custom_key_func=custom_key_func,
fields=fields,
@@ -196,15 +204,15 @@ def subscribe_idempotent_handler(
# Create an async handler that processes the message
async def async_handler(message: Any) -> Any:
- logger.info(f"IDEMPOTENT HANDLER CALLED for {event_type}")
+ self.logger.info(f"IDEMPOTENT HANDLER CALLED for {event_type}")
# Extract event from confluent-kafka Message
if not hasattr(message, "value"):
- logger.error(f"Received non-Message object for {event_type}: {type(message)}")
+ self.logger.error(f"Received non-Message object for {event_type}: {type(message)}")
return None
# Debug log to check message details
- logger.info(
+ self.logger.info(
f"Handler for {event_type} - Message type: {type(message)}, "
f"has key: {hasattr(message, 'key')}, "
f"has topic: {hasattr(message, 'topic')}"
@@ -213,32 +221,32 @@ async def async_handler(message: Any) -> Any:
raw_value = message.value()
# Debug the raw value
- logger.info(f"Raw value extracted: {raw_value[:100] if raw_value else 'None or empty'}")
+ self.logger.info(f"Raw value extracted: {raw_value[:100] if raw_value else 'None or empty'}")
# Handle tombstone messages (null value for log compaction)
if raw_value is None:
- logger.warning(f"Received empty message for {event_type} - tombstone or consumed value")
+ self.logger.warning(f"Received empty message for {event_type} - tombstone or consumed value")
return None
# Handle empty messages
if not raw_value:
- logger.warning(f"Received empty message for {event_type} - empty bytes")
+ self.logger.warning(f"Received empty message for {event_type} - empty bytes")
return None
try:
# Deserialize using schema registry if available
event = self.consumer._schema_registry.deserialize_event(raw_value, message.topic())
if not event:
- logger.error(f"Failed to deserialize event for {event_type}")
+ self.logger.error(f"Failed to deserialize event for {event_type}")
return None
# Call the idempotent wrapper directly in async context
await idempotent_wrapper(event)
- logger.debug(f"Successfully processed {event_type} event: {event.event_id}")
+ self.logger.debug(f"Successfully processed {event_type} event: {event.event_id}")
return None
except Exception as e:
- logger.error(f"Failed to process message for {event_type}: {e}", exc_info=True)
+ self.logger.error(f"Failed to process message for {event_type}: {e}", exc_info=True)
raise
# Register with the dispatcher if available
@@ -250,17 +258,17 @@ async def dispatch_handler(event: BaseEvent) -> None:
self.dispatcher.register(EventType(event_type))(dispatch_handler)
else:
# Fallback to direct consumer registration if no dispatcher
- logger.error(f"No EventDispatcher available for registering idempotent handler for {event_type}")
+ self.logger.error(f"No EventDispatcher available for registering idempotent handler for {event_type}")
async def start(self, topics: list[KafkaTopic]) -> None:
"""Start the consumer with idempotency"""
- logger.info(f"IdempotentConsumerWrapper.start called with topics: {topics}")
+ self.logger.info(f"IdempotentConsumerWrapper.start called with topics: {topics}")
# Make handlers idempotent before starting
self.make_handlers_idempotent()
# Start the consumer with required topics parameter
await self.consumer.start(topics)
- logger.info("IdempotentConsumerWrapper started successfully")
+ self.logger.info("IdempotentConsumerWrapper started successfully")
async def stop(self) -> None:
"""Stop the consumer"""
diff --git a/backend/app/services/k8s_worker/worker.py b/backend/app/services/k8s_worker/worker.py
index 66bee135..96ef651d 100644
--- a/backend/app/services/k8s_worker/worker.py
+++ b/backend/app/services/k8s_worker/worker.py
@@ -1,4 +1,5 @@
import asyncio
+import logging
import os
import signal
import time
@@ -6,20 +7,19 @@
from typing import Any
import redis.asyncio as redis
+from beanie import init_beanie
from kubernetes import client as k8s_client
from kubernetes import config as k8s_config
from kubernetes.client.rest import ApiException
-from motor.motor_asyncio import AsyncIOMotorClient
+from pymongo.asynchronous.mongo_client import AsyncMongoClient
-from app.core.database_context import Database, DBClient
+from app.core.database_context import DBClient
from app.core.lifecycle import LifecycleEnabled
-from app.core.logging import logger
from app.core.metrics import ExecutionMetrics, KubernetesMetrics
-from app.db.schema.schema_manager import SchemaManager
+from app.db.docs import ALL_DOCUMENTS
from app.domain.enums.events import EventType
from app.domain.enums.kafka import KafkaTopic
from app.domain.enums.storage import ExecutionErrorType
-from app.domain.execution import ResourceUsageDomain
from app.events.core import ConsumerConfig, EventDispatcher, ProducerConfig, UnifiedConsumer, UnifiedProducer
from app.events.event_store import EventStore, create_event_store
from app.events.schema.schema_registry import (
@@ -59,19 +59,19 @@ class KubernetesWorker(LifecycleEnabled):
def __init__(
self,
config: K8sWorkerConfig,
- database: Database,
producer: UnifiedProducer,
schema_registry_manager: SchemaRegistryManager,
event_store: EventStore,
idempotency_manager: IdempotencyManager,
+ logger: logging.Logger,
):
+ self.logger = logger
self.metrics = KubernetesMetrics()
self.execution_metrics = ExecutionMetrics()
self.config = config or K8sWorkerConfig()
settings = get_settings()
self.kafka_servers = self.config.kafka_bootstrap_servers or settings.KAFKA_BOOTSTRAP_SERVERS
- self._db: Database = database
self._event_store = event_store
# Kubernetes clients
@@ -96,11 +96,11 @@ def __init__(
async def start(self) -> None:
"""Start the Kubernetes worker"""
if self._running:
- logger.warning("KubernetesWorker already running")
+ self.logger.warning("KubernetesWorker already running")
return
- logger.info("Starting KubernetesWorker service...")
- logger.info("DEBUG: About to initialize Kubernetes client")
+ self.logger.info("Starting KubernetesWorker service...")
+ self.logger.info("DEBUG: About to initialize Kubernetes client")
if self.config.namespace == "default":
raise RuntimeError(
@@ -109,11 +109,11 @@ async def start(self) -> None:
# Initialize Kubernetes client
self._initialize_kubernetes_client()
- logger.info("DEBUG: Kubernetes client initialized")
+ self.logger.info("DEBUG: Kubernetes client initialized")
- logger.info("Using provided producer")
+ self.logger.info("Using provided producer")
- logger.info("Idempotency manager provided")
+ self.logger.info("Idempotency manager provided")
# Create consumer configuration
consumer_config = ConsumerConfig(
@@ -123,18 +123,19 @@ async def start(self) -> None:
)
# Create dispatcher and register handlers for saga commands
- self.dispatcher = EventDispatcher()
+ self.dispatcher = EventDispatcher(logger=self.logger)
self.dispatcher.register_handler(EventType.CREATE_POD_COMMAND, self._handle_create_pod_command_wrapper)
self.dispatcher.register_handler(EventType.DELETE_POD_COMMAND, self._handle_delete_pod_command_wrapper)
# Create consumer with dispatcher
- self.consumer = UnifiedConsumer(consumer_config, event_dispatcher=self.dispatcher)
+ self.consumer = UnifiedConsumer(consumer_config, event_dispatcher=self.dispatcher, logger=self.logger)
# Wrap consumer with idempotency - use content hash for pod commands
self.idempotent_consumer = IdempotentConsumerWrapper(
consumer=self.consumer,
idempotency_manager=self.idempotency_manager,
dispatcher=self.dispatcher,
+ logger=self.logger,
default_key_strategy="content_hash", # Hash execution_id + script for deduplication
default_ttl_seconds=3600, # 1 hour TTL for pod creation events
enable_for_all_handlers=True, # Enable idempotency for all handlers
@@ -146,21 +147,21 @@ async def start(self) -> None:
# Create daemonset for image pre-pulling
asyncio.create_task(self.ensure_image_pre_puller_daemonset())
- logger.info("Image pre-puller daemonset task scheduled")
+ self.logger.info("Image pre-puller daemonset task scheduled")
- logger.info("KubernetesWorker service started successfully")
+ self.logger.info("KubernetesWorker service started successfully")
async def stop(self) -> None:
"""Stop the Kubernetes worker"""
if not self._running:
return
- logger.info("Stopping KubernetesWorker service...")
+ self.logger.info("Stopping KubernetesWorker service...")
self._running = False
# Wait for active creations to complete
if self._active_creations:
- logger.info(f"Waiting for {len(self._active_creations)} active pod creations to complete...")
+ self.logger.info(f"Waiting for {len(self._active_creations)} active pod creations to complete...")
timeout = 30
start_time = time.time()
@@ -168,7 +169,7 @@ async def stop(self) -> None:
await asyncio.sleep(1)
if self._active_creations:
- logger.warning(f"Timeout waiting for pod creations, {len(self._active_creations)} still active")
+ self.logger.warning(f"Timeout waiting for pod creations, {len(self._active_creations)} still active")
# Stop the consumer (idempotent wrapper only)
if self.idempotent_consumer:
@@ -181,25 +182,25 @@ async def stop(self) -> None:
if self.producer:
await self.producer.stop()
- logger.info("KubernetesWorker service stopped")
+ self.logger.info("KubernetesWorker service stopped")
def _initialize_kubernetes_client(self) -> None:
"""Initialize Kubernetes API clients"""
try:
# Load config
if self.config.in_cluster:
- logger.info("Using in-cluster Kubernetes configuration")
+ self.logger.info("Using in-cluster Kubernetes configuration")
k8s_config.load_incluster_config()
elif self.config.kubeconfig_path and os.path.exists(self.config.kubeconfig_path):
- logger.info(f"Using kubeconfig from {self.config.kubeconfig_path}")
+ self.logger.info(f"Using kubeconfig from {self.config.kubeconfig_path}")
k8s_config.load_kube_config(config_file=self.config.kubeconfig_path)
else:
# Try default locations
if os.path.exists("/var/run/secrets/kubernetes.io/serviceaccount"):
- logger.info("Detected in-cluster environment")
+ self.logger.info("Detected in-cluster environment")
k8s_config.load_incluster_config()
else:
- logger.info("Using default kubeconfig")
+ self.logger.info("Using default kubeconfig")
k8s_config.load_kube_config()
# Get the default configuration that was set by load_kube_config
@@ -207,8 +208,8 @@ def _initialize_kubernetes_client(self) -> None:
# The certificate data should already be configured by load_kube_config
# Log the configuration for debugging
- logger.info(f"Kubernetes API host: {configuration.host}")
- logger.info(f"SSL CA cert configured: {configuration.ssl_ca_cert is not None}")
+ self.logger.info(f"Kubernetes API host: {configuration.host}")
+ self.logger.info(f"SSL CA cert configured: {configuration.ssl_ca_cert is not None}")
# Create API clients with the configuration
api_client = k8s_client.ApiClient(configuration)
@@ -218,22 +219,22 @@ def _initialize_kubernetes_client(self) -> None:
# Test connection with namespace-scoped operation
_ = self.v1.list_namespaced_pod(namespace=self.config.namespace, limit=1)
- logger.info(f"Successfully connected to Kubernetes API, namespace {self.config.namespace} accessible")
+ self.logger.info(f"Successfully connected to Kubernetes API, namespace {self.config.namespace} accessible")
except Exception as e:
- logger.error(f"Failed to initialize Kubernetes client: {e}")
+ self.logger.error(f"Failed to initialize Kubernetes client: {e}")
raise
async def _handle_create_pod_command_wrapper(self, event: BaseEvent) -> None:
"""Wrapper for handling CreatePodCommandEvent with type safety."""
assert isinstance(event, CreatePodCommandEvent)
- logger.info(f"Processing create_pod_command for execution {event.execution_id} from saga {event.saga_id}")
+ self.logger.info(f"Processing create_pod_command for execution {event.execution_id} from saga {event.saga_id}")
await self._handle_create_pod_command(event)
async def _handle_delete_pod_command_wrapper(self, event: BaseEvent) -> None:
"""Wrapper for handling DeletePodCommandEvent."""
assert isinstance(event, DeletePodCommandEvent)
- logger.info(f"Processing delete_pod_command for execution {event.execution_id} from saga {event.saga_id}")
+ self.logger.info(f"Processing delete_pod_command for execution {event.execution_id} from saga {event.saga_id}")
await self._handle_delete_pod_command(event)
async def _handle_create_pod_command(self, command: CreatePodCommandEvent) -> None:
@@ -242,7 +243,7 @@ async def _handle_create_pod_command(self, command: CreatePodCommandEvent) -> No
# Check if already processing
if execution_id in self._active_creations:
- logger.warning(f"Already creating pod for execution {execution_id}")
+ self.logger.warning(f"Already creating pod for execution {execution_id}")
return
# Create pod asynchronously
@@ -251,7 +252,7 @@ async def _handle_create_pod_command(self, command: CreatePodCommandEvent) -> No
async def _handle_delete_pod_command(self, command: DeletePodCommandEvent) -> None:
"""Handle delete pod command from saga orchestrator (compensation)"""
execution_id = command.execution_id
- logger.info(f"Deleting pod for execution {execution_id} due to: {command.reason}")
+ self.logger.info(f"Deleting pod for execution {execution_id} due to: {command.reason}")
try:
# Delete the pod
@@ -263,7 +264,7 @@ async def _handle_delete_pod_command(self, command: DeletePodCommandEvent) -> No
namespace=self.config.namespace,
grace_period_seconds=30,
)
- logger.info(f"Successfully deleted pod {pod_name}")
+ self.logger.info(f"Successfully deleted pod {pod_name}")
# Delete associated ConfigMap
configmap_name = f"script-{execution_id}"
@@ -271,15 +272,15 @@ async def _handle_delete_pod_command(self, command: DeletePodCommandEvent) -> No
await asyncio.to_thread(
self.v1.delete_namespaced_config_map, name=configmap_name, namespace=self.config.namespace
)
- logger.info(f"Successfully deleted ConfigMap {configmap_name}")
+ self.logger.info(f"Successfully deleted ConfigMap {configmap_name}")
# NetworkPolicy cleanup is managed via a static cluster policy; no per-execution NP deletion
except ApiException as e:
if e.status == 404:
- logger.warning(f"Resources for execution {execution_id} not found (may have already been deleted)")
+ self.logger.warning(f"Resources for execution {execution_id} not found (may have already been deleted)")
else:
- logger.error(f"Failed to delete resources for execution {execution_id}: {e}")
+ self.logger.error(f"Failed to delete resources for execution {execution_id}: {e}")
async def _create_pod_for_execution(self, command: CreatePodCommandEvent) -> None:
"""Create pod for execution"""
@@ -315,13 +316,13 @@ async def _create_pod_for_execution(self, command: CreatePodCommandEvent) -> Non
self.metrics.record_k8s_pod_creation_duration(duration, command.language)
self.metrics.record_k8s_pod_created("success", command.language)
- logger.info(
+ self.logger.info(
f"Successfully created pod {pod.metadata.name} for execution {execution_id}. "
f"Duration: {duration:.2f}s"
)
except Exception as e:
- logger.error(f"Failed to create pod for execution {execution_id}: {e}", exc_info=True)
+ self.logger.error(f"Failed to create pod for execution {execution_id}: {e}", exc_info=True)
# Update metrics
self.metrics.record_k8s_pod_created("failed", "unknown")
@@ -365,10 +366,10 @@ async def _create_config_map(self, config_map: k8s_client.V1ConfigMap) -> None:
self.v1.create_namespaced_config_map, namespace=self.config.namespace, body=config_map
)
self.metrics.record_k8s_config_map_created("success")
- logger.debug(f"Created ConfigMap {config_map.metadata.name}")
+ self.logger.debug(f"Created ConfigMap {config_map.metadata.name}")
except ApiException as e:
if e.status == 409: # Already exists
- logger.warning(f"ConfigMap {config_map.metadata.name} already exists")
+ self.logger.warning(f"ConfigMap {config_map.metadata.name} already exists")
self.metrics.record_k8s_config_map_created("already_exists")
else:
self.metrics.record_k8s_config_map_created("failed")
@@ -380,10 +381,10 @@ async def _create_pod(self, pod: k8s_client.V1Pod) -> None:
raise RuntimeError("Kubernetes client not initialized")
try:
await asyncio.to_thread(self.v1.create_namespaced_pod, namespace=self.config.namespace, body=pod)
- logger.debug(f"Created Pod {pod.metadata.name}")
+ self.logger.debug(f"Created Pod {pod.metadata.name}")
except ApiException as e:
if e.status == 409: # Already exists
- logger.warning(f"Pod {pod.metadata.name} already exists")
+ self.logger.warning(f"Pod {pod.metadata.name} already exists")
else:
raise
@@ -398,7 +399,7 @@ async def _publish_execution_started(self, command: CreatePodCommandEvent, pod:
metadata=command.metadata,
)
if not self.producer:
- logger.error("Producer not initialized")
+ self.logger.error("Producer not initialized")
return
await self.producer.produce(event_to_produce=event)
@@ -412,7 +413,7 @@ async def _publish_pod_created(self, command: CreatePodCommandEvent, pod: k8s_cl
)
if not self.producer:
- logger.error("Producer not initialized")
+ self.logger.error("Producer not initialized")
return
await self.producer.produce(event_to_produce=event)
@@ -423,13 +424,13 @@ async def _publish_pod_creation_failed(self, command: CreatePodCommandEvent, err
error_type=ExecutionErrorType.SYSTEM_ERROR,
exit_code=-1,
stderr=f"Failed to create pod: {error}",
- resource_usage=ResourceUsageDomain.from_dict({}),
+ resource_usage=None,
metadata=command.metadata,
error_message=str(error),
)
if not self.producer:
- logger.error("Producer not initialized")
+ self.logger.error("Producer not initialized")
return
await self.producer.produce(event_to_produce=event)
@@ -448,7 +449,7 @@ async def get_status(self) -> dict[str, Any]:
async def ensure_image_pre_puller_daemonset(self) -> None:
"""Ensure the runtime image pre-puller DaemonSet exists"""
if not self.apps_v1:
- logger.warning("Kubernetes AppsV1Api client not initialized. Skipping DaemonSet creation.")
+ self.logger.warning("Kubernetes AppsV1Api client not initialized. Skipping DaemonSet creation.")
return
daemonset_name = "runtime-image-pre-puller"
@@ -461,7 +462,7 @@ async def ensure_image_pre_puller_daemonset(self) -> None:
for i, image_ref in enumerate(sorted(list(all_images))):
sanitized_image_ref = image_ref.split("/")[-1].replace(":", "-").replace(".", "-").replace("_", "-")
- logger.info(f"DAEMONSET: before: {image_ref} -> {sanitized_image_ref}")
+ self.logger.info(f"DAEMONSET: before: {image_ref} -> {sanitized_image_ref}")
container_name = f"pull-{i}-{sanitized_image_ref}"
init_containers.append(
{
@@ -494,51 +495,56 @@ async def ensure_image_pre_puller_daemonset(self) -> None:
await asyncio.to_thread(
self.apps_v1.read_namespaced_daemon_set, name=daemonset_name, namespace=namespace
)
- logger.info(f"DaemonSet '{daemonset_name}' exists. Replacing to ensure it is up-to-date.")
+ self.logger.info(f"DaemonSet '{daemonset_name}' exists. Replacing to ensure it is up-to-date.")
await asyncio.to_thread(
self.apps_v1.replace_namespaced_daemon_set, name=daemonset_name, namespace=namespace, body=manifest
)
- logger.info(f"DaemonSet '{daemonset_name}' replaced successfully.")
+ self.logger.info(f"DaemonSet '{daemonset_name}' replaced successfully.")
except ApiException as e:
if e.status == 404:
- logger.info(f"DaemonSet '{daemonset_name}' not found. Creating...")
+ self.logger.info(f"DaemonSet '{daemonset_name}' not found. Creating...")
await asyncio.to_thread(
self.apps_v1.create_namespaced_daemon_set, namespace=namespace, body=manifest
)
- logger.info(f"DaemonSet '{daemonset_name}' created successfully.")
+ self.logger.info(f"DaemonSet '{daemonset_name}' created successfully.")
else:
raise
except ApiException as e:
- logger.error(f"K8s API error applying DaemonSet '{daemonset_name}': {e.reason}", exc_info=True)
+ self.logger.error(f"K8s API error applying DaemonSet '{daemonset_name}': {e.reason}", exc_info=True)
except Exception as e:
- logger.error(f"Unexpected error applying image-puller DaemonSet: {e}", exc_info=True)
+ self.logger.error(f"Unexpected error applying image-puller DaemonSet: {e}", exc_info=True)
async def run_kubernetes_worker() -> None:
"""Run the Kubernetes worker service"""
+ import os
from contextlib import AsyncExitStack
+ from app.core.logging import setup_logger
+
+ logger = setup_logger(os.environ.get("LOG_LEVEL", "INFO"))
logger.info("Initializing database connection...")
settings = get_settings()
- db_client: DBClient = AsyncIOMotorClient(settings.MONGODB_URL, tz_aware=True, serverSelectionTimeoutMS=5000)
+ db_client: DBClient = AsyncMongoClient(settings.MONGODB_URL, tz_aware=True, serverSelectionTimeoutMS=5000)
db_name = settings.DATABASE_NAME
database = db_client[db_name]
await db_client.admin.command("ping")
logger.info(f"Connected to database: {db_name}")
- await SchemaManager(database).apply_all()
+ # Initialize Beanie ODM (indexes are idempotently created via Document.Settings.indexes)
+ await init_beanie(database=database, document_models=ALL_DOCUMENTS)
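+ # Sketch of what an ALL_DOCUMENTS entry is expected to look like (hypothetical
+ # `ExecutionDoc` shown for illustration only; the real documents live in app.db.docs):
+ #     class ExecutionDoc(Document):
+ #         execution_id: str
+ #         class Settings:
+ #             name = "executions"
+ #             indexes = [IndexModel([("execution_id", 1)], unique=True)]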
logger.info("Initializing schema registry...")
- schema_registry_manager = create_schema_registry_manager()
+ schema_registry_manager = create_schema_registry_manager(logger)
await initialize_event_schemas(schema_registry_manager)
logger.info("Creating event store...")
- event_store = create_event_store(database, schema_registry_manager)
+ event_store = create_event_store(schema_registry_manager, logger)
logger.info("Creating producer for Kubernetes worker...")
producer_config = ProducerConfig(bootstrap_servers=settings.KAFKA_BOOTSTRAP_SERVERS)
- producer = UnifiedProducer(producer_config, schema_registry_manager)
+ producer = UnifiedProducer(producer_config, schema_registry_manager, logger)
config = K8sWorkerConfig()
r = redis.Redis(
@@ -553,16 +559,16 @@ async def run_kubernetes_worker() -> None:
socket_timeout=5,
)
idem_repo = RedisIdempotencyRepository(r, key_prefix="idempotency")
- idem_manager = create_idempotency_manager(repository=idem_repo, config=IdempotencyConfig())
+ idem_manager = create_idempotency_manager(repository=idem_repo, config=IdempotencyConfig(), logger=logger)
await idem_manager.initialize()
worker = KubernetesWorker(
config=config,
- database=database,
producer=producer,
schema_registry_manager=schema_registry_manager,
event_store=event_store,
idempotency_manager=idem_manager,
+ logger=logger,
)
def signal_handler(sig: int, frame: Any) -> None:
diff --git a/backend/app/services/kafka_event_service.py b/backend/app/services/kafka_event_service.py
index 166f634f..25a13f46 100644
--- a/backend/app/services/kafka_event_service.py
+++ b/backend/app/services/kafka_event_service.py
@@ -1,3 +1,4 @@
+import logging
import time
from datetime import datetime, timezone
from typing import Any, Dict
@@ -6,7 +7,6 @@
from opentelemetry import trace
from app.core.correlation import CorrelationContext
-from app.core.logging import logger
from app.core.metrics.context import get_event_metrics
from app.core.tracing.utils import inject_trace_context
from app.db.repositories.event_repository import EventRepository
@@ -14,6 +14,7 @@
from app.domain.events import Event
from app.domain.events import EventMetadata as DomainEventMetadata
from app.events.core import UnifiedProducer
+from app.infrastructure.kafka.events.base import BaseEvent
from app.infrastructure.kafka.events.metadata import AvroEventMetadata
from app.infrastructure.kafka.mappings import get_event_class_for_type
from app.settings import get_settings
@@ -22,9 +23,10 @@
class KafkaEventService:
- def __init__(self, event_repository: EventRepository, kafka_producer: UnifiedProducer):
+ def __init__(self, event_repository: EventRepository, kafka_producer: UnifiedProducer, logger: logging.Logger):
self.event_repository = event_repository
self.kafka_producer = kafka_producer
+ self.logger = logger
self.metrics = get_event_metrics()
self.settings = get_settings()
@@ -70,8 +72,10 @@ async def publish_event(
event_id = str(uuid4())
timestamp = datetime.now(timezone.utc)
- # Convert to domain metadata for storage
- domain_metadata = DomainEventMetadata.from_dict(avro_metadata.to_dict())
+ # Convert to domain metadata for storage (only include defined fields)
+ domain_metadata = DomainEventMetadata(
+ **avro_metadata.model_dump(include=set(DomainEventMetadata.__dataclass_fields__))
+ )
event = Event(
event_id=event_id,
@@ -128,7 +132,7 @@ async def publish_event(
duration = time.time() - start_time
self.metrics.record_event_processing_duration(duration, event_type)
- logger.info(
+ self.logger.info(
"Event published",
extra={
"event_type": event_type,
@@ -148,7 +152,7 @@ async def publish_execution_event(
error_message: str | None = None,
) -> str:
"""Publish execution-related event using provided metadata (no framework coupling)."""
- logger.info(
+ self.logger.info(
"Publishing execution event",
extra={
"event_type": event_type,
@@ -169,7 +173,7 @@ async def publish_execution_event(
metadata=metadata,
)
- logger.info(
+ self.logger.info(
"Execution event published successfully",
extra={
"event_type": event_type,
@@ -202,6 +206,84 @@ async def publish_pod_event(
metadata=metadata,
)
+ async def publish_base_event(self, event: BaseEvent, key: str | None = None) -> str:
+ """
+ Publish a pre-built BaseEvent to Kafka and store an audit copy.
+
+ Used by PodMonitor and other components that create fully-formed events.
+ This ensures events are stored in the events collection AND published to Kafka.
+
+ Args:
+ event: Pre-built BaseEvent with all fields populated
+ key: Optional Kafka message key (defaults to aggregate_id)
+
+ Returns:
+ Event ID of published event
+ """
+ with tracer.start_as_current_span("publish_base_event") as span:
+ span.set_attribute("event.type", str(event.event_type))
+ if event.aggregate_id:
+ span.set_attribute("aggregate.id", event.aggregate_id)
+
+ start_time = time.time()
+
+ # Convert to domain metadata for storage (only include defined fields)
+ domain_metadata = DomainEventMetadata(
+ **event.metadata.model_dump(include=set(DomainEventMetadata.__dataclass_fields__))
+ )
+
+ # Build payload from event attributes (exclude base fields)
+ base_fields = {"event_id", "event_type", "event_version", "timestamp", "aggregate_id", "metadata", "topic"}
+ payload = {k: v for k, v in vars(event).items() if k not in base_fields and not k.startswith("_")}
+
+ # Create domain event for storage
+ domain_event = Event(
+ event_id=event.event_id,
+ event_type=event.event_type,
+ event_version=event.event_version,
+ timestamp=event.timestamp,
+ aggregate_id=event.aggregate_id,
+ metadata=domain_metadata,
+ payload=payload,
+ )
+ await self.event_repository.store_event(domain_event)
+
+ # Prepare headers
+ headers: Dict[str, str] = {
+ "event_type": str(event.event_type),
+ "correlation_id": event.metadata.correlation_id or "",
+ "service": event.metadata.service_name,
+ }
+
+ # Add trace context
+ span_context = span.get_span_context()
+ if span_context.is_valid:
+ headers["trace_id"] = f"{span_context.trace_id:032x}"
+ headers["span_id"] = f"{span_context.span_id:016x}"
+
+ headers = inject_trace_context(headers)
+
+ # Publish to Kafka
+ await self.kafka_producer.produce(
+ event_to_produce=event,
+ key=key or event.aggregate_id,
+ headers=headers,
+ )
+
+ self.metrics.record_event_published(event.event_type)
+
+ duration = time.time() - start_time
+ self.metrics.record_event_processing_duration(duration, event.event_type)
+
+ self.logger.info(
+ "Base event published",
+ extra={
+ "event_type": str(event.event_type),
+ "event_id": event.event_id,
+ "aggregate_id": event.aggregate_id,
+ },
+ )
+
+ return event.event_id
+
async def close(self) -> None:
"""Close event service resources"""
await self.kafka_producer.stop()
diff --git a/backend/app/services/notification_service.py b/backend/app/services/notification_service.py
index e5dfa937..e04754e0 100644
--- a/backend/app/services/notification_service.py
+++ b/backend/app/services/notification_service.py
@@ -1,4 +1,5 @@
import asyncio
+import logging
from dataclasses import dataclass, field
from datetime import UTC, datetime, timedelta
from enum import auto
@@ -6,8 +7,6 @@
import httpx
-from app.core.exceptions import ServiceError
-from app.core.logging import logger
from app.core.metrics.context import get_notification_metrics
from app.core.tracing.utils import add_span_attributes
from app.core.utils import StringEnum
@@ -22,8 +21,14 @@
from app.domain.enums.user import UserRole
from app.domain.notification import (
DomainNotification,
+ DomainNotificationCreate,
DomainNotificationListResult,
DomainNotificationSubscription,
+ DomainNotificationUpdate,
+ DomainSubscriptionUpdate,
+ NotificationNotFoundError,
+ NotificationThrottledError,
+ NotificationValidationError,
)
from app.events.core import ConsumerConfig, EventDispatcher, UnifiedConsumer
from app.events.schema.schema_registry import SchemaRegistryManager
@@ -116,6 +121,7 @@ def __init__(
schema_registry_manager: SchemaRegistryManager,
sse_bus: SSERedisBus,
settings: Settings,
+ logger: logging.Logger,
) -> None:
self.repository = notification_repository
self.event_service = event_service
@@ -124,6 +130,7 @@ def __init__(
self.settings = settings
self.schema_registry_manager = schema_registry_manager
self.sse_bus = sse_bus
+ self.logger = logger
# State
self._state = ServiceState.IDLE
@@ -136,7 +143,7 @@ def __init__(
self._dispatcher: EventDispatcher | None = None
self._consumer_task: asyncio.Task[None] | None = None
- logger.info(
+ self.logger.info(
"NotificationService initialized",
extra={
"repository": type(notification_repository).__name__,
@@ -158,7 +165,7 @@ def state(self) -> ServiceState:
def initialize(self) -> None:
if self._state != ServiceState.IDLE:
- logger.warning(f"Cannot initialize in state: {self._state}")
+ self.logger.warning(f"Cannot initialize in state: {self._state}")
return
self._state = ServiceState.INITIALIZING
@@ -167,14 +174,14 @@ def initialize(self) -> None:
self._state = ServiceState.RUNNING
self._start_background_tasks()
- logger.info("Notification service initialized (without Kafka consumer)")
+ self.logger.info("Notification service initialized (without Kafka consumer)")
async def shutdown(self) -> None:
"""Shutdown notification service."""
if self._state == ServiceState.STOPPED:
return
- logger.info("Shutting down notification service...")
+ self.logger.info("Shutting down notification service...")
self._state = ServiceState.STOPPING
# Cancel all tasks
@@ -193,7 +200,7 @@ async def shutdown(self) -> None:
await self._throttle_cache.clear()
self._state = ServiceState.STOPPED
- logger.info("Notification service stopped")
+ self.logger.info("Notification service stopped")
def _start_background_tasks(self) -> None:
"""Start background processing tasks."""
@@ -220,17 +227,17 @@ async def _subscribe_to_events(self) -> None:
execution_results_topic = get_topic_for_event(EventType.EXECUTION_COMPLETED)
# Log topics for debugging
- logger.info(f"Notification service will subscribe to topics: {execution_results_topic}")
+ self.logger.info(f"Notification service will subscribe to topics: {execution_results_topic}")
# Create dispatcher and register handlers for specific event types
- self._dispatcher = EventDispatcher()
+ self._dispatcher = EventDispatcher(logger=self.logger)
# Use a single handler for execution result events (simpler and less brittle)
self._dispatcher.register_handler(EventType.EXECUTION_COMPLETED, self._handle_execution_event)
self._dispatcher.register_handler(EventType.EXECUTION_FAILED, self._handle_execution_event)
self._dispatcher.register_handler(EventType.EXECUTION_TIMEOUT, self._handle_execution_event)
# Create consumer with dispatcher
- self._consumer = UnifiedConsumer(consumer_config, event_dispatcher=self._dispatcher)
+ self._consumer = UnifiedConsumer(consumer_config, event_dispatcher=self._dispatcher, logger=self.logger)
# Start consumer
await self._consumer.start([execution_results_topic])
@@ -240,7 +247,7 @@ async def _subscribe_to_events(self) -> None:
self._tasks.add(self._consumer_task)
self._consumer_task.add_done_callback(self._tasks.discard)
- logger.info("Notification service subscribed to execution events")
+ self.logger.info("Notification service subscribed to execution events")
async def create_notification(
self,
@@ -255,8 +262,8 @@ async def create_notification(
metadata: NotificationContext | None = None,
) -> DomainNotification:
if not tags:
- raise ServiceError("tags must be a non-empty list", status_code=422)
- logger.info(
+ raise NotificationValidationError("tags must be a non-empty list")
+ self.logger.info(
f"Creating notification for user {user_id}",
extra={
"user_id": user_id,
@@ -279,12 +286,15 @@ async def create_notification(
f"Max {self.settings.NOTIF_THROTTLE_MAX_PER_HOUR} "
f"per {self.settings.NOTIF_THROTTLE_WINDOW_HOURS} hour(s)"
)
- logger.warning(error_msg)
- # Throttling is a client-driven rate issue
- raise ServiceError(error_msg, status_code=429)
+ self.logger.warning(error_msg)
+ raise NotificationThrottledError(
+ user_id,
+ self.settings.NOTIF_THROTTLE_MAX_PER_HOUR,
+ self.settings.NOTIF_THROTTLE_WINDOW_HOURS,
+ )
# Create notification
- notification = DomainNotification(
+ create_data = DomainNotificationCreate(
user_id=user_id,
channel=channel,
subject=subject,
@@ -293,12 +303,11 @@ async def create_notification(
severity=severity,
tags=tags,
scheduled_for=scheduled_for,
- status=NotificationStatus.PENDING,
metadata=metadata or {},
)
# Save to database
- await self.repository.create_notification(notification)
+ notification = await self.repository.create_notification(create_data)
# Publish event
event_bus = await self.event_bus_manager.get_event_bus()
@@ -353,7 +362,7 @@ async def worker(uid: str) -> str:
throttled = sum(1 for r in results if r == "throttled")
failed = sum(1 for r in results if r == "failed")
- logger.info(
+ self.logger.info(
"System notification completed",
extra={
"severity": str(cfg.severity),
@@ -408,7 +417,9 @@ async def _create_system_for_user(
)
return "created"
except Exception as e:
- logger.error("Failed to create system notification for user", extra={"user_id": user_id, "error": str(e)})
+ self.logger.error(
+ "Failed to create system notification for user", extra={"user_id": user_id, "error": str(e)}
+ )
return "failed"
async def _send_in_app(
@@ -443,7 +454,7 @@ async def _send_webhook(
headers = notification.webhook_headers or {}
headers["Content-Type"] = "application/json"
- logger.debug(
+ self.logger.debug(
f"Sending webhook notification to {webhook_url}",
extra={
"notification_id": str(notification.notification_id),
@@ -462,7 +473,7 @@ async def _send_webhook(
async with httpx.AsyncClient() as client:
response = await client.post(webhook_url, json=payload, headers=headers, timeout=30.0)
response.raise_for_status()
- logger.debug(
+ self.logger.debug(
"Webhook delivered successfully",
extra={
"notification_id": str(notification.notification_id),
@@ -498,7 +509,7 @@ async def _send_slack(self, notification: DomainNotification, subscription: Doma
if attachments and isinstance(attachments, list):
attachments[0]["actions"] = [{"type": "button", "text": "View Details", "url": notification.action_url}]
- logger.debug(
+ self.logger.debug(
"Sending Slack notification",
extra={
"notification_id": str(notification.notification_id),
@@ -516,7 +527,7 @@ async def _send_slack(self, notification: DomainNotification, subscription: Doma
async with httpx.AsyncClient() as client:
response = await client.post(subscription.slack_webhook, json=slack_message, timeout=30.0)
response.raise_for_status()
- logger.debug(
+ self.logger.debug(
"Slack notification delivered successfully",
extra={"notification_id": str(notification.notification_id), "status_code": response.status_code},
)
@@ -549,7 +560,7 @@ async def _process_pending_notifications(self) -> None:
await asyncio.sleep(5)
except Exception as e:
- logger.error(f"Error processing pending notifications: {e}")
+ self.logger.error(f"Error processing pending notifications: {e}")
await asyncio.sleep(10)
async def _cleanup_old_notifications(self) -> None:
@@ -565,10 +576,10 @@ async def _cleanup_old_notifications(self) -> None:
# Delete old notifications
deleted_count = await self.repository.cleanup_old_notifications(self.settings.NOTIF_OLD_DAYS)
- logger.info(f"Cleaned up {deleted_count} old notifications")
+ self.logger.info(f"Cleaned up {deleted_count} old notifications")
except Exception as e:
- logger.error(f"Error cleaning up old notifications: {e}")
+ self.logger.error(f"Error cleaning up old notifications: {e}")
async def _run_consumer(self) -> None:
"""Run the event consumer loop."""
@@ -577,17 +588,17 @@ async def _run_consumer(self) -> None:
# Consumer handles polling internally
await asyncio.sleep(1)
except asyncio.CancelledError:
- logger.info("Notification consumer task cancelled")
+ self.logger.info("Notification consumer task cancelled")
break
except Exception as e:
- logger.error(f"Error in notification consumer loop: {e}")
+ self.logger.error(f"Error in notification consumer loop: {e}")
await asyncio.sleep(5)
async def _handle_execution_timeout_typed(self, event: ExecutionTimeoutEvent) -> None:
"""Handle typed execution timeout event."""
user_id = event.metadata.user_id
if not user_id:
- logger.error("No user_id in event metadata")
+ self.logger.error("No user_id in event metadata")
return
title = f"Execution Timeout: {event.execution_id}"
@@ -607,7 +618,7 @@ async def _handle_execution_completed_typed(self, event: ExecutionCompletedEvent
"""Handle typed execution completed event."""
user_id = event.metadata.user_id
if not user_id:
- logger.error("No user_id in event metadata")
+ self.logger.error("No user_id in event metadata")
return
title = f"Execution Completed: {event.execution_id}"
@@ -635,15 +646,15 @@ async def _handle_execution_event(self, event: BaseEvent) -> None:
elif isinstance(event, ExecutionTimeoutEvent):
await self._handle_execution_timeout_typed(event)
else:
- logger.warning(f"Unhandled execution event type: {event.event_type}")
+ self.logger.warning(f"Unhandled execution event type: {event.event_type}")
except Exception as e:
- logger.error(f"Error handling execution event: {e}", exc_info=True)
+ self.logger.error(f"Error handling execution event: {e}", exc_info=True)
async def _handle_execution_failed_typed(self, event: ExecutionFailedEvent) -> None:
"""Handle typed execution failed event."""
user_id = event.metadata.user_id
if not user_id:
- logger.error("No user_id in event metadata")
+ self.logger.error("No user_id in event metadata")
return
# Use model_dump to get all event data
@@ -677,7 +688,7 @@ async def mark_as_read(self, user_id: str, notification_id: str) -> bool:
{"notification_id": str(notification_id), "user_id": user_id, "read_at": datetime.now(UTC).isoformat()},
)
else:
- raise ServiceError("Notification not found", status_code=404)
+ raise NotificationNotFoundError(notification_id)
return True
@@ -729,42 +740,26 @@ async def update_subscription(
# Validate channel-specific requirements
if channel == NotificationChannel.WEBHOOK and enabled:
if not webhook_url:
- raise ServiceError("webhook_url is required when enabling WEBHOOK", status_code=422)
+ raise NotificationValidationError("webhook_url is required when enabling WEBHOOK")
if not (webhook_url.startswith("http://") or webhook_url.startswith("https://")):
- raise ServiceError("webhook_url must start with http:// or https://", status_code=422)
+ raise NotificationValidationError("webhook_url must start with http:// or https://")
if channel == NotificationChannel.SLACK and enabled:
if not slack_webhook:
- raise ServiceError("slack_webhook is required when enabling SLACK", status_code=422)
+ raise NotificationValidationError("slack_webhook is required when enabling SLACK")
if not slack_webhook.startswith("https://hooks.slack.com/"):
- raise ServiceError("slack_webhook must be a valid Slack webhook URL", status_code=422)
-
- # Get existing or create new
- subscription = await self.repository.get_subscription(user_id, channel)
-
- if not subscription:
- subscription = DomainNotificationSubscription(
- user_id=user_id,
- channel=channel,
- enabled=enabled,
- )
- else:
- subscription.enabled = enabled
-
- # Update URLs if provided
- if webhook_url is not None:
- subscription.webhook_url = webhook_url
- if slack_webhook is not None:
- subscription.slack_webhook = slack_webhook
- if severities is not None:
- subscription.severities = severities
- if include_tags is not None:
- subscription.include_tags = include_tags
- if exclude_tags is not None:
- subscription.exclude_tags = exclude_tags
-
- await self.repository.upsert_subscription(user_id, channel, subscription)
+ raise NotificationValidationError("slack_webhook must be a valid Slack webhook URL")
+
+ # Build update data
+ update_data = DomainSubscriptionUpdate(
+ enabled=enabled,
+ webhook_url=webhook_url,
+ slack_webhook=slack_webhook,
+ severities=severities,
+ include_tags=include_tags,
+ exclude_tags=exclude_tags,
+ )
- return subscription
+ return await self.repository.upsert_subscription(user_id, channel, update_data)
async def mark_all_as_read(self, user_id: str) -> int:
"""Mark all notifications as read for a user."""
@@ -778,7 +773,7 @@ async def mark_all_as_read(self, user_id: str) -> int:
return count
- async def get_subscriptions(self, user_id: str) -> dict[str, DomainNotificationSubscription]:
+ async def get_subscriptions(self, user_id: str) -> dict[NotificationChannel, DomainNotificationSubscription]:
"""Get all notification subscriptions for a user."""
return await self.repository.get_all_subscriptions(user_id)
@@ -786,7 +781,7 @@ async def delete_notification(self, user_id: str, notification_id: str) -> bool:
"""Delete a notification."""
deleted = await self.repository.delete_notification(str(notification_id), user_id)
if not deleted:
- raise ServiceError("Notification not found", status_code=404)
+ raise NotificationNotFoundError(notification_id)
return deleted
async def _publish_notification_sse(self, notification: DomainNotification) -> None:
@@ -834,7 +829,7 @@ async def _deliver_notification(self, notification: DomainNotification) -> None:
if not claimed:
return
- logger.info(
+ self.logger.info(
f"Delivering notification {notification.notification_id}",
extra={
"notification_id": str(notification.notification_id),
@@ -851,10 +846,12 @@ async def _deliver_notification(self, notification: DomainNotification) -> None:
# Check if notification should be skipped
skip_reason = await self._should_skip_notification(notification, subscription)
if skip_reason:
- logger.info(skip_reason)
- notification.status = NotificationStatus.SKIPPED
- notification.error_message = skip_reason
- await self.repository.update_notification(notification)
+ self.logger.info(skip_reason)
+ await self.repository.update_notification(
+ notification.notification_id,
+ notification.user_id,
+ DomainNotificationUpdate(status=NotificationStatus.SKIPPED, error_message=skip_reason),
+ )
return
# At this point, subscription is guaranteed to be non-None (checked in _should_skip_notification)
@@ -870,16 +867,18 @@ async def _deliver_notification(self, notification: DomainNotification) -> None:
f"Available channels: {list(self._channel_handlers.keys())}"
)
- logger.debug(f"Using handler {handler.__name__} for channel {notification.channel}")
+ self.logger.debug(f"Using handler {handler.__name__} for channel {notification.channel}")
await handler(notification, subscription)
delivery_time = asyncio.get_event_loop().time() - start_time
- # Mark delivered if handler didn't change it
- notification.status = NotificationStatus.DELIVERED
- notification.delivered_at = datetime.now(UTC)
- await self.repository.update_notification(notification)
+ # Mark delivered
+ await self.repository.update_notification(
+ notification.notification_id,
+ notification.user_id,
+ DomainNotificationUpdate(status=NotificationStatus.DELIVERED, delivered_at=datetime.now(UTC)),
+ )
- logger.info(
+ self.logger.info(
f"Successfully delivered notification {notification.notification_id}",
extra={
"notification_id": str(notification.notification_id),
@@ -904,26 +903,43 @@ async def _deliver_notification(self, notification: DomainNotification) -> None:
"max_retries": notification.max_retries,
}
- logger.error(
+ self.logger.error(
f"Failed to deliver notification {notification.notification_id}: {str(e)}",
extra=error_details,
exc_info=True,
)
- notification.status = NotificationStatus.FAILED
- notification.failed_at = datetime.now(UTC)
- notification.error_message = f"Delivery failed via {notification.channel}: {str(e)}"
- notification.retry_count = notification.retry_count + 1
+ new_retry_count = notification.retry_count + 1
+ error_message = f"Delivery failed via {notification.channel}: {str(e)}"
+ failed_at = datetime.now(UTC)
# Schedule retry if under limit
- if notification.retry_count < notification.max_retries:
+ if new_retry_count < notification.max_retries:
retry_time = datetime.now(UTC) + timedelta(minutes=self.settings.NOTIF_RETRY_DELAY_MINUTES)
- notification.scheduled_for = retry_time
- notification.status = NotificationStatus.PENDING
- logger.info(
- f"Scheduled retry {notification.retry_count}/{notification.max_retries} "
- f"for {notification.notification_id}",
+ self.logger.info(
+ f"Scheduled retry {new_retry_count}/{notification.max_retries} for {notification.notification_id}",
extra={"retry_at": retry_time.isoformat()},
)
-
- await self.repository.update_notification(notification)
+ # Will be retried - keep status as PENDING
+ # Note: scheduled_for is not part of DomainNotificationUpdate, so only the status/retry fields are updated here
+ await self.repository.update_notification(
+ notification.notification_id,
+ notification.user_id,
+ DomainNotificationUpdate(
+ status=NotificationStatus.PENDING,
+ failed_at=failed_at,
+ error_message=error_message,
+ retry_count=new_retry_count,
+ ),
+ )
+ else:
+ await self.repository.update_notification(
+ notification.notification_id,
+ notification.user_id,
+ DomainNotificationUpdate(
+ status=NotificationStatus.FAILED,
+ failed_at=failed_at,
+ error_message=error_message,
+ retry_count=new_retry_count,
+ ),
+ )
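The notification hunks above replace in-place mutation of the domain object with a partial-update value handed to the repository. A minimal sketch of that pattern, assuming a simple optional-field dataclass and a helper that builds the persisted update (both invented for illustration, not the project's real API):

    from dataclasses import dataclass, fields
    from datetime import datetime
    from enum import Enum


    class NotificationStatus(str, Enum):
        PENDING = "pending"
        DELIVERED = "delivered"
        FAILED = "failed"
        SKIPPED = "skipped"


    @dataclass
    class DomainNotificationUpdate:
        # Field names assumed from the hunk above; all optional so callers send only what changed.
        status: NotificationStatus | None = None
        delivered_at: datetime | None = None
        failed_at: datetime | None = None
        error_message: str | None = None
        retry_count: int | None = None


    def to_update_doc(update: DomainNotificationUpdate) -> dict:
        """Keep only the fields the caller actually set (a partial '$set'-style document)."""
        return {f.name: getattr(update, f.name) for f in fields(update) if getattr(update, f.name) is not None}


    # Example: mark a notification delivered without touching retry/error fields.
    print(to_update_doc(DomainNotificationUpdate(status=NotificationStatus.DELIVERED, delivered_at=datetime.now())))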
diff --git a/backend/app/services/pod_monitor/event_mapper.py b/backend/app/services/pod_monitor/event_mapper.py
index fc16bc1f..c608035a 100644
--- a/backend/app/services/pod_monitor/event_mapper.py
+++ b/backend/app/services/pod_monitor/event_mapper.py
@@ -1,11 +1,11 @@
import ast
import json
+import logging
from dataclasses import dataclass
from typing import Protocol
from kubernetes import client as k8s_client
-from app.core.logging import logger
from app.domain.enums.kafka import GroupId
from app.domain.enums.storage import ExecutionErrorType
from app.domain.execution import ResourceUsageDomain
@@ -40,12 +40,12 @@ class PodContext:
@dataclass(frozen=True)
class PodLogs:
- """Parsed pod logs and execution results"""
+ """Parsed pod logs and execution results. Only created when parsing succeeds."""
- stdout: str = ""
- stderr: str = ""
- exit_code: int | None = None
- resource_usage: ResourceUsageDomain | None = None
+ stdout: str
+ stderr: str
+ exit_code: int
+ resource_usage: ResourceUsageDomain
class EventMapper(Protocol):
@@ -57,7 +57,8 @@ def __call__(self, ctx: PodContext) -> BaseEvent | None: ...
class PodEventMapper:
"""Maps Kubernetes pod objects to application events"""
- def __init__(self, k8s_api: k8s_client.CoreV1Api | None = None) -> None:
+ def __init__(self, logger: logging.Logger, k8s_api: k8s_client.CoreV1Api | None = None) -> None:
+ self.logger = logger
self._event_cache: dict[str, PodPhase] = {}
self._k8s_api = k8s_api
@@ -76,14 +77,14 @@ def __init__(self, k8s_api: k8s_client.CoreV1Api | None = None) -> None:
def map_pod_event(self, pod: k8s_client.V1Pod, event_type: str) -> EventList:
"""Map a Kubernetes pod to application events"""
- logger.info(
+ self.logger.info(
f"POD-EVENT: type={event_type} name={getattr(pod.metadata, 'name', None)} "
f"ns={getattr(pod.metadata, 'namespace', None)} phase={getattr(pod.status, 'phase', None)}"
)
# Extract execution ID
execution_id = self._extract_execution_id(pod)
if not execution_id:
- logger.warning(
+ self.logger.warning(
f"POD-EVENT: missing execution_id name={getattr(pod.metadata, 'name', None)} "
f"labels={getattr(pod.metadata, 'labels', None)} "
f"annotations={getattr(pod.metadata, 'annotations', None)}"
@@ -98,13 +99,13 @@ def map_pod_event(self, pod: k8s_client.V1Pod, event_type: str) -> EventList:
# Skip duplicate events
if pod.metadata and self._is_duplicate(pod.metadata.name, phase):
- logger.debug(f"POD-EVENT: duplicate ignored name={pod.metadata.name} phase={phase}")
+ self.logger.debug(f"POD-EVENT: duplicate ignored name={pod.metadata.name} phase={phase}")
return []
ctx = PodContext(
pod=pod, execution_id=execution_id, metadata=self._create_metadata(pod), phase=phase, event_type=event_type
)
- logger.info(
+ self.logger.info(
f"POD-EVENT: ctx execution_id={ctx.execution_id} phase={ctx.phase} "
f"reason={getattr(getattr(pod, 'status', None), 'reason', None)}"
)
@@ -114,7 +115,7 @@ def map_pod_event(self, pod: k8s_client.V1Pod, event_type: str) -> EventList:
# Check for timeout first - if pod timed out, only return timeout event
if timeout_event := self._check_timeout(ctx):
- logger.info(
+ self.logger.info(
f"POD-EVENT: mapped TIMEOUT exec={ctx.execution_id} phase={ctx.phase} "
f"adl={getattr(getattr(pod, 'spec', None), 'active_deadline_seconds', None)}"
)
@@ -129,21 +130,23 @@ def map_pod_event(self, pod: k8s_client.V1Pod, event_type: str) -> EventList:
and pod.metadata
and prior_phase == "Pending"
):
- logger.debug(f"POD-EVENT: skipping running map due to empty statuses after Pending exec={execution_id}")
+ self.logger.debug(
+ f"POD-EVENT: skipping running map due to empty statuses after Pending exec={execution_id}"
+ )
return events
# Phase-based mappers
for mapper in self._phase_mappers.get(phase, []):
if event := mapper(ctx):
mapper_name = getattr(mapper, "__name__", repr(mapper))
- logger.info(f"POD-EVENT: phase-map {mapper_name} -> {event.event_type} exec={ctx.execution_id}")
+ self.logger.info(f"POD-EVENT: phase-map {mapper_name} -> {event.event_type} exec={ctx.execution_id}")
events.append(event)
# Event type mappers
for mapper in self._event_type_mappers.get(event_type, []):
if event := mapper(ctx):
mapper_name = getattr(mapper, "__name__", repr(mapper))
- logger.info(f"POD-EVENT: type-map {mapper_name} -> {event.event_type} exec={ctx.execution_id}")
+ self.logger.info(f"POD-EVENT: type-map {mapper_name} -> {event.event_type} exec={ctx.execution_id}")
events.append(event)
return events
@@ -155,17 +158,19 @@ def _extract_execution_id(self, pod: k8s_client.V1Pod) -> str | None:
# Try labels first
if pod.metadata.labels and (exec_id := pod.metadata.labels.get("execution-id")):
- logger.debug(f"POD-EVENT: extracted exec-id from label name={pod.metadata.name} exec_id={exec_id}")
+ self.logger.debug(f"POD-EVENT: extracted exec-id from label name={pod.metadata.name} exec_id={exec_id}")
return str(exec_id)
# Try annotations
if pod.metadata.annotations and (exec_id := pod.metadata.annotations.get("integr8s.io/execution-id")):
- logger.debug(f"POD-EVENT: extracted exec-id from annotation name={pod.metadata.name} exec_id={exec_id}")
+ self.logger.debug(
+ f"POD-EVENT: extracted exec-id from annotation name={pod.metadata.name} exec_id={exec_id}"
+ )
return str(exec_id)
# Try pod name pattern
if pod.metadata.name and pod.metadata.name.startswith("exec-"):
- logger.debug(f"POD-EVENT: extracted exec-id from name pattern name={pod.metadata.name}")
+ self.logger.debug(f"POD-EVENT: extracted exec-id from name pattern name={pod.metadata.name}")
return str(pod.metadata.name[5:])
return None
@@ -185,7 +190,7 @@ def _create_metadata(self, pod: k8s_client.V1Pod) -> EventMetadata:
service_version="1.0.0",
correlation_id=correlation_id,
)
- logger.info(f"POD-EVENT: metadata user_id={md.user_id} corr={md.correlation_id} name={pod.metadata.name}")
+ self.logger.info(f"POD-EVENT: metadata user_id={md.user_id} corr={md.correlation_id} name={pod.metadata.name}")
return md
def _is_duplicate(self, pod_name: str, phase: PodPhase) -> bool:
@@ -215,7 +220,7 @@ def _map_scheduled(self, ctx: PodContext) -> PodScheduledEvent | None:
node_name=ctx.pod.spec.node_name or "pending",
metadata=ctx.metadata,
)
- logger.debug(f"POD-EVENT: mapped scheduled -> {evt.event_type} exec={ctx.execution_id}")
+ self.logger.debug(f"POD-EVENT: mapped scheduled -> {evt.event_type} exec={ctx.execution_id}")
return evt
def _map_running(self, ctx: PodContext) -> PodRunningEvent | None:
@@ -240,7 +245,7 @@ def _map_running(self, ctx: PodContext) -> PodRunningEvent | None:
container_statuses=json.dumps(container_statuses), # Serialize as JSON string
metadata=ctx.metadata,
)
- logger.debug(f"POD-EVENT: mapped running -> {evt.event_type} exec={ctx.execution_id}")
+ self.logger.debug(f"POD-EVENT: mapped running -> {evt.event_type} exec={ctx.execution_id}")
return evt
def _map_completed(self, ctx: PodContext) -> ExecutionCompletedEvent | None:
@@ -250,18 +255,20 @@ def _map_completed(self, ctx: PodContext) -> ExecutionCompletedEvent | None:
return None
logs = self._extract_logs(ctx.pod)
- exit_code = logs.exit_code if logs.exit_code is not None else container.state.terminated.exit_code
+ if not logs:
+ self.logger.error(f"POD-EVENT: failed to extract logs for completed pod exec={ctx.execution_id}")
+ return None
evt = ExecutionCompletedEvent(
execution_id=ctx.execution_id,
- aggregate_id=ctx.execution_id, # Set aggregate_id to execution_id
- exit_code=exit_code,
+ aggregate_id=ctx.execution_id,
+ exit_code=logs.exit_code,
+ resource_usage=logs.resource_usage,
stdout=logs.stdout,
stderr=logs.stderr,
- resource_usage=logs.resource_usage or ResourceUsageDomain.from_dict({}),
metadata=ctx.metadata,
)
- logger.info(f"POD-EVENT: mapped completed exec={ctx.execution_id} exit_code={exit_code}")
+ self.logger.info(f"POD-EVENT: mapped completed exec={ctx.execution_id} exit_code={logs.exit_code}")
return evt
def _map_failed_or_completed(self, ctx: PodContext) -> BaseEvent | None:
@@ -279,29 +286,24 @@ def _map_failed(self, ctx: PodContext) -> ExecutionFailedEvent | None:
error_info = self._analyze_failure(ctx.pod)
logs = self._extract_logs(ctx.pod)
- # If no stderr from logs but we have an error message, use it as stderr
- stderr = logs.stderr if logs.stderr else error_info.message
- # Ensure exit_code is populated (fallback to logs or generic non-zero)
- exit_code = (
- error_info.exit_code
- if error_info.exit_code is not None
- else (logs.exit_code if logs.exit_code is not None else 1)
- )
+ # Use logs data if available, falling back to error_info
+ stdout = logs.stdout if logs else ""
+ stderr = logs.stderr if logs and logs.stderr else error_info.message
+ exit_code = error_info.exit_code if error_info.exit_code is not None else (logs.exit_code if logs else 1)
evt = ExecutionFailedEvent(
execution_id=ctx.execution_id,
- aggregate_id=ctx.execution_id, # Set aggregate_id to execution_id
+ aggregate_id=ctx.execution_id,
error_type=error_info.error_type,
exit_code=exit_code,
- stdout=logs.stdout,
+ stdout=stdout,
stderr=stderr,
error_message=stderr,
- resource_usage=logs.resource_usage or ResourceUsageDomain.from_dict({}),
+ resource_usage=logs.resource_usage if logs else None,
metadata=ctx.metadata,
)
- logger.info(
- f"POD-EVENT: mapped failed exec={ctx.execution_id} error_type={error_info.error_type} "
- f"exit={error_info.exit_code}"
+ self.logger.info(
+ f"POD-EVENT: mapped failed exec={ctx.execution_id} error_type={error_info.error_type} exit={exit_code}"
)
return evt
@@ -320,7 +322,7 @@ def _map_terminated(self, ctx: PodContext) -> PodTerminatedEvent | None:
message=getattr(terminated, "message", None),
metadata=ctx.metadata,
)
- logger.info(
+ self.logger.info(
f"POD-EVENT: mapped terminated exec={ctx.execution_id} reason={terminated.reason} "
f"exit={terminated.exit_code}"
)
@@ -331,16 +333,22 @@ def _check_timeout(self, ctx: PodContext) -> ExecutionTimeoutEvent | None:
return None
logs = self._extract_logs(ctx.pod)
+ if not logs:
+ self.logger.error(f"POD-EVENT: failed to extract logs for timed out pod exec={ctx.execution_id}")
+ return None
+
evt = ExecutionTimeoutEvent(
execution_id=ctx.execution_id,
aggregate_id=ctx.execution_id,
timeout_seconds=ctx.pod.spec.active_deadline_seconds or 0,
+ resource_usage=logs.resource_usage,
stdout=logs.stdout,
stderr=logs.stderr,
- resource_usage=logs.resource_usage or ResourceUsageDomain.from_dict({}),
metadata=ctx.metadata,
)
- logger.info(f"POD-EVENT: mapped timeout exec={ctx.execution_id} adl={ctx.pod.spec.active_deadline_seconds}")
+ self.logger.info(
+ f"POD-EVENT: mapped timeout exec={ctx.execution_id} adl={ctx.pod.spec.active_deadline_seconds}"
+ )
return evt
def _get_main_container(self, pod: k8s_client.V1Pod) -> k8s_client.V1ContainerStatus | None:
@@ -435,11 +443,11 @@ def _analyze_failure(self, pod: k8s_client.V1Pod) -> FailureInfo:
return default
- def _extract_logs(self, pod: k8s_client.V1Pod) -> PodLogs:
- """Extract and parse pod logs"""
+ def _extract_logs(self, pod: k8s_client.V1Pod) -> PodLogs | None:
+ """Extract and parse pod logs. Returns None if extraction fails."""
# Without k8s API or metadata, can't fetch logs
if not self._k8s_api or not pod.metadata:
- return PodLogs()
+ return None
# Check if any container terminated
has_terminated = any(
@@ -447,8 +455,8 @@ def _extract_logs(self, pod: k8s_client.V1Pod) -> PodLogs:
)
if not has_terminated:
- logger.debug(f"Pod {pod.metadata.name} has no terminated containers")
- return PodLogs()
+ self.logger.debug(f"Pod {pod.metadata.name} has no terminated containers")
+ return None
try:
logs = self._k8s_api.read_namespaced_pod_log(
@@ -456,17 +464,17 @@ def _extract_logs(self, pod: k8s_client.V1Pod) -> PodLogs:
)
if not logs:
- return PodLogs()
+ return None
# Try to parse executor JSON
return self._parse_executor_output(logs)
except Exception as e:
self._log_extraction_error(pod.metadata.name, str(e))
- return PodLogs()
+ return None
- def _parse_executor_output(self, logs: str) -> PodLogs:
- """Parse executor JSON output from logs"""
+ def _parse_executor_output(self, logs: str) -> PodLogs | None:
+ """Parse executor JSON output from logs. Returns None if parsing fails."""
logs_stripped = logs.strip()
# Try full output as JSON
@@ -478,9 +486,9 @@ def _parse_executor_output(self, logs: str) -> PodLogs:
if result := self._try_parse_json(line.strip()):
return result
- # Fallback to raw logs
- logger.warning("Logs do not contain valid executor JSON, treating as raw output")
- return PodLogs(stdout=logs)
+ # No valid executor JSON found
+ self.logger.warning("Logs do not contain valid executor JSON")
+ return None
def _try_parse_json(self, text: str) -> PodLogs | None:
"""Try to parse text as executor JSON output"""
@@ -492,7 +500,7 @@ def _try_parse_json(self, text: str) -> PodLogs | None:
stdout=data.get("stdout", ""),
stderr=data.get("stderr", ""),
exit_code=data.get("exit_code", 0),
- resource_usage=ResourceUsageDomain.from_dict(data.get("resource_usage", {})),
+ resource_usage=ResourceUsageDomain(**data.get("resource_usage", {})),
)
def _log_extraction_error(self, pod_name: str, error: str) -> None:
@@ -500,11 +508,11 @@ def _log_extraction_error(self, pod_name: str, error: str) -> None:
error_lower = error.lower()
if "404" in error or "not found" in error_lower:
- logger.debug(f"Pod {pod_name} logs not found - pod may have been deleted")
+ self.logger.debug(f"Pod {pod_name} logs not found - pod may have been deleted")
elif "400" in error:
- logger.debug(f"Pod {pod_name} logs not available - container may still be creating")
+ self.logger.debug(f"Pod {pod_name} logs not available - container may still be creating")
else:
- logger.warning(f"Failed to extract logs from pod {pod_name}: {error}")
+ self.logger.warning(f"Failed to extract logs from pod {pod_name}: {error}")
def clear_cache(self) -> None:
"""Clear event cache"""
diff --git a/backend/app/services/pod_monitor/monitor.py b/backend/app/services/pod_monitor/monitor.py
index 97fbbb4f..b2512e68 100644
--- a/backend/app/services/pod_monitor/monitor.py
+++ b/backend/app/services/pod_monitor/monitor.py
@@ -1,28 +1,31 @@
import asyncio
+import logging
import signal
import time
from collections.abc import AsyncIterator
from contextlib import AsyncExitStack, asynccontextmanager
from dataclasses import dataclass
from enum import auto
-from typing import Any, Protocol
+from typing import Any
+from beanie import init_beanie
from kubernetes import client as k8s_client
from kubernetes import config as k8s_config
from kubernetes import watch
from kubernetes.client.rest import ApiException
+from pymongo.asynchronous.mongo_client import AsyncMongoClient
from app.core.k8s_clients import K8sClients, close_k8s_clients, create_k8s_clients
from app.core.lifecycle import LifecycleEnabled
-from app.core.logging import logger
+from app.core.logging import setup_logger
from app.core.metrics.context import get_kubernetes_metrics
from app.core.utils import StringEnum
-
-# Metrics will be passed as parameter to avoid globals
+from app.db.docs import ALL_DOCUMENTS
+from app.db.repositories.event_repository import EventRepository
from app.events.core import ProducerConfig, UnifiedProducer
from app.events.schema.schema_registry import create_schema_registry_manager, initialize_event_schemas
from app.infrastructure.kafka.events import BaseEvent
-from app.infrastructure.kafka.mappings import get_topic_for_event
+from app.services.kafka_event_service import KafkaEventService
from app.services.pod_monitor.config import PodMonitorConfig
from app.services.pod_monitor.event_mapper import PodEventMapper
from app.settings import get_settings
@@ -96,65 +99,34 @@ class ReconciliationResult:
error: str | None = None
-class EventPublisher(Protocol):
- """Protocol for event publishing."""
-
- async def send_event(self, event: BaseEvent, topic: str, key: str | None = None) -> bool:
- """Send an event to a topic."""
- ...
-
- async def is_healthy(self) -> bool:
- """Check if publisher is healthy."""
- ...
-
-
-class UnifiedProducerAdapter:
- """Adapter to make UnifiedProducer compatible with EventPublisher protocol."""
-
- def __init__(self, producer: UnifiedProducer) -> None:
- self._producer = producer
-
- async def send_event(self, event: BaseEvent, topic: str, key: str | None = None) -> bool:
- """Send event and return success status."""
- try:
- await self._producer.produce(event_to_produce=event, key=key)
- return True
- except Exception as e:
- logger.error(f"Failed to send event: {e}")
- return False
-
- async def is_healthy(self) -> bool:
- """Check if producer is healthy."""
- # UnifiedProducer doesn't have is_healthy, assume healthy if initialized
- return self._producer._producer is not None
-
-
class PodMonitor(LifecycleEnabled):
"""
Monitors Kubernetes pods and publishes lifecycle events.
This service watches pods with specific labels using the K8s watch API,
maps Kubernetes events to application events, and publishes them to Kafka.
+ Events are stored in the events collection AND published to Kafka via KafkaEventService.
"""
def __init__(
- self, config: PodMonitorConfig, producer: UnifiedProducer, k8s_clients: K8sClients | None = None
+ self,
+ config: PodMonitorConfig,
+ kafka_event_service: KafkaEventService,
+ logger: logging.Logger,
+ k8s_clients: K8sClients | None = None,
) -> None:
"""Initialize the pod monitor."""
+ self.logger = logger
self.config = config or PodMonitorConfig()
- settings = get_settings()
-
- # Kafka configuration
- self.kafka_servers = self.config.kafka_bootstrap_servers or settings.KAFKA_BOOTSTRAP_SERVERS
# Kubernetes clients (initialized on start)
self._v1: k8s_client.CoreV1Api | None = None
self._watch: watch.Watch | None = None
self._clients: K8sClients | None = k8s_clients
- # Components - producer is required
- self._event_mapper = PodEventMapper()
- self._producer = UnifiedProducerAdapter(producer)
+ # Components
+ self._event_mapper = PodEventMapper(logger=self.logger)
+ self._kafka_event_service = kafka_event_service
# State
self._state = MonitorState.IDLE
@@ -177,10 +149,10 @@ def state(self) -> MonitorState:
async def start(self) -> None:
"""Start the pod monitor."""
if self._state != MonitorState.IDLE:
- logger.warning(f"Cannot start monitor in state: {self._state}")
+ self.logger.warning(f"Cannot start monitor in state: {self._state}")
return
- logger.info("Starting PodMonitor service...")
+ self.logger.info("Starting PodMonitor service...")
# Initialize components
self._initialize_kubernetes_client()
@@ -193,14 +165,14 @@ async def start(self) -> None:
if self.config.enable_state_reconciliation:
self._reconcile_task = asyncio.create_task(self._reconciliation_loop())
- logger.info("PodMonitor service started successfully")
+ self.logger.info("PodMonitor service started successfully")
async def stop(self) -> None:
"""Stop the pod monitor."""
if self._state == MonitorState.STOPPED:
return
- logger.info("Stopping PodMonitor service...")
+ self.logger.info("Stopping PodMonitor service...")
self._state = MonitorState.STOPPING
# Cancel tasks
@@ -221,25 +193,25 @@ async def stop(self) -> None:
self._event_mapper.clear_cache()
self._state = MonitorState.STOPPED
- logger.info("PodMonitor service stopped")
+ self.logger.info("PodMonitor service stopped")
def _initialize_kubernetes_client(self) -> None:
"""Initialize Kubernetes API clients."""
if self._clients is None:
match (self.config.in_cluster, self.config.kubeconfig_path):
case (True, _):
- logger.info("Using in-cluster Kubernetes configuration")
+ self.logger.info("Using in-cluster Kubernetes configuration")
k8s_config.load_incluster_config()
case (False, path) if path:
- logger.info(f"Using kubeconfig from {path}")
+ self.logger.info(f"Using kubeconfig from {path}")
k8s_config.load_kube_config(config_file=path)
case _:
- logger.info("Using default kubeconfig")
+ self.logger.info("Using default kubeconfig")
k8s_config.load_kube_config()
configuration = k8s_client.Configuration.get_default_copy()
- logger.info(f"Kubernetes API host: {configuration.host}")
- logger.info(f"SSL CA cert configured: {configuration.ssl_ca_cert is not None}")
+ self.logger.info(f"Kubernetes API host: {configuration.host}")
+ self.logger.info(f"SSL CA cert configured: {configuration.ssl_ca_cert is not None}")
api_client = k8s_client.ApiClient(configuration)
self._v1 = k8s_client.CoreV1Api(api_client)
@@ -248,8 +220,8 @@ def _initialize_kubernetes_client(self) -> None:
self._watch = watch.Watch()
self._v1.get_api_resources()
- logger.info("Successfully connected to Kubernetes API")
- self._event_mapper = PodEventMapper(k8s_api=self._v1)
+ self.logger.info("Successfully connected to Kubernetes API")
+ self._event_mapper = PodEventMapper(logger=self.logger, k8s_api=self._v1)
async def _watch_pods(self) -> None:
"""Main watch loop for pods."""
@@ -261,17 +233,17 @@ async def _watch_pods(self) -> None:
except ApiException as e:
match e.status:
case 410: # Gone - resource version too old
- logger.warning("Resource version expired, resetting watch")
+ self.logger.warning("Resource version expired, resetting watch")
self._last_resource_version = None
self._metrics.record_pod_monitor_watch_error(str(ErrorType.RESOURCE_VERSION_EXPIRED.value))
case _:
- logger.error(f"API error in watch: {e}")
+ self.logger.error(f"API error in watch: {e}")
self._metrics.record_pod_monitor_watch_error(str(ErrorType.API_ERROR.value))
await self._handle_watch_error()
except Exception as e:
- logger.error(f"Unexpected error in watch: {e}", exc_info=True)
+ self.logger.error(f"Unexpected error in watch: {e}", exc_info=True)
self._metrics.record_pod_monitor_watch_error(str(ErrorType.UNEXPECTED.value))
await self._handle_watch_error()
@@ -287,7 +259,7 @@ async def _watch_pod_events(self) -> None:
resource_version=self._last_resource_version,
)
- logger.info(f"Starting pod watch with selector: {context.label_selector}, namespace: {context.namespace}")
+ self.logger.info(f"Starting pod watch with selector: {context.label_selector}, namespace: {context.namespace}")
# Create watch stream
kwargs = {
@@ -342,7 +314,7 @@ async def _process_raw_event(self, raw_event: KubeEvent) -> None:
await self._process_pod_event(event)
except (KeyError, ValueError) as e:
- logger.error(f"Invalid event format: {e}")
+ self.logger.error(f"Invalid event format: {e}")
self._metrics.record_pod_monitor_watch_error(str(ErrorType.PROCESSING_ERROR.value))
async def _process_pod_event(self, event: PodEvent) -> None:
@@ -379,7 +351,7 @@ async def _process_pod_event(self, event: PodEvent) -> None:
# Log event
if app_events:
- logger.info(
+ self.logger.info(
f"Processed {event.event_type.value} event for pod {pod_name} "
f"(phase: {pod_phase or 'Unknown'}), "
f"published {len(app_events)} events"
@@ -390,52 +362,33 @@ async def _process_pod_event(self, event: PodEvent) -> None:
self._metrics.record_pod_monitor_event_processing_duration(duration, str(event.event_type.value))
except Exception as e:
- logger.error(f"Error processing pod event: {e}", exc_info=True)
+ self.logger.error(f"Error processing pod event: {e}", exc_info=True)
self._metrics.record_pod_monitor_watch_error(str(ErrorType.PROCESSING_ERROR.value))
async def _publish_event(self, event: BaseEvent, pod: k8s_client.V1Pod) -> None:
- """Publish event to Kafka."""
+ """Publish event to Kafka and store in events collection."""
try:
- # Get proper topic from event type mapping
-
- topic = str(get_topic_for_event(event.event_type))
-
# Add correlation ID from pod labels
if pod.metadata and pod.metadata.labels:
event.metadata.correlation_id = pod.metadata.labels.get("execution-id")
- # Get execution ID from event if it has one
execution_id = getattr(event, "execution_id", None) or event.aggregate_id
+ key = str(execution_id or (pod.metadata.name if pod.metadata else "unknown"))
- logger.info(f"Publishing event {event.event_type} to topic {topic} for execution_id: {execution_id}")
-
- # Check producer health
- if not await self._producer.is_healthy():
- logger.error(f"Producer is not healthy, cannot send event {event.event_type}")
- return
-
- # Publish event
- key = str(execution_id or pod.metadata.name)
- success = await self._producer.send_event(event=event, topic=topic, key=key)
+ await self._kafka_event_service.publish_base_event(event=event, key=key)
- if not success:
- logger.error(f"Failed to send event {event.event_type} to topic {topic}")
- return
-
- # Event published successfully
phase = pod.status.phase if pod.status else "Unknown"
self._metrics.record_pod_monitor_event_published(str(event.event_type), phase)
- logger.info(f"Successfully published {event.event_type} event to {topic}")
except Exception as e:
- logger.error(f"Error publishing event: {e}", exc_info=True)
+ self.logger.error(f"Error publishing event: {e}", exc_info=True)
async def _handle_watch_error(self) -> None:
"""Handle watch errors with exponential backoff."""
self._reconnect_attempts += 1
if self._reconnect_attempts > self.config.max_reconnect_attempts:
- logger.error(
+ self.logger.error(
f"Max reconnect attempts ({self.config.max_reconnect_attempts}) exceeded, stopping pod monitor"
)
self._state = MonitorState.STOPPING
@@ -444,7 +397,7 @@ async def _handle_watch_error(self) -> None:
# Calculate exponential backoff
backoff = min(self.config.watch_reconnect_delay * (2 ** (self._reconnect_attempts - 1)), MAX_BACKOFF_SECONDS)
- logger.info(
+ self.logger.info(
f"Reconnecting watch in {backoff}s "
f"(attempt {self._reconnect_attempts}/{self.config.max_reconnect_attempts})"
)
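The reconnect delay above doubles per attempt and is clamped to MAX_BACKOFF_SECONDS. A tiny standalone sketch of the same arithmetic, with a 5-second base delay and 300-second cap used purely as example values:

    def reconnect_backoff(attempt: int, base_delay: float = 5.0, max_backoff: float = 300.0) -> float:
        # attempt 1 -> base_delay, then doubles each time, capped at max_backoff
        return min(base_delay * (2 ** (attempt - 1)), max_backoff)


    if __name__ == "__main__":
        # attempts 1..7 -> 5, 10, 20, 40, 80, 160, 300 (capped)
        print([reconnect_backoff(n) for n in range(1, 8)])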
@@ -463,7 +416,7 @@ async def _reconciliation_loop(self) -> None:
self._log_reconciliation_result(result)
except Exception as e:
- logger.error(f"Error in reconciliation loop: {e}", exc_info=True)
+ self.logger.error(f"Error in reconciliation loop: {e}", exc_info=True)
async def _reconcile_state(self) -> ReconciliationResult:
"""Reconcile tracked pods with actual state."""
@@ -472,11 +425,11 @@ async def _reconcile_state(self) -> ReconciliationResult:
start_time = time.time()
try:
- logger.info("Starting pod state reconciliation")
+ self.logger.info("Starting pod state reconciliation")
# List all pods matching selector
if not self._v1:
- logger.warning("K8s API not initialized, skipping reconciliation")
+ self.logger.warning("K8s API not initialized, skipping reconciliation")
return ReconciliationResult(
missing_pods=set(),
extra_pods=set(),
@@ -499,7 +452,7 @@ async def _reconcile_state(self) -> ReconciliationResult:
# Process missing pods
for pod in pods.items:
if pod.metadata.name in missing_pods:
- logger.info(f"Reconciling missing pod: {pod.metadata.name}")
+ self.logger.info(f"Reconciling missing pod: {pod.metadata.name}")
event = PodEvent(
event_type=WatchEventType.ADDED, pod=pod, resource_version=pod.metadata.resource_version
)
@@ -507,7 +460,7 @@ async def _reconcile_state(self) -> ReconciliationResult:
# Remove extra pods
for pod_name in extra_pods:
- logger.info(f"Removing stale pod from tracking: {pod_name}")
+ self.logger.info(f"Removing stale pod from tracking: {pod_name}")
self._tracked_pods.discard(pod_name)
# Update metrics
@@ -521,7 +474,7 @@ async def _reconcile_state(self) -> ReconciliationResult:
)
except Exception as e:
- logger.error(f"Failed to reconcile state: {e}", exc_info=True)
+ self.logger.error(f"Failed to reconcile state: {e}", exc_info=True)
self._metrics.record_pod_monitor_reconciliation_run("failed")
return ReconciliationResult(
@@ -535,13 +488,13 @@ async def _reconcile_state(self) -> ReconciliationResult:
def _log_reconciliation_result(self, result: ReconciliationResult) -> None:
"""Log reconciliation result."""
if result.success:
- logger.info(
+ self.logger.info(
f"Reconciliation completed in {result.duration_seconds:.2f}s. "
f"Found {len(result.missing_pods)} missing, "
f"{len(result.extra_pods)} extra pods"
)
else:
- logger.error(f"Reconciliation failed after {result.duration_seconds:.2f}s: {result.error}")
+ self.logger.error(f"Reconciliation failed after {result.duration_seconds:.2f}s: {result.error}")
async def get_status(self) -> StatusDict:
"""Get monitor status."""
@@ -561,11 +514,17 @@ async def get_status(self) -> StatusDict:
@asynccontextmanager
async def create_pod_monitor(
config: PodMonitorConfig,
- producer: UnifiedProducer,
+ kafka_event_service: KafkaEventService,
+ logger: logging.Logger,
k8s_clients: K8sClients | None = None,
) -> AsyncIterator[PodMonitor]:
"""Create and manage a pod monitor instance."""
- monitor = PodMonitor(config=config, producer=producer, k8s_clients=k8s_clients)
+ monitor = PodMonitor(
+ config=config,
+ kafka_event_service=kafka_event_service,
+ logger=logger,
+ k8s_clients=k8s_clients,
+ )
try:
await monitor.start()
@@ -576,17 +535,45 @@ async def create_pod_monitor(
async def run_pod_monitor() -> None:
"""Run the pod monitor service."""
+ import os
+
+ logger = setup_logger(os.environ.get("LOG_LEVEL", "INFO"))
+ settings = get_settings()
+
+ # Initialize MongoDB
+ db_client: AsyncMongoClient[Any] = AsyncMongoClient(
+ settings.MONGODB_URL, tz_aware=True, serverSelectionTimeoutMS=5000
+ )
+ database = db_client[settings.DATABASE_NAME]
+ await db_client.admin.command("ping")
+ logger.info(f"Connected to database: {settings.DATABASE_NAME}")
+ await init_beanie(database=database, document_models=ALL_DOCUMENTS)
+
# Initialize schema registry
- schema_registry_manager = create_schema_registry_manager()
+ schema_registry_manager = create_schema_registry_manager(logger)
await initialize_event_schemas(schema_registry_manager)
- # Create producer and monitor
- settings = get_settings()
+ # Create producer
producer_config = ProducerConfig(bootstrap_servers=settings.KAFKA_BOOTSTRAP_SERVERS)
- producer = UnifiedProducer(producer_config, schema_registry_manager)
+ producer = UnifiedProducer(producer_config, schema_registry_manager, logger)
+
+ # Create KafkaEventService (stores events + publishes to Kafka)
+ event_repository = EventRepository(logger)
+ kafka_event_service = KafkaEventService(
+ event_repository=event_repository,
+ kafka_producer=producer,
+ logger=logger,
+ )
+
+ # Create monitor
monitor_config = PodMonitorConfig()
- clients = create_k8s_clients()
- monitor = PodMonitor(config=monitor_config, producer=producer, k8s_clients=clients)
+ clients = create_k8s_clients(logger)
+ monitor = PodMonitor(
+ config=monitor_config,
+ kafka_event_service=kafka_event_service,
+ logger=logger,
+ k8s_clients=clients,
+ )
# Setup signal handlers
loop = asyncio.get_running_loop()
@@ -596,13 +583,12 @@ async def shutdown() -> None:
logger.info("Initiating graceful shutdown...")
await monitor.stop()
await producer.stop()
+ await db_client.close()
- # Register signal handlers
for sig in (signal.SIGINT, signal.SIGTERM):
loop.add_signal_handler(sig, lambda: asyncio.create_task(shutdown()))
async with AsyncExitStack() as stack:
- # Ensure Kubernetes clients are always closed, even on exceptions
stack.callback(close_k8s_clients, clients)
await stack.enter_async_context(producer)
await stack.enter_async_context(monitor)
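run_pod_monitor now owns its own process bootstrap and teardown. A reduced sketch of the shutdown wiring it relies on, with a stand-in Service in place of the real monitor and producer, showing signal handlers plus AsyncExitStack unwinding:

    import asyncio
    import signal
    from contextlib import AsyncExitStack


    class Service:
        """Stand-in for the real lifecycle objects (monitor, producer)."""

        async def __aenter__(self) -> "Service":
            print("service started")
            return self

        async def __aexit__(self, *exc: object) -> None:
            print("service stopped")


    async def main() -> None:
        stop = asyncio.Event()
        loop = asyncio.get_running_loop()
        for sig in (signal.SIGINT, signal.SIGTERM):
            loop.add_signal_handler(sig, stop.set)  # Unix-only, as in the service above

        async with AsyncExitStack() as stack:
            await stack.enter_async_context(Service())
            await stop.wait()  # park until a signal arrives; the stack unwinds on exit


    if __name__ == "__main__":
        asyncio.run(main())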
diff --git a/backend/app/services/rate_limit_service.py b/backend/app/services/rate_limit_service.py
index e4b8d9e5..cc98a4c1 100644
--- a/backend/app/services/rate_limit_service.py
+++ b/backend/app/services/rate_limit_service.py
@@ -5,13 +5,14 @@
from contextlib import contextmanager
from dataclasses import dataclass
from datetime import datetime, timedelta, timezone
-from typing import Any, Awaitable, Generator, Optional, cast
+from typing import Any, Awaitable, Dict, Generator, Optional, cast
import redis.asyncio as redis
from app.core.metrics.rate_limit import RateLimitMetrics
from app.core.tracing.utils import add_span_attributes
from app.domain.rate_limit import (
+ EndpointGroup,
RateLimitAlgorithm,
RateLimitConfig,
RateLimitRule,
@@ -19,10 +20,93 @@
UserRateLimit,
UserRateLimitSummary,
)
-from app.infrastructure.mappers import RateLimitConfigMapper
from app.settings import Settings
+def _rule_to_dict(rule: RateLimitRule) -> Dict[str, Any]:
+ return {
+ "endpoint_pattern": rule.endpoint_pattern,
+ "group": rule.group.value,
+ "requests": rule.requests,
+ "window_seconds": rule.window_seconds,
+ "burst_multiplier": rule.burst_multiplier,
+ "algorithm": rule.algorithm.value,
+ "priority": rule.priority,
+ "enabled": rule.enabled,
+ }
+
+
+def _rule_from_dict(data: Dict[str, Any]) -> RateLimitRule:
+ return RateLimitRule(
+ endpoint_pattern=data["endpoint_pattern"],
+ group=EndpointGroup(data["group"]),
+ requests=data["requests"],
+ window_seconds=data["window_seconds"],
+ burst_multiplier=data.get("burst_multiplier", 1.5),
+ algorithm=RateLimitAlgorithm(data.get("algorithm", RateLimitAlgorithm.SLIDING_WINDOW)),
+ priority=data.get("priority", 0),
+ enabled=data.get("enabled", True),
+ )
+
+
+def _user_limit_to_dict(user_limit: UserRateLimit) -> Dict[str, Any]:
+ return {
+ "user_id": user_limit.user_id,
+ "bypass_rate_limit": user_limit.bypass_rate_limit,
+ "global_multiplier": user_limit.global_multiplier,
+ "rules": [_rule_to_dict(rule) for rule in user_limit.rules],
+ "created_at": user_limit.created_at.isoformat() if user_limit.created_at else None,
+ "updated_at": user_limit.updated_at.isoformat() if user_limit.updated_at else None,
+ "notes": user_limit.notes,
+ }
+
+
+def _user_limit_from_dict(data: Dict[str, Any]) -> UserRateLimit:
+ created_at = data.get("created_at")
+ if created_at and isinstance(created_at, str):
+ created_at = datetime.fromisoformat(created_at)
+ elif not created_at:
+ created_at = datetime.now(timezone.utc)
+
+ updated_at = data.get("updated_at")
+ if updated_at and isinstance(updated_at, str):
+ updated_at = datetime.fromisoformat(updated_at)
+ elif not updated_at:
+ updated_at = datetime.now(timezone.utc)
+
+ return UserRateLimit(
+ user_id=data["user_id"],
+ bypass_rate_limit=data.get("bypass_rate_limit", False),
+ global_multiplier=data.get("global_multiplier", 1.0),
+ rules=[_rule_from_dict(rule_data) for rule_data in data.get("rules", [])],
+ created_at=created_at,
+ updated_at=updated_at,
+ notes=data.get("notes"),
+ )
+
+
+def _config_to_json(config: RateLimitConfig) -> str:
+ data = {
+ "default_rules": [_rule_to_dict(rule) for rule in config.default_rules],
+ "user_overrides": {uid: _user_limit_to_dict(user_limit) for uid, user_limit in config.user_overrides.items()},
+ "global_enabled": config.global_enabled,
+ "redis_ttl": config.redis_ttl,
+ }
+ return json.dumps(data)
+
+
+def _config_from_json(json_str: str | bytes) -> RateLimitConfig:
+ data = json.loads(json_str)
+ return RateLimitConfig(
+ default_rules=[_rule_from_dict(rule_data) for rule_data in data.get("default_rules", [])],
+ user_overrides={
+ uid: _user_limit_from_dict(user_data) for uid, user_data in data.get("user_overrides", {}).items()
+ },
+ global_enabled=data.get("global_enabled", True),
+ redis_ttl=data.get("redis_ttl", 3600),
+ )
+
+
class RateLimitService:
def __init__(self, redis_client: redis.Redis, settings: Settings, metrics: "RateLimitMetrics"):
self.redis = redis_client
@@ -380,17 +464,16 @@ async def _get_config(self) -> RateLimitConfig:
# Try to get from Redis cache
config_key = f"{self.prefix}config"
config_data = await self.redis.get(config_key)
- mapper = RateLimitConfigMapper()
if config_data:
- config = mapper.model_validate_json(config_data)
+ config = _config_from_json(config_data)
else:
# Return default config and cache it
config = RateLimitConfig.get_default_config()
await self.redis.setex(
config_key,
300, # Cache for 5 minutes
- mapper.model_dump_json(config),
+ _config_to_json(config),
)
# Prepare for fast matching
@@ -409,10 +492,9 @@ async def _get_config(self) -> RateLimitConfig:
async def update_config(self, config: RateLimitConfig) -> None:
config_key = f"{self.prefix}config"
- mapper = RateLimitConfigMapper()
with self._timer(self.metrics.redis_duration, {"operation": "update_config"}):
- await self.redis.setex(config_key, 300, mapper.model_dump_json(config))
+ await self.redis.setex(config_key, 300, _config_to_json(config))
# Update configuration metrics - just record the absolute values
active_rules_count = len([r for r in config.default_rules if r.enabled])
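These module-level helpers replace the mapper with plain dict/JSON round-trips so the config can be cached in Redis as a string via setex/get. A minimal, self-contained version of the same round-trip idea (RuleSketch is an invented placeholder, not the real RateLimitRule):

    import json
    from dataclasses import asdict, dataclass


    @dataclass
    class RuleSketch:
        endpoint_pattern: str
        requests: int
        window_seconds: int
        enabled: bool = True


    def rule_to_json(rule: RuleSketch) -> str:
        return json.dumps(asdict(rule))


    def rule_from_json(raw: str | bytes) -> RuleSketch:
        data = json.loads(raw)
        return RuleSketch(
            endpoint_pattern=data["endpoint_pattern"],
            requests=data["requests"],
            window_seconds=data["window_seconds"],
            enabled=data.get("enabled", True),
        )


    # Round-trip check: serialize, deserialize, and compare.
    assert rule_from_json(rule_to_json(RuleSketch("/api/v1/execute", 10, 60))) == RuleSketch("/api/v1/execute", 10, 60)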
diff --git a/backend/app/services/replay_service.py b/backend/app/services/replay_service.py
index 916f3d76..2bf4e86d 100644
--- a/backend/app/services/replay_service.py
+++ b/backend/app/services/replay_service.py
@@ -1,12 +1,13 @@
+import logging
from datetime import datetime, timedelta, timezone
from typing import List
-from app.core.exceptions import ServiceError
-from app.core.logging import logger
from app.db.repositories.replay_repository import ReplayRepository
from app.domain.replay import (
ReplayConfig,
+ ReplayOperationError,
ReplayOperationResult,
+ ReplaySessionNotFoundError,
ReplaySessionState,
)
from app.schemas_pydantic.replay import CleanupResponse
@@ -19,9 +20,12 @@
class ReplayService:
"""Service for managing replay sessions and providing business logic"""
- def __init__(self, repository: ReplayRepository, event_replay_service: EventReplayService) -> None:
+ def __init__(
+ self, repository: ReplayRepository, event_replay_service: EventReplayService, logger: logging.Logger
+ ) -> None:
self.repository = repository
self.event_replay_service = event_replay_service
+ self.logger = logger
async def create_session_from_config(self, config: ReplayConfig) -> ReplayOperationResult:
"""Create a new replay session from a domain config"""
@@ -36,12 +40,12 @@ async def create_session_from_config(self, config: ReplayConfig) -> ReplayOperat
message="Replay session created successfully",
)
except Exception as e:
- logger.error(f"Failed to create replay session: {e}")
- raise ServiceError(str(e), status_code=500) from e
+ self.logger.error(f"Failed to create replay session: {e}")
+ raise ReplayOperationError("", "create", str(e)) from e
async def start_session(self, session_id: str) -> ReplayOperationResult:
"""Start a replay session"""
- logger.info(f"Starting replay session {session_id}")
+ self.logger.info(f"Starting replay session {session_id}")
try:
await self.event_replay_service.start_replay(session_id)
@@ -51,11 +55,11 @@ async def start_session(self, session_id: str) -> ReplayOperationResult:
session_id=session_id, status=ReplayStatus.RUNNING, message="Replay session started"
)
- except ValueError as e:
- raise ServiceError(str(e), status_code=404) from e
+ except ValueError:
+ raise ReplaySessionNotFoundError(session_id)
except Exception as e:
- logger.error(f"Failed to start replay session: {e}")
- raise ServiceError(str(e), status_code=500) from e
+ self.logger.error(f"Failed to start replay session: {e}")
+ raise ReplayOperationError(session_id, "start", str(e)) from e
async def pause_session(self, session_id: str) -> ReplayOperationResult:
"""Pause a replay session"""
@@ -68,11 +72,11 @@ async def pause_session(self, session_id: str) -> ReplayOperationResult:
session_id=session_id, status=ReplayStatus.PAUSED, message="Replay session paused"
)
- except ValueError as e:
- raise ServiceError(str(e), status_code=404) from e
+ except ValueError:
+ raise ReplaySessionNotFoundError(session_id)
except Exception as e:
- logger.error(f"Failed to pause replay session: {e}")
- raise ServiceError(str(e), status_code=500) from e
+ self.logger.error(f"Failed to pause replay session: {e}")
+ raise ReplayOperationError(session_id, "pause", str(e)) from e
async def resume_session(self, session_id: str) -> ReplayOperationResult:
"""Resume a paused replay session"""
@@ -85,11 +89,11 @@ async def resume_session(self, session_id: str) -> ReplayOperationResult:
session_id=session_id, status=ReplayStatus.RUNNING, message="Replay session resumed"
)
- except ValueError as e:
- raise ServiceError(str(e), status_code=404) from e
+ except ValueError:
+ raise ReplaySessionNotFoundError(session_id)
except Exception as e:
- logger.error(f"Failed to resume replay session: {e}")
- raise ServiceError(str(e), status_code=500) from e
+ self.logger.error(f"Failed to resume replay session: {e}")
+ raise ReplayOperationError(session_id, "resume", str(e)) from e
async def cancel_session(self, session_id: str) -> ReplayOperationResult:
"""Cancel a replay session"""
@@ -102,11 +106,11 @@ async def cancel_session(self, session_id: str) -> ReplayOperationResult:
session_id=session_id, status=ReplayStatus.CANCELLED, message="Replay session cancelled"
)
- except ValueError as e:
- raise ServiceError(str(e), status_code=404) from e
+ except ValueError:
+ raise ReplaySessionNotFoundError(session_id)
except Exception as e:
- logger.error(f"Failed to cancel replay session: {e}")
- raise ServiceError(str(e), status_code=500) from e
+ self.logger.error(f"Failed to cancel replay session: {e}")
+ raise ReplayOperationError(session_id, "cancel", str(e)) from e
def list_sessions(self, status: ReplayStatus | None = None, limit: int = 100) -> List[ReplaySessionState]:
"""List replay sessions with optional filtering (domain objects)."""
@@ -118,13 +122,13 @@ def get_session(self, session_id: str) -> ReplaySessionState:
# Get from memory-based service for performance
session = self.event_replay_service.get_session(session_id)
if not session:
- raise ServiceError("Session not found", status_code=404)
+ raise ReplaySessionNotFoundError(session_id)
return session
- except ServiceError:
+ except ReplaySessionNotFoundError:
raise
except Exception as e:
- logger.error(f"Failed to get replay session {session_id}: {e}")
- raise ServiceError("Internal server error", status_code=500) from e
+ self.logger.error(f"Failed to get replay session {session_id}: {e}")
+ raise ReplayOperationError(session_id, "get", str(e)) from e
async def cleanup_old_sessions(self, older_than_hours: int = 24) -> CleanupResponse:
"""Clean up old replay sessions"""
@@ -133,10 +137,10 @@ async def cleanup_old_sessions(self, older_than_hours: int = 24) -> CleanupRespo
# Clean up from database
cutoff_time = datetime.now(timezone.utc) - timedelta(hours=older_than_hours)
- removed_db = await self.repository.delete_old_sessions(cutoff_time.isoformat())
+ removed_db = await self.repository.delete_old_sessions(cutoff_time)
total_removed = max(removed_memory, removed_db)
return CleanupResponse(removed_sessions=total_removed, message=f"Removed {total_removed} old sessions")
except Exception as e:
- logger.error(f"Failed to cleanup old sessions: {e}")
- raise ServiceError(str(e), status_code=500) from e
+ self.logger.error(f"Failed to cleanup old sessions: {e}")
+ raise ReplayOperationError("", "cleanup", str(e)) from e
diff --git a/backend/app/services/result_processor/processor.py b/backend/app/services/result_processor/processor.py
index 1e852bf0..cb88b9c0 100644
--- a/backend/app/services/result_processor/processor.py
+++ b/backend/app/services/result_processor/processor.py
@@ -1,21 +1,24 @@
import asyncio
+import logging
+from contextlib import AsyncExitStack
from enum import auto
from typing import Any
+from beanie import init_beanie
from pydantic import BaseModel, ConfigDict, Field
+from pymongo.asynchronous.mongo_client import AsyncMongoClient
from app.core.container import create_result_processor_container
-from app.core.exceptions import ServiceError
from app.core.lifecycle import LifecycleEnabled
-from app.core.logging import logger
from app.core.metrics.context import get_execution_metrics
from app.core.utils import StringEnum
+from app.db.docs import ALL_DOCUMENTS
from app.db.repositories.execution_repository import ExecutionRepository
from app.domain.enums.events import EventType
from app.domain.enums.execution import ExecutionStatus
from app.domain.enums.kafka import GroupId, KafkaTopic
from app.domain.enums.storage import ExecutionErrorType, StorageType
-from app.domain.execution import ExecutionResultDomain
+from app.domain.execution import ExecutionNotFoundError, ExecutionResultDomain
from app.events.core import ConsumerConfig, EventDispatcher, UnifiedConsumer, UnifiedProducer
from app.infrastructure.kafka import BaseEvent
from app.infrastructure.kafka.events.execution import (
@@ -63,7 +66,11 @@ class ResultProcessor(LifecycleEnabled):
"""Service for processing execution completion events and storing results."""
def __init__(
- self, execution_repo: ExecutionRepository, producer: UnifiedProducer, idempotency_manager: IdempotencyManager
+ self,
+ execution_repo: ExecutionRepository,
+ producer: UnifiedProducer,
+ idempotency_manager: IdempotencyManager,
+ logger: logging.Logger,
) -> None:
"""Initialize the result processor."""
self.config = ResultProcessorConfig()
@@ -74,30 +81,31 @@ def __init__(
self._state = ProcessingState.IDLE
self._consumer: IdempotentConsumerWrapper | None = None
self._dispatcher: EventDispatcher | None = None
+ self.logger = logger
async def start(self) -> None:
"""Start the result processor."""
if self._state != ProcessingState.IDLE:
- logger.warning(f"Cannot start processor in state: {self._state}")
+ self.logger.warning(f"Cannot start processor in state: {self._state}")
return
- logger.info("Starting ResultProcessor...")
+ self.logger.info("Starting ResultProcessor...")
# Initialize idempotency manager (safe to call multiple times)
await self._idempotency_manager.initialize()
- logger.info("Idempotency manager initialized for ResultProcessor")
+ self.logger.info("Idempotency manager initialized for ResultProcessor")
self._dispatcher = self._create_dispatcher()
self._consumer = await self._create_consumer()
self._state = ProcessingState.PROCESSING
- logger.info("ResultProcessor started successfully with idempotency protection")
+ self.logger.info("ResultProcessor started successfully with idempotency protection")
async def stop(self) -> None:
"""Stop the result processor."""
if self._state == ProcessingState.STOPPED:
return
- logger.info("Stopping ResultProcessor...")
+ self.logger.info("Stopping ResultProcessor...")
self._state = ProcessingState.STOPPED
if self._consumer:
@@ -105,11 +113,11 @@ async def stop(self) -> None:
await self._idempotency_manager.close()
await self._producer.stop()
- logger.info("ResultProcessor stopped")
+ self.logger.info("ResultProcessor stopped")
def _create_dispatcher(self) -> EventDispatcher:
"""Create and configure event dispatcher with handlers."""
- dispatcher = EventDispatcher()
+ dispatcher = EventDispatcher(logger=self.logger)
# Register handlers for specific event types
dispatcher.register_handler(EventType.EXECUTION_COMPLETED, self._handle_completed_wrapper)
@@ -133,11 +141,12 @@ async def _create_consumer(self) -> IdempotentConsumerWrapper:
if not self._dispatcher:
raise RuntimeError("Event dispatcher not initialized")
- base_consumer = UnifiedConsumer(consumer_config, event_dispatcher=self._dispatcher)
+ base_consumer = UnifiedConsumer(consumer_config, event_dispatcher=self._dispatcher, logger=self.logger)
wrapper = IdempotentConsumerWrapper(
consumer=base_consumer,
idempotency_manager=self._idempotency_manager,
dispatcher=self._dispatcher,
+ logger=self.logger,
default_key_strategy="content_hash",
default_ttl_seconds=7200,
enable_for_all_handlers=True,
@@ -164,7 +173,7 @@ async def _handle_completed(self, event: ExecutionCompletedEvent) -> None:
exec_obj = await self._execution_repo.get_execution(event.execution_id)
if exec_obj is None:
- raise ServiceError(message=f"Execution {event.execution_id} not found", status_code=404)
+ raise ExecutionNotFoundError(event.execution_id)
lang_and_version = f"{exec_obj.lang}-{exec_obj.lang_version}"
@@ -199,7 +208,7 @@ async def _handle_completed(self, event: ExecutionCompletedEvent) -> None:
await self._execution_repo.write_terminal_result(result)
await self._publish_result_stored(result)
except Exception as e:
- logger.error(f"Failed to handle ExecutionCompletedEvent: {e}", exc_info=True)
+ self.logger.error(f"Failed to handle ExecutionCompletedEvent: {e}", exc_info=True)
await self._publish_result_failed(event.execution_id, str(e))
async def _handle_failed(self, event: ExecutionFailedEvent) -> None:
@@ -208,7 +217,7 @@ async def _handle_failed(self, event: ExecutionFailedEvent) -> None:
# Fetch execution to get language and version for metrics
exec_obj = await self._execution_repo.get_execution(event.execution_id)
if exec_obj is None:
- raise ServiceError(message=f"Execution {event.execution_id} not found", status_code=404)
+ raise ExecutionNotFoundError(event.execution_id)
self._metrics.record_error(event.error_type)
lang_and_version = f"{exec_obj.lang}-{exec_obj.lang_version}"
@@ -228,7 +237,7 @@ async def _handle_failed(self, event: ExecutionFailedEvent) -> None:
await self._execution_repo.write_terminal_result(result)
await self._publish_result_stored(result)
except Exception as e:
- logger.error(f"Failed to handle ExecutionFailedEvent: {e}", exc_info=True)
+ self.logger.error(f"Failed to handle ExecutionFailedEvent: {e}", exc_info=True)
await self._publish_result_failed(event.execution_id, str(e))
async def _handle_timeout(self, event: ExecutionTimeoutEvent) -> None:
@@ -236,7 +245,7 @@ async def _handle_timeout(self, event: ExecutionTimeoutEvent) -> None:
exec_obj = await self._execution_repo.get_execution(event.execution_id)
if exec_obj is None:
- raise ServiceError(message=f"Execution {event.execution_id} not found", status_code=404)
+ raise ExecutionNotFoundError(event.execution_id)
self._metrics.record_error(ExecutionErrorType.TIMEOUT)
lang_and_version = f"{exec_obj.lang}-{exec_obj.lang_version}"
@@ -259,7 +268,7 @@ async def _handle_timeout(self, event: ExecutionTimeoutEvent) -> None:
await self._execution_repo.write_terminal_result(result)
await self._publish_result_stored(result)
except Exception as e:
- logger.error(f"Failed to handle ExecutionTimeoutEvent: {e}", exc_info=True)
+ self.logger.error(f"Failed to handle ExecutionTimeoutEvent: {e}", exc_info=True)
await self._publish_result_failed(event.execution_id, str(e))
async def _publish_result_stored(self, result: ExecutionResultDomain) -> None:
@@ -302,22 +311,32 @@ async def get_status(self) -> dict[str, Any]:
async def run_result_processor() -> None:
- from contextlib import AsyncExitStack
+ settings = get_settings()
+
+ # Initialize MongoDB and Beanie ODM (required per-process)
+ db_client: AsyncMongoClient[dict[str, object]] = AsyncMongoClient(
+ settings.MONGODB_URL, tz_aware=True, serverSelectionTimeoutMS=5000
+ )
+ await init_beanie(database=db_client[settings.DATABASE_NAME], document_models=ALL_DOCUMENTS)
container = create_result_processor_container()
producer = await container.get(UnifiedProducer)
idempotency_manager = await container.get(IdempotencyManager)
execution_repo = await container.get(ExecutionRepository)
+ logger = await container.get(logging.Logger)
+ logger.info(f"Beanie ODM initialized with {len(ALL_DOCUMENTS)} document models")
processor = ResultProcessor(
execution_repo=execution_repo,
producer=producer,
idempotency_manager=idempotency_manager,
+ logger=logger,
)
async with AsyncExitStack() as stack:
await stack.enter_async_context(processor)
stack.push_async_callback(container.close)
+ stack.push_async_callback(db_client.close)
while True:
await asyncio.sleep(60)
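Like the pod monitor, the result processor now performs its own per-process Mongo/Beanie bootstrap before any repository is touched. A minimal standalone version of that bootstrap, where the connection URL, database name, and ExampleDoc model are placeholders:

    import asyncio

    from beanie import Document, init_beanie
    from pymongo.asynchronous.mongo_client import AsyncMongoClient


    class ExampleDoc(Document):
        name: str


    async def bootstrap() -> None:
        client: AsyncMongoClient = AsyncMongoClient("mongodb://localhost:27017", tz_aware=True)
        await client.admin.command("ping")  # fail fast if Mongo is unreachable
        await init_beanie(database=client["example_db"], document_models=[ExampleDoc])
        try:
            await ExampleDoc(name="smoke-test").insert()
        finally:
            await client.close()


    if __name__ == "__main__":
        asyncio.run(bootstrap())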
diff --git a/backend/app/services/result_processor/resource_cleaner.py b/backend/app/services/result_processor/resource_cleaner.py
index 2fb5c4c4..1a48a1da 100644
--- a/backend/app/services/result_processor/resource_cleaner.py
+++ b/backend/app/services/result_processor/resource_cleaner.py
@@ -1,4 +1,5 @@
import asyncio
+import logging
from datetime import datetime, timedelta, timezone
from functools import partial
from typing import Any
@@ -7,8 +8,7 @@
from kubernetes import config as k8s_config
from kubernetes.client.rest import ApiException
-from app.core.exceptions import ServiceError
-from app.core.logging import logger
+from app.domain.exceptions import InfrastructureError, InvalidStateError
# Python 3.12 type aliases
type ResourceDict = dict[str, list[str]]
@@ -18,10 +18,11 @@
class ResourceCleaner:
"""Service for cleaning up Kubernetes resources"""
- def __init__(self) -> None:
+ def __init__(self, logger: logging.Logger) -> None:
self.v1: k8s_client.CoreV1Api | None = None
self.networking_v1: k8s_client.NetworkingV1Api | None = None
self._initialized = False
+ self.logger = logger
async def initialize(self) -> None:
"""Initialize Kubernetes clients"""
@@ -31,18 +32,18 @@ async def initialize(self) -> None:
try:
try:
k8s_config.load_incluster_config()
- logger.info("Using in-cluster Kubernetes config")
+ self.logger.info("Using in-cluster Kubernetes config")
except k8s_config.ConfigException:
k8s_config.load_kube_config()
- logger.info("Using kubeconfig")
+ self.logger.info("Using kubeconfig")
self.v1 = k8s_client.CoreV1Api()
self.networking_v1 = k8s_client.NetworkingV1Api()
self._initialized = True
except Exception as e:
- logger.error(f"Failed to initialize Kubernetes client: {e}")
- raise ServiceError(f"Kubernetes initialization failed: {e}") from e
+ self.logger.error(f"Failed to initialize Kubernetes client: {e}")
+ raise InfrastructureError(f"Kubernetes initialization failed: {e}") from e
async def cleanup_pod_resources(
self,
@@ -54,7 +55,7 @@ async def cleanup_pod_resources(
) -> None:
"""Clean up all resources associated with a pod"""
await self.initialize()
- logger.info(f"Cleaning up resources for pod: {pod_name}")
+ self.logger.info(f"Cleaning up resources for pod: {pod_name}")
try:
tasks = [
@@ -71,19 +72,19 @@ async def cleanup_pod_resources(
await asyncio.wait_for(asyncio.gather(*tasks, return_exceptions=True), timeout=timeout)
- logger.info(f"Successfully cleaned up resources for pod: {pod_name}")
+ self.logger.info(f"Successfully cleaned up resources for pod: {pod_name}")
except asyncio.TimeoutError as e:
- logger.error(f"Timeout cleaning up resources for pod: {pod_name}")
- raise ServiceError("Resource cleanup timed out") from e
+ self.logger.error(f"Timeout cleaning up resources for pod: {pod_name}")
+ raise InfrastructureError("Resource cleanup timed out") from e
except Exception as e:
- logger.error(f"Failed to cleanup resources: {e}")
- raise ServiceError(f"Resource cleanup failed: {e}") from e
+ self.logger.error(f"Failed to cleanup resources: {e}")
+ raise InfrastructureError(f"Resource cleanup failed: {e}") from e
async def _delete_pod(self, pod_name: str, namespace: str) -> None:
"""Delete a pod"""
if not self.v1:
- raise ServiceError("Kubernetes client not initialized")
+ raise InvalidStateError("Kubernetes client not initialized")
try:
loop = asyncio.get_event_loop()
@@ -93,19 +94,19 @@ async def _delete_pod(self, pod_name: str, namespace: str) -> None:
None, partial(self.v1.delete_namespaced_pod, pod_name, namespace, grace_period_seconds=30)
)
- logger.info(f"Deleted pod: {pod_name}")
+ self.logger.info(f"Deleted pod: {pod_name}")
except ApiException as e:
if e.status == 404:
- logger.info(f"Pod {pod_name} already deleted")
+ self.logger.info(f"Pod {pod_name} already deleted")
else:
- logger.error(f"Failed to delete pod: {e}")
+ self.logger.error(f"Failed to delete pod: {e}")
raise
async def _delete_configmaps(self, execution_id: str, namespace: str) -> None:
"""Delete ConfigMaps for an execution"""
if not self.v1:
- raise ServiceError("Kubernetes client not initialized")
+ raise InvalidStateError("Kubernetes client not initialized")
await self._delete_labeled_resources(
execution_id,
@@ -118,7 +119,7 @@ async def _delete_configmaps(self, execution_id: str, namespace: str) -> None:
async def _delete_pvcs(self, execution_id: str, namespace: str) -> None:
"""Delete PersistentVolumeClaims for an execution"""
if not self.v1:
- raise ServiceError("Kubernetes client not initialized")
+ raise InvalidStateError("Kubernetes client not initialized")
await self._delete_labeled_resources(
execution_id,
@@ -140,10 +141,10 @@ async def _delete_labeled_resources(
for resource in resources.items:
await loop.run_in_executor(None, delete_func, resource.metadata.name, namespace)
- logger.info(f"Deleted {resource_type}: {resource.metadata.name}")
+ self.logger.info(f"Deleted {resource_type}: {resource.metadata.name}")
except ApiException as e:
- logger.error(f"Failed to delete {resource_type}s: {e}")
+ self.logger.error(f"Failed to delete {resource_type}s: {e}")
async def cleanup_orphaned_resources(
self,
@@ -168,15 +169,15 @@ async def cleanup_orphaned_resources(
return cleaned
except Exception as e:
- logger.error(f"Failed to cleanup orphaned resources: {e}")
- raise ServiceError(f"Orphaned resource cleanup failed: {e}") from e
+ self.logger.error(f"Failed to cleanup orphaned resources: {e}")
+ raise InfrastructureError(f"Orphaned resource cleanup failed: {e}") from e
async def _cleanup_orphaned_pods(
self, namespace: str, cutoff_time: datetime, cleaned: ResourceDict, dry_run: bool
) -> None:
"""Clean up orphaned pods"""
if not self.v1:
- raise ServiceError("Kubernetes client not initialized")
+ raise InvalidStateError("Kubernetes client not initialized")
loop = asyncio.get_event_loop()
pods = await loop.run_in_executor(
@@ -196,14 +197,14 @@ async def _cleanup_orphaned_pods(
try:
await self._delete_pod(pod.metadata.name, namespace)
except Exception as e:
- logger.error(f"Failed to delete orphaned pod {pod.metadata.name}: {e}")
+ self.logger.error(f"Failed to delete orphaned pod {pod.metadata.name}: {e}")
async def _cleanup_orphaned_configmaps(
self, namespace: str, cutoff_time: datetime, cleaned: ResourceDict, dry_run: bool
) -> None:
"""Clean up orphaned ConfigMaps"""
if not self.v1:
- raise ServiceError("Kubernetes client not initialized")
+ raise InvalidStateError("Kubernetes client not initialized")
loop = asyncio.get_event_loop()
configmaps = await loop.run_in_executor(
@@ -220,7 +221,7 @@ async def _cleanup_orphaned_configmaps(
None, self.v1.delete_namespaced_config_map, cm.metadata.name, namespace
)
except Exception as e:
- logger.error(f"Failed to delete orphaned ConfigMap {cm.metadata.name}: {e}")
+ self.logger.error(f"Failed to delete orphaned ConfigMap {cm.metadata.name}: {e}")
async def get_resource_usage(self, namespace: str = "default") -> CountDict:
"""Get current resource usage counts"""
@@ -235,33 +236,33 @@ async def get_resource_usage(self, namespace: str = "default") -> CountDict:
# Get pods count
try:
if not self.v1:
- raise ServiceError("Kubernetes client not initialized")
+ raise InvalidStateError("Kubernetes client not initialized")
pods = await loop.run_in_executor(
None, partial(self.v1.list_namespaced_pod, namespace, label_selector=label_selector)
)
pod_count = len(pods.items)
except Exception as e:
- logger.warning(f"Failed to get pods: {e}")
+ self.logger.warning(f"Failed to get pods: {e}")
pod_count = 0
# Get configmaps count
try:
if not self.v1:
- raise ServiceError("Kubernetes client not initialized")
+ raise InvalidStateError("Kubernetes client not initialized")
configmaps = await loop.run_in_executor(
None, partial(self.v1.list_namespaced_config_map, namespace, label_selector=label_selector)
)
configmap_count = len(configmaps.items)
except Exception as e:
- logger.warning(f"Failed to get configmaps: {e}")
+ self.logger.warning(f"Failed to get configmaps: {e}")
configmap_count = 0
# Get network policies count
try:
if not self.networking_v1:
- raise ServiceError("Kubernetes networking client not initialized")
+ raise InvalidStateError("Kubernetes networking client not initialized")
policies = await loop.run_in_executor(
None,
@@ -271,7 +272,7 @@ async def get_resource_usage(self, namespace: str = "default") -> CountDict:
)
policy_count = len(policies.items)
except Exception as e:
- logger.warning(f"Failed to get network policies: {e}")
+ self.logger.warning(f"Failed to get network policies: {e}")
policy_count = 0
return {
@@ -281,5 +282,5 @@ async def get_resource_usage(self, namespace: str = "default") -> CountDict:
}
except Exception as e:
- logger.error(f"Failed to get resource usage: {e}")
+ self.logger.error(f"Failed to get resource usage: {e}")
return default_counts
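
The replacements above split the old catch-all ServiceError into two narrower exceptions: InfrastructureError for failures reaching the Kubernetes API, and InvalidStateError for using the client before initialize() has run. A minimal sketch of that split, with class bodies assumed since their definitions are not part of this hunk:

    # Assumed shape of the two exception types; only the split in responsibility
    # comes from the hunk above, the bodies are illustrative.
    class InfrastructureError(Exception):
        """An external dependency (Kubernetes API, network, timeout) failed."""


    class InvalidStateError(Exception):
        """The object was used before reaching a valid state, e.g. initialize() was never called."""


    def require_initialized(client: object | None) -> None:
        # Mirrors the `if not self.v1: raise InvalidStateError(...)` guards above.
        if client is None:
            raise InvalidStateError("Kubernetes client not initialized")
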
diff --git a/backend/app/services/saga/execution_saga.py b/backend/app/services/saga/execution_saga.py
index 616915f2..81f705da 100644
--- a/backend/app/services/saga/execution_saga.py
+++ b/backend/app/services/saga/execution_saga.py
@@ -3,6 +3,7 @@
from app.db.repositories.resource_allocation_repository import ResourceAllocationRepository
from app.domain.enums.events import EventType
+from app.domain.saga import DomainResourceAllocationCreate
from app.events.core import UnifiedProducer
from app.infrastructure.kafka.events.execution import ExecutionRequestedEvent
from app.infrastructure.kafka.events.metadata import AvroEventMetadata as EventMetadata
@@ -79,19 +80,18 @@ async def execute(self, context: SagaContext, event: ExecutionRequestedEvent) ->
raise ValueError("Resource limit exceeded")
# Create allocation record via repository
- ok = await self.alloc_repo.create_allocation(
- execution_id,
- execution_id=execution_id,
- language=event.language,
- cpu_request=event.cpu_request,
- memory_request=event.memory_request,
- cpu_limit=event.cpu_limit,
- memory_limit=event.memory_limit,
+ allocation = await self.alloc_repo.create_allocation(
+ DomainResourceAllocationCreate(
+ execution_id=execution_id,
+ language=event.language,
+ cpu_request=event.cpu_request,
+ memory_request=event.memory_request,
+ cpu_limit=event.cpu_limit,
+ memory_limit=event.memory_limit,
+ )
)
- if not ok:
- raise RuntimeError("Failed to persist resource allocation")
- context.set("allocation_id", execution_id)
+ context.set("allocation_id", allocation.allocation_id)
context.set("resources_allocated", True)
return True
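
The rewritten step passes a single DomainResourceAllocationCreate payload and records the repository-generated allocation_id rather than reusing the execution id. A hedged sketch of the shape this implies; field names are taken from the call site, while the types, base class, and allocation model are assumptions:

    # Hedged sketch of the payload implied by the new call site above.
    from pydantic import BaseModel


    class DomainResourceAllocationCreate(BaseModel):
        execution_id: str
        language: str
        cpu_request: str
        memory_request: str
        cpu_limit: str
        memory_limit: str


    class DomainResourceAllocation(DomainResourceAllocationCreate):
        allocation_id: str  # generated by the repository on insert


    async def allocate(repo, event, execution_id: str) -> str:
        # Mirrors the saga step: build one typed payload, persist it, and keep the
        # repository-generated allocation_id instead of reusing the execution id.
        allocation = await repo.create_allocation(
            DomainResourceAllocationCreate(
                execution_id=execution_id,
                language=event.language,
                cpu_request=event.cpu_request,
                memory_request=event.memory_request,
                cpu_limit=event.cpu_limit,
                memory_limit=event.memory_limit,
            )
        )
        return allocation.allocation_id
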
diff --git a/backend/app/services/saga/saga_orchestrator.py b/backend/app/services/saga/saga_orchestrator.py
index 89f8c239..e200bf63 100644
--- a/backend/app/services/saga/saga_orchestrator.py
+++ b/backend/app/services/saga/saga_orchestrator.py
@@ -124,7 +124,7 @@ async def _start_consumer(self) -> None:
enable_auto_commit=False,
)
- dispatcher = EventDispatcher()
+ dispatcher = EventDispatcher(logger=logger)
for event_type in event_types_to_register:
dispatcher.register_handler(event_type, self._handle_event)
logger.info(f"Registered handler for event type: {event_type}")
@@ -132,11 +132,13 @@ async def _start_consumer(self) -> None:
base_consumer = UnifiedConsumer(
config=consumer_config,
event_dispatcher=dispatcher,
+ logger=logger,
)
self._consumer = IdempotentConsumerWrapper(
consumer=base_consumer,
idempotency_manager=self._idempotency_manager,
dispatcher=dispatcher,
+ logger=logger,
default_key_strategy="event_based",
default_ttl_seconds=7200,
enable_for_all_handlers=False,
@@ -193,7 +195,8 @@ async def _start_saga(self, saga_name: str, trigger_event: BaseEvent) -> str | N
existing = await self._repo.get_saga_by_execution_and_name(execution_id, saga_name)
if existing:
logger.info(f"Saga {saga_name} already exists for execution {execution_id}")
- return existing.saga_id
+ saga_id: str = existing.saga_id
+ return saga_id
instance = Saga(
saga_id=str(uuid4()),
diff --git a/backend/app/services/saga/saga_service.py b/backend/app/services/saga/saga_service.py
index 27617bb0..5ed6e4e3 100644
--- a/backend/app/services/saga/saga_service.py
+++ b/backend/app/services/saga/saga_service.py
@@ -1,4 +1,5 @@
-from app.core.logging import logger
+import logging
+
from app.db.repositories import ExecutionRepository, SagaRepository
from app.domain.enums import SagaState, UserRole
from app.domain.saga.exceptions import (
@@ -7,19 +8,26 @@
SagaNotFoundError,
)
from app.domain.saga.models import Saga, SagaFilter, SagaListResult
-from app.domain.user import User
+from app.schemas_pydantic.user import User
from app.services.saga import SagaOrchestrator
class SagaService:
"""Service for saga business logic and orchestration."""
- def __init__(self, saga_repo: SagaRepository, execution_repo: ExecutionRepository, orchestrator: SagaOrchestrator):
+ def __init__(
+ self,
+ saga_repo: SagaRepository,
+ execution_repo: ExecutionRepository,
+ orchestrator: SagaOrchestrator,
+ logger: logging.Logger,
+ ):
self.saga_repo = saga_repo
self.execution_repo = execution_repo
self.orchestrator = orchestrator
+ self.logger = logger
- logger.info(
+ self.logger.info(
"SagaService initialized",
extra={
"saga_repo": type(saga_repo).__name__,
@@ -39,27 +47,35 @@ async def check_execution_access(self, execution_id: str, user: User) -> bool:
if execution and execution.user_id == user.user_id:
return True
- logger.debug(
- f"Access denied for user {user.user_id} to execution {execution_id}",
- extra={"user_role": user.role, "execution_exists": execution is not None},
+ self.logger.debug(
+ "Access denied to execution",
+ extra={
+ "user_id": user.user_id,
+ "execution_id": execution_id,
+ "user_role": user.role,
+ "execution_exists": execution is not None,
+ },
)
return False
async def get_saga_with_access_check(self, saga_id: str, user: User) -> Saga:
"""Get saga with access control."""
- logger.debug(f"Getting saga {saga_id} for user {user.user_id}", extra={"user_role": user.role})
+ self.logger.debug(
+ "Getting saga for user", extra={"saga_id": saga_id, "user_id": user.user_id, "user_role": user.role}
+ )
saga = await self.saga_repo.get_saga(saga_id)
if not saga:
- logger.warning(f"Saga {saga_id} not found")
- raise SagaNotFoundError(f"Saga {saga_id} not found")
+ self.logger.warning("Saga not found", extra={"saga_id": saga_id})
+ raise SagaNotFoundError(saga_id)
# Check access permissions
if not await self.check_execution_access(saga.execution_id, user):
- logger.warning(
- f"Access denied for user {user.user_id} to saga {saga_id}", extra={"execution_id": saga.execution_id}
+ self.logger.warning(
+ "Access denied to saga",
+ extra={"user_id": user.user_id, "saga_id": saga_id, "execution_id": saga.execution_id},
)
- raise SagaAccessDeniedError(f"Access denied - you don't have access to execution {saga.execution_id}")
+ raise SagaAccessDeniedError(saga_id, user.user_id)
return saga
@@ -69,10 +85,11 @@ async def get_execution_sagas(
"""Get sagas for an execution with access control."""
# Check access to execution
if not await self.check_execution_access(execution_id, user):
- logger.warning(
- f"Access denied for user {user.user_id} to execution {execution_id}", extra={"user_role": user.role}
+ self.logger.warning(
+ "Access denied to execution",
+ extra={"user_id": user.user_id, "execution_id": execution_id, "user_role": user.role},
)
- raise SagaAccessDeniedError(f"Access denied - no access to execution {execution_id}")
+ raise SagaAccessDeniedError(execution_id, user.user_id)
return await self.saga_repo.get_sagas_by_execution(execution_id, state, limit=limit, skip=skip)
@@ -86,39 +103,49 @@ async def list_user_sagas(
if user.role != UserRole.ADMIN:
user_execution_ids = await self.saga_repo.get_user_execution_ids(user.user_id)
saga_filter.execution_ids = user_execution_ids
- logger.debug(
- f"Filtering sagas for user {user.user_id}",
- extra={"execution_count": len(user_execution_ids) if user_execution_ids else 0},
+ self.logger.debug(
+ "Filtering sagas for user",
+ extra={
+ "user_id": user.user_id,
+ "execution_count": len(user_execution_ids) if user_execution_ids else 0,
+ },
)
# Get sagas from repository
result = await self.saga_repo.list_sagas(saga_filter, limit, skip)
- logger.debug(
- f"Listed {len(result.sagas)} sagas for user {user.user_id}",
- extra={"total": result.total, "state_filter": str(state) if state else None},
+ self.logger.debug(
+ "Listed sagas for user",
+ extra={
+ "user_id": user.user_id,
+ "count": len(result.sagas),
+ "total": result.total,
+ "state_filter": str(state) if state else None,
+ },
)
return result
async def cancel_saga(self, saga_id: str, user: User) -> bool:
"""Cancel a saga with permission check."""
- logger.info(f"User {user.user_id} requesting cancellation of saga {saga_id}", extra={"user_role": user.role})
+ self.logger.info(
+ "User requesting saga cancellation",
+ extra={"user_id": user.user_id, "saga_id": saga_id, "user_role": user.role},
+ )
# Get saga with access check
saga = await self.get_saga_with_access_check(saga_id, user)
# Check if saga can be cancelled
if saga.state not in [SagaState.RUNNING, SagaState.CREATED]:
- raise SagaInvalidStateError(
- f"Cannot cancel saga in {saga.state} state. Only RUNNING or CREATED sagas can be cancelled."
- )
+ raise SagaInvalidStateError(saga_id, str(saga.state), "cancel")
# Use orchestrator to cancel
success = await self.orchestrator.cancel_saga(saga_id)
if success:
- logger.info(
- f"User {user.user_id} cancelled saga {saga_id}", extra={"user_role": user.role, "saga_id": saga_id}
+ self.logger.info(
+ "User cancelled saga",
+ extra={"user_id": user.user_id, "saga_id": saga_id, "user_role": user.role},
)
else:
- logger.error(f"Failed to cancel saga {saga_id} for user {user.user_id}", extra={"saga_id": saga_id})
+ self.logger.error("Failed to cancel saga", extra={"saga_id": saga_id, "user_id": user.user_id})
return success
async def get_saga_statistics(self, user: User, include_all: bool = False) -> dict[str, object]:
@@ -134,22 +161,22 @@ async def get_saga_statistics(self, user: User, include_all: bool = False) -> di
async def get_saga_status_from_orchestrator(self, saga_id: str, user: User) -> Saga | None:
"""Get saga status from orchestrator with fallback to database."""
- logger.debug(f"Getting live saga status for {saga_id}")
+ self.logger.debug("Getting live saga status", extra={"saga_id": saga_id})
# Try orchestrator first for live status
saga = await self.orchestrator.get_saga_status(saga_id)
if saga:
# Check access
if not await self.check_execution_access(saga.execution_id, user):
- logger.warning(
- f"Access denied for user {user.user_id} to live saga {saga_id}",
- extra={"execution_id": saga.execution_id},
+ self.logger.warning(
+ "Access denied to live saga",
+ extra={"user_id": user.user_id, "saga_id": saga_id, "execution_id": saga.execution_id},
)
- raise SagaAccessDeniedError(f"Access denied - no access to execution {saga.execution_id}")
+ raise SagaAccessDeniedError(saga_id, user.user_id)
- logger.debug(f"Retrieved live status for saga {saga_id}")
+ self.logger.debug("Retrieved live status for saga", extra={"saga_id": saga_id})
return saga
# Fall back to repository
- logger.debug(f"No live status found for saga {saga_id}, checking database")
+ self.logger.debug("No live status found for saga, checking database", extra={"saga_id": saga_id})
return await self.get_saga_with_access_check(saga_id, user)
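
The service now raises structured domain exceptions (SagaNotFoundError(saga_id), SagaAccessDeniedError(saga_id, user_id), SagaInvalidStateError(saga_id, state, action)) instead of pre-formatted message strings. A rough sketch of how such classes might build their messages; the real definitions live in app.domain.saga.exceptions and may differ:

    # Illustrative shapes only; argument order follows the call sites above.
    class SagaError(Exception):
        """Base class for saga domain errors."""


    class SagaNotFoundError(SagaError):
        def __init__(self, saga_id: str) -> None:
            self.saga_id = saga_id
            super().__init__(f"Saga {saga_id} not found")


    class SagaAccessDeniedError(SagaError):
        def __init__(self, saga_id: str, user_id: str) -> None:
            self.saga_id = saga_id
            self.user_id = user_id
            super().__init__(f"User {user_id} does not have access to saga {saga_id}")


    class SagaInvalidStateError(SagaError):
        def __init__(self, saga_id: str, state: str, action: str) -> None:
            self.saga_id = saga_id
            self.state = state
            self.action = action
            super().__init__(f"Cannot {action} saga {saga_id} in state {state}")

Centralizing message formatting in the exception keeps call sites to structured data, which also lets an API layer map exception types to HTTP status codes without parsing strings.
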
diff --git a/backend/app/services/saved_script_service.py b/backend/app/services/saved_script_service.py
index d36e6e96..adedd344 100644
--- a/backend/app/services/saved_script_service.py
+++ b/backend/app/services/saved_script_service.py
@@ -1,21 +1,23 @@
-from app.core.exceptions import ServiceError
-from app.core.logging import logger
+import logging
+
from app.db.repositories import SavedScriptRepository
from app.domain.saved_script import (
DomainSavedScript,
DomainSavedScriptCreate,
DomainSavedScriptUpdate,
+ SavedScriptNotFoundError,
)
class SavedScriptService:
- def __init__(self, saved_script_repo: SavedScriptRepository):
+ def __init__(self, saved_script_repo: SavedScriptRepository, logger: logging.Logger):
self.saved_script_repo = saved_script_repo
+ self.logger = logger
async def create_saved_script(
self, saved_script_create: DomainSavedScriptCreate, user_id: str
) -> DomainSavedScript:
- logger.info(
+ self.logger.info(
"Creating new saved script",
extra={
"user_id": user_id,
@@ -26,7 +28,7 @@ async def create_saved_script(
created_script = await self.saved_script_repo.create_saved_script(saved_script_create, user_id)
- logger.info(
+ self.logger.info(
"Successfully created saved script",
extra={
"script_id": str(created_script.script_id),
@@ -37,7 +39,7 @@ async def create_saved_script(
return created_script
async def get_saved_script(self, script_id: str, user_id: str) -> DomainSavedScript:
- logger.info(
+ self.logger.info(
"Retrieving saved script",
extra={
"user_id": user_id,
@@ -47,14 +49,13 @@ async def get_saved_script(self, script_id: str, user_id: str) -> DomainSavedScr
script = await self.saved_script_repo.get_saved_script(script_id, user_id)
if not script:
- logger.warning(
+ self.logger.warning(
"Script not found for user",
extra={"user_id": user_id, "script_id": script_id},
)
+ raise SavedScriptNotFoundError(script_id)
- raise ServiceError("Script not found", status_code=404)
-
- logger.info(
+ self.logger.info(
"Successfully retrieved script",
extra={"script_id": script.script_id, "script_name": script.name},
)
@@ -63,7 +64,7 @@ async def get_saved_script(self, script_id: str, user_id: str) -> DomainSavedScr
async def update_saved_script(
self, script_id: str, user_id: str, update_data: DomainSavedScriptUpdate
) -> DomainSavedScript:
- logger.info(
+ self.logger.info(
"Updating saved script",
extra={
"user_id": user_id,
@@ -76,16 +77,16 @@ async def update_saved_script(
await self.saved_script_repo.update_saved_script(script_id, user_id, update_data)
updated_script = await self.saved_script_repo.get_saved_script(script_id, user_id)
if not updated_script:
- raise ServiceError("Script not found", status_code=404)
+ raise SavedScriptNotFoundError(script_id)
- logger.info(
+ self.logger.info(
"Successfully updated script",
extra={"script_id": script_id, "script_name": updated_script.name},
)
return updated_script
async def delete_saved_script(self, script_id: str, user_id: str) -> None:
- logger.info(
+ self.logger.info(
"Deleting saved script",
extra={
"user_id": user_id,
@@ -95,13 +96,13 @@ async def delete_saved_script(self, script_id: str, user_id: str) -> None:
await self.saved_script_repo.delete_saved_script(script_id, user_id)
- logger.info(
+ self.logger.info(
"Successfully deleted script",
extra={"script_id": script_id, "user_id": user_id},
)
async def list_saved_scripts(self, user_id: str) -> list[DomainSavedScript]:
- logger.info(
+ self.logger.info(
"Listing saved scripts",
extra={
"user_id": user_id,
@@ -110,7 +111,7 @@ async def list_saved_scripts(self, user_id: str) -> list[DomainSavedScript]:
scripts = await self.saved_script_repo.list_saved_scripts(user_id)
- logger.info(
+ self.logger.info(
"Successfully retrieved saved scripts",
extra={"user_id": user_id, "script_count": len(scripts)},
)
diff --git a/backend/app/services/sse/kafka_redis_bridge.py b/backend/app/services/sse/kafka_redis_bridge.py
index 64f8c216..f34b29c3 100644
--- a/backend/app/services/sse/kafka_redis_bridge.py
+++ b/backend/app/services/sse/kafka_redis_bridge.py
@@ -1,10 +1,10 @@
from __future__ import annotations
import asyncio
+import logging
import os
from app.core.lifecycle import LifecycleEnabled
-from app.core.logging import logger
from app.core.metrics.events import EventMetrics
from app.domain.enums.events import EventType
from app.domain.enums.kafka import KafkaTopic
@@ -30,11 +30,13 @@ def __init__(
settings: Settings,
event_metrics: EventMetrics,
sse_bus: SSERedisBus,
+ logger: logging.Logger,
) -> None:
self.schema_registry = schema_registry
self.settings = settings
self.event_metrics = event_metrics
self.sse_bus = sse_bus
+ self.logger = logger
self.num_consumers = settings.SSE_CONSUMER_POOL_SIZE
self.consumers: list[UnifiedConsumer] = []
@@ -48,7 +50,7 @@ async def start(self) -> None:
if self._initialized:
return
- logger.info(f"Starting SSE Kafka→Redis bridge with {self.num_consumers} consumers")
+ self.logger.info(f"Starting SSE Kafka→Redis bridge with {self.num_consumers} consumers")
for i in range(self.num_consumers):
consumer = await self._create_consumer(i)
@@ -56,14 +58,14 @@ async def start(self) -> None:
self._running = True
self._initialized = True
- logger.info("SSE Kafka→Redis bridge started successfully")
+ self.logger.info("SSE Kafka→Redis bridge started successfully")
async def stop(self) -> None:
async with self._lock:
if not self._initialized:
return
- logger.info("Stopping SSE Kafka→Redis bridge")
+ self.logger.info("Stopping SSE Kafka→Redis bridge")
self._running = False
for consumer in self.consumers:
@@ -71,7 +73,7 @@ async def stop(self) -> None:
self.consumers.clear()
self._initialized = False
- logger.info("SSE Kafka→Redis bridge stopped")
+ self.logger.info("SSE Kafka→Redis bridge stopped")
async def _create_consumer(self, consumer_index: int) -> UnifiedConsumer:
suffix = os.environ.get("KAFKA_GROUP_SUFFIX", "")
@@ -93,10 +95,10 @@ async def _create_consumer(self, consumer_index: int) -> UnifiedConsumer:
heartbeat_interval_ms=3000,
)
- dispatcher = EventDispatcher()
+ dispatcher = EventDispatcher(logger=self.logger)
self._register_routing_handlers(dispatcher)
- consumer = UnifiedConsumer(config=config, event_dispatcher=dispatcher)
+ consumer = UnifiedConsumer(config=config, event_dispatcher=dispatcher, logger=self.logger)
topics = [
KafkaTopic.EXECUTION_EVENTS,
@@ -109,7 +111,7 @@ async def _create_consumer(self, consumer_index: int) -> UnifiedConsumer:
]
await consumer.start(topics)
- logger.info(f"Bridge consumer {consumer_index} started")
+ self.logger.info(f"Bridge consumer {consumer_index} started")
return consumer
def _register_routing_handlers(self, dispatcher: EventDispatcher) -> None:
@@ -137,13 +139,13 @@ async def route_event(event: BaseEvent) -> None:
data = event.model_dump()
execution_id = data.get("execution_id")
if not execution_id:
- logger.debug(f"Event {event.event_type} has no execution_id")
+ self.logger.debug(f"Event {event.event_type} has no execution_id")
return
try:
await self.sse_bus.publish_event(execution_id, event)
- logger.info(f"Published {event.event_type} to Redis for {execution_id}")
+ self.logger.info(f"Published {event.event_type} to Redis for {execution_id}")
except Exception as e:
- logger.error(
+ self.logger.error(
f"Failed to publish {event.event_type} to Redis for {execution_id}: {e}",
exc_info=True,
)
@@ -165,10 +167,12 @@ def create_sse_kafka_redis_bridge(
settings: Settings,
event_metrics: EventMetrics,
sse_bus: SSERedisBus,
+ logger: logging.Logger,
) -> SSEKafkaRedisBridge:
return SSEKafkaRedisBridge(
schema_registry=schema_registry,
settings=settings,
event_metrics=event_metrics,
sse_bus=sse_bus,
+ logger=logger,
)
diff --git a/backend/app/services/sse/redis_bus.py b/backend/app/services/sse/redis_bus.py
index e0d9fad8..979edd16 100644
--- a/backend/app/services/sse/redis_bus.py
+++ b/backend/app/services/sse/redis_bus.py
@@ -1,11 +1,11 @@
from __future__ import annotations
+import logging
from typing import Type, TypeVar
import redis.asyncio as redis
from pydantic import BaseModel
-from app.core.logging import logger
from app.infrastructure.kafka.events.base import BaseEvent
from app.schemas_pydantic.sse import RedisNotificationMessage, RedisSSEMessage
@@ -15,9 +15,10 @@
class SSERedisSubscription:
"""Subscription wrapper for Redis pubsub with typed message parsing."""
- def __init__(self, pubsub: redis.client.PubSub, channel: str) -> None:
+ def __init__(self, pubsub: redis.client.PubSub, channel: str, logger: logging.Logger) -> None:
self._pubsub = pubsub
self._channel = channel
+ self.logger = logger
async def get(self, model: Type[T]) -> T | None:
"""Get next typed message from the subscription."""
@@ -27,7 +28,7 @@ async def get(self, model: Type[T]) -> T | None:
try:
return model.model_validate_json(msg["data"])
except Exception as e:
- logger.warning(
+ self.logger.warning(
f"Failed to parse Redis message on channel {self._channel}: {e}",
extra={"channel": self._channel, "model": model.__name__},
)
@@ -44,9 +45,14 @@ class SSERedisBus:
"""Redis-backed pub/sub bus for SSE event fan-out across workers."""
def __init__(
- self, redis_client: redis.Redis, exec_prefix: str = "sse:exec:", notif_prefix: str = "sse:notif:"
+ self,
+ redis_client: redis.Redis,
+ logger: logging.Logger,
+ exec_prefix: str = "sse:exec:",
+ notif_prefix: str = "sse:notif:",
) -> None:
self._redis = redis_client
+ self.logger = logger
self._exec_prefix = exec_prefix
self._notif_prefix = notif_prefix
@@ -68,7 +74,7 @@ async def open_subscription(self, execution_id: str) -> SSERedisSubscription:
pubsub = self._redis.pubsub()
channel = self._exec_channel(execution_id)
await pubsub.subscribe(channel)
- return SSERedisSubscription(pubsub, channel)
+ return SSERedisSubscription(pubsub, channel, self.logger)
async def publish_notification(self, user_id: str, notification: RedisNotificationMessage) -> None:
"""Publish a typed notification message to Redis for SSE delivery."""
@@ -78,4 +84,4 @@ async def open_notification_subscription(self, user_id: str) -> SSERedisSubscrip
pubsub = self._redis.pubsub()
channel = self._notif_channel(user_id)
await pubsub.subscribe(channel)
- return SSERedisSubscription(pubsub, channel)
+ return SSERedisSubscription(pubsub, channel, self.logger)
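
A rough usage sketch of the updated bus, assuming a locally reachable Redis: the bus now takes the logger explicitly, and subscription.get() parses raw pub/sub payloads into the typed RedisSSEMessage model, returning None when nothing is pending (per its T | None signature). The execution id and logger name below are placeholders:

    import asyncio
    import logging

    import redis.asyncio as redis

    from app.schemas_pydantic.sse import RedisSSEMessage
    from app.services.sse.redis_bus import SSERedisBus


    async def tail_execution(execution_id: str) -> None:
        client = redis.Redis(host="localhost", port=6379)
        bus = SSERedisBus(client, logger=logging.getLogger("sse"))
        subscription = await bus.open_subscription(execution_id)
        try:
            # Poll until a typed message is parsed from the channel.
            while (msg := await subscription.get(RedisSSEMessage)) is None:
                await asyncio.sleep(0.1)
            print(msg.event_type)
        finally:
            await subscription.close()
            await client.aclose()


    asyncio.run(tail_execution("exec-123"))  # placeholder execution id
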
diff --git a/backend/app/services/sse/sse_service.py b/backend/app/services/sse/sse_service.py
index 9824ec8a..3608ec2e 100644
--- a/backend/app/services/sse/sse_service.py
+++ b/backend/app/services/sse/sse_service.py
@@ -1,9 +1,9 @@
import asyncio
+import logging
from collections.abc import AsyncGenerator
from datetime import datetime, timezone
from typing import Any, Dict
-from app.core.logging import logger
from app.core.metrics.context import get_connection_metrics
from app.db.repositories.sse_repository import SSERepository
from app.domain.enums.events import EventType
@@ -37,12 +37,14 @@ def __init__(
sse_bus: SSERedisBus,
shutdown_manager: SSEShutdownManager,
settings: Settings,
+ logger: logging.Logger,
) -> None:
self.repository = repository
self.router = router
self.sse_bus = sse_bus
self.shutdown_manager = shutdown_manager
self.settings = settings
+ self.logger = logger
self.metrics = get_connection_metrics()
self.heartbeat_interval = getattr(settings, "SSE_HEARTBEAT_INTERVAL", 30)
@@ -75,9 +77,9 @@ async def create_execution_stream(self, execution_id: str, user_id: str) -> Asyn
)
# Complete Redis subscription after handshake
- logger.info(f"Opening Redis subscription for execution {execution_id}")
+ self.logger.info("Opening Redis subscription for execution", extra={"execution_id": execution_id})
subscription = await sub_task
- logger.info(f"Redis subscription opened for execution {execution_id}")
+ self.logger.info("Redis subscription opened for execution", extra={"execution_id": execution_id})
initial_status = await self.repository.get_execution_status(execution_id)
if initial_status:
@@ -103,7 +105,7 @@ async def create_execution_stream(self, execution_id: str, user_id: str) -> Asyn
if subscription is not None:
await subscription.close()
await self.shutdown_manager.unregister_connection(execution_id, connection_id)
- logger.info(f"SSE connection closed: execution_id={execution_id}")
+ self.logger.info("SSE connection closed", extra={"execution_id": execution_id})
async def _stream_events_redis(
self,
@@ -142,19 +144,25 @@ async def _stream_events_redis(
if not msg:
continue
- logger.info(f"Received Redis message for execution {execution_id}: {msg.event_type}")
+ self.logger.info(
+ "Received Redis message for execution",
+ extra={"execution_id": execution_id, "event_type": str(msg.event_type)},
+ )
try:
sse_event = await self._build_sse_event_from_redis(execution_id, msg)
yield self._format_sse_event(sse_event)
# End on terminal event types
if msg.event_type in self.TERMINAL_EVENT_TYPES:
- logger.info(f"Terminal event for execution {execution_id}: {msg.event_type}")
+ self.logger.info(
+ "Terminal event for execution",
+ extra={"execution_id": execution_id, "event_type": str(msg.event_type)},
+ )
break
except Exception as e:
- logger.warning(
- f"Failed to process SSE message for execution {execution_id}: {e}",
- extra={"execution_id": execution_id, "event_type": str(msg.event_type)},
+ self.logger.warning(
+ "Failed to process SSE message",
+ extra={"execution_id": execution_id, "event_type": str(msg.event_type), "error": str(e)},
)
continue
diff --git a/backend/app/services/sse/sse_shutdown_manager.py b/backend/app/services/sse/sse_shutdown_manager.py
index 87011d2b..086682b9 100644
--- a/backend/app/services/sse/sse_shutdown_manager.py
+++ b/backend/app/services/sse/sse_shutdown_manager.py
@@ -1,9 +1,9 @@
import asyncio
+import logging
import time
from enum import Enum
from typing import Dict, Set
-from app.core.logging import logger
from app.core.metrics.context import get_connection_metrics
from app.domain.sse import ShutdownStatus
from app.services.sse.kafka_redis_bridge import SSEKafkaRedisBridge
@@ -34,8 +34,13 @@ class SSEShutdownManager:
"""
def __init__(
- self, drain_timeout: float = 30.0, notification_timeout: float = 5.0, force_close_timeout: float = 10.0
+ self,
+ logger: logging.Logger,
+ drain_timeout: float = 30.0,
+ notification_timeout: float = 5.0,
+ force_close_timeout: float = 10.0,
):
+ self.logger = logger
self.drain_timeout = drain_timeout
self.notification_timeout = notification_timeout
self.force_close_timeout = force_close_timeout
@@ -59,10 +64,9 @@ def __init__(
self._shutdown_event = asyncio.Event()
self._drain_complete_event = asyncio.Event()
- logger.info(
- f"SSEShutdownManager initialized: "
- f"drain_timeout={drain_timeout}s, "
- f"notification_timeout={notification_timeout}s"
+ self.logger.info(
+ "SSEShutdownManager initialized",
+ extra={"drain_timeout": drain_timeout, "notification_timeout": notification_timeout},
)
def set_router(self, router: "SSEKafkaRedisBridge") -> None:
@@ -78,9 +82,9 @@ async def register_connection(self, execution_id: str, connection_id: str) -> as
"""
async with self._lock:
if self._shutdown_initiated:
- logger.warning(
- f"Rejecting new SSE connection during shutdown: "
- f"execution_id={execution_id}, connection_id={connection_id}"
+ self.logger.warning(
+ "Rejecting new SSE connection during shutdown",
+ extra={"execution_id": execution_id, "connection_id": connection_id},
)
return None
@@ -93,7 +97,7 @@ async def register_connection(self, execution_id: str, connection_id: str) -> as
shutdown_event = asyncio.Event()
self._connection_callbacks[connection_id] = shutdown_event
- logger.debug(f"Registered SSE connection: {connection_id}")
+ self.logger.debug("Registered SSE connection", extra={"connection_id": connection_id})
self.metrics.increment_sse_connections("executions")
return shutdown_event
@@ -109,7 +113,7 @@ async def unregister_connection(self, execution_id: str, connection_id: str) ->
self._connection_callbacks.pop(connection_id, None)
self._draining_connections.discard(connection_id)
- logger.debug(f"Unregistered SSE connection: {connection_id}")
+ self.logger.debug("Unregistered SSE connection", extra={"connection_id": connection_id})
self.metrics.decrement_sse_connections("executions")
# Check if all connections are drained
@@ -120,7 +124,7 @@ async def initiate_shutdown(self) -> None:
"""Initiate graceful shutdown of all SSE connections"""
async with self._lock:
if self._shutdown_initiated:
- logger.warning("SSE shutdown already initiated")
+ self.logger.warning("SSE shutdown already initiated")
return
self._shutdown_initiated = True
@@ -128,7 +132,7 @@ async def initiate_shutdown(self) -> None:
self._phase = ShutdownPhase.DRAINING
total_connections = sum(len(conns) for conns in self._active_connections.values())
- logger.info(f"Initiating SSE shutdown with {total_connections} active connections")
+ self.logger.info(f"Initiating SSE shutdown with {total_connections} active connections")
self.metrics.update_sse_draining_connections(total_connections)
@@ -139,7 +143,7 @@ async def initiate_shutdown(self) -> None:
try:
await self._execute_shutdown()
except Exception as e:
- logger.error(f"Error during SSE shutdown: {e}")
+ self.logger.error(f"Error during SSE shutdown: {e}")
raise
finally:
self._shutdown_complete = True
@@ -150,7 +154,7 @@ async def _execute_shutdown(self) -> None:
# Phase 1: Stop accepting new connections (already done by setting _shutdown_initiated)
phase_start = time.time()
- logger.info("Phase 1: Stopped accepting new SSE connections")
+ self.logger.info("Phase 1: Stopped accepting new SSE connections")
# Phase 2: Notify connections about shutdown
await self._notify_connections()
@@ -170,9 +174,9 @@ async def _execute_shutdown(self) -> None:
if self._shutdown_start_time is not None:
total_duration = time.time() - self._shutdown_start_time
self.metrics.update_sse_shutdown_duration(total_duration, "total")
- logger.info(f"SSE shutdown complete in {total_duration:.2f}s")
+ self.logger.info(f"SSE shutdown complete in {total_duration:.2f}s")
else:
- logger.info("SSE shutdown complete")
+ self.logger.info("SSE shutdown complete")
async def _notify_connections(self) -> None:
"""Notify all active connections about shutdown"""
@@ -183,7 +187,7 @@ async def _notify_connections(self) -> None:
connection_events = list(self._connection_callbacks.values())
self._draining_connections = set(self._connection_callbacks.keys())
- logger.info(f"Phase 2: Notifying {active_count} connections about shutdown")
+ self.logger.info(f"Phase 2: Notifying {active_count} connections about shutdown")
self.metrics.update_sse_draining_connections(active_count)
# Trigger shutdown events for all connections
@@ -194,7 +198,7 @@ async def _notify_connections(self) -> None:
# Give connections time to send shutdown messages
await asyncio.sleep(self.notification_timeout)
- logger.info("Shutdown notification phase complete")
+ self.logger.info("Shutdown notification phase complete")
async def _drain_connections(self) -> None:
"""Wait for connections to close gracefully"""
@@ -203,7 +207,7 @@ async def _drain_connections(self) -> None:
async with self._lock:
remaining = sum(len(conns) for conns in self._active_connections.values())
- logger.info(f"Phase 3: Draining {remaining} connections (timeout: {self.drain_timeout}s)")
+ self.logger.info(f"Phase 3: Draining {remaining} connections (timeout: {self.drain_timeout}s)")
self.metrics.update_sse_draining_connections(remaining)
start_time = time.time()
@@ -223,14 +227,14 @@ async def _drain_connections(self) -> None:
remaining = sum(len(conns) for conns in self._active_connections.values())
if remaining < last_count:
- logger.info(f"Connections remaining: {remaining}")
+ self.logger.info(f"Connections remaining: {remaining}")
self.metrics.update_sse_draining_connections(remaining)
last_count = remaining
if remaining == 0:
- logger.info("All connections drained gracefully")
+ self.logger.info("All connections drained gracefully")
else:
- logger.warning(f"{remaining} connections still active after drain timeout")
+ self.logger.warning(f"{remaining} connections still active after drain timeout")
async def _force_close_connections(self) -> None:
"""Force close any remaining connections"""
@@ -240,10 +244,10 @@ async def _force_close_connections(self) -> None:
remaining_count = sum(len(conns) for conns in self._active_connections.values())
if remaining_count == 0:
- logger.info("Phase 4: No connections to force close")
+ self.logger.info("Phase 4: No connections to force close")
return
- logger.warning(f"Phase 4: Force closing {remaining_count} connections")
+ self.logger.warning(f"Phase 4: Force closing {remaining_count} connections")
self.metrics.update_sse_draining_connections(remaining_count)
# Clear all tracking - connections will be forcibly terminated
@@ -256,7 +260,7 @@ async def _force_close_connections(self) -> None:
await self._router.stop()
self.metrics.update_sse_draining_connections(0)
- logger.info("Force close phase complete")
+ self.logger.info("Force close phase complete")
def is_shutting_down(self) -> bool:
"""Check if shutdown is in progress"""
@@ -300,11 +304,15 @@ async def _wait_for_complete(self) -> None:
def create_sse_shutdown_manager(
- drain_timeout: float = 30.0, notification_timeout: float = 5.0, force_close_timeout: float = 10.0
+ logger: logging.Logger,
+ drain_timeout: float = 30.0,
+ notification_timeout: float = 5.0,
+ force_close_timeout: float = 10.0,
) -> SSEShutdownManager:
"""Factory function to create an SSE shutdown manager.
Args:
+ logger: Logger instance
drain_timeout: Time to wait for connections to close gracefully
notification_timeout: Time to wait for shutdown notifications to be sent
force_close_timeout: Time before force closing connections
@@ -313,5 +321,8 @@ def create_sse_shutdown_manager(
A new SSE shutdown manager instance
"""
return SSEShutdownManager(
- drain_timeout=drain_timeout, notification_timeout=notification_timeout, force_close_timeout=force_close_timeout
+ logger=logger,
+ drain_timeout=drain_timeout,
+ notification_timeout=notification_timeout,
+ force_close_timeout=force_close_timeout,
)
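
A small wiring sketch for the updated factory signature, where the logger is now required alongside the existing timeouts; the logger name is illustrative:

    import logging

    from app.services.sse.sse_shutdown_manager import create_sse_shutdown_manager

    shutdown_manager = create_sse_shutdown_manager(
        logger=logging.getLogger("sse.shutdown"),  # illustrative logger name
        drain_timeout=30.0,
        notification_timeout=5.0,
        force_close_timeout=10.0,
    )
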
diff --git a/backend/app/services/user_settings_service.py b/backend/app/services/user_settings_service.py
index 5e53d8d2..ce51aaf2 100644
--- a/backend/app/services/user_settings_service.py
+++ b/backend/app/services/user_settings_service.py
@@ -1,10 +1,11 @@
import asyncio
+import json
+import logging
from datetime import datetime, timedelta, timezone
from typing import Any, List
from cachetools import TTLCache
-from app.core.logging import logger
from app.db.repositories.user_settings_repository import UserSettingsRepository
from app.domain.enums import Theme
from app.domain.enums.auth import SettingsType
@@ -23,9 +24,12 @@
class UserSettingsService:
- def __init__(self, repository: UserSettingsRepository, event_service: KafkaEventService) -> None:
+ def __init__(
+ self, repository: UserSettingsRepository, event_service: KafkaEventService, logger: logging.Logger
+ ) -> None:
self.repository = repository
self.event_service = event_service
+ self.logger = logger
# TTL+LRU cache for settings
self._cache_ttl = timedelta(minutes=5)
self._max_cache_size = 1000
@@ -36,7 +40,7 @@ def __init__(self, repository: UserSettingsRepository, event_service: KafkaEvent
self._event_bus_manager: EventBusManager | None = None
self._subscription_id: str | None = None
- logger.info(
+ self.logger.info(
"UserSettingsService initialized",
extra={"cache_ttl_seconds": self._cache_ttl.total_seconds(), "max_cache_size": self._max_cache_size},
)
@@ -45,7 +49,7 @@ async def get_user_settings(self, user_id: str) -> DomainUserSettings:
"""Get settings with cache; rebuild and cache on miss."""
if user_id in self._cache:
cached = self._cache[user_id]
- logger.debug(f"Settings cache hit for user {user_id}", extra={"cache_size": len(self._cache)})
+ self.logger.debug(f"Settings cache hit for user {user_id}", extra={"cache_size": len(self._cache)})
return cached
return await self.get_user_settings_fresh(user_id)
@@ -67,6 +71,7 @@ async def get_user_settings_fresh(self, user_id: str) -> DomainUserSettings:
"""Bypass cache and rebuild settings from snapshot + events."""
snapshot = await self.repository.get_snapshot(user_id)
+ settings: DomainUserSettings
if snapshot:
settings = snapshot
events = await self._get_settings_events(user_id, since=snapshot.updated_at)
@@ -182,17 +187,17 @@ async def update_user_settings(
settings_type = SettingsType.DISPLAY
else:
settings_type = SettingsType.PREFERENCES
- # Flatten changes to string map for the generic event
- changes: dict[str, str] = {}
- for k, v in updated.items():
- changes[k] = str(v)
+ # Stringify all values for Avro compatibility (nested dicts become JSON strings)
+ updated_stringified: dict[str, str] = {
+ k: json.dumps(v) if isinstance(v, dict) else str(v) for k, v in updated.items()
+ }
await self.event_service.publish_event(
event_type=EventType.USER_SETTINGS_UPDATED,
aggregate_id=f"user_settings_{user_id}",
payload={
"user_id": user_id,
"settings_type": settings_type,
- "changes": changes,
+ "updated": updated_stringified,
"reason": reason,
},
metadata=None,
@@ -297,7 +302,7 @@ async def restore_settings_to_point(self, user_id: str, timestamp: datetime) ->
payload={
"user_id": user_id,
"settings_type": SettingsType.PREFERENCES,
- "changes": {"restored_to": timestamp.isoformat()},
+ "updated": {"restored_to": timestamp.isoformat()},
},
metadata=None,
)
@@ -327,7 +332,7 @@ async def _get_settings_events(
event_type=et,
timestamp=e.timestamp,
payload=e.payload,
- correlation_id=e.correlation_id,
+ correlation_id=e.metadata.correlation_id if e.metadata else None,
)
)
return out
@@ -339,23 +344,10 @@ def _apply_event(self, settings: DomainUserSettings, event: DomainSettingsEvent)
settings.theme = Theme(new_theme)
return settings
- upd = event.payload.get("updated")
- if not upd:
- return settings
-
- # Top-level
- if "theme" in upd:
- settings.theme = Theme(upd["theme"])
- if "timezone" in upd:
- settings.timezone = upd["timezone"]
- if "date_format" in upd:
- settings.date_format = upd["date_format"]
- if "time_format" in upd:
- settings.time_format = upd["time_format"]
- # Nested
- if "notifications" in upd and isinstance(upd["notifications"], dict):
- n = upd["notifications"]
- channels: list[NotificationChannel] = [NotificationChannel(c) for c in n.get("channels", [])]
+ if event.event_type == EventType.USER_NOTIFICATION_SETTINGS_UPDATED:
+ n = event.payload.get("settings", {})
+ channels_raw = event.payload.get("channels", [])
+ channels: list[NotificationChannel] = [NotificationChannel(c) for c in channels_raw] if channels_raw else []
settings.notifications = DomainNotificationSettings(
execution_completed=n.get("execution_completed", settings.notifications.execution_completed),
execution_failed=n.get("execution_failed", settings.notifications.execution_failed),
@@ -363,8 +355,11 @@ def _apply_event(self, settings: DomainUserSettings, event: DomainSettingsEvent)
security_alerts=n.get("security_alerts", settings.notifications.security_alerts),
channels=channels or settings.notifications.channels,
)
- if "editor" in upd and isinstance(upd["editor"], dict):
- e = upd["editor"]
+ settings.updated_at = event.timestamp
+ return settings
+
+ if event.event_type == EventType.USER_EDITOR_SETTINGS_UPDATED:
+ e = event.payload.get("settings", {})
settings.editor = DomainEditorSettings(
theme=e.get("theme", settings.editor.theme),
font_size=e.get("font_size", settings.editor.font_size),
@@ -373,8 +368,58 @@ def _apply_event(self, settings: DomainUserSettings, event: DomainSettingsEvent)
word_wrap=e.get("word_wrap", settings.editor.word_wrap),
show_line_numbers=e.get("show_line_numbers", settings.editor.show_line_numbers),
)
- if "custom_settings" in upd and isinstance(upd["custom_settings"], dict):
- settings.custom_settings = upd["custom_settings"]
+ settings.updated_at = event.timestamp
+ return settings
+
+ upd = event.payload.get("updated")
+ if not upd:
+ return settings
+
+ # Helper to parse JSON strings or return dict as-is
+ def parse_value(val: object) -> object:
+ if isinstance(val, str):
+ try:
+ return json.loads(val)
+ except (json.JSONDecodeError, ValueError):
+ return val
+ return val
+
+ # Top-level
+ if "theme" in upd:
+ settings.theme = Theme(str(upd["theme"]))
+ if "timezone" in upd:
+ settings.timezone = str(upd["timezone"])
+ if "date_format" in upd:
+ settings.date_format = str(upd["date_format"])
+ if "time_format" in upd:
+ settings.time_format = str(upd["time_format"])
+ # Nested (may be JSON strings or dicts)
+ if "notifications" in upd:
+ n = parse_value(upd["notifications"])
+ if isinstance(n, dict):
+ notif_channels: list[NotificationChannel] = [NotificationChannel(c) for c in n.get("channels", [])]
+ settings.notifications = DomainNotificationSettings(
+ execution_completed=n.get("execution_completed", settings.notifications.execution_completed),
+ execution_failed=n.get("execution_failed", settings.notifications.execution_failed),
+ system_updates=n.get("system_updates", settings.notifications.system_updates),
+ security_alerts=n.get("security_alerts", settings.notifications.security_alerts),
+ channels=notif_channels or settings.notifications.channels,
+ )
+ if "editor" in upd:
+ e = parse_value(upd["editor"])
+ if isinstance(e, dict):
+ settings.editor = DomainEditorSettings(
+ theme=e.get("theme", settings.editor.theme),
+ font_size=e.get("font_size", settings.editor.font_size),
+ tab_size=e.get("tab_size", settings.editor.tab_size),
+ use_tabs=e.get("use_tabs", settings.editor.use_tabs),
+ word_wrap=e.get("word_wrap", settings.editor.word_wrap),
+ show_line_numbers=e.get("show_line_numbers", settings.editor.show_line_numbers),
+ )
+ if "custom_settings" in upd:
+ cs = parse_value(upd["custom_settings"])
+ if isinstance(cs, dict):
+ settings.custom_settings = cs
settings.version = event.payload.get("version", settings.version)
settings.updated_at = event.timestamp
return settings
@@ -383,12 +428,12 @@ def invalidate_cache(self, user_id: str) -> None:
"""Invalidate cached settings for a user"""
removed = self._cache.pop(user_id, None) is not None
if removed:
- logger.debug(f"Invalidated cache for user {user_id}", extra={"cache_size": len(self._cache)})
+ self.logger.debug(f"Invalidated cache for user {user_id}", extra={"cache_size": len(self._cache)})
def _add_to_cache(self, user_id: str, settings: DomainUserSettings) -> None:
"""Add settings to TTL+LRU cache."""
self._cache[user_id] = settings
- logger.debug(f"Cached settings for user {user_id}", extra={"cache_size": len(self._cache)})
+ self.logger.debug(f"Cached settings for user {user_id}", extra={"cache_size": len(self._cache)})
def get_cache_stats(self) -> dict[str, Any]:
"""Get cache statistics for monitoring."""
@@ -407,4 +452,4 @@ async def reset_user_settings(self, user_id: str) -> None:
# Delete from database
await self.repository.delete_user_settings(user_id)
- logger.info(f"Reset settings for user {user_id}")
+ self.logger.info(f"Reset settings for user {user_id}")
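
The publish-side stringification and the replay-side parse_value helper are intentionally symmetric: nested dicts are serialized to JSON strings so the event payload stays a flat string map for Avro, and replay decodes them back. A self-contained illustration of that round trip using only the standard library:

    import json

    # Publish side: every value becomes a string so the Avro map stays flat.
    updated = {"theme": "dark", "editor": {"font_size": 14, "use_tabs": False}}
    updated_stringified = {k: json.dumps(v) if isinstance(v, dict) else str(v) for k, v in updated.items()}
    assert updated_stringified["editor"] == '{"font_size": 14, "use_tabs": false}'

    # Replay side: strings that parse as JSON become dicts again, everything else passes through.
    def parse_value(val: object) -> object:
        if isinstance(val, str):
            try:
                return json.loads(val)
            except (json.JSONDecodeError, ValueError):
                return val
        return val

    assert parse_value(updated_stringified["editor"]) == {"font_size": 14, "use_tabs": False}
    assert parse_value(updated_stringified["theme"]) == "dark"
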
diff --git a/backend/pyproject.toml b/backend/pyproject.toml
index 15a1fea3..24e6749f 100644
--- a/backend/pyproject.toml
+++ b/backend/pyproject.toml
@@ -55,7 +55,6 @@ dependencies = [
"markdown-it-py==4.0.0",
"MarkupSafe==3.0.2",
"mdurl==0.1.2",
- "motor==3.6.0",
"msgpack==1.1.0",
"multidict==6.7.0",
"oauthlib==3.2.2",
@@ -78,6 +77,7 @@ dependencies = [
"opentelemetry-semantic-conventions==0.43b0",
"opentelemetry-util-http==0.43b0",
"packaging==24.1",
+ "beanie==2.0.1",
"passlib==1.7.4",
"pathspec==0.12.1",
"prometheus-fastapi-instrumentator==7.0.0",
@@ -93,7 +93,7 @@ dependencies = [
"pydantic_core==2.23.4",
"Pygments==2.19.2",
"PyJWT==2.9.0",
- "pymongo==4.9.2",
+ "pymongo==4.12.1",
"pyparsing==3.2.3",
"python-dateutil==2.9.0.post0",
"python-dotenv==1.0.1",
@@ -203,8 +203,14 @@ markers = [
"performance: marks tests as performance tests"
]
asyncio_mode = "auto"
-asyncio_default_fixture_loop_scope = "function"
+asyncio_default_fixture_loop_scope = "session"
+asyncio_default_test_loop_scope = "session"
log_cli = false
log_cli_level = "ERROR"
log_level = "ERROR"
addopts = "-n 4 --dist loadfile --tb=short -q --no-header -q"
+
+# Coverage configuration
+[tool.coverage.run]
+# Use sysmon for faster coverage (requires Python 3.12+)
+core = "sysmon"
diff --git a/backend/scripts/create_topics.py b/backend/scripts/create_topics.py
index 85094ae6..0620a010 100755
--- a/backend/scripts/create_topics.py
+++ b/backend/scripts/create_topics.py
@@ -4,15 +4,18 @@
"""
import asyncio
+import os
import sys
from typing import List
-from app.core.logging import logger
+from app.core.logging import setup_logger
from app.infrastructure.kafka.topics import get_all_topics, get_topic_configs
from app.settings import get_settings
from confluent_kafka import KafkaException
from confluent_kafka.admin import AdminClient, NewTopic
+logger = setup_logger(os.environ.get("LOG_LEVEL", "INFO"))
+
async def create_topics() -> None:
"""Create all required Kafka topics"""
diff --git a/backend/scripts/seed_users.py b/backend/scripts/seed_users.py
index a7653c69..a8450954 100755
--- a/backend/scripts/seed_users.py
+++ b/backend/scripts/seed_users.py
@@ -18,14 +18,15 @@
from typing import Any
from bson import ObjectId
-from motor.motor_asyncio import AsyncIOMotorClient, AsyncIOMotorDatabase
from passlib.context import CryptContext
+from pymongo.asynchronous.database import AsyncDatabase
+from pymongo.asynchronous.mongo_client import AsyncMongoClient
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
async def upsert_user(
- db: AsyncIOMotorDatabase[dict[str, Any]],
+ db: AsyncDatabase[dict[str, Any]],
username: str,
email: str,
password: str,
@@ -73,7 +74,7 @@ async def seed_users() -> None:
print(f"Connecting to MongoDB (database: {db_name})...")
- client: AsyncIOMotorClient[dict[str, Any]] = AsyncIOMotorClient(mongodb_url)
+ client: AsyncMongoClient[dict[str, Any]] = AsyncMongoClient(mongodb_url)
db = client[db_name]
# Default user
@@ -103,7 +104,7 @@ async def seed_users() -> None:
print("Admin: admin / admin123 (or ADMIN_USER_PASSWORD)")
print("=" * 50)
- client.close()
+ await client.close()
if __name__ == "__main__":
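
The seeding script follows the same Motor-to-PyMongo migration as the rest of the PR: AsyncMongoClient replaces AsyncIOMotorClient and close() becomes awaitable. A minimal sketch of the connect/ping/close pattern, assuming the compose MongoDB credentials used elsewhere in this diff:

    import asyncio
    from typing import Any

    from pymongo.asynchronous.mongo_client import AsyncMongoClient


    async def ping(mongodb_url: str) -> None:
        client: AsyncMongoClient[dict[str, Any]] = AsyncMongoClient(mongodb_url, serverSelectionTimeoutMS=5000)
        try:
            await client.admin.command("ping")
        finally:
            await client.close()  # a coroutine in the async client, unlike Motor's sync close()


    asyncio.run(ping("mongodb://root:rootpassword@localhost:27017"))
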
diff --git a/backend/tests/conftest.py b/backend/tests/conftest.py
index ee0703e1..7a076ffa 100644
--- a/backend/tests/conftest.py
+++ b/backend/tests/conftest.py
@@ -11,7 +11,7 @@
from dishka import AsyncContainer
from dotenv import load_dotenv
from httpx import ASGITransport
-from motor.motor_asyncio import AsyncIOMotorDatabase
+from app.core.database_context import Database
import redis.asyncio as redis
# Load test environment variables BEFORE any app imports
@@ -126,9 +126,20 @@ def create_test_app():
# ===== App without lifespan for tests =====
-@pytest_asyncio.fixture(scope="function")
-async def app():
- """Create FastAPI app for the function without starting lifespan."""
+@pytest_asyncio.fixture(scope="session")
+async def app(_test_env): # type: ignore[valid-type]
+ """Create FastAPI app once per session/worker.
+
+ Session-scoped to avoid Pydantic schema validator memory issues when
+ FastAPI recreates OpenAPI schemas hundreds of times with pytest-xdist.
+ See: https://github.com/pydantic/pydantic/issues/1864
+
+ Depends on _test_env to ensure env vars (REDIS_DB, DATABASE_NAME, etc.)
+ are set before the app/Settings are created.
+
+ Note: Tests must not modify app.state or registered routes.
+ Use function-scoped `client` fixture for test isolation.
+ """
application = create_test_app()
yield application
@@ -138,7 +149,7 @@ async def app():
await container.close()
-@pytest_asyncio.fixture(scope="function")
+@pytest_asyncio.fixture(scope="session")
async def app_container(app): # type: ignore[valid-type]
"""Expose the Dishka container attached to the app."""
container: AsyncContainer = app.state.dishka_container # type: ignore[attr-defined]
@@ -146,7 +157,7 @@ async def app_container(app): # type: ignore[valid-type]
# ===== Client (function-scoped for clean cookies per test) =====
-@pytest_asyncio.fixture(scope="function")
+@pytest_asyncio.fixture
async def client(app) -> AsyncGenerator[httpx.AsyncClient, None]: # type: ignore[valid-type]
# Use httpx with ASGI app directly
# The app fixture already handles lifespan via LifespanManager
@@ -168,42 +179,27 @@ async def _container_scope(container: AsyncContainer):
yield scope
-@pytest_asyncio.fixture(scope="function")
+@pytest_asyncio.fixture
async def scope(app_container: AsyncContainer): # type: ignore[valid-type]
async with _container_scope(app_container) as s:
yield s
-@pytest_asyncio.fixture(scope="function")
-async def db(scope) -> AsyncGenerator[AsyncIOMotorDatabase, None]: # type: ignore[valid-type]
- database: AsyncIOMotorDatabase = await scope.get(AsyncIOMotorDatabase)
+@pytest_asyncio.fixture
+async def db(scope) -> AsyncGenerator[Database, None]: # type: ignore[valid-type]
+ database: Database = await scope.get(Database)
yield database
-@pytest_asyncio.fixture(scope="function")
+@pytest_asyncio.fixture
async def redis_client(scope) -> AsyncGenerator[redis.Redis, None]: # type: ignore[valid-type]
client: redis.Redis = await scope.get(redis.Redis)
yield client
-# ===== Per-test cleanup =====
-@pytest_asyncio.fixture(scope="function", autouse=True)
-async def _cleanup(db: AsyncIOMotorDatabase, redis_client: redis.Redis):
- # Pre-test: ensure clean state
- collections = await db.list_collection_names()
- for name in collections:
- if not name.startswith("system."):
- await db.drop_collection(name)
- await redis_client.flushdb()
-
- yield
-
- # Post-test: cleanup for next test
- collections = await db.list_collection_names()
- for name in collections:
- if not name.startswith("system."):
- await db.drop_collection(name)
- await redis_client.flushdb()
+# ===== Per-test cleanup (only for integration tests, see integration/conftest.py) =====
+# Note: autouse cleanup moved to tests/integration/conftest.py to avoid
+# requiring DB/Redis for unit tests. Unit tests use tests/unit/conftest.py instead.
# ===== HTTP helpers (auth) =====
@@ -216,7 +212,7 @@ async def _http_login(client: httpx.AsyncClient, username: str, password: str) -
# Session-scoped shared users for convenience
@pytest.fixture(scope="session")
-def shared_user_credentials():
+def test_user_credentials():
uid = os.environ.get("PYTEST_SESSION_ID", uuid.uuid4().hex[:8])
return {
"username": f"test_user_{uid}",
@@ -227,7 +223,7 @@ def shared_user_credentials():
@pytest.fixture(scope="session")
-def shared_admin_credentials():
+def test_admin_credentials():
uid = os.environ.get("PYTEST_SESSION_ID", uuid.uuid4().hex[:8])
return {
"username": f"admin_user_{uid}",
@@ -237,28 +233,29 @@ def shared_admin_credentials():
}
-@pytest_asyncio.fixture(scope="function")
-async def shared_user(client: httpx.AsyncClient, shared_user_credentials):
- creds = shared_user_credentials
- # Always attempt to register; DB is wiped after each test
+@pytest_asyncio.fixture
+async def test_user(client: httpx.AsyncClient, test_user_credentials):
+ """Function-scoped authenticated user. Recreated each test (DB wiped between tests)."""
+ creds = test_user_credentials
r = await client.post("/api/v1/auth/register", json=creds)
if r.status_code not in (200, 201, 400):
- pytest.skip(f"Cannot create shared user (status {r.status_code}).")
+ pytest.skip(f"Cannot create test user (status {r.status_code}).")
csrf = await _http_login(client, creds["username"], creds["password"])
return {**creds, "csrf_token": csrf, "headers": {"X-CSRF-Token": csrf}}
-@pytest_asyncio.fixture(scope="function")
-async def shared_admin(client: httpx.AsyncClient, shared_admin_credentials):
- creds = shared_admin_credentials
+@pytest_asyncio.fixture
+async def test_admin(client: httpx.AsyncClient, test_admin_credentials):
+ """Function-scoped authenticated admin. Recreated each test (DB wiped between tests)."""
+ creds = test_admin_credentials
r = await client.post("/api/v1/auth/register", json=creds)
if r.status_code not in (200, 201, 400):
- pytest.skip(f"Cannot create shared admin (status {r.status_code}).")
+ pytest.skip(f"Cannot create test admin (status {r.status_code}).")
csrf = await _http_login(client, creds["username"], creds["password"])
return {**creds, "csrf_token": csrf, "headers": {"X-CSRF-Token": csrf}}
-@pytest_asyncio.fixture(scope="function")
+@pytest_asyncio.fixture
async def another_user(client: httpx.AsyncClient):
username = f"test_user_{uuid.uuid4().hex[:8]}"
email = f"{username}@example.com"
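
With the autouse cleanup removed from the shared conftest, the per-test wipe is described as living in tests/integration/conftest.py, which is not part of this diff. A hypothetical sketch of what that fixture could look like, mirroring the logic deleted above:

    # Hypothetical content for tests/integration/conftest.py (not shown in this
    # diff); it mirrors the pre/post wipe that used to live in tests/conftest.py.
    import pytest_asyncio
    import redis.asyncio as redis

    from app.core.database_context import Database


    @pytest_asyncio.fixture(autouse=True)
    async def _cleanup(db: Database, redis_client: redis.Redis):
        async def wipe() -> None:
            for name in await db.list_collection_names():
                if not name.startswith("system."):
                    await db.drop_collection(name)
            await redis_client.flushdb()

        await wipe()  # clean state before the test
        yield
        await wipe()  # clean up for the next test
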
diff --git a/backend/tests/fixtures/real_services.py b/backend/tests/fixtures/real_services.py
index ea1bd905..7a51e602 100644
--- a/backend/tests/fixtures/real_services.py
+++ b/backend/tests/fixtures/real_services.py
@@ -9,28 +9,29 @@
import pytest
import pytest_asyncio
-from motor.motor_asyncio import AsyncIOMotorClient, AsyncIOMotorDatabase
import redis.asyncio as redis
from aiokafka import AIOKafkaProducer, AIOKafkaConsumer
from aiokafka.errors import KafkaConnectionError
+from pymongo.asynchronous.mongo_client import AsyncMongoClient
+from app.core.database_context import Database, DBClient
from app.settings import Settings
class TestServiceConnections:
"""Manages connections to real services for testing."""
-
+
def __init__(self, test_id: str):
self.test_id = test_id
- self.mongo_client: Optional[AsyncIOMotorClient] = None
+ self.mongo_client: Optional[DBClient] = None
self.redis_client: Optional[redis.Redis] = None
self.kafka_producer: Optional[AIOKafkaProducer] = None
self.kafka_consumer: Optional[AIOKafkaConsumer] = None
self.db_name = f"test_{test_id}"
-
- async def connect_mongodb(self, url: str) -> AsyncIOMotorDatabase:
+
+ async def connect_mongodb(self, url: str) -> Database:
"""Connect to MongoDB and return test-specific database."""
- self.mongo_client = AsyncIOMotorClient(
+ self.mongo_client = AsyncMongoClient(
url,
serverSelectionTimeoutMS=5000,
connectTimeoutMS=5000,
@@ -97,7 +98,7 @@ async def cleanup(self):
# Drop test MongoDB database
if self.mongo_client:
await self.mongo_client.drop_database(self.db_name)
- self.mongo_client.close()
+ await self.mongo_client.close()
# Clear Redis test database
if self.redis_client:
@@ -130,7 +131,7 @@ async def real_services(request) -> AsyncGenerator[TestServiceConnections, None]
@pytest_asyncio.fixture
-async def real_mongodb(real_services: TestServiceConnections) -> AsyncIOMotorDatabase:
+async def real_mongodb(real_services: TestServiceConnections) -> Database:
"""Get real MongoDB database for testing."""
# Use MongoDB from docker-compose with auth
return await real_services.connect_mongodb(
@@ -158,7 +159,7 @@ async def real_kafka_consumer(real_services: TestServiceConnections) -> Optional
@asynccontextmanager
-async def mongodb_transaction(db: AsyncIOMotorDatabase):
+async def mongodb_transaction(db: Database):
"""
Context manager for MongoDB transactions.
Automatically rolls back on error.
@@ -189,9 +190,9 @@ async def redis_pipeline(client: redis.Redis):
class TestDataFactory:
"""Factory for creating test data in real services."""
-
+
@staticmethod
- async def create_test_user(db: AsyncIOMotorDatabase, **kwargs) -> Dict[str, Any]:
+ async def create_test_user(db: Database, **kwargs) -> Dict[str, Any]:
"""Create a test user in MongoDB."""
user_data = {
"user_id": str(uuid.uuid4()),
@@ -211,7 +212,7 @@ async def create_test_user(db: AsyncIOMotorDatabase, **kwargs) -> Dict[str, Any]
return user_data
@staticmethod
- async def create_test_execution(db: AsyncIOMotorDatabase, **kwargs) -> Dict[str, Any]:
+ async def create_test_execution(db: Database, **kwargs) -> Dict[str, Any]:
"""Create a test execution in MongoDB."""
execution_data = {
"execution_id": str(uuid.uuid4()),
@@ -230,7 +231,7 @@ async def create_test_execution(db: AsyncIOMotorDatabase, **kwargs) -> Dict[str,
return execution_data
@staticmethod
- async def create_test_event(db: AsyncIOMotorDatabase, **kwargs) -> Dict[str, Any]:
+ async def create_test_event(db: Database, **kwargs) -> Dict[str, Any]:
"""Create a test event in MongoDB."""
event_data = {
"event_id": str(uuid.uuid4()),
@@ -308,35 +309,39 @@ async def wait_for_service(check_func, timeout: int = 30, service_name: str = "s
async def ensure_services_running():
"""Ensure required Docker services are running."""
import subprocess
-
+
# Check MongoDB
- try:
- client = AsyncIOMotorClient(
+ async def check_mongo() -> None:
+ client = AsyncMongoClient(
"mongodb://root:rootpassword@localhost:27017",
serverSelectionTimeoutMS=5000
)
- await client.admin.command("ping")
- client.close()
+ try:
+ await client.admin.command("ping")
+ finally:
+ await client.close()
+
+ try:
+ await check_mongo()
except Exception:
print("Starting MongoDB...")
subprocess.run(["docker-compose", "up", "-d", "mongo"], check=False)
- await wait_for_service(
- lambda: AsyncIOMotorClient("mongodb://root:rootpassword@localhost:27017").admin.command("ping"),
- service_name="MongoDB"
- )
-
+ await wait_for_service(check_mongo, service_name="MongoDB")
+
# Check Redis
- try:
+ async def check_redis() -> None:
r = redis.Redis(host="localhost", port=6379, socket_connect_timeout=5)
- await r.execute_command("PING")
- await r.aclose()
+ try:
+ await r.execute_command("PING")
+ finally:
+ await r.aclose()
+
+ try:
+ await check_redis()
except Exception:
print("Starting Redis...")
subprocess.run(["docker-compose", "up", "-d", "redis"], check=False)
- await wait_for_service(
- lambda: redis.Redis(host="localhost", port=6379).execute_command("PING"),
- service_name="Redis"
- )
+ await wait_for_service(check_redis, service_name="Redis")
# Kafka is optional - don't fail if not available
try:
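
For orientation: the Database and DBClient names imported from app.core.database_context are not part of this patch. Based on the conftest note further below, they are presumably thin aliases over PyMongo's async types, roughly as sketched here (names and parametrization assumed):

    # Sketch of what app.core.database_context presumably exposes; the real module may differ.
    from pymongo.asynchronous.database import AsyncDatabase
    from pymongo.asynchronous.mongo_client import AsyncMongoClient

    # "MongoDocument" is referenced in the integration conftest note below; its import path is assumed.
    DBClient = AsyncMongoClient   # raw async client handle
    Database = AsyncDatabase      # per-test database handle, e.g. AsyncDatabase[MongoDocument]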
diff --git a/backend/tests/integration/app/__init__.py b/backend/tests/integration/app/__init__.py
new file mode 100644
index 00000000..8b137891
--- /dev/null
+++ b/backend/tests/integration/app/__init__.py
@@ -0,0 +1 @@
+
diff --git a/backend/tests/unit/app/test_main_app.py b/backend/tests/integration/app/test_main_app.py
similarity index 97%
rename from backend/tests/unit/app/test_main_app.py
rename to backend/tests/integration/app/test_main_app.py
index 5b84e75f..36af7d12 100644
--- a/backend/tests/unit/app/test_main_app.py
+++ b/backend/tests/integration/app/test_main_app.py
@@ -12,7 +12,7 @@
RequestSizeLimitMiddleware,
)
-pytestmark = pytest.mark.unit
+pytestmark = pytest.mark.integration
def test_create_app_real_instance(app) -> None: # type: ignore[valid-type]
diff --git a/backend/tests/integration/conftest.py b/backend/tests/integration/conftest.py
new file mode 100644
index 00000000..02ad99d2
--- /dev/null
+++ b/backend/tests/integration/conftest.py
@@ -0,0 +1,33 @@
+"""Integration tests conftest - with infrastructure cleanup."""
+import pytest_asyncio
+import redis.asyncio as redis
+from beanie import init_beanie
+
+from app.core.database_context import Database
+from app.db.docs import ALL_DOCUMENTS
+
+
+@pytest_asyncio.fixture(autouse=True)
+async def _cleanup(db: Database, redis_client: redis.Redis):
+ """Clean DB and Redis before each integration test.
+
+ Only pre-test cleanup is performed; post-test cleanup causes event loop issues
+ when SSE/streaming tests hold connections across loop boundaries.
+
+ NOTE: With pytest-xdist, each worker uses a separate Redis database
+ (gw0→db0, gw1→db1, etc.), so flushdb() is safe and only affects
+ that worker's database. See tests/conftest.py for REDIS_DB setup.
+ """
+ collections = await db.list_collection_names()
+ for name in collections:
+ if not name.startswith("system."):
+ await db.drop_collection(name)
+
+ await redis_client.flushdb()
+
+ # Initialize Beanie with document models
+ # Note: the db fixture already yields the AsyncDatabase object (type alias Database = AsyncDatabase[MongoDocument])
+ await init_beanie(database=db, document_models=ALL_DOCUMENTS)
+
+ yield
+ # No post-test cleanup to avoid "Event loop is closed" errors
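
The gw0→db0 mapping mentioned in the docstring above is configured in tests/conftest.py, which is not part of this hunk; a minimal sketch of how such a per-worker mapping is commonly derived (helper name assumed):

    # Sketch only: derive a per-worker Redis DB index from the pytest-xdist worker id.
    import os

    def redis_db_for_worker() -> int:
        worker = os.environ.get("PYTEST_XDIST_WORKER", "gw0")  # "master" when xdist is off
        digits = "".join(ch for ch in worker if ch.isdigit())
        return int(digits) if digits else 0

    os.environ.setdefault("REDIS_DB", str(redis_db_for_worker()))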
diff --git a/backend/tests/integration/core/test_container.py b/backend/tests/integration/core/test_container.py
index ff8806bb..36bad89a 100644
--- a/backend/tests/integration/core/test_container.py
+++ b/backend/tests/integration/core/test_container.py
@@ -1,6 +1,6 @@
import pytest
from dishka import AsyncContainer
-from motor.motor_asyncio import AsyncIOMotorDatabase
+from app.core.database_context import Database
from app.services.event_service import EventService
@@ -13,7 +13,7 @@ async def test_container_resolves_services(app_container, scope) -> None: # typ
assert isinstance(app_container, AsyncContainer)
# Can resolve core dependencies from DI
- db: AsyncIOMotorDatabase = await scope.get(AsyncIOMotorDatabase)
+ db: Database = await scope.get(Database)
assert db.name and isinstance(db.name, str)
svc: EventService = await scope.get(EventService)
diff --git a/backend/tests/integration/core/test_database_context.py b/backend/tests/integration/core/test_database_context.py
deleted file mode 100644
index e2ac5753..00000000
--- a/backend/tests/integration/core/test_database_context.py
+++ /dev/null
@@ -1,22 +0,0 @@
-import pytest
-
-from app.core.database_context import AsyncDatabaseConnection, ContextualDatabaseProvider, DatabaseNotInitializedError
-from motor.motor_asyncio import AsyncIOMotorDatabase
-
-pytestmark = [pytest.mark.integration, pytest.mark.mongodb]
-
-
-@pytest.mark.asyncio
-async def test_database_connection_from_di(scope) -> None: # type: ignore[valid-type]
- # Resolve both the raw connection and the database via DI
- conn: AsyncDatabaseConnection = await scope.get(AsyncDatabaseConnection)
- db: AsyncIOMotorDatabase = await scope.get(AsyncIOMotorDatabase)
- assert conn.is_connected() is True
- assert db.name and isinstance(db.name, str)
-
-
-def test_contextual_provider_requires_set() -> None:
- provider = ContextualDatabaseProvider()
- assert provider.is_initialized() is False
- with pytest.raises(DatabaseNotInitializedError):
- _ = provider.client
diff --git a/backend/tests/unit/core/test_dishka_lifespan.py b/backend/tests/integration/core/test_dishka_lifespan.py
similarity index 100%
rename from backend/tests/unit/core/test_dishka_lifespan.py
rename to backend/tests/integration/core/test_dishka_lifespan.py
diff --git a/backend/tests/unit/app/__init__.py b/backend/tests/integration/db/repositories/__init__.py
similarity index 100%
rename from backend/tests/unit/app/__init__.py
rename to backend/tests/integration/db/repositories/__init__.py
diff --git a/backend/tests/unit/db/repositories/test_admin_settings_repository.py b/backend/tests/integration/db/repositories/test_admin_settings_repository.py
similarity index 87%
rename from backend/tests/unit/db/repositories/test_admin_settings_repository.py
rename to backend/tests/integration/db/repositories/test_admin_settings_repository.py
index a6897574..7c19cf50 100644
--- a/backend/tests/unit/db/repositories/test_admin_settings_repository.py
+++ b/backend/tests/integration/db/repositories/test_admin_settings_repository.py
@@ -1,14 +1,13 @@
import pytest
-
from app.db.repositories.admin.admin_settings_repository import AdminSettingsRepository
from app.domain.admin import SystemSettings
-pytestmark = pytest.mark.unit
+pytestmark = pytest.mark.integration
@pytest.fixture()
-def repo(db) -> AdminSettingsRepository: # type: ignore[valid-type]
- return AdminSettingsRepository(db)
+async def repo(scope) -> AdminSettingsRepository: # type: ignore[valid-type]
+ return await scope.get(AdminSettingsRepository)
@pytest.mark.asyncio
diff --git a/backend/tests/integration/db/repositories/test_dlq_repository.py b/backend/tests/integration/db/repositories/test_dlq_repository.py
new file mode 100644
index 00000000..07d3711f
--- /dev/null
+++ b/backend/tests/integration/db/repositories/test_dlq_repository.py
@@ -0,0 +1,105 @@
+import logging
+from datetime import datetime, timezone
+
+import pytest
+from app.db.docs import DLQMessageDocument
+from app.db.repositories.dlq_repository import DLQRepository
+from app.dlq import DLQMessageStatus
+from app.domain.enums.events import EventType
+
+pytestmark = pytest.mark.integration
+
+_test_logger = logging.getLogger("test.db.repositories.dlq_repository")
+
+
+@pytest.fixture()
+def repo() -> DLQRepository:
+ return DLQRepository(_test_logger)
+
+
+async def insert_test_dlq_docs():
+ """Insert test DLQ documents using Beanie."""
+ now = datetime.now(timezone.utc)
+
+ docs = [
+ DLQMessageDocument(
+ event_id="id1",
+ event_type=str(EventType.USER_LOGGED_IN),
+ event={
+ "event_type": str(EventType.USER_LOGGED_IN),
+ "metadata": {"service_name": "svc", "service_version": "1"},
+ "user_id": "u1",
+ "login_method": "password",
+ },
+ original_topic="t1",
+ error="err",
+ retry_count=0,
+ failed_at=now,
+ status=DLQMessageStatus.PENDING,
+ producer_id="p1",
+ ),
+ DLQMessageDocument(
+ event_id="id2",
+ event_type=str(EventType.USER_LOGGED_IN),
+ event={
+ "event_type": str(EventType.USER_LOGGED_IN),
+ "metadata": {"service_name": "svc", "service_version": "1"},
+ "user_id": "u1",
+ "login_method": "password",
+ },
+ original_topic="t1",
+ error="err",
+ retry_count=0,
+ failed_at=now,
+ status=DLQMessageStatus.RETRIED,
+ producer_id="p1",
+ ),
+ DLQMessageDocument(
+ event_id="id3",
+ event_type=str(EventType.EXECUTION_STARTED),
+ event={
+ "event_type": str(EventType.EXECUTION_STARTED),
+ "metadata": {"service_name": "svc", "service_version": "1"},
+ "execution_id": "x1",
+ "pod_name": "p1",
+ },
+ original_topic="t2",
+ error="err",
+ retry_count=0,
+ failed_at=now,
+ status=DLQMessageStatus.PENDING,
+ producer_id="p1",
+ ),
+ ]
+
+ for doc in docs:
+ await doc.insert()
+
+
+@pytest.mark.asyncio
+async def test_stats_list_get_and_updates(repo: DLQRepository) -> None:
+ await insert_test_dlq_docs()
+
+ stats = await repo.get_dlq_stats()
+ assert isinstance(stats.by_status, dict) and len(stats.by_topic) >= 1
+
+ res = await repo.get_messages(limit=2)
+ assert res.total >= 3 and len(res.messages) <= 2
+ msg = await repo.get_message_by_id("id1")
+ assert msg and msg.event_id == "id1"
+ assert await repo.mark_message_retried("id1") in (True, False)
+ assert await repo.mark_message_discarded("id1", "r") in (True, False)
+
+ topics = await repo.get_topics_summary()
+ assert any(t.topic == "t1" for t in topics)
+
+
+@pytest.mark.asyncio
+async def test_retry_batch(repo: DLQRepository) -> None:
+ class Manager:
+ async def retry_message_manually(self, eid: str) -> bool: # noqa: ARG002
+ return True
+
+ result = await repo.retry_messages_batch(["missing"], Manager())
+ # Missing messages cause failures
+ assert result.total == 1 and result.failed >= 1
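
Because these repository tests now go through Beanie documents rather than raw Motor collections, queries return typed documents instead of dicts. A small sketch of the query style this enables, assuming the DLQMessageDocument fields used above and that init_beanie has already run (as in the integration conftest):

    # Sketch: typed Beanie queries over the documents inserted above.
    from app.db.docs import DLQMessageDocument
    from app.dlq import DLQMessageStatus

    async def pending_for_topic(topic: str) -> list[DLQMessageDocument]:
        # Beanie builds the Mongo filter from the typed comparisons below.
        return await DLQMessageDocument.find(
            DLQMessageDocument.status == DLQMessageStatus.PENDING,
            DLQMessageDocument.original_topic == topic,
        ).to_list()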
diff --git a/backend/tests/integration/db/repositories/test_execution_repository.py b/backend/tests/integration/db/repositories/test_execution_repository.py
new file mode 100644
index 00000000..eb3bf2cb
--- /dev/null
+++ b/backend/tests/integration/db/repositories/test_execution_repository.py
@@ -0,0 +1,46 @@
+import logging
+from uuid import uuid4
+
+import pytest
+from app.db.repositories.execution_repository import ExecutionRepository
+from app.domain.enums.execution import ExecutionStatus
+from app.domain.execution import DomainExecutionCreate, DomainExecutionUpdate
+
+_test_logger = logging.getLogger("test.db.repositories.execution_repository")
+
+pytestmark = pytest.mark.integration
+
+
+@pytest.mark.asyncio
+async def test_execution_crud_and_query() -> None:
+ repo = ExecutionRepository(logger=_test_logger)
+ user_id = str(uuid4())
+
+ # Create
+ create_data = DomainExecutionCreate(
+ script="print('hello')",
+ lang="python",
+ lang_version="3.11",
+ user_id=user_id,
+ )
+ created = await repo.create_execution(create_data)
+ assert created.execution_id
+
+ # Get
+ got = await repo.get_execution(created.execution_id)
+ assert got and got.script.startswith("print") and got.status == ExecutionStatus.QUEUED
+
+ # Update
+ update = DomainExecutionUpdate(status=ExecutionStatus.RUNNING, stdout="ok")
+ ok = await repo.update_execution(created.execution_id, update)
+ assert ok is True
+ got2 = await repo.get_execution(created.execution_id)
+ assert got2 and got2.status == ExecutionStatus.RUNNING
+
+ # List
+ items = await repo.get_executions({"user_id": user_id}, limit=10, skip=0, sort=[("created_at", 1)])
+ assert any(x.execution_id == created.execution_id for x in items)
+
+ # Delete
+ assert await repo.delete_execution(created.execution_id) is True
+ assert await repo.get_execution(created.execution_id) is None
diff --git a/backend/tests/unit/db/repositories/test_saved_script_repository.py b/backend/tests/integration/db/repositories/test_saved_script_repository.py
similarity index 80%
rename from backend/tests/unit/db/repositories/test_saved_script_repository.py
rename to backend/tests/integration/db/repositories/test_saved_script_repository.py
index 473a833f..85fc2b58 100644
--- a/backend/tests/unit/db/repositories/test_saved_script_repository.py
+++ b/backend/tests/integration/db/repositories/test_saved_script_repository.py
@@ -1,13 +1,17 @@
import pytest
-
from app.db.repositories.saved_script_repository import SavedScriptRepository
from app.domain.saved_script import DomainSavedScriptCreate, DomainSavedScriptUpdate
-pytestmark = pytest.mark.unit
+pytestmark = pytest.mark.integration
+
+
+@pytest.fixture()
+async def repo(scope) -> SavedScriptRepository: # type: ignore[valid-type]
+ return await scope.get(SavedScriptRepository)
+
@pytest.mark.asyncio
-async def test_create_get_update_delete_saved_script(db) -> None: # type: ignore[valid-type]
- repo = SavedScriptRepository(db)
+async def test_create_get_update_delete_saved_script(repo: SavedScriptRepository) -> None:
create = DomainSavedScriptCreate(name="n", lang="python", lang_version="3.11", description=None, script="print(1)")
created = await repo.create_saved_script(create, user_id="u1")
assert created.user_id == "u1" and created.script == "print(1)"
diff --git a/backend/tests/integration/db/schema/test_schema_manager.py b/backend/tests/integration/db/schema/test_schema_manager.py
deleted file mode 100644
index bf5c7207..00000000
--- a/backend/tests/integration/db/schema/test_schema_manager.py
+++ /dev/null
@@ -1,96 +0,0 @@
-import pytest
-
-from app.db.schema.schema_manager import SchemaManager
-
-
-pytestmark = [pytest.mark.integration, pytest.mark.mongodb]
-
-
-@pytest.mark.asyncio
-async def test_is_applied_and_mark_applied(db) -> None: # type: ignore[valid-type]
- mgr = SchemaManager(db)
- mig_id = "test_migration_123"
- assert await mgr._is_applied(mig_id) is False
- await mgr._mark_applied(mig_id, "desc")
- assert await mgr._is_applied(mig_id) is True
- doc = await db["schema_versions"].find_one({"_id": mig_id})
- assert doc and doc.get("description") == "desc" and "applied_at" in doc
-
-
-@pytest.mark.asyncio
-async def test_apply_all_idempotent_and_creates_indexes(db) -> None: # type: ignore[valid-type]
- mgr = SchemaManager(db)
- await mgr.apply_all()
- # Apply again should be a no-op
- await mgr.apply_all()
- versions = await db["schema_versions"].count_documents({})
- assert versions >= 9
-
- # Verify some expected indexes exist
- async def idx_names(coll: str) -> set[str]:
- lst = await db[coll].list_indexes().to_list(length=None)
- return {i.get("name", "") for i in lst}
-
- # events
- ev_idx = await idx_names("events")
- assert {"idx_event_id_unique", "idx_event_type_ts", "idx_text_search"}.issubset(ev_idx)
- # user settings
- us_idx = await idx_names("user_settings_snapshots")
- assert {"idx_settings_user_unique", "idx_settings_updated_at_desc"}.issubset(us_idx)
- # replay
- rp_idx = await idx_names("replay_sessions")
- assert {"idx_replay_session_id", "idx_replay_status"}.issubset(rp_idx)
- # notifications
- notif_idx = await idx_names("notifications")
- assert {"idx_notif_user_created_desc", "idx_notif_id_unique"}.issubset(notif_idx)
- subs_idx = await idx_names("notification_subscriptions")
- assert {"idx_sub_user_channel_unique", "idx_sub_enabled"}.issubset(subs_idx)
- # idempotency
- idem_idx = await idx_names("idempotency_keys")
- assert {"idx_idem_key_unique", "idx_idem_created_ttl"}.issubset(idem_idx)
- # sagas
- saga_idx = await idx_names("sagas")
- assert {"idx_saga_id_unique", "idx_saga_state_created"}.issubset(saga_idx)
- # execution_results
- res_idx = await idx_names("execution_results")
- assert {"idx_results_execution_unique", "idx_results_created_at"}.issubset(res_idx)
- # dlq
- dlq_idx = await idx_names("dlq_messages")
- assert {"idx_dlq_event_id_unique", "idx_dlq_failed_desc"}.issubset(dlq_idx)
-
-
-@pytest.mark.asyncio
-async def test_migrations_handle_exceptions_gracefully(db, monkeypatch) -> None: # type: ignore[valid-type]
- # Patch events.create_indexes to fail and db.command to fail (validator)
- mgr = SchemaManager(db)
-
- async def failing_create(*_args, **_kwargs): # noqa: ANN001
- raise RuntimeError("boom")
-
- async def failing_command(*_args, **_kwargs):
- raise RuntimeError("cmd_fail")
-
- monkeypatch.setattr(db["events"], "create_indexes", failing_create, raising=True)
- monkeypatch.setattr(db, "command", failing_command, raising=True)
-
- # Call individual migrations; they should not raise despite failures
- await mgr._m_0001_events_init()
- await mgr._m_0002_user_settings()
- await mgr._m_0003_replay()
- await mgr._m_0004_notifications()
- await mgr._m_0005_idempotency()
- await mgr._m_0006_sagas()
- await mgr._m_0007_execution_results()
- await mgr._m_0008_dlq()
- await mgr._m_0009_event_store_extra()
-
-
-@pytest.mark.asyncio
-async def test_apply_all_skips_already_applied(db) -> None: # type: ignore[valid-type]
- mgr = SchemaManager(db)
- # Mark first migration as applied
- await db["schema_versions"].insert_one({"_id": "0001_events_init"})
- await mgr.apply_all()
- # Ensure we have all migrations recorded and no duplicates
- count = await db["schema_versions"].count_documents({})
- assert count >= 9
diff --git a/backend/tests/integration/dlq/test_dlq_discard_policy.py b/backend/tests/integration/dlq/test_dlq_discard_policy.py
index 371ef59e..cd85a62f 100644
--- a/backend/tests/integration/dlq/test_dlq_discard_policy.py
+++ b/backend/tests/integration/dlq/test_dlq_discard_policy.py
@@ -1,23 +1,29 @@
import asyncio
import json
+import logging
+import os
from datetime import datetime, timezone
import pytest
from confluent_kafka import Producer
+from app.db.docs import DLQMessageDocument
from app.dlq.manager import create_dlq_manager
-from app.dlq.models import DLQFields, DLQMessageStatus, RetryPolicy, RetryStrategy
-import os
+from app.dlq.models import DLQMessageStatus, RetryPolicy, RetryStrategy
from app.domain.enums.kafka import KafkaTopic
+from app.events.schema.schema_registry import create_schema_registry_manager
from tests.helpers import make_execution_requested_event
from tests.helpers.eventually import eventually
pytestmark = [pytest.mark.integration, pytest.mark.kafka, pytest.mark.mongodb]
+_test_logger = logging.getLogger("test.dlq.discard_policy")
+
@pytest.mark.asyncio
async def test_dlq_manager_discards_with_manual_policy(db) -> None: # type: ignore[valid-type]
- manager = create_dlq_manager(database=db)
+ schema_registry = create_schema_registry_manager(_test_logger)
+ manager = create_dlq_manager(schema_registry=schema_registry, logger=_test_logger)
prefix = os.environ.get("KAFKA_TOPIC_PREFIX", "")
topic = f"{prefix}{str(KafkaTopic.EXECUTION_EVENTS)}"
manager.set_retry_policy(topic, RetryPolicy(topic=topic, strategy=RetryStrategy.MANUAL))
@@ -42,11 +48,9 @@ async def test_dlq_manager_discards_with_manual_policy(db) -> None: # type: ign
producer.flush(5)
async with manager:
- coll = db.get_collection("dlq_messages")
-
async def _discarded() -> None:
- doc = await coll.find_one({"event_id": ev.event_id})
+ doc = await DLQMessageDocument.find_one({"event_id": ev.event_id})
assert doc is not None
- assert doc.get(str(DLQFields.STATUS)) == DLQMessageStatus.DISCARDED
+ assert doc.status == DLQMessageStatus.DISCARDED
await eventually(_discarded, timeout=10.0, interval=0.2)
diff --git a/backend/tests/integration/dlq/test_dlq_manager.py b/backend/tests/integration/dlq/test_dlq_manager.py
index 33a9a622..f45f2ceb 100644
--- a/backend/tests/integration/dlq/test_dlq_manager.py
+++ b/backend/tests/integration/dlq/test_dlq_manager.py
@@ -1,24 +1,30 @@
import asyncio
import json
+import logging
+import os
from datetime import datetime, timezone
import pytest
from confluent_kafka import Producer
+from app.db.docs import DLQMessageDocument
from app.dlq.manager import create_dlq_manager
-import os
from app.domain.enums.kafka import KafkaTopic
+from app.events.schema.schema_registry import create_schema_registry_manager
from tests.helpers import make_execution_requested_event
from tests.helpers.eventually import eventually
pytestmark = [pytest.mark.integration, pytest.mark.kafka, pytest.mark.mongodb]
+_test_logger = logging.getLogger("test.dlq.manager")
+
@pytest.mark.asyncio
async def test_dlq_manager_persists_in_mongo(db) -> None: # type: ignore[valid-type]
- manager = create_dlq_manager(database=db)
+ schema_registry = create_schema_registry_manager(_test_logger)
+ manager = create_dlq_manager(schema_registry=schema_registry, logger=_test_logger)
- # Build a DLQ payload matching DLQMapper.from_kafka_message expectations
+ # Build a DLQ payload
ev = make_execution_requested_event(execution_id="exec-dlq-1")
prefix = os.environ.get("KAFKA_TOPIC_PREFIX", "")
@@ -42,10 +48,8 @@ async def test_dlq_manager_persists_in_mongo(db) -> None: # type: ignore[valid-
# Run the manager briefly to consume and persist
async with manager:
- coll = db.get_collection("dlq_messages")
-
async def _exists():
- doc = await coll.find_one({"event_id": ev.event_id})
+ doc = await DLQMessageDocument.find_one({"event_id": ev.event_id})
assert doc is not None
# Poll until the document appears
diff --git a/backend/tests/integration/dlq/test_dlq_retry_immediate.py b/backend/tests/integration/dlq/test_dlq_retry_immediate.py
index e0418828..0752cc34 100644
--- a/backend/tests/integration/dlq/test_dlq_retry_immediate.py
+++ b/backend/tests/integration/dlq/test_dlq_retry_immediate.py
@@ -1,23 +1,29 @@
import asyncio
import json
+import logging
+import os
from datetime import datetime, timezone
import pytest
from confluent_kafka import Producer
+from app.db.docs import DLQMessageDocument
from app.dlq.manager import create_dlq_manager
-from app.dlq.models import DLQFields, DLQMessageStatus, RetryPolicy, RetryStrategy
-import os
+from app.dlq.models import DLQMessageStatus, RetryPolicy, RetryStrategy
from app.domain.enums.kafka import KafkaTopic
+from app.events.schema.schema_registry import create_schema_registry_manager
from tests.helpers import make_execution_requested_event
from tests.helpers.eventually import eventually
pytestmark = [pytest.mark.integration, pytest.mark.kafka, pytest.mark.mongodb]
+_test_logger = logging.getLogger("test.dlq.retry_immediate")
+
@pytest.mark.asyncio
async def test_dlq_manager_immediate_retry_updates_doc(db) -> None: # type: ignore[valid-type]
- manager = create_dlq_manager(database=db)
+ schema_registry = create_schema_registry_manager(_test_logger)
+ manager = create_dlq_manager(schema_registry=schema_registry, logger=_test_logger)
prefix = os.environ.get("KAFKA_TOPIC_PREFIX", "")
topic = f"{prefix}{str(KafkaTopic.EXECUTION_EVENTS)}"
manager.set_retry_policy(
@@ -45,13 +51,11 @@ async def test_dlq_manager_immediate_retry_updates_doc(db) -> None: # type: ign
prod.flush(5)
async with manager:
- coll = db.get_collection("dlq_messages")
-
async def _retried() -> None:
- doc = await coll.find_one({"event_id": ev.event_id})
+ doc = await DLQMessageDocument.find_one({"event_id": ev.event_id})
assert doc is not None
- assert doc.get(str(DLQFields.STATUS)) == DLQMessageStatus.RETRIED
- assert doc.get(str(DLQFields.RETRY_COUNT)) == 1
- assert doc.get(str(DLQFields.RETRIED_AT)) is not None
+ assert doc.status == DLQMessageStatus.RETRIED
+ assert doc.retry_count == 1
+ assert doc.retried_at is not None
await eventually(_retried, timeout=10.0, interval=0.2)
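
These DLQ tests lean on tests/helpers/eventually.py, which is not included in this excerpt. Judging by the call sites, the helper is a simple poll-until-the-assertion-passes loop, roughly like this sketch (signature inferred, not the project's actual code):

    # Sketch of an "eventually" helper: retry an async assertion until it passes or times out.
    import asyncio
    import time
    from typing import Awaitable, Callable

    async def eventually(
        check: Callable[[], Awaitable[None]],
        timeout: float = 10.0,
        interval: float = 0.2,
    ) -> None:
        deadline = time.monotonic() + timeout
        while True:
            try:
                await check()
                return
            except AssertionError:
                if time.monotonic() >= deadline:
                    raise
                await asyncio.sleep(interval)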
diff --git a/backend/tests/integration/events/test_admin_utils.py b/backend/tests/integration/events/test_admin_utils.py
index 5f0492f5..7ab34509 100644
--- a/backend/tests/integration/events/test_admin_utils.py
+++ b/backend/tests/integration/events/test_admin_utils.py
@@ -1,53 +1,22 @@
+import logging
import os
-import uuid
import pytest
-
from app.events.admin_utils import AdminUtils
-
-pytestmark = [pytest.mark.integration, pytest.mark.kafka]
-
-
-def _unique_topic(prefix: str = "test") -> str:
- sid = os.environ.get("PYTEST_SESSION_ID", "sid")
- return f"{prefix}.adminutils.{sid}.{uuid.uuid4().hex[:8]}"
+_test_logger = logging.getLogger("test.events.admin_utils")
+@pytest.mark.kafka
@pytest.mark.asyncio
-async def test_create_topic_and_verify_partitions() -> None:
- admin = AdminUtils()
- topic = _unique_topic()
+async def test_admin_utils_real_topic_checks() -> None:
+ prefix = os.environ.get("KAFKA_TOPIC_PREFIX", "test.")
+ topic = f"{prefix}adminutils.{os.environ.get('PYTEST_SESSION_ID', 'sid')}"
+ au = AdminUtils(logger=_test_logger)
- created = await admin.create_topic(topic, num_partitions=3, replication_factor=1)
- assert created is True
-
- md = admin.admin_client.list_topics(timeout=10)
- t = md.topics.get(topic)
- assert t is not None
- assert len(getattr(t, "partitions", {})) == 3
-
-
-@pytest.mark.asyncio
-async def test_check_topic_exists_after_ensure() -> None:
- admin = AdminUtils()
- topic = _unique_topic()
+ # Ensure topic exists (idempotent)
+ res = await au.ensure_topics_exist([(topic, 1)])
+ assert res.get(topic) in (True, False)  # topic may already exist, in which case creation reports False
- res = await admin.ensure_topics_exist([(topic, 1)])
- assert res.get(topic) is True
-
- exists = await admin.check_topic_exists(topic)
+ exists = await au.check_topic_exists(topic)
assert exists is True
-
-
-@pytest.mark.asyncio
-async def test_create_topic_twice_second_call_returns_false() -> None:
- admin = AdminUtils()
- topic = _unique_topic()
-
- first = await admin.create_topic(topic, num_partitions=1)
- second = await admin.create_topic(topic, num_partitions=1)
-
- assert first is True
- assert second is False
-
diff --git a/backend/tests/integration/events/test_consume_roundtrip.py b/backend/tests/integration/events/test_consume_roundtrip.py
index f1bd4d99..185196e5 100644
--- a/backend/tests/integration/events/test_consume_roundtrip.py
+++ b/backend/tests/integration/events/test_consume_roundtrip.py
@@ -1,21 +1,22 @@
import asyncio
+import logging
import uuid
import pytest
-
from app.domain.enums.events import EventType
+from app.domain.enums.kafka import KafkaTopic
from app.events.core import UnifiedConsumer, UnifiedProducer
+from app.events.core.dispatcher import EventDispatcher
from app.events.core.types import ConsumerConfig
from app.events.schema.schema_registry import SchemaRegistryManager, initialize_event_schemas
-from app.domain.enums.kafka import KafkaTopic
-from tests.helpers import make_execution_requested_event
-from app.core.metrics.context import get_event_metrics
-from app.events.core.dispatcher import EventDispatcher
from app.settings import get_settings
+from tests.helpers import make_execution_requested_event
pytestmark = [pytest.mark.integration, pytest.mark.kafka]
+_test_logger = logging.getLogger("test.events.consume_roundtrip")
+
@pytest.mark.asyncio
async def test_produce_consume_roundtrip(scope) -> None: # type: ignore[valid-type]
@@ -28,7 +29,7 @@ async def test_produce_consume_roundtrip(scope) -> None: # type: ignore[valid-t
# Build a consumer that handles EXECUTION_REQUESTED
settings = get_settings()
- dispatcher = EventDispatcher()
+ dispatcher = EventDispatcher(logger=_test_logger)
received = asyncio.Event()
@dispatcher.register(EventType.EXECUTION_REQUESTED)
@@ -43,7 +44,7 @@ async def _handle(_event) -> None: # noqa: ANN001
auto_offset_reset="earliest",
)
- consumer = UnifiedConsumer(config, dispatcher)
+ consumer = UnifiedConsumer(config, dispatcher, logger=_test_logger)
await consumer.start([str(KafkaTopic.EXECUTION_EVENTS)])
try:
diff --git a/backend/tests/integration/events/test_consumer_group_monitor.py b/backend/tests/integration/events/test_consumer_group_monitor.py
index 9b62cf24..cfab3017 100644
--- a/backend/tests/integration/events/test_consumer_group_monitor.py
+++ b/backend/tests/integration/events/test_consumer_group_monitor.py
@@ -1,13 +1,16 @@
+import logging
+
import pytest
+from app.events.consumer_group_monitor import ConsumerGroupHealth, NativeConsumerGroupMonitor
-from app.events.consumer_group_monitor import NativeConsumerGroupMonitor, ConsumerGroupHealth
+_test_logger = logging.getLogger("test.events.consumer_group_monitor")
@pytest.mark.integration
@pytest.mark.kafka
@pytest.mark.asyncio
async def test_list_groups_and_error_status():
- mon = NativeConsumerGroupMonitor()
+ mon = NativeConsumerGroupMonitor(logger=_test_logger)
groups = await mon.list_consumer_groups()
assert isinstance(groups, list)
diff --git a/backend/tests/integration/events/test_consumer_group_monitor_e2e.py b/backend/tests/integration/events/test_consumer_group_monitor_e2e.py
index 55fac8f0..1be58358 100644
--- a/backend/tests/integration/events/test_consumer_group_monitor_e2e.py
+++ b/backend/tests/integration/events/test_consumer_group_monitor_e2e.py
@@ -1,21 +1,21 @@
-import asyncio
+import logging
from uuid import uuid4
import pytest
-
from app.events.consumer_group_monitor import (
ConsumerGroupHealth,
ConsumerGroupStatus,
NativeConsumerGroupMonitor,
)
-
pytestmark = [pytest.mark.integration, pytest.mark.kafka]
+_test_logger = logging.getLogger("test.events.consumer_group_monitor_e2e")
+
@pytest.mark.asyncio
async def test_consumer_group_status_error_path_and_summary():
- monitor = NativeConsumerGroupMonitor(bootstrap_servers="localhost:9092")
+ monitor = NativeConsumerGroupMonitor(bootstrap_servers="localhost:9092", logger=_test_logger)
# Non-existent group triggers error-handling path and returns minimal status
gid = f"does-not-exist-{uuid4().hex[:8]}"
status = await monitor.get_consumer_group_status(gid, timeout=5.0, include_lag=False)
@@ -28,11 +28,19 @@ async def test_consumer_group_status_error_path_and_summary():
def test_assess_group_health_branches():
- m = NativeConsumerGroupMonitor()
+ m = NativeConsumerGroupMonitor(logger=_test_logger)
# Error state
s = ConsumerGroupStatus(
- group_id="g", state="ERROR", protocol="p", protocol_type="ptype", coordinator="c",
- members=[], member_count=0, assigned_partitions=0, partition_distribution={}, total_lag=0
+ group_id="g",
+ state="ERROR",
+ protocol="p",
+ protocol_type="ptype",
+ coordinator="c",
+ members=[],
+ member_count=0,
+ assigned_partitions=0,
+ partition_distribution={},
+ total_lag=0,
)
h, msg = m._assess_group_health(s) # noqa: SLF001
assert h is ConsumerGroupHealth.UNHEALTHY and "error" in msg.lower()
@@ -74,7 +82,7 @@ def test_assess_group_health_branches():
@pytest.mark.asyncio
async def test_multiple_group_status_mixed_errors():
- m = NativeConsumerGroupMonitor(bootstrap_servers="localhost:9092")
+ m = NativeConsumerGroupMonitor(bootstrap_servers="localhost:9092", logger=_test_logger)
gids = [f"none-{uuid4().hex[:6]}", f"none-{uuid4().hex[:6]}"]
res = await m.get_multiple_group_status(gids, timeout=5.0, include_lag=False)
assert set(res.keys()) == set(gids)
diff --git a/backend/tests/integration/events/test_consumer_min_e2e.py b/backend/tests/integration/events/test_consumer_min_e2e.py
index a228cc75..715ac7e4 100644
--- a/backend/tests/integration/events/test_consumer_min_e2e.py
+++ b/backend/tests/integration/events/test_consumer_min_e2e.py
@@ -1,3 +1,4 @@
+import logging
from uuid import uuid4
import pytest
@@ -6,12 +7,14 @@
pytestmark = [pytest.mark.integration, pytest.mark.kafka]
+_test_logger = logging.getLogger("test.events.consumer_min_e2e")
+
@pytest.mark.asyncio
async def test_consumer_start_status_seek_and_stop():
cfg = ConsumerConfig(bootstrap_servers="localhost:9092", group_id=f"test-consumer-{uuid4().hex[:6]}")
- disp = EventDispatcher()
- c = UnifiedConsumer(cfg, event_dispatcher=disp)
+ disp = EventDispatcher(logger=_test_logger)
+ c = UnifiedConsumer(cfg, event_dispatcher=disp, logger=_test_logger)
await c.start([KafkaTopic.EXECUTION_EVENTS])
try:
st = c.get_status()
diff --git a/backend/tests/integration/events/test_dlq_handler.py b/backend/tests/integration/events/test_dlq_handler.py
index b30b6c6a..5659529b 100644
--- a/backend/tests/integration/events/test_dlq_handler.py
+++ b/backend/tests/integration/events/test_dlq_handler.py
@@ -1,12 +1,14 @@
-import pytest
+import logging
-from app.events.core import create_dlq_error_handler, create_immediate_dlq_handler
-from app.events.core import UnifiedProducer
+import pytest
+from app.events.core import UnifiedProducer, create_dlq_error_handler, create_immediate_dlq_handler
from app.infrastructure.kafka.events.metadata import AvroEventMetadata
from app.infrastructure.kafka.events.saga import SagaStartedEvent
pytestmark = [pytest.mark.integration, pytest.mark.kafka]
+_test_logger = logging.getLogger("test.events.dlq_handler")
+
@pytest.mark.asyncio
async def test_dlq_handler_with_retries(scope, monkeypatch): # type: ignore[valid-type]
@@ -17,9 +19,14 @@ async def _record_send_to_dlq(original_event, original_topic, error, retry_count
calls.append((original_event.event_id, original_topic, str(error), retry_count))
monkeypatch.setattr(p, "send_to_dlq", _record_send_to_dlq)
- h = create_dlq_error_handler(p, original_topic="t", max_retries=2)
- e = SagaStartedEvent(saga_id="s", saga_name="n", execution_id="x", initial_event_id="i",
- metadata=AvroEventMetadata(service_name="a", service_version="1"))
+ h = create_dlq_error_handler(p, original_topic="t", max_retries=2, logger=_test_logger)
+ e = SagaStartedEvent(
+ saga_id="s",
+ saga_name="n",
+ execution_id="x",
+ initial_event_id="i",
+ metadata=AvroEventMetadata(service_name="a", service_version="1"),
+ )
# Calls 1 and 2 should not send to DLQ
await h(RuntimeError("boom"), e)
await h(RuntimeError("boom"), e)
@@ -39,8 +46,13 @@ async def _record_send_to_dlq(original_event, original_topic, error, retry_count
calls.append((original_event.event_id, original_topic, str(error), retry_count))
monkeypatch.setattr(p, "send_to_dlq", _record_send_to_dlq)
- h = create_immediate_dlq_handler(p, original_topic="t")
- e = SagaStartedEvent(saga_id="s2", saga_name="n", execution_id="x", initial_event_id="i",
- metadata=AvroEventMetadata(service_name="a", service_version="1"))
+ h = create_immediate_dlq_handler(p, original_topic="t", logger=_test_logger)
+ e = SagaStartedEvent(
+ saga_id="s2",
+ saga_name="n",
+ execution_id="x",
+ initial_event_id="i",
+ metadata=AvroEventMetadata(service_name="a", service_version="1"),
+ )
await h(RuntimeError("x"), e)
assert calls and calls[0][3] == 0
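
The behaviour asserted in these two tests (failures tolerated up to max_retries, then the event is forwarded to the DLQ with the observed retry count) implies the factory keeps a per-event attempt counter. An illustrative shape for such a handler, not the app's actual implementation:

    # Sketch: error handler that counts failures per event and forwards to the DLQ
    # once max_retries is exceeded. Parameter names mirror the tests; internals are assumed.
    from collections import defaultdict

    def create_dlq_error_handler_sketch(producer, original_topic: str, max_retries: int, logger):
        attempts: dict[str, int] = defaultdict(int)

        async def handle(error: Exception, event) -> None:
            attempts[event.event_id] += 1
            count = attempts[event.event_id]
            if count > max_retries:
                logger.warning("Retries exhausted for %s, sending to DLQ", event.event_id)
                await producer.send_to_dlq(event, original_topic, error, count - 1)
            else:
                logger.info("Handler error %d/%d for %s", count, max_retries, event.event_id)

        return handle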
diff --git a/backend/tests/integration/events/test_event_dispatcher.py b/backend/tests/integration/events/test_event_dispatcher.py
index 89b59766..c88e3fa6 100644
--- a/backend/tests/integration/events/test_event_dispatcher.py
+++ b/backend/tests/integration/events/test_event_dispatcher.py
@@ -1,20 +1,22 @@
import asyncio
+import logging
import uuid
import pytest
-
from app.domain.enums.events import EventType
from app.domain.enums.kafka import KafkaTopic
from app.events.core import UnifiedConsumer, UnifiedProducer
from app.events.core.dispatcher import EventDispatcher
from app.events.core.types import ConsumerConfig
from app.events.schema.schema_registry import SchemaRegistryManager, initialize_event_schemas
-from tests.helpers import make_execution_requested_event
from app.settings import get_settings
+from tests.helpers import make_execution_requested_event
pytestmark = [pytest.mark.integration, pytest.mark.kafka]
+_test_logger = logging.getLogger("test.events.event_dispatcher")
+
@pytest.mark.asyncio
async def test_dispatcher_with_multiple_handlers(scope) -> None: # type: ignore[valid-type]
@@ -23,7 +25,7 @@ async def test_dispatcher_with_multiple_handlers(scope) -> None: # type: ignore
await initialize_event_schemas(registry)
# Build dispatcher with two handlers for the same event
- dispatcher = EventDispatcher()
+ dispatcher = EventDispatcher(logger=_test_logger)
h1_called = asyncio.Event()
h2_called = asyncio.Event()
@@ -43,7 +45,7 @@ async def h2(_e) -> None: # noqa: ANN001
enable_auto_commit=True,
auto_offset_reset="earliest",
)
- consumer = UnifiedConsumer(cfg, dispatcher)
+ consumer = UnifiedConsumer(cfg, dispatcher, logger=_test_logger)
await consumer.start([str(KafkaTopic.EXECUTION_EVENTS)])
# Produce a request event via DI
diff --git a/backend/tests/integration/events/test_event_store.py b/backend/tests/integration/events/test_event_store.py
deleted file mode 100644
index 6d00e0ad..00000000
--- a/backend/tests/integration/events/test_event_store.py
+++ /dev/null
@@ -1,68 +0,0 @@
-from datetime import datetime, timezone, timedelta
-
-import pytest
-
-from app.events.event_store import EventStore
-from app.events.schema.schema_registry import SchemaRegistryManager
-from app.infrastructure.kafka.events.metadata import AvroEventMetadata
-from app.infrastructure.kafka.events.pod import PodCreatedEvent
-from app.infrastructure.kafka.events.user import UserLoggedInEvent
-from motor.motor_asyncio import AsyncIOMotorDatabase
-
-pytestmark = [pytest.mark.integration, pytest.mark.mongodb]
-
-
-@pytest.fixture()
-async def event_store(scope) -> EventStore: # type: ignore[valid-type]
- db: AsyncIOMotorDatabase = await scope.get(AsyncIOMotorDatabase)
- schema_registry: SchemaRegistryManager = await scope.get(SchemaRegistryManager)
- store = EventStore(db=db, schema_registry=schema_registry)
- await store.initialize()
- return store
-
-
-@pytest.mark.asyncio
-async def test_store_and_query_events(event_store: EventStore) -> None:
- ev1 = PodCreatedEvent(
- execution_id="x1",
- pod_name="pod1",
- namespace="ns",
- metadata=AvroEventMetadata(service_name="svc", service_version="1", user_id="u1", correlation_id="cid"),
- )
- assert await event_store.store_event(ev1) is True
-
- ev2 = PodCreatedEvent(
- execution_id="x2",
- pod_name="pod2",
- namespace="ns",
- metadata=AvroEventMetadata(service_name="svc", service_version="1", user_id="u1"),
- )
- res = await event_store.store_batch([ev1, ev2])
- assert res["total"] == 2 and res["stored"] >= 1
-
- items = await event_store.get_events_by_type(ev1.event_type)
- assert any(getattr(e, "execution_id", None) == "x1" for e in items)
- exec_items = await event_store.get_execution_events("x1")
- assert any(getattr(e, "execution_id", None) == "x1" for e in exec_items)
- user_items = await event_store.get_user_events("u1")
- assert len(user_items) >= 2
- chain = await event_store.get_correlation_chain("cid")
- assert isinstance(chain, list)
- # Security types (may be empty)
- _ = await event_store.get_security_events()
-
-
-@pytest.mark.asyncio
-async def test_replay_events(event_store: EventStore) -> None:
- ev = UserLoggedInEvent(user_id="u1", login_method="password",
- metadata=AvroEventMetadata(service_name="svc", service_version="1"))
- await event_store.store_event(ev)
-
- called = {"n": 0}
-
- async def cb(_): # noqa: ANN001
- called["n"] += 1
-
- start = datetime.now(timezone.utc) - timedelta(days=1)
- cnt = await event_store.replay_events(start_time=start, callback=cb)
- assert cnt >= 1 and called["n"] >= 1
diff --git a/backend/tests/integration/events/test_event_store_consumer.py b/backend/tests/integration/events/test_event_store_consumer.py
index 778f4da1..111d6fe2 100644
--- a/backend/tests/integration/events/test_event_store_consumer.py
+++ b/backend/tests/integration/events/test_event_store_consumer.py
@@ -1,21 +1,20 @@
-import asyncio
+import logging
import uuid
import pytest
-
-from motor.motor_asyncio import AsyncIOMotorDatabase
-
+from app.core.database_context import Database
from app.domain.enums.kafka import KafkaTopic
from app.events.core import UnifiedProducer
-from app.events.event_store_consumer import EventStoreConsumer, create_event_store_consumer
from app.events.event_store import EventStore
+from app.events.event_store_consumer import EventStoreConsumer, create_event_store_consumer
from app.events.schema.schema_registry import SchemaRegistryManager, initialize_event_schemas
-from app.infrastructure.kafka.events.user import UserLoggedInEvent
from app.infrastructure.kafka.events.metadata import AvroEventMetadata
-
+from app.infrastructure.kafka.events.user import UserLoggedInEvent
pytestmark = [pytest.mark.integration, pytest.mark.kafka, pytest.mark.mongodb]
+_test_logger = logging.getLogger("test.events.event_store_consumer")
+
@pytest.mark.asyncio
async def test_event_store_consumer_stores_events(scope) -> None: # type: ignore[valid-type]
@@ -25,7 +24,7 @@ async def test_event_store_consumer_stores_events(scope) -> None: # type: ignor
# Resolve DI
producer: UnifiedProducer = await scope.get(UnifiedProducer)
- db: AsyncIOMotorDatabase = await scope.get(AsyncIOMotorDatabase)
+ db: Database = await scope.get(Database)
store: EventStore = await scope.get(EventStore)
# Build an event
@@ -40,6 +39,7 @@ async def test_event_store_consumer_stores_events(scope) -> None: # type: ignor
event_store=store,
topics=[KafkaTopic.USER_EVENTS],
schema_registry_manager=registry,
+ logger=_test_logger,
producer=producer,
batch_size=10,
batch_timeout_seconds=0.5,
diff --git a/backend/tests/integration/events/test_event_store_consumer_flush_e2e.py b/backend/tests/integration/events/test_event_store_consumer_flush_e2e.py
deleted file mode 100644
index 3e29ed13..00000000
--- a/backend/tests/integration/events/test_event_store_consumer_flush_e2e.py
+++ /dev/null
@@ -1,52 +0,0 @@
-import asyncio
-from uuid import uuid4
-
-import pytest
-from motor.motor_asyncio import AsyncIOMotorDatabase
-
-from app.domain.enums.events import EventType
-from app.domain.enums.kafka import KafkaTopic
-from app.events.event_store import EventStore
-from app.events.event_store_consumer import create_event_store_consumer
-from app.events.core import UnifiedProducer
-from app.events.schema.schema_registry import SchemaRegistryManager
-from tests.helpers import make_execution_requested_event
-from tests.helpers.eventually import eventually
-
-pytestmark = [pytest.mark.integration, pytest.mark.kafka, pytest.mark.mongodb]
-
-
-@pytest.mark.asyncio
-async def test_event_store_consumer_flush_on_timeout(scope): # type: ignore[valid-type]
- producer: UnifiedProducer = await scope.get(UnifiedProducer)
- schema: SchemaRegistryManager = await scope.get(SchemaRegistryManager)
- db: AsyncIOMotorDatabase = await scope.get(AsyncIOMotorDatabase)
- store = EventStore(db=db, schema_registry=schema)
- await store.initialize()
-
- consumer = create_event_store_consumer(
- event_store=store,
- topics=[KafkaTopic.EXECUTION_EVENTS],
- schema_registry_manager=schema,
- producer=producer,
- batch_size=100,
- batch_timeout_seconds=0.2,
- )
- await consumer.start()
- try:
- # Directly invoke handler to enqueue
- exec_ids = []
- for _ in range(3):
- x = f"exec-{uuid4().hex[:6]}"
- exec_ids.append(x)
- ev = make_execution_requested_event(execution_id=x)
- await consumer._handle_event(ev) # noqa: SLF001
-
- async def _all_present() -> None:
- docs = await db[store.collection_name].find({"event_type": str(EventType.EXECUTION_REQUESTED)}).to_list(50)
- have = {d.get("execution_id") for d in docs}
- assert set(exec_ids).issubset(have)
-
- await eventually(_all_present, timeout=5.0, interval=0.2)
- finally:
- await consumer.stop()
diff --git a/backend/tests/integration/events/test_event_store_e2e.py b/backend/tests/integration/events/test_event_store_e2e.py
deleted file mode 100644
index fa127421..00000000
--- a/backend/tests/integration/events/test_event_store_e2e.py
+++ /dev/null
@@ -1,45 +0,0 @@
-from datetime import datetime, timezone, timedelta
-
-import pytest
-from motor.motor_asyncio import AsyncIOMotorDatabase
-
-from app.domain.enums.events import EventType
-from app.events.event_store import EventStore
-from app.events.schema.schema_registry import SchemaRegistryManager
-from tests.helpers import make_execution_requested_event
-
-
-pytestmark = [pytest.mark.integration, pytest.mark.mongodb]
-
-
-@pytest.mark.asyncio
-async def test_event_store_initialize_and_crud(scope): # type: ignore[valid-type]
- schema: SchemaRegistryManager = await scope.get(SchemaRegistryManager)
- db: AsyncIOMotorDatabase = await scope.get(AsyncIOMotorDatabase)
- store = EventStore(db=db, schema_registry=schema, ttl_days=1)
- await store.initialize()
-
- # Store single event
- ev = make_execution_requested_event(execution_id="e-1")
- assert await store.store_event(ev) is True
-
- # Duplicate insert should be treated as success True (DuplicateKey swallowed)
- assert await store.store_event(ev) is True
-
- # Batch store with duplicates
- ev2 = ev.model_copy(update={"event_id": "new-2", "execution_id": "e-2"})
- res = await store.store_batch([ev, ev2])
- assert res["total"] == 2 and res["stored"] >= 1
-
- # Queries
- by_id = await store.get_event(ev.event_id)
- assert by_id is not None and by_id.event_id == ev.event_id
-
- by_type = await store.get_events_by_type(EventType.EXECUTION_REQUESTED, limit=10)
- assert any(e.event_id == ev.event_id for e in by_type)
-
- by_exec = await store.get_execution_events("e-1")
- assert any(e.event_id == ev.event_id for e in by_exec)
-
- by_user = await store.get_user_events("u-unknown", limit=10)
- assert isinstance(by_user, list)
diff --git a/backend/tests/integration/events/test_producer_e2e.py b/backend/tests/integration/events/test_producer_e2e.py
index 68924e53..eedbfaa0 100644
--- a/backend/tests/integration/events/test_producer_e2e.py
+++ b/backend/tests/integration/events/test_producer_e2e.py
@@ -1,21 +1,22 @@
-import asyncio
import json
+import logging
from uuid import uuid4
import pytest
-
-from app.events.core import UnifiedProducer, ProducerConfig
+from app.events.core import ProducerConfig, UnifiedProducer
from app.events.schema.schema_registry import SchemaRegistryManager
-from tests.helpers import make_execution_requested_event
+from tests.helpers import make_execution_requested_event
pytestmark = [pytest.mark.integration, pytest.mark.kafka]
+_test_logger = logging.getLogger("test.events.producer_e2e")
+
@pytest.mark.asyncio
async def test_unified_producer_start_produce_send_to_dlq_stop(scope): # type: ignore[valid-type]
schema: SchemaRegistryManager = await scope.get(SchemaRegistryManager)
- prod = UnifiedProducer(ProducerConfig(bootstrap_servers="localhost:9092"), schema)
+ prod = UnifiedProducer(ProducerConfig(bootstrap_servers="localhost:9092"), schema, logger=_test_logger)
await prod.start()
try:
@@ -33,17 +34,14 @@ async def test_unified_producer_start_produce_send_to_dlq_stop(scope): # type:
def test_producer_handle_stats_path():
# Directly run stats parsing to cover branch logic; avoid relying on timing
- from app.events.core.producer import UnifiedProducer as UP, ProducerMetrics
+ from app.events.core.producer import ProducerMetrics
+ from app.events.core.producer import UnifiedProducer as UP
+
m = ProducerMetrics()
p = object.__new__(UP) # bypass __init__ safely for method call
# Inject required attributes
p._metrics = m # type: ignore[attr-defined]
p._stats_callback = None # type: ignore[attr-defined]
- payload = json.dumps({
- "msg_cnt": 1,
- "topics": {
- "t": {"partitions": {"0": {"msgq_cnt": 2, "rtt": {"avg": 5}}}}
- }
- })
+ payload = json.dumps({"msg_cnt": 1, "topics": {"t": {"partitions": {"0": {"msgq_cnt": 2, "rtt": {"avg": 5}}}}}})
UP._handle_stats(p, payload) # type: ignore[misc]
assert m.queue_size == 1 and m.avg_latency_ms > 0
diff --git a/backend/tests/integration/events/test_schema_registry_e2e.py b/backend/tests/integration/events/test_schema_registry_e2e.py
index 84fa7aae..44c5f827 100644
--- a/backend/tests/integration/events/test_schema_registry_e2e.py
+++ b/backend/tests/integration/events/test_schema_registry_e2e.py
@@ -1,14 +1,14 @@
-import asyncio
-import struct
+import logging
import pytest
+from app.events.schema.schema_registry import MAGIC_BYTE, SchemaRegistryManager
-from app.events.schema.schema_registry import SchemaRegistryManager, MAGIC_BYTE
from tests.helpers import make_execution_requested_event
-
pytestmark = [pytest.mark.integration]
+_test_logger = logging.getLogger("test.events.schema_registry_e2e")
+
@pytest.mark.asyncio
async def test_schema_registry_serialize_deserialize_roundtrip(scope): # type: ignore[valid-type]
@@ -25,6 +25,6 @@ async def test_schema_registry_serialize_deserialize_roundtrip(scope): # type:
def test_schema_registry_deserialize_invalid_header():
- reg = SchemaRegistryManager()
+ reg = SchemaRegistryManager(logger=_test_logger)
with pytest.raises(ValueError):
reg.deserialize_event(b"\x01\x00\x00\x00\x01", topic="t") # wrong magic byte
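
The invalid-header assertion above relies on the Confluent wire format: a serialized message starts with a 0x00 magic byte, followed by a 4-byte big-endian schema id, then the Avro payload. A small sketch of building and checking that header (the MAGIC_BYTE constant is imported in the test; the helpers below are illustrative, not the registry manager's code):

    # Sketch: Confluent Schema Registry framing assumed by the serializer under test.
    import struct

    MAGIC_BYTE = 0  # matches the constant imported in the test above

    def frame(schema_id: int, avro_payload: bytes) -> bytes:
        return struct.pack(">bI", MAGIC_BYTE, schema_id) + avro_payload

    def parse_header(message: bytes) -> int:
        magic, schema_id = struct.unpack(">bI", message[:5])
        if magic != MAGIC_BYTE:
            raise ValueError(f"Unexpected magic byte {magic!r}")  # mirrors the ValueError asserted above
        return schema_id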
diff --git a/backend/tests/integration/events/test_schema_registry_real.py b/backend/tests/integration/events/test_schema_registry_real.py
index 962910d2..895f109d 100644
--- a/backend/tests/integration/events/test_schema_registry_real.py
+++ b/backend/tests/integration/events/test_schema_registry_real.py
@@ -1,15 +1,18 @@
-import pytest
+import logging
+import pytest
from app.events.schema.schema_registry import SchemaRegistryManager
from app.infrastructure.kafka.events.metadata import AvroEventMetadata
from app.infrastructure.kafka.events.pod import PodCreatedEvent
pytestmark = [pytest.mark.integration, pytest.mark.kafka]
+_test_logger = logging.getLogger("test.events.schema_registry_real")
+
def test_serialize_and_deserialize_event_real_registry() -> None:
# Uses real Schema Registry configured via env (SCHEMA_REGISTRY_URL)
- m = SchemaRegistryManager()
+ m = SchemaRegistryManager(logger=_test_logger)
ev = PodCreatedEvent(
execution_id="e1",
pod_name="p",
diff --git a/backend/tests/integration/idempotency/test_consumer_idempotent.py b/backend/tests/integration/idempotency/test_consumer_idempotent.py
index 14337e45..e5334149 100644
--- a/backend/tests/integration/idempotency/test_consumer_idempotent.py
+++ b/backend/tests/integration/idempotency/test_consumer_idempotent.py
@@ -1,4 +1,5 @@
import asyncio
+import logging
import uuid
import pytest
@@ -15,6 +16,8 @@
pytestmark = [pytest.mark.integration, pytest.mark.kafka, pytest.mark.redis]
+_test_logger = logging.getLogger("test.idempotency.consumer_idempotent")
+
@pytest.mark.asyncio
async def test_consumer_idempotent_wrapper_blocks_duplicates(scope) -> None: # type: ignore[valid-type]
@@ -22,7 +25,7 @@ async def test_consumer_idempotent_wrapper_blocks_duplicates(scope) -> None: #
idm: IdempotencyManager = await scope.get(IdempotencyManager)
# Build a dispatcher with a counter
- disp: Disp = EventDispatcher()
+ disp: Disp = EventDispatcher(logger=_test_logger)
seen = {"n": 0}
@disp.register(EventType.EXECUTION_REQUESTED)
@@ -37,13 +40,14 @@ async def handle(_ev): # noqa: ANN001
enable_auto_commit=True,
auto_offset_reset="earliest",
)
- base = UnifiedConsumer(cfg, event_dispatcher=disp)
+ base = UnifiedConsumer(cfg, event_dispatcher=disp, logger=_test_logger)
wrapper = IdempotentConsumerWrapper(
consumer=base,
idempotency_manager=idm,
dispatcher=disp,
default_key_strategy="event_based",
enable_for_all_handlers=True,
+ logger=_test_logger,
)
await wrapper.start([KafkaTopic.EXECUTION_EVENTS])
diff --git a/backend/tests/integration/idempotency/test_decorator_idempotent.py b/backend/tests/integration/idempotency/test_decorator_idempotent.py
index 98d6b3ca..3f4d73ce 100644
--- a/backend/tests/integration/idempotency/test_decorator_idempotent.py
+++ b/backend/tests/integration/idempotency/test_decorator_idempotent.py
@@ -1,9 +1,12 @@
+import logging
import pytest
from tests.helpers import make_execution_requested_event
from app.services.idempotency.idempotency_manager import IdempotencyManager
from app.services.idempotency.middleware import idempotent_handler
+_test_logger = logging.getLogger("test.idempotency.decorator_idempotent")
+
pytestmark = [pytest.mark.integration]
@@ -14,7 +17,7 @@ async def test_decorator_blocks_duplicate_event(scope) -> None: # type: ignore[
calls = {"n": 0}
- @idempotent_handler(idempotency_manager=idm, key_strategy="event_based")
+ @idempotent_handler(idempotency_manager=idm, key_strategy="event_based", logger=_test_logger)
async def h(ev): # noqa: ANN001
calls["n"] += 1
@@ -34,7 +37,7 @@ async def test_decorator_custom_key_blocks(scope) -> None: # type: ignore[valid
def fixed_key(_ev): # noqa: ANN001
return "fixed-key"
- @idempotent_handler(idempotency_manager=idm, key_strategy="custom", custom_key_func=fixed_key)
+ @idempotent_handler(idempotency_manager=idm, key_strategy="custom", custom_key_func=fixed_key, logger=_test_logger)
async def h(ev): # noqa: ANN001
calls["n"] += 1
diff --git a/backend/tests/integration/idempotency/test_idempotency.py b/backend/tests/integration/idempotency/test_idempotency.py
index 3ff9541e..6620ef6f 100644
--- a/backend/tests/integration/idempotency/test_idempotency.py
+++ b/backend/tests/integration/idempotency/test_idempotency.py
@@ -1,5 +1,6 @@
import asyncio
import json
+import logging
import uuid
from datetime import datetime, timedelta, timezone
import pytest
@@ -14,6 +15,9 @@
pytestmark = [pytest.mark.integration, pytest.mark.redis]
+# Test logger for all tests
+_test_logger = logging.getLogger("test.idempotency")
+
class TestIdempotencyManager:
"""IdempotencyManager backed by real Redis repository (DI-provided client)."""
@@ -30,7 +34,7 @@ async def manager(self, redis_client): # type: ignore[valid-type]
enable_metrics=False,
)
repo = RedisIdempotencyRepository(redis_client, key_prefix=prefix)
- m = IdempotencyManager(cfg, repo)
+ m = IdempotencyManager(cfg, repo, _test_logger)
await m.initialize()
try:
yield m
@@ -248,7 +252,7 @@ async def manager(self, redis_client): # type: ignore[valid-type]
prefix = f"handler_test:{uuid.uuid4().hex[:6]}"
config = IdempotencyConfig(key_prefix=prefix, enable_metrics=False)
repo = RedisIdempotencyRepository(redis_client, key_prefix=prefix)
- m = IdempotencyManager(config, repo)
+ m = IdempotencyManager(config, repo, _test_logger)
await m.initialize()
try:
yield m
@@ -267,7 +271,8 @@ async def actual_handler(event: BaseEvent):
handler = IdempotentEventHandler(
handler=actual_handler,
idempotency_manager=manager,
- key_strategy="event_based"
+ key_strategy="event_based",
+ logger=_test_logger,
)
# Process event
@@ -290,7 +295,8 @@ async def actual_handler(event: BaseEvent):
handler = IdempotentEventHandler(
handler=actual_handler,
idempotency_manager=manager,
- key_strategy="event_based"
+ key_strategy="event_based",
+ logger=_test_logger,
)
# Process event twice
@@ -311,7 +317,8 @@ async def failing_handler(event: BaseEvent):
handler = IdempotentEventHandler(
handler=failing_handler,
idempotency_manager=manager,
- key_strategy="event_based"
+ key_strategy="event_based",
+ logger=_test_logger,
)
# Process event (should raise)
@@ -340,7 +347,8 @@ async def on_duplicate(event: BaseEvent, result):
handler=actual_handler,
idempotency_manager=manager,
key_strategy="event_based",
- on_duplicate=on_duplicate
+ on_duplicate=on_duplicate,
+ logger=_test_logger,
)
# Process twice
@@ -361,7 +369,8 @@ async def test_decorator_integration(self, manager):
@idempotent_handler(
idempotency_manager=manager,
key_strategy="content_hash",
- ttl_seconds=300
+ ttl_seconds=300,
+ logger=_test_logger,
)
async def my_handler(event: BaseEvent):
processed_events.append(event)
@@ -402,7 +411,8 @@ def extract_script_key(event: BaseEvent) -> str:
handler=process_script,
idempotency_manager=manager,
key_strategy="custom",
- custom_key_func=extract_script_key
+ custom_key_func=extract_script_key,
+ logger=_test_logger,
)
# Events with same script
@@ -495,7 +505,7 @@ async def test_metrics_enabled(self, redis_client): # type: ignore[valid-type]
"""Test manager with metrics enabled"""
config = IdempotencyConfig(key_prefix=f"metrics:{uuid.uuid4().hex[:6]}", enable_metrics=True)
repository = RedisIdempotencyRepository(redis_client, key_prefix=config.key_prefix)
- manager = IdempotencyManager(config, repository)
+ manager = IdempotencyManager(config, repository, _test_logger)
# Initialize with metrics
await manager.initialize()
diff --git a/backend/tests/integration/idempotency/test_idempotent_handler.py b/backend/tests/integration/idempotency/test_idempotent_handler.py
index 1a6f1c07..76ea369a 100644
--- a/backend/tests/integration/idempotency/test_idempotent_handler.py
+++ b/backend/tests/integration/idempotency/test_idempotent_handler.py
@@ -1,3 +1,5 @@
+import logging
+
import pytest
from app.events.schema.schema_registry import SchemaRegistryManager
@@ -8,6 +10,8 @@
pytestmark = [pytest.mark.integration]
+_test_logger = logging.getLogger("test.idempotency.idempotent_handler")
+
@pytest.mark.asyncio
async def test_idempotent_handler_blocks_duplicates(scope) -> None: # type: ignore[valid-type]
@@ -22,6 +26,7 @@ async def _handler(ev) -> None: # noqa: ANN001
handler=_handler,
idempotency_manager=manager,
key_strategy="event_based",
+ logger=_test_logger,
)
ev = make_execution_requested_event(execution_id="exec-dup-1")
@@ -45,6 +50,7 @@ async def _handler(ev) -> None: # noqa: ANN001
handler=_handler,
idempotency_manager=manager,
key_strategy="content_hash",
+ logger=_test_logger,
)
e1 = make_execution_requested_event(execution_id="exec-dup-2")
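The hunks above thread an explicitly named `logging.Logger` through `IdempotencyManager`, `IdempotentEventHandler`, and the `idempotent_handler` decorator rather than relying on an implicit module-level logger. One practical upside of a named logger in tests is that pytest's `caplog` fixture can scope capture to it. The snippet below only illustrates that capture pattern; the log message text is a stand-in, not taken from the handler implementation.

```python
import logging

_test_logger = logging.getLogger("test.idempotency.idempotent_handler")


def test_named_logger_is_capturable(caplog) -> None:
    # Scope capture to the injected logger's name; records emitted by the
    # components that received this logger would surface the same way.
    with caplog.at_level(logging.INFO, logger="test.idempotency.idempotent_handler"):
        _test_logger.info("duplicate event skipped")  # stand-in message
    assert "duplicate event skipped" in caplog.messages
```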
diff --git a/backend/tests/integration/test_execution_routes.py b/backend/tests/integration/k8s/test_execution_routes.py
similarity index 84%
rename from backend/tests/integration/test_execution_routes.py
rename to backend/tests/integration/k8s/test_execution_routes.py
index ac3e4053..fb85978e 100644
--- a/backend/tests/integration/test_execution_routes.py
+++ b/backend/tests/integration/k8s/test_execution_routes.py
@@ -14,23 +14,7 @@
)
-def has_k8s_workers() -> bool:
- """Check if K8s workers are available for execution."""
- # Check if K8s worker container is running
- import subprocess
- try:
- result = subprocess.run(
- ["docker", "ps", "--filter", "name=k8s-worker", "--format", "{{.Names}}"],
- capture_output=True,
- text=True,
- timeout=2
- )
- return "k8s-worker" in result.stdout
- except Exception:
- return False
-
-
-@pytest.mark.integration
+@pytest.mark.k8s
class TestExecution:
"""Test execution endpoints against real backend."""
@@ -52,13 +36,12 @@ async def test_execute_requires_authentication(self, client: AsyncClient) -> Non
for word in ["not authenticated", "unauthorized", "login"])
@pytest.mark.asyncio
- @pytest.mark.skipif(not has_k8s_workers(), reason="K8s workers not available")
- async def test_execute_simple_python_script(self, client: AsyncClient, shared_user: Dict[str, str]) -> None:
+ async def test_execute_simple_python_script(self, client: AsyncClient, test_user: Dict[str, str]) -> None:
"""Test executing a simple Python script."""
# Login first
login_data = {
- "username": shared_user["username"],
- "password": shared_user["password"]
+ "username": test_user["username"],
+ "password": test_user["password"]
}
login_response = await client.post("/api/v1/auth/login", data=login_data)
assert login_response.status_code == 200
@@ -96,13 +79,12 @@ async def test_execute_simple_python_script(self, client: AsyncClient, shared_us
]
@pytest.mark.asyncio
- @pytest.mark.skipif(not has_k8s_workers(), reason="K8s workers not available")
- async def test_get_execution_result(self, client: AsyncClient, shared_user: Dict[str, str]) -> None:
+ async def test_get_execution_result(self, client: AsyncClient, test_user: Dict[str, str]) -> None:
"""Test getting execution result after completion using SSE (event-driven)."""
# Login first
login_data = {
- "username": shared_user["username"],
- "password": shared_user["password"]
+ "username": test_user["username"],
+ "password": test_user["password"]
}
login_response = await client.post("/api/v1/auth/login", data=login_data)
assert login_response.status_code == 200
@@ -137,13 +119,12 @@ async def test_get_execution_result(self, client: AsyncClient, shared_user: Dict
assert "Line 2" in execution_result.stdout
@pytest.mark.asyncio
- @pytest.mark.skipif(not has_k8s_workers(), reason="K8s workers not available")
- async def test_execute_with_error(self, client: AsyncClient, shared_user: Dict[str, str]) -> None:
+ async def test_execute_with_error(self, client: AsyncClient, test_user: Dict[str, str]) -> None:
"""Test executing a script that produces an error."""
# Login first
login_data = {
- "username": shared_user["username"],
- "password": shared_user["password"]
+ "username": test_user["username"],
+ "password": test_user["password"]
}
login_response = await client.post("/api/v1/auth/login", data=login_data)
assert login_response.status_code == 200
@@ -163,13 +144,12 @@ async def test_execute_with_error(self, client: AsyncClient, shared_user: Dict[s
# No waiting - execution was accepted, error will be processed asynchronously
@pytest.mark.asyncio
- @pytest.mark.skipif(not has_k8s_workers(), reason="K8s workers not available")
- async def test_execute_with_resource_tracking(self, client: AsyncClient, shared_user: Dict[str, str]) -> None:
+ async def test_execute_with_resource_tracking(self, client: AsyncClient, test_user: Dict[str, str]) -> None:
"""Test that execution tracks resource usage."""
# Login first
login_data = {
- "username": shared_user["username"],
- "password": shared_user["password"]
+ "username": test_user["username"],
+ "password": test_user["password"]
}
login_response = await client.post("/api/v1/auth/login", data=login_data)
assert login_response.status_code == 200
@@ -205,14 +185,13 @@ async def test_execute_with_resource_tracking(self, client: AsyncClient, shared_
assert resource_usage.peak_memory_kb >= 0
@pytest.mark.asyncio
- @pytest.mark.skipif(not has_k8s_workers(), reason="K8s workers not available")
async def test_execute_with_different_language_versions(self, client: AsyncClient,
- shared_user: Dict[str, str]) -> None:
+ test_user: Dict[str, str]) -> None:
"""Test execution with different Python versions."""
# Login first
login_data = {
- "username": shared_user["username"],
- "password": shared_user["password"]
+ "username": test_user["username"],
+ "password": test_user["password"]
}
login_response = await client.post("/api/v1/auth/login", data=login_data)
assert login_response.status_code == 200
@@ -240,13 +219,12 @@ async def test_execute_with_different_language_versions(self, client: AsyncClien
assert "execution_id" in data
@pytest.mark.asyncio
- @pytest.mark.skipif(not has_k8s_workers(), reason="K8s workers not available")
- async def test_execute_with_large_output(self, client: AsyncClient, shared_user: Dict[str, str]) -> None:
+ async def test_execute_with_large_output(self, client: AsyncClient, test_user: Dict[str, str]) -> None:
"""Test execution with large output."""
# Login first
login_data = {
- "username": shared_user["username"],
- "password": shared_user["password"]
+ "username": test_user["username"],
+ "password": test_user["password"]
}
login_response = await client.post("/api/v1/auth/login", data=login_data)
assert login_response.status_code == 200
@@ -279,13 +257,12 @@ async def test_execute_with_large_output(self, client: AsyncClient, shared_user:
assert "End of output" in result_data["stdout"] or len(result_data["stdout"]) > 10000
@pytest.mark.asyncio
- @pytest.mark.skipif(not has_k8s_workers(), reason="K8s workers not available")
- async def test_cancel_running_execution(self, client: AsyncClient, shared_user: Dict[str, str]) -> None:
+ async def test_cancel_running_execution(self, client: AsyncClient, test_user: Dict[str, str]) -> None:
"""Test cancelling a running execution."""
# Login first
login_data = {
- "username": shared_user["username"],
- "password": shared_user["password"]
+ "username": test_user["username"],
+ "password": test_user["password"]
}
login_response = await client.post("/api/v1/auth/login", data=login_data)
assert login_response.status_code == 200
@@ -326,8 +303,7 @@ async def test_cancel_running_execution(self, client: AsyncClient, shared_user:
# Cancel response of 200 means cancellation was accepted
@pytest.mark.asyncio
- @pytest.mark.skipif(not has_k8s_workers(), reason="K8s workers not available")
- async def test_execution_with_timeout(self, client: AsyncClient, shared_user: Dict[str, str]) -> None:
+ async def test_execution_with_timeout(self, client: AsyncClient, test_user: Dict[str, str]) -> None:
"""Bounded check: long-running executions don't finish immediately.
The backend's default timeout is 300s. To keep integration fast,
@@ -336,8 +312,8 @@ async def test_execution_with_timeout(self, client: AsyncClient, shared_user: Di
"""
# Login first
login_data = {
- "username": shared_user["username"],
- "password": shared_user["password"]
+ "username": test_user["username"],
+ "password": test_user["password"]
}
login_response = await client.post("/api/v1/auth/login", data=login_data)
assert login_response.status_code == 200
@@ -364,13 +340,12 @@ async def test_execution_with_timeout(self, client: AsyncClient, shared_user: Di
# No need to wait or observe states
@pytest.mark.asyncio
- @pytest.mark.skipif(not has_k8s_workers(), reason="K8s workers not available")
- async def test_sandbox_restrictions(self, client: AsyncClient, shared_user: Dict[str, str]) -> None:
+ async def test_sandbox_restrictions(self, client: AsyncClient, test_user: Dict[str, str]) -> None:
"""Test that dangerous operations are blocked by sandbox."""
# Login first
login_data = {
- "username": shared_user["username"],
- "password": shared_user["password"]
+ "username": test_user["username"],
+ "password": test_user["password"]
}
login_response = await client.post("/api/v1/auth/login", data=login_data)
assert login_response.status_code == 200
@@ -421,13 +396,12 @@ async def test_sandbox_restrictions(self, client: AsyncClient, shared_user: Dict
assert exec_response.status_code in [400, 422]
@pytest.mark.asyncio
- @pytest.mark.skipif(not has_k8s_workers(), reason="K8s workers not available")
- async def test_concurrent_executions_by_same_user(self, client: AsyncClient, shared_user: Dict[str, str]) -> None:
+ async def test_concurrent_executions_by_same_user(self, client: AsyncClient, test_user: Dict[str, str]) -> None:
"""Test running multiple executions concurrently."""
# Login first
login_data = {
- "username": shared_user["username"],
- "password": shared_user["password"]
+ "username": test_user["username"],
+ "password": test_user["password"]
}
login_response = await client.post("/api/v1/auth/login", data=login_data)
assert login_response.status_code == 200
@@ -489,11 +463,10 @@ async def test_get_k8s_resource_limits(self, client: AsyncClient) -> None:
assert key in limits
@pytest.mark.asyncio
- @pytest.mark.skipif(not has_k8s_workers(), reason="K8s workers not available")
- async def test_get_user_executions_list(self, client: AsyncClient, shared_user: Dict[str, str]) -> None:
+ async def test_get_user_executions_list(self, client: AsyncClient, test_user: Dict[str, str]) -> None:
"""User executions list returns paginated executions for current user."""
# Login first
- login_data = {"username": shared_user["username"], "password": shared_user["password"]}
+ login_data = {"username": test_user["username"], "password": test_user["password"]}
login_response = await client.post("/api/v1/auth/login", data=login_data)
assert login_response.status_code == 200
@@ -504,12 +477,11 @@ async def test_get_user_executions_list(self, client: AsyncClient, shared_user:
assert set(["executions", "total", "limit", "skip", "has_more"]).issubset(payload.keys())
@pytest.mark.asyncio
- @pytest.mark.skipif(not has_k8s_workers(), reason="K8s workers not available")
async def test_execution_idempotency_same_key_returns_same_execution(self, client: AsyncClient,
- shared_user: Dict[str, str]) -> None:
+ test_user: Dict[str, str]) -> None:
"""Submitting the same request with the same Idempotency-Key yields the same execution_id."""
# Login first
- login_data = {"username": shared_user["username"], "password": shared_user["password"]}
+ login_data = {"username": test_user["username"], "password": test_user["password"]}
login_response = await client.post("/api/v1/auth/login", data=login_data)
assert login_response.status_code == 200
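Moving this module under `tests/integration/k8s/` and tagging it with `@pytest.mark.k8s` replaces the per-test `docker ps` probe with marker-based selection at collection time. The marker registration itself is not part of this diff; a minimal conftest.py sketch (assumed, not the project's actual configuration) would be:

```python
# conftest.py -- sketch only; the repository's real marker registration is not shown here
def pytest_configure(config) -> None:
    # Declare the custom marker so --strict-markers runs accept it.
    config.addinivalue_line("markers", "k8s: tests that require Kubernetes workers")
```

Environments without a cluster can then deselect the whole directory with `pytest -m "not k8s"`, which is what the removed `skipif(not has_k8s_workers(), ...)` guards approximated at runtime.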
diff --git a/backend/tests/integration/k8s/test_k8s_worker_create_pod.py b/backend/tests/integration/k8s/test_k8s_worker_create_pod.py
index 1722ea2d..732ce094 100644
--- a/backend/tests/integration/k8s/test_k8s_worker_create_pod.py
+++ b/backend/tests/integration/k8s/test_k8s_worker_create_pod.py
@@ -1,22 +1,22 @@
+import logging
import os
import uuid
import pytest
-from kubernetes.client.rest import ApiException
-
+from app.events.core import UnifiedProducer
+from app.events.event_store import EventStore
+from app.events.schema.schema_registry import SchemaRegistryManager
from app.infrastructure.kafka.events.metadata import AvroEventMetadata
from app.infrastructure.kafka.events.saga import CreatePodCommandEvent
+from app.services.idempotency import IdempotencyManager
from app.services.k8s_worker.config import K8sWorkerConfig
from app.services.k8s_worker.worker import KubernetesWorker
-
-from motor.motor_asyncio import AsyncIOMotorDatabase
-from app.events.event_store import EventStore
-from app.events.schema.schema_registry import SchemaRegistryManager
-from app.events.core import UnifiedProducer
-from app.services.idempotency import IdempotencyManager
+from kubernetes.client.rest import ApiException
pytestmark = [pytest.mark.integration, pytest.mark.k8s]
+_test_logger = logging.getLogger("test.k8s.worker_create_pod")
+
@pytest.mark.asyncio
async def test_worker_creates_configmap_and_pod(scope, monkeypatch): # type: ignore[valid-type]
@@ -26,7 +26,6 @@ async def test_worker_creates_configmap_and_pod(scope, monkeypatch): # type: ig
ns = "integr8scode"
monkeypatch.setenv("K8S_NAMESPACE", ns)
- database: AsyncIOMotorDatabase = await scope.get(AsyncIOMotorDatabase)
schema: SchemaRegistryManager = await scope.get(SchemaRegistryManager)
store: EventStore = await scope.get(EventStore)
producer: UnifiedProducer = await scope.get(UnifiedProducer)
@@ -35,11 +34,11 @@ async def test_worker_creates_configmap_and_pod(scope, monkeypatch): # type: ig
cfg = K8sWorkerConfig(namespace=ns, max_concurrent_pods=1)
worker = KubernetesWorker(
config=cfg,
- database=database,
producer=producer,
schema_registry_manager=schema,
event_store=store,
idempotency_manager=idem,
+ logger=_test_logger,
)
# Initialize k8s clients using worker's own method
diff --git a/backend/tests/integration/k8s/test_resource_cleaner_integration.py b/backend/tests/integration/k8s/test_resource_cleaner_integration.py
index 7d6ec26a..483937e0 100644
--- a/backend/tests/integration/k8s/test_resource_cleaner_integration.py
+++ b/backend/tests/integration/k8s/test_resource_cleaner_integration.py
@@ -1,4 +1,5 @@
import asyncio
+import logging
from datetime import datetime, timedelta, timezone
import pytest
@@ -9,6 +10,8 @@
pytestmark = [pytest.mark.integration, pytest.mark.k8s]
+_test_logger = logging.getLogger("test.k8s.resource_cleaner_integration")
+
def _ensure_kubeconfig():
try:
@@ -33,7 +36,7 @@ async def test_cleanup_orphaned_configmaps_dry_run():
v1.create_namespaced_config_map(namespace=ns, body=body)
try:
- cleaner = ResourceCleaner()
+ cleaner = ResourceCleaner(logger=_test_logger)
# Force as orphaned by using a large cutoff
cleaned = await cleaner.cleanup_orphaned_resources(namespace=ns, max_age_hours=0, dry_run=True)
diff --git a/backend/tests/integration/k8s/test_resource_cleaner_k8s.py b/backend/tests/integration/k8s/test_resource_cleaner_k8s.py
index 87ccc97c..2a36af62 100644
--- a/backend/tests/integration/k8s/test_resource_cleaner_k8s.py
+++ b/backend/tests/integration/k8s/test_resource_cleaner_k8s.py
@@ -1,4 +1,5 @@
import asyncio
+import logging
import os
import pytest
@@ -8,10 +9,12 @@
pytestmark = [pytest.mark.integration, pytest.mark.k8s]
+_test_logger = logging.getLogger("test.k8s.resource_cleaner_k8s")
+
@pytest.mark.asyncio
async def test_initialize_and_get_usage() -> None:
- rc = ResourceCleaner()
+ rc = ResourceCleaner(logger=_test_logger)
await rc.initialize()
usage = await rc.get_resource_usage(namespace=os.environ.get("K8S_NAMESPACE", "default"))
assert set(usage.keys()) >= {"pods", "configmaps", "network_policies"}
@@ -19,7 +22,7 @@ async def test_initialize_and_get_usage() -> None:
@pytest.mark.asyncio
async def test_cleanup_orphaned_resources_dry_run() -> None:
- rc = ResourceCleaner()
+ rc = ResourceCleaner(logger=_test_logger)
await rc.initialize()
cleaned = await rc.cleanup_orphaned_resources(
namespace=os.environ.get("K8S_NAMESPACE", "default"),
@@ -31,7 +34,7 @@ async def test_cleanup_orphaned_resources_dry_run() -> None:
@pytest.mark.asyncio
async def test_cleanup_nonexistent_pod() -> None:
- rc = ResourceCleaner()
+ rc = ResourceCleaner(logger=_test_logger)
await rc.initialize()
# Attempt to delete a pod that doesn't exist - should complete without errors
diff --git a/backend/tests/integration/result_processor/test_result_processor.py b/backend/tests/integration/result_processor/test_result_processor.py
index af16b600..e400e8dc 100644
--- a/backend/tests/integration/result_processor/test_result_processor.py
+++ b/backend/tests/integration/result_processor/test_result_processor.py
@@ -1,15 +1,17 @@
import asyncio
+import logging
import uuid
from tests.helpers.eventually import eventually
import pytest
-from motor.motor_asyncio import AsyncIOMotorDatabase
+from app.core.database_context import Database
from app.db.repositories.execution_repository import ExecutionRepository
from app.domain.enums.events import EventType
from app.domain.enums.execution import ExecutionStatus
+from app.domain.execution import DomainExecutionCreate
from app.domain.enums.kafka import KafkaTopic
-from app.domain.execution.models import DomainExecution, ResourceUsageDomain
+from app.domain.execution.models import ResourceUsageDomain
from app.events.core import UnifiedConsumer, UnifiedProducer
from app.events.core.dispatcher import EventDispatcher
from app.events.core.types import ConsumerConfig
@@ -22,6 +24,8 @@
pytestmark = [pytest.mark.integration, pytest.mark.kafka, pytest.mark.mongodb]
+_test_logger = logging.getLogger("test.result_processor.processor")
+
@pytest.mark.asyncio
async def test_result_processor_persists_and_emits(scope) -> None: # type: ignore[valid-type]
@@ -30,33 +34,32 @@ async def test_result_processor_persists_and_emits(scope) -> None: # type: igno
await initialize_event_schemas(registry)
# Dependencies
- db: AsyncIOMotorDatabase = await scope.get(AsyncIOMotorDatabase)
+ db: Database = await scope.get(Database)
repo: ExecutionRepository = await scope.get(ExecutionRepository)
producer: UnifiedProducer = await scope.get(UnifiedProducer)
idem: IdempotencyManager = await scope.get(IdempotencyManager)
# Create a base execution to satisfy ResultProcessor lookup
- execution_id = f"exec-{uuid.uuid4().hex[:8]}"
- base = DomainExecution(
- execution_id=execution_id,
+ created = await repo.create_execution(DomainExecutionCreate(
script="print('x')",
- status=ExecutionStatus.RUNNING,
+ user_id="u1",
lang="python",
lang_version="3.11",
- user_id="u1",
- )
- await repo.create_execution(base)
+ status=ExecutionStatus.RUNNING,
+ ))
+ execution_id = created.execution_id
# Build and start the processor
processor = ResultProcessor(
execution_repo=repo,
producer=producer,
idempotency_manager=idem,
+ logger=_test_logger,
)
# Setup a small consumer to capture ResultStoredEvent
settings = get_settings()
- dispatcher = EventDispatcher()
+ dispatcher = EventDispatcher(logger=_test_logger)
stored_received = asyncio.Event()
@dispatcher.register(EventType.RESULT_STORED)
@@ -70,7 +73,7 @@ async def _stored(_event) -> None: # noqa: ANN001
enable_auto_commit=True,
auto_offset_reset="earliest",
)
- stored_consumer = UnifiedConsumer(cconf, dispatcher)
+ stored_consumer = UnifiedConsumer(cconf, dispatcher, logger=_test_logger)
await stored_consumer.start([str(KafkaTopic.EXECUTION_RESULTS)])
try:
@@ -94,8 +97,9 @@ async def _stored(_event) -> None: # noqa: ANN001
# Wait for DB persistence (event-driven polling)
async def _persisted() -> None:
- doc = await db.get_collection("execution_results").find_one({"_id": execution_id})
+ doc = await db.get_collection("executions").find_one({"execution_id": execution_id})
assert doc is not None
+ assert doc.get("status") == ExecutionStatus.COMPLETED.value
await eventually(_persisted, timeout=12.0, interval=0.2)
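`_persisted` above is retried through `tests.helpers.eventually`, whose implementation is outside this diff. A minimal sketch consistent with the `eventually(check, timeout=12.0, interval=0.2)` call site (an assumption, not the project's actual helper) is:

```python
import asyncio
import time
from typing import Awaitable, Callable


async def eventually(
    check: Callable[[], Awaitable[None]],
    timeout: float = 10.0,
    interval: float = 0.2,
) -> None:
    """Re-run an async assertion callback until it passes or the timeout expires."""
    deadline = time.monotonic() + timeout
    while True:
        try:
            await check()
            return
        except AssertionError:
            if time.monotonic() >= deadline:
                raise  # surface the most recent assertion failure
            await asyncio.sleep(interval)
```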
diff --git a/backend/tests/integration/services/admin/test_admin_user_service.py b/backend/tests/integration/services/admin/test_admin_user_service.py
index c576bfc1..a392a908 100644
--- a/backend/tests/integration/services/admin/test_admin_user_service.py
+++ b/backend/tests/integration/services/admin/test_admin_user_service.py
@@ -1,7 +1,7 @@
from datetime import datetime, timezone
import pytest
-from motor.motor_asyncio import AsyncIOMotorDatabase
+from app.core.database_context import Database
from app.domain.enums.user import UserRole
from app.services.admin import AdminUserService
@@ -12,7 +12,7 @@
@pytest.mark.asyncio
async def test_get_user_overview_basic(scope) -> None: # type: ignore[valid-type]
svc: AdminUserService = await scope.get(AdminUserService)
- db: AsyncIOMotorDatabase = await scope.get(AsyncIOMotorDatabase)
+ db: Database = await scope.get(Database)
await db.get_collection("users").insert_one({
"user_id": "u1",
"username": "bob",
diff --git a/backend/tests/integration/services/events/test_event_service_integration.py b/backend/tests/integration/services/events/test_event_service_integration.py
deleted file mode 100644
index 0f7257dc..00000000
--- a/backend/tests/integration/services/events/test_event_service_integration.py
+++ /dev/null
@@ -1,63 +0,0 @@
-from datetime import datetime, timezone, timedelta
-
-import pytest
-
-from app.db.repositories import EventRepository
-from app.domain.events.event_models import EventFields, Event, EventFilter
-from app.domain.enums.common import SortOrder
-from app.domain.enums.user import UserRole
-from app.infrastructure.kafka.events.metadata import AvroEventMetadata
-from app.domain.enums.events import EventType
-from app.services.event_service import EventService
-
-pytestmark = [pytest.mark.integration, pytest.mark.mongodb]
-
-
-@pytest.mark.asyncio
-async def test_event_service_access_and_queries(scope) -> None: # type: ignore[valid-type]
- repo: EventRepository = await scope.get(EventRepository)
- svc: EventService = await scope.get(EventService)
-
- now = datetime.now(timezone.utc)
- # Seed some events (domain Event, not infra BaseEvent)
- md1 = AvroEventMetadata(service_name="svc", service_version="1", user_id="u1", correlation_id="c1")
- md2 = AvroEventMetadata(service_name="svc", service_version="1", user_id="u2", correlation_id="c1")
- e1 = Event(event_id="e1", event_type=str(EventType.USER_LOGGED_IN), event_version="1.0", timestamp=now,
- metadata=md1, payload={"user_id": "u1", "login_method": "password"}, aggregate_id="agg1")
- e2 = Event(event_id="e2", event_type=str(EventType.USER_LOGGED_IN), event_version="1.0", timestamp=now,
- metadata=md2, payload={"user_id": "u2", "login_method": "password"}, aggregate_id="agg2")
- await repo.store_event(e1)
- await repo.store_event(e2)
-
- # get_execution_events returns None when non-admin for different user; then admin sees
- events_user = await svc.get_execution_events("agg1", "u2", UserRole.USER)
- assert events_user is None
- events_admin = await svc.get_execution_events("agg1", "admin", UserRole.ADMIN)
- assert any(ev.aggregate_id == "agg1" for ev in events_admin.events)
-
- # query_events_advanced: basic run (empty filters) should return a result structure
- res = await svc.query_events_advanced("u1", UserRole.USER, filters=EventFilter(), sort_by="correlation_id", sort_order=SortOrder.ASC)
- assert res is not None
-
- # get_events_by_correlation filters non-admin to their own user_id
- by_corr_user = await svc.get_events_by_correlation("c1", user_id="u1", user_role=UserRole.USER, include_all_users=False)
- assert all(ev.metadata.user_id == "u1" for ev in by_corr_user.events)
- by_corr_admin = await svc.get_events_by_correlation("c1", user_id="admin", user_role=UserRole.ADMIN, include_all_users=True)
- assert len(by_corr_admin.events) >= 2
-
- # get_event_statistics (time window)
- _ = await svc.get_event_statistics("u1", UserRole.USER, start_time=now - timedelta(days=1), end_time=now + timedelta(days=1))
-
- # get_event enforces access control
- one_allowed = await svc.get_event(e1.event_id, user_id="u1", user_role=UserRole.USER)
- assert one_allowed is not None
- one_denied = await svc.get_event(e1.event_id, user_id="u2", user_role=UserRole.USER)
- assert one_denied is None
-
- # aggregate_events injects user filter for non-admin
- pipe = [{"$match": {EventFields.EVENT_TYPE: str(e1.event_type)}}]
- _ = await svc.aggregate_events("u1", UserRole.USER, pipe)
-
- # list_event_types returns at least one type
- types = await svc.list_event_types("u1", UserRole.USER)
- assert isinstance(types, list) and len(types) >= 1
diff --git a/backend/tests/integration/services/saved_script/test_saved_script_service.py b/backend/tests/integration/services/saved_script/test_saved_script_service.py
index 9532b06f..16d980c8 100644
--- a/backend/tests/integration/services/saved_script/test_saved_script_service.py
+++ b/backend/tests/integration/services/saved_script/test_saved_script_service.py
@@ -1,13 +1,10 @@
import pytest
-from app.core.exceptions import ServiceError
-from app.domain.saved_script import DomainSavedScriptCreate, DomainSavedScriptUpdate
+from app.domain.saved_script import DomainSavedScriptCreate, DomainSavedScriptUpdate, SavedScriptNotFoundError
from app.services.saved_script_service import SavedScriptService
pytestmark = [pytest.mark.integration, pytest.mark.mongodb]
-pytestmark = pytest.mark.unit
-
def _create_payload() -> DomainSavedScriptCreate:
return DomainSavedScriptCreate(name="n", description=None, script="print(1)")
@@ -29,5 +26,5 @@ async def test_crud_saved_script(scope) -> None: # type: ignore[valid-type]
assert any(s.script_id == created.script_id for s in lst)
await service.delete_saved_script(str(created.script_id), "u1")
- with pytest.raises(ServiceError):
+ with pytest.raises(SavedScriptNotFoundError):
await service.get_saved_script(str(created.script_id), "u1")
diff --git a/backend/tests/integration/services/sse/__init__.py b/backend/tests/integration/services/sse/__init__.py
new file mode 100644
index 00000000..8b137891
--- /dev/null
+++ b/backend/tests/integration/services/sse/__init__.py
@@ -0,0 +1 @@
+
diff --git a/backend/tests/integration/services/sse/test_partitioned_event_router_integration.py b/backend/tests/integration/services/sse/test_partitioned_event_router_integration.py
index f4b604ac..40720ee4 100644
--- a/backend/tests/integration/services/sse/test_partitioned_event_router_integration.py
+++ b/backend/tests/integration/services/sse/test_partitioned_event_router_integration.py
@@ -1,4 +1,5 @@
import asyncio
+import logging
from uuid import uuid4
from tests.helpers.eventually import eventually
import pytest
@@ -15,6 +16,8 @@
pytestmark = [pytest.mark.integration, pytest.mark.redis]
+_test_logger = logging.getLogger("test.services.sse.partitioned_event_router_integration")
+
@pytest.mark.asyncio
async def test_router_bridges_to_redis(redis_client) -> None: # type: ignore[valid-type]
@@ -24,14 +27,16 @@ async def test_router_bridges_to_redis(redis_client) -> None: # type: ignore[va
redis_client,
exec_prefix=f"sse:exec:{suffix}:",
notif_prefix=f"sse:notif:{suffix}:",
+ logger=_test_logger,
)
router = SSEKafkaRedisBridge(
- schema_registry=SchemaRegistryManager(),
+ schema_registry=SchemaRegistryManager(logger=_test_logger),
settings=settings,
event_metrics=EventMetrics(),
sse_bus=bus,
+ logger=_test_logger,
)
- disp = EventDispatcher()
+ disp = EventDispatcher(logger=_test_logger)
router._register_routing_handlers(disp)
# Open Redis subscription for our execution id
@@ -57,14 +62,16 @@ async def test_router_start_and_stop(redis_client) -> None: # type: ignore[vali
settings.SSE_CONSUMER_POOL_SIZE = 1
suffix = uuid4().hex[:6]
router = SSEKafkaRedisBridge(
- schema_registry=SchemaRegistryManager(),
+ schema_registry=SchemaRegistryManager(logger=_test_logger),
settings=settings,
event_metrics=EventMetrics(),
sse_bus=SSERedisBus(
redis_client,
exec_prefix=f"sse:exec:{suffix}:",
notif_prefix=f"sse:notif:{suffix}:",
+ logger=_test_logger,
),
+ logger=_test_logger,
)
await router.start()
diff --git a/backend/tests/unit/services/sse/test_redis_bus.py b/backend/tests/integration/services/sse/test_redis_bus.py
similarity index 90%
rename from backend/tests/unit/services/sse/test_redis_bus.py
rename to backend/tests/integration/services/sse/test_redis_bus.py
index c24598e3..ae54a6e4 100644
--- a/backend/tests/unit/services/sse/test_redis_bus.py
+++ b/backend/tests/integration/services/sse/test_redis_bus.py
@@ -1,15 +1,18 @@
import asyncio
import json
+import logging
from typing import Any
import pytest
-pytestmark = pytest.mark.unit
+pytestmark = pytest.mark.integration
from app.domain.enums.events import EventType
from app.schemas_pydantic.sse import RedisNotificationMessage, RedisSSEMessage
from app.services.sse.redis_bus import SSERedisBus
+_test_logger = logging.getLogger("test.services.sse.redis_bus")
+
class _DummyEvent:
def __init__(self, execution_id: str, event_type: EventType, extra: dict[str, Any] | None = None) -> None:
@@ -62,7 +65,7 @@ def pubsub(self) -> _FakePubSub:
@pytest.mark.asyncio
async def test_publish_and_subscribe_round_trip() -> None:
r = _FakeRedis()
- bus = SSERedisBus(r)
+ bus = SSERedisBus(r, logger=_test_logger)
# Subscribe
sub = await bus.open_subscription("exec-1")
@@ -77,13 +80,13 @@ async def test_publish_and_subscribe_round_trip() -> None:
assert ch.endswith("exec-1")
# Push to pubsub and read from subscription
await r._pubsub.push(ch, payload)
- msg = await sub.get(RedisSSEMessage, timeout=0.02)
+ msg = await sub.get(RedisSSEMessage)
assert msg and msg.event_type == EventType.EXECUTION_COMPLETED
assert msg.execution_id == "exec-1"
# Non-message / invalid JSON paths
await r._pubsub.push(ch, b"not-json")
- assert await sub.get(RedisSSEMessage, timeout=0.02) is None
+ assert await sub.get(RedisSSEMessage) is None
# Close
await sub.close()
@@ -93,7 +96,7 @@ async def test_publish_and_subscribe_round_trip() -> None:
@pytest.mark.asyncio
async def test_notifications_channels() -> None:
r = _FakeRedis()
- bus = SSERedisBus(r)
+ bus = SSERedisBus(r, logger=_test_logger)
nsub = await bus.open_notification_subscription("user-1")
assert "sse:notif:user-1" in r._pubsub.subscribed
@@ -111,7 +114,7 @@ async def test_notifications_channels() -> None:
ch, payload = r.published[-1]
assert ch.endswith("user-1")
await r._pubsub.push(ch, payload)
- got = await nsub.get(RedisNotificationMessage, timeout=0.02)
+ got = await nsub.get(RedisNotificationMessage)
assert got is not None
assert got.notification_id == "n1"
diff --git a/backend/tests/integration/test_admin_routes.py b/backend/tests/integration/test_admin_routes.py
index 54babedd..03206678 100644
--- a/backend/tests/integration/test_admin_routes.py
+++ b/backend/tests/integration/test_admin_routes.py
@@ -27,12 +27,12 @@ async def test_get_settings_requires_auth(self, client: AsyncClient) -> None:
assert "not authenticated" in error["detail"].lower() or "unauthorized" in error["detail"].lower()
@pytest.mark.asyncio
- async def test_get_settings_with_admin_auth(self, client: AsyncClient, shared_admin: Dict[str, str]) -> None:
+ async def test_get_settings_with_admin_auth(self, client: AsyncClient, test_admin: Dict[str, str]) -> None:
"""Test getting system settings with admin authentication."""
# Login and get cookies
login_data = {
- "username": shared_admin["username"],
- "password": shared_admin["password"]
+ "username": test_admin["username"],
+ "password": test_admin["password"]
}
login_response = await client.post("/api/v1/auth/login", data=login_data)
assert login_response.status_code == 200
@@ -68,12 +68,12 @@ async def test_get_settings_with_admin_auth(self, client: AsyncClient, shared_ad
assert settings.monitoring_settings.sampling_rate == 0.1
@pytest.mark.asyncio
- async def test_update_and_reset_settings(self, client: AsyncClient, shared_admin: Dict[str, str]) -> None:
+ async def test_update_and_reset_settings(self, client: AsyncClient, test_admin: Dict[str, str]) -> None:
"""Test updating and resetting system settings."""
# Login as admin
login_data = {
- "username": shared_admin["username"],
- "password": shared_admin["password"]
+ "username": test_admin["username"],
+ "password": test_admin["password"]
}
login_response = await client.post("/api/v1/auth/login", data=login_data)
assert login_response.status_code == 200
@@ -125,12 +125,12 @@ async def test_update_and_reset_settings(self, client: AsyncClient, shared_admin
assert reset_settings.monitoring_settings.log_level == "INFO"
@pytest.mark.asyncio
- async def test_regular_user_cannot_access_settings(self, client: AsyncClient, shared_user: Dict[str, str]) -> None:
+ async def test_regular_user_cannot_access_settings(self, client: AsyncClient, test_user: Dict[str, str]) -> None:
"""Test that regular users cannot access admin settings."""
# Login as regular user
login_data = {
- "username": shared_user["username"],
- "password": shared_user["password"]
+ "username": test_user["username"],
+ "password": test_user["password"]
}
login_response = await client.post("/api/v1/auth/login", data=login_data)
assert login_response.status_code == 200
@@ -149,12 +149,12 @@ class TestAdminUsers:
"""Test admin user management endpoints against real backend."""
@pytest.mark.asyncio
- async def test_list_users_with_pagination(self, client: AsyncClient, shared_admin: Dict[str, str]) -> None:
+ async def test_list_users_with_pagination(self, client: AsyncClient, test_admin: Dict[str, str]) -> None:
"""Test listing users with pagination."""
# Login as admin
login_data = {
- "username": shared_admin["username"],
- "password": shared_admin["password"]
+ "username": test_admin["username"],
+ "password": test_admin["password"]
}
login_response = await client.post("/api/v1/auth/login", data=login_data)
assert login_response.status_code == 200
@@ -188,12 +188,12 @@ async def test_list_users_with_pagination(self, client: AsyncClient, shared_admi
assert "updated_at" in user
@pytest.mark.asyncio
- async def test_create_and_manage_user(self, client: AsyncClient, shared_admin: Dict[str, str]) -> None:
+ async def test_create_and_manage_user(self, client: AsyncClient, test_admin: Dict[str, str]) -> None:
"""Test full user CRUD operations."""
# Login as admin
login_data = {
- "username": shared_admin["username"],
- "password": shared_admin["password"]
+ "username": test_admin["username"],
+ "password": test_admin["password"]
}
login_response = await client.post("/api/v1/auth/login", data=login_data)
assert login_response.status_code == 200
@@ -257,12 +257,12 @@ class TestAdminEvents:
"""Test admin event management endpoints against real backend."""
@pytest.mark.asyncio
- async def test_browse_events(self, client: AsyncClient, shared_admin: Dict[str, str]) -> None:
+ async def test_browse_events(self, client: AsyncClient, test_admin: Dict[str, str]) -> None:
"""Test browsing events with filters."""
# Login as admin
login_data = {
- "username": shared_admin["username"],
- "password": shared_admin["password"]
+ "username": test_admin["username"],
+ "password": test_admin["password"]
}
login_response = await client.post("/api/v1/auth/login", data=login_data)
assert login_response.status_code == 200
@@ -291,12 +291,12 @@ async def test_browse_events(self, client: AsyncClient, shared_admin: Dict[str,
assert data["total"] >= 0
@pytest.mark.asyncio
- async def test_event_statistics(self, client: AsyncClient, shared_admin: Dict[str, str]) -> None:
+ async def test_event_statistics(self, client: AsyncClient, test_admin: Dict[str, str]) -> None:
"""Test getting event statistics."""
# Login as admin
login_data = {
- "username": shared_admin["username"],
- "password": shared_admin["password"]
+ "username": test_admin["username"],
+ "password": test_admin["password"]
}
login_response = await client.post("/api/v1/auth/login", data=login_data)
assert login_response.status_code == 200
@@ -324,10 +324,10 @@ async def test_event_statistics(self, client: AsyncClient, shared_admin: Dict[st
assert data["error_rate"] >= 0.0
@pytest.mark.asyncio
- async def test_admin_events_export_csv_and_json(self, client: AsyncClient, shared_admin: Dict[str, str]) -> None:
+ async def test_admin_events_export_csv_and_json(self, client: AsyncClient, test_admin: Dict[str, str]) -> None:
"""Export admin events as CSV and JSON and validate basic structure."""
# Login as admin
- login_data = {"username": shared_admin["username"], "password": shared_admin["password"]}
+ login_data = {"username": test_admin["username"], "password": test_admin["password"]}
login_response = await client.post("/api/v1/auth/login", data=login_data)
assert login_response.status_code == 200
@@ -352,10 +352,10 @@ async def test_admin_events_export_csv_and_json(self, client: AsyncClient, share
@pytest.mark.asyncio
async def test_admin_user_rate_limits_and_password_reset(self, client: AsyncClient,
- shared_admin: Dict[str, str]) -> None:
+ test_admin: Dict[str, str]) -> None:
"""Create a user, manage rate limits, and reset password via admin endpoints."""
# Login as admin
- login_data = {"username": shared_admin["username"], "password": shared_admin["password"]}
+ login_data = {"username": test_admin["username"], "password": test_admin["password"]}
login_response = await client.post("/api/v1/auth/login", data=login_data)
assert login_response.status_code == 200
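All of these route tests now take per-test `test_user`/`test_admin` fixtures instead of the session-scoped `shared_user`/`shared_admin` ones. The fixture definitions live in a conftest outside this diff; a hypothetical shape that matches how the tests consume them (a dict with `username` and `password`) is sketched below. The registration endpoint and field names are assumptions.

```python
# conftest.py -- hypothetical sketch; the real fixtures are defined outside this diff
import uuid
from typing import AsyncIterator, Dict

import pytest_asyncio
from httpx import AsyncClient


@pytest_asyncio.fixture
async def test_user(client: AsyncClient) -> AsyncIterator[Dict[str, str]]:
    # Create a fresh user per test so tests stay isolated and order-independent.
    username = f"user-{uuid.uuid4().hex[:8]}"
    password = "Str0ng!passw0rd"
    resp = await client.post(  # assumed registration endpoint
        "/api/v1/auth/register",
        json={"username": username, "email": f"{username}@example.com", "password": password},
    )
    assert resp.status_code in (200, 201)
    yield {"username": username, "password": password}
```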
diff --git a/backend/tests/integration/test_dlq_routes.py b/backend/tests/integration/test_dlq_routes.py
index d21d928f..5cc114a0 100644
--- a/backend/tests/integration/test_dlq_routes.py
+++ b/backend/tests/integration/test_dlq_routes.py
@@ -1,3 +1,4 @@
+from datetime import datetime
from typing import Dict
import pytest
@@ -32,12 +33,12 @@ async def test_dlq_requires_authentication(self, client: AsyncClient) -> None:
for word in ["not authenticated", "unauthorized", "login"])
@pytest.mark.asyncio
- async def test_get_dlq_statistics(self, client: AsyncClient, shared_user: Dict[str, str]) -> None:
+ async def test_get_dlq_statistics(self, client: AsyncClient, test_user: Dict[str, str]) -> None:
"""Test getting DLQ statistics."""
# Login first
login_data = {
- "username": shared_user["username"],
- "password": shared_user["password"]
+ "username": test_user["username"],
+ "password": test_user["password"]
}
login_response = await client.post("/api/v1/auth/login", data=login_data)
assert login_response.status_code == 200
@@ -85,12 +86,12 @@ async def test_get_dlq_statistics(self, client: AsyncClient, shared_user: Dict[s
assert stats.age_stats[key] >= 0
@pytest.mark.asyncio
- async def test_list_dlq_messages(self, client: AsyncClient, shared_user: Dict[str, str]) -> None:
+ async def test_list_dlq_messages(self, client: AsyncClient, test_user: Dict[str, str]) -> None:
"""Test listing DLQ messages with filters."""
# Login first
login_data = {
- "username": shared_user["username"],
- "password": shared_user["password"]
+ "username": test_user["username"],
+ "password": test_user["password"]
}
login_response = await client.post("/api/v1/auth/login", data=login_data)
assert login_response.status_code == 200
@@ -129,12 +130,12 @@ async def test_list_dlq_messages(self, client: AsyncClient, shared_user: Dict[st
assert isinstance(message.details, dict)
@pytest.mark.asyncio
- async def test_filter_dlq_messages_by_status(self, client: AsyncClient, shared_user: Dict[str, str]) -> None:
+ async def test_filter_dlq_messages_by_status(self, client: AsyncClient, test_user: Dict[str, str]) -> None:
"""Test filtering DLQ messages by status."""
# Login first
login_data = {
- "username": shared_user["username"],
- "password": shared_user["password"]
+ "username": test_user["username"],
+ "password": test_user["password"]
}
login_response = await client.post("/api/v1/auth/login", data=login_data)
assert login_response.status_code == 200
@@ -152,12 +153,12 @@ async def test_filter_dlq_messages_by_status(self, client: AsyncClient, shared_u
assert message.status == status
@pytest.mark.asyncio
- async def test_filter_dlq_messages_by_topic(self, client: AsyncClient, shared_user: Dict[str, str]) -> None:
+ async def test_filter_dlq_messages_by_topic(self, client: AsyncClient, test_user: Dict[str, str]) -> None:
"""Test filtering DLQ messages by topic."""
# Login first
login_data = {
- "username": shared_user["username"],
- "password": shared_user["password"]
+ "username": test_user["username"],
+ "password": test_user["password"]
}
login_response = await client.post("/api/v1/auth/login", data=login_data)
assert login_response.status_code == 200
@@ -175,12 +176,12 @@ async def test_filter_dlq_messages_by_topic(self, client: AsyncClient, shared_us
assert message.original_topic == test_topic
@pytest.mark.asyncio
- async def test_get_single_dlq_message_detail(self, client: AsyncClient, shared_user: Dict[str, str]) -> None:
+ async def test_get_single_dlq_message_detail(self, client: AsyncClient, test_user: Dict[str, str]) -> None:
"""Test getting detailed information for a single DLQ message."""
# Login first
login_data = {
- "username": shared_user["username"],
- "password": shared_user["password"]
+ "username": test_user["username"],
+ "password": test_user["password"]
}
login_response = await client.post("/api/v1/auth/login", data=login_data)
assert login_response.status_code == 200
@@ -223,12 +224,12 @@ async def test_get_single_dlq_message_detail(self, client: AsyncClient, shared_u
assert message_detail.dlq_partition >= 0
@pytest.mark.asyncio
- async def test_get_nonexistent_dlq_message(self, client: AsyncClient, shared_user: Dict[str, str]) -> None:
+ async def test_get_nonexistent_dlq_message(self, client: AsyncClient, test_user: Dict[str, str]) -> None:
"""Test getting a non-existent DLQ message."""
# Login first
login_data = {
- "username": shared_user["username"],
- "password": shared_user["password"]
+ "username": test_user["username"],
+ "password": test_user["password"]
}
login_response = await client.post("/api/v1/auth/login", data=login_data)
assert login_response.status_code == 200
@@ -243,12 +244,12 @@ async def test_get_nonexistent_dlq_message(self, client: AsyncClient, shared_use
assert "not found" in error_data["detail"].lower()
@pytest.mark.asyncio
- async def test_set_retry_policy(self, client: AsyncClient, shared_user: Dict[str, str]) -> None:
+ async def test_set_retry_policy(self, client: AsyncClient, test_user: Dict[str, str]) -> None:
"""Test setting a retry policy for a topic."""
# Login first
login_data = {
- "username": shared_user["username"],
- "password": shared_user["password"]
+ "username": test_user["username"],
+ "password": test_user["password"]
}
login_response = await client.post("/api/v1/auth/login", data=login_data)
assert login_response.status_code == 200
@@ -273,12 +274,12 @@ async def test_set_retry_policy(self, client: AsyncClient, shared_user: Dict[str
assert policy_data["topic"] in result.message
@pytest.mark.asyncio
- async def test_retry_dlq_messages_batch(self, client: AsyncClient, shared_user: Dict[str, str]) -> None:
+ async def test_retry_dlq_messages_batch(self, client: AsyncClient, test_user: Dict[str, str]) -> None:
"""Test retrying a batch of DLQ messages."""
# Login first
login_data = {
- "username": shared_user["username"],
- "password": shared_user["password"]
+ "username": test_user["username"],
+ "password": test_user["password"]
}
login_response = await client.post("/api/v1/auth/login", data=login_data)
assert login_response.status_code == 200
@@ -318,12 +319,12 @@ async def test_retry_dlq_messages_batch(self, client: AsyncClient, shared_user:
assert "success" in detail
@pytest.mark.asyncio
- async def test_discard_dlq_message(self, client: AsyncClient, shared_user: Dict[str, str]) -> None:
+ async def test_discard_dlq_message(self, client: AsyncClient, test_user: Dict[str, str]) -> None:
"""Test discarding a DLQ message."""
# Login first
login_data = {
- "username": shared_user["username"],
- "password": shared_user["password"]
+ "username": test_user["username"],
+ "password": test_user["password"]
}
login_response = await client.post("/api/v1/auth/login", data=login_data)
assert login_response.status_code == 200
@@ -357,12 +358,12 @@ async def test_discard_dlq_message(self, client: AsyncClient, shared_user: Dict[
assert detail_data["status"] == "discarded"
@pytest.mark.asyncio
- async def test_get_dlq_topics_summary(self, client: AsyncClient, shared_user: Dict[str, str]) -> None:
+ async def test_get_dlq_topics_summary(self, client: AsyncClient, test_user: Dict[str, str]) -> None:
"""Test getting DLQ topics summary."""
# Login first
login_data = {
- "username": shared_user["username"],
- "password": shared_user["password"]
+ "username": test_user["username"],
+ "password": test_user["password"]
}
login_response = await client.post("/api/v1/auth/login", data=login_data)
assert login_response.status_code == 200
@@ -386,15 +387,15 @@ async def test_get_dlq_topics_summary(self, client: AsyncClient, shared_user: Di
# Check status breakdown
for status, count in topic_summary.status_breakdown.items():
- assert status in ["pending", "retrying", "failed", "discarded"]
+ assert status in ["pending", "scheduled", "retried", "discarded"]
assert isinstance(count, int)
assert count >= 0
- # Check dates if present
+ # Check dates if present (may be str or datetime)
if topic_summary.oldest_message:
- assert isinstance(topic_summary.oldest_message, str)
+ assert isinstance(topic_summary.oldest_message, (str, datetime))
if topic_summary.newest_message:
- assert isinstance(topic_summary.newest_message, str)
+ assert isinstance(topic_summary.newest_message, (str, datetime))
# Check retry stats
if topic_summary.avg_retry_count is not None:
@@ -403,12 +404,12 @@ async def test_get_dlq_topics_summary(self, client: AsyncClient, shared_user: Di
assert topic_summary.max_retry_count >= 0
@pytest.mark.asyncio
- async def test_dlq_message_pagination(self, client: AsyncClient, shared_user: Dict[str, str]) -> None:
+ async def test_dlq_message_pagination(self, client: AsyncClient, test_user: Dict[str, str]) -> None:
"""Test DLQ message pagination."""
# Login first
login_data = {
- "username": shared_user["username"],
- "password": shared_user["password"]
+ "username": test_user["username"],
+ "password": test_user["password"]
}
login_response = await client.post("/api/v1/auth/login", data=login_data)
assert login_response.status_code == 200
@@ -441,12 +442,12 @@ async def test_dlq_message_pagination(self, client: AsyncClient, shared_user: Di
assert len(page1_ids.intersection(page2_ids)) == 0
@pytest.mark.asyncio
- async def test_dlq_error_handling(self, client: AsyncClient, shared_user: Dict[str, str]) -> None:
+ async def test_dlq_error_handling(self, client: AsyncClient, test_user: Dict[str, str]) -> None:
"""Test DLQ error handling for invalid requests."""
# Login first
login_data = {
- "username": shared_user["username"],
- "password": shared_user["password"]
+ "username": test_user["username"],
+ "password": test_user["password"]
}
login_response = await client.post("/api/v1/auth/login", data=login_data)
assert login_response.status_code == 200
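The topics-summary assertion earlier in this file now expects the status vocabulary `pending`, `scheduled`, `retried`, `discarded`, and accepts either `str` or `datetime` for the oldest/newest message fields. For orientation only, a hypothetical enum mirroring just the asserted values (the backend's real DLQ status type is not shown in this diff):

```python
from enum import Enum


class DLQMessageStatus(str, Enum):
    # Values mirror what the updated test asserts; the actual backend enum may differ.
    PENDING = "pending"
    SCHEDULED = "scheduled"
    RETRIED = "retried"
    DISCARDED = "discarded"
```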
diff --git a/backend/tests/integration/test_events_routes.py b/backend/tests/integration/test_events_routes.py
index 86c1cc85..342bd8ad 100644
--- a/backend/tests/integration/test_events_routes.py
+++ b/backend/tests/integration/test_events_routes.py
@@ -32,9 +32,9 @@ async def test_events_require_authentication(self, client: AsyncClient) -> None:
for word in ["not authenticated", "unauthorized", "login"])
@pytest.mark.asyncio
- async def test_get_user_events(self, client: AsyncClient, shared_user: Dict[str, str]) -> None:
+ async def test_get_user_events(self, client: AsyncClient, test_user: Dict[str, str]) -> None:
"""Test getting user's events."""
- # Already authenticated via shared_user fixture
+ # Already authenticated via test_user fixture
# Get user events
response = await client.get("/api/v1/events/user?limit=10&skip=0")
@@ -73,9 +73,9 @@ async def test_get_user_events(self, client: AsyncClient, shared_user: Dict[str,
assert isinstance(event.correlation_id, str)
@pytest.mark.asyncio
- async def test_get_user_events_with_filters(self, client: AsyncClient, shared_user: Dict[str, str]) -> None:
+ async def test_get_user_events_with_filters(self, client: AsyncClient, test_user: Dict[str, str]) -> None:
"""Test filtering user events."""
- # Already authenticated via shared_user fixture
+ # Already authenticated via test_user fixture
# Create an execution to generate events
execution_request = {
@@ -107,12 +107,12 @@ async def test_get_user_events_with_filters(self, client: AsyncClient, shared_us
events_response.events) == 0
@pytest.mark.asyncio
- async def test_get_execution_events(self, client: AsyncClient, shared_user: Dict[str, str]) -> None:
+ async def test_get_execution_events(self, client: AsyncClient, test_user: Dict[str, str]) -> None:
"""Test getting events for a specific execution."""
# Login first
login_data = {
- "username": shared_user["username"],
- "password": shared_user["password"]
+ "username": test_user["username"],
+ "password": test_user["password"]
}
login_response = await client.post("/api/v1/auth/login", data=login_data)
assert login_response.status_code == 200
@@ -147,12 +147,12 @@ async def test_get_execution_events(self, client: AsyncClient, shared_user: Dict
assert execution_id in event.aggregate_id or event.aggregate_id == execution_id
@pytest.mark.asyncio
- async def test_query_events_advanced(self, client: AsyncClient, shared_user: Dict[str, str]) -> None:
+ async def test_query_events_advanced(self, client: AsyncClient, test_user: Dict[str, str]) -> None:
"""Test advanced event querying with filters."""
# Login first
login_data = {
- "username": shared_user["username"],
- "password": shared_user["password"]
+ "username": test_user["username"],
+ "password": test_user["password"]
}
login_response = await client.post("/api/v1/auth/login", data=login_data)
assert login_response.status_code == 200
@@ -191,12 +191,12 @@ async def test_query_events_advanced(self, client: AsyncClient, shared_user: Dic
assert t1 >= t2 # Descending order
@pytest.mark.asyncio
- async def test_get_events_by_correlation_id(self, client: AsyncClient, shared_user: Dict[str, str]) -> None:
+ async def test_get_events_by_correlation_id(self, client: AsyncClient, test_user: Dict[str, str]) -> None:
"""Test getting events by correlation ID."""
# Login first
login_data = {
- "username": shared_user["username"],
- "password": shared_user["password"]
+ "username": test_user["username"],
+ "password": test_user["password"]
}
login_response = await client.post("/api/v1/auth/login", data=login_data)
assert login_response.status_code == 200
@@ -231,12 +231,12 @@ async def test_get_events_by_correlation_id(self, client: AsyncClient, shared_us
assert event.correlation_id == correlation_id
@pytest.mark.asyncio
- async def test_get_current_request_events(self, client: AsyncClient, shared_user: Dict[str, str]) -> None:
+ async def test_get_current_request_events(self, client: AsyncClient, test_user: Dict[str, str]) -> None:
"""Test getting events for the current request."""
# Login first
login_data = {
- "username": shared_user["username"],
- "password": shared_user["password"]
+ "username": test_user["username"],
+ "password": test_user["password"]
}
login_response = await client.post("/api/v1/auth/login", data=login_data)
assert login_response.status_code == 200
@@ -253,12 +253,12 @@ async def test_get_current_request_events(self, client: AsyncClient, shared_user
assert events_response.total >= 0
@pytest.mark.asyncio
- async def test_get_event_statistics(self, client: AsyncClient, shared_user: Dict[str, str]) -> None:
+ async def test_get_event_statistics(self, client: AsyncClient, test_user: Dict[str, str]) -> None:
"""Test getting event statistics."""
# Login first
login_data = {
- "username": shared_user["username"],
- "password": shared_user["password"]
+ "username": test_user["username"],
+ "password": test_user["password"]
}
login_response = await client.post("/api/v1/auth/login", data=login_data)
assert login_response.status_code == 200
@@ -289,12 +289,12 @@ async def test_get_event_statistics(self, client: AsyncClient, shared_user: Dict
assert hourly_stat["count"] >= 0
@pytest.mark.asyncio
- async def test_get_single_event(self, client: AsyncClient, shared_user: Dict[str, str]) -> None:
+ async def test_get_single_event(self, client: AsyncClient, test_user: Dict[str, str]) -> None:
"""Test getting a single event by ID."""
# Login first
login_data = {
- "username": shared_user["username"],
- "password": shared_user["password"]
+ "username": test_user["username"],
+ "password": test_user["password"]
}
login_response = await client.post("/api/v1/auth/login", data=login_data)
assert login_response.status_code == 200
@@ -320,12 +320,12 @@ async def test_get_single_event(self, client: AsyncClient, shared_user: Dict[str
assert event.timestamp is not None
@pytest.mark.asyncio
- async def test_get_nonexistent_event(self, client: AsyncClient, shared_user: Dict[str, str]) -> None:
+ async def test_get_nonexistent_event(self, client: AsyncClient, test_user: Dict[str, str]) -> None:
"""Test getting a non-existent event."""
# Login first
login_data = {
- "username": shared_user["username"],
- "password": shared_user["password"]
+ "username": test_user["username"],
+ "password": test_user["password"]
}
login_response = await client.post("/api/v1/auth/login", data=login_data)
assert login_response.status_code == 200
@@ -340,12 +340,12 @@ async def test_get_nonexistent_event(self, client: AsyncClient, shared_user: Dic
assert "not found" in error_data["detail"].lower()
@pytest.mark.asyncio
- async def test_list_event_types(self, client: AsyncClient, shared_user: Dict[str, str]) -> None:
+ async def test_list_event_types(self, client: AsyncClient, test_user: Dict[str, str]) -> None:
"""Test listing available event types."""
# Login first
login_data = {
- "username": shared_user["username"],
- "password": shared_user["password"]
+ "username": test_user["username"],
+ "password": test_user["password"]
}
login_response = await client.post("/api/v1/auth/login", data=login_data)
assert login_response.status_code == 200
@@ -371,12 +371,12 @@ async def test_list_event_types(self, client: AsyncClient, shared_user: Dict[str
assert len(event_type) > 0
@pytest.mark.asyncio
- async def test_publish_custom_event_requires_admin(self, client: AsyncClient, shared_user: Dict[str, str]) -> None:
+ async def test_publish_custom_event_requires_admin(self, client: AsyncClient, test_user: Dict[str, str]) -> None:
"""Test that publishing custom events requires admin privileges."""
# Login as regular user
login_data = {
- "username": shared_user["username"],
- "password": shared_user["password"]
+ "username": test_user["username"],
+ "password": test_user["password"]
}
login_response = await client.post("/api/v1/auth/login", data=login_data)
assert login_response.status_code == 200
@@ -397,12 +397,12 @@ async def test_publish_custom_event_requires_admin(self, client: AsyncClient, sh
@pytest.mark.asyncio
@pytest.mark.kafka
- async def test_publish_custom_event_as_admin(self, client: AsyncClient, shared_admin: Dict[str, str]) -> None:
+ async def test_publish_custom_event_as_admin(self, client: AsyncClient, test_admin: Dict[str, str]) -> None:
"""Test publishing custom events as admin."""
# Login as admin
login_data = {
- "username": shared_admin["username"],
- "password": shared_admin["password"]
+ "username": test_admin["username"],
+ "password": test_admin["password"]
}
login_response = await client.post("/api/v1/auth/login", data=login_data)
assert login_response.status_code == 200
@@ -434,12 +434,12 @@ async def test_publish_custom_event_as_admin(self, client: AsyncClient, shared_a
assert publish_response.timestamp is not None
@pytest.mark.asyncio
- async def test_aggregate_events(self, client: AsyncClient, shared_user: Dict[str, str]) -> None:
+ async def test_aggregate_events(self, client: AsyncClient, test_user: Dict[str, str]) -> None:
"""Test event aggregation."""
# Login first
login_data = {
- "username": shared_user["username"],
- "password": shared_user["password"]
+ "username": test_user["username"],
+ "password": test_user["password"]
}
login_response = await client.post("/api/v1/auth/login", data=login_data)
assert login_response.status_code == 200
@@ -469,12 +469,12 @@ async def test_aggregate_events(self, client: AsyncClient, shared_user: Dict[str
assert result["count"] >= 0
@pytest.mark.asyncio
- async def test_delete_event_requires_admin(self, client: AsyncClient, shared_user: Dict[str, str]) -> None:
+ async def test_delete_event_requires_admin(self, client: AsyncClient, test_user: Dict[str, str]) -> None:
"""Test that deleting events requires admin privileges."""
# Login as regular user
login_data = {
- "username": shared_user["username"],
- "password": shared_user["password"]
+ "username": test_user["username"],
+ "password": test_user["password"]
}
login_response = await client.post("/api/v1/auth/login", data=login_data)
assert login_response.status_code == 200
@@ -486,12 +486,12 @@ async def test_delete_event_requires_admin(self, client: AsyncClient, shared_use
@pytest.mark.asyncio
async def test_replay_aggregate_events_requires_admin(self, client: AsyncClient,
- shared_user: Dict[str, str]) -> None:
+ test_user: Dict[str, str]) -> None:
"""Test that replaying events requires admin privileges."""
# Login as regular user
login_data = {
- "username": shared_user["username"],
- "password": shared_user["password"]
+ "username": test_user["username"],
+ "password": test_user["password"]
}
login_response = await client.post("/api/v1/auth/login", data=login_data)
assert login_response.status_code == 200
@@ -502,12 +502,12 @@ async def test_replay_aggregate_events_requires_admin(self, client: AsyncClient,
assert response.status_code == 403 # Forbidden for non-admin
@pytest.mark.asyncio
- async def test_replay_aggregate_events_dry_run(self, client: AsyncClient, shared_admin: Dict[str, str]) -> None:
+ async def test_replay_aggregate_events_dry_run(self, client: AsyncClient, test_admin: Dict[str, str]) -> None:
"""Test replaying events in dry-run mode."""
# Login as admin
login_data = {
- "username": shared_admin["username"],
- "password": shared_admin["password"]
+ "username": test_admin["username"],
+ "password": test_admin["password"]
}
login_response = await client.post("/api/v1/auth/login", data=login_data)
assert login_response.status_code == 200
@@ -543,12 +543,12 @@ async def test_replay_aggregate_events_dry_run(self, client: AsyncClient, shared
assert "detail" in error_data
@pytest.mark.asyncio
- async def test_event_pagination(self, client: AsyncClient, shared_user: Dict[str, str]) -> None:
+ async def test_event_pagination(self, client: AsyncClient, test_user: Dict[str, str]) -> None:
"""Test event pagination."""
# Login first
login_data = {
- "username": shared_user["username"],
- "password": shared_user["password"]
+ "username": test_user["username"],
+ "password": test_user["password"]
}
login_response = await client.post("/api/v1/auth/login", data=login_data)
assert login_response.status_code == 200
@@ -582,13 +582,13 @@ async def test_event_pagination(self, client: AsyncClient, shared_user: Dict[str
@pytest.mark.asyncio
async def test_events_isolation_between_users(self, client: AsyncClient,
- shared_user: Dict[str, str],
- shared_admin: Dict[str, str]) -> None:
+ test_user: Dict[str, str],
+ test_admin: Dict[str, str]) -> None:
"""Test that events are properly isolated between users."""
# Get events as regular user
user_login_data = {
- "username": shared_user["username"],
- "password": shared_user["password"]
+ "username": test_user["username"],
+ "password": test_user["password"]
}
user_login_response = await client.post("/api/v1/auth/login", data=user_login_data)
assert user_login_response.status_code == 200
@@ -601,8 +601,8 @@ async def test_events_isolation_between_users(self, client: AsyncClient,
# Get events as admin (without include_all_users flag)
admin_login_data = {
- "username": shared_admin["username"],
- "password": shared_admin["password"]
+ "username": test_admin["username"],
+ "password": test_admin["password"]
}
admin_login_response = await client.post("/api/v1/auth/login", data=admin_login_data)
assert admin_login_response.status_code == 200
@@ -618,9 +618,9 @@ async def test_events_isolation_between_users(self, client: AsyncClient,
for event in user_events["events"]:
meta = event.get("metadata") or {}
if meta.get("user_id"):
- assert meta["user_id"] == shared_user.get("user_id", meta["user_id"])
+ assert meta["user_id"] == test_user.get("user_id", meta["user_id"])
for event in admin_events["events"]:
meta = event.get("metadata") or {}
if meta.get("user_id"):
- assert meta["user_id"] == shared_admin.get("user_id", meta["user_id"])
+ assert meta["user_id"] == test_admin.get("user_id", meta["user_id"])
diff --git a/backend/tests/integration/test_health_routes.py b/backend/tests/integration/test_health_routes.py
index e2204a74..40105561 100644
--- a/backend/tests/integration/test_health_routes.py
+++ b/backend/tests/integration/test_health_routes.py
@@ -48,11 +48,11 @@ async def test_concurrent_liveness_fetch(self, client: AsyncClient) -> None:
assert all(r.status_code == 200 for r in responses)
@pytest.mark.asyncio
- async def test_app_responds_during_load(self, client: AsyncClient, shared_user: Dict[str, str]) -> None:
+ async def test_app_responds_during_load(self, client: AsyncClient, test_user: Dict[str, str]) -> None:
# Login first for creating load
login_data = {
- "username": shared_user["username"],
- "password": shared_user["password"]
+ "username": test_user["username"],
+ "password": test_user["password"]
}
login_response = await client.post("/api/v1/auth/login", data=login_data)
assert login_response.status_code == 200
diff --git a/backend/tests/integration/test_notifications_routes.py b/backend/tests/integration/test_notifications_routes.py
index 3ea848f0..5e60164f 100644
--- a/backend/tests/integration/test_notifications_routes.py
+++ b/backend/tests/integration/test_notifications_routes.py
@@ -31,12 +31,12 @@ async def test_notifications_require_authentication(self, client: AsyncClient) -
for word in ["not authenticated", "unauthorized", "login"])
@pytest.mark.asyncio
- async def test_list_user_notifications(self, client: AsyncClient, shared_user: Dict[str, str]) -> None:
+ async def test_list_user_notifications(self, client: AsyncClient, test_user: Dict[str, str]) -> None:
"""Test listing user's notifications."""
# Login first
login_data = {
- "username": shared_user["username"],
- "password": shared_user["password"]
+ "username": test_user["username"],
+ "password": test_user["password"]
}
login_response = await client.post("/api/v1/auth/login", data=login_data)
assert login_response.status_code == 200
@@ -66,12 +66,12 @@ async def test_list_user_notifications(self, client: AsyncClient, shared_user: D
assert n.created_at is not None
@pytest.mark.asyncio
- async def test_filter_notifications_by_status(self, client: AsyncClient, shared_user: Dict[str, str]) -> None:
+ async def test_filter_notifications_by_status(self, client: AsyncClient, test_user: Dict[str, str]) -> None:
"""Test filtering notifications by status."""
# Login first
login_data = {
- "username": shared_user["username"],
- "password": shared_user["password"]
+ "username": test_user["username"],
+ "password": test_user["password"]
}
login_response = await client.post("/api/v1/auth/login", data=login_data)
assert login_response.status_code == 200
@@ -89,12 +89,12 @@ async def test_filter_notifications_by_status(self, client: AsyncClient, shared_
assert notification.status == status
@pytest.mark.asyncio
- async def test_get_unread_count(self, client: AsyncClient, shared_user: Dict[str, str]) -> None:
+ async def test_get_unread_count(self, client: AsyncClient, test_user: Dict[str, str]) -> None:
"""Test getting count of unread notifications."""
# Login first
login_data = {
- "username": shared_user["username"],
- "password": shared_user["password"]
+ "username": test_user["username"],
+ "password": test_user["password"]
}
login_response = await client.post("/api/v1/auth/login", data=login_data)
assert login_response.status_code == 200
@@ -113,12 +113,12 @@ async def test_get_unread_count(self, client: AsyncClient, shared_user: Dict[str
# Note: listing cannot filter 'unread' directly; count is authoritative
@pytest.mark.asyncio
- async def test_mark_notification_as_read(self, client: AsyncClient, shared_user: Dict[str, str]) -> None:
+ async def test_mark_notification_as_read(self, client: AsyncClient, test_user: Dict[str, str]) -> None:
"""Test marking a notification as read."""
# Login first
login_data = {
- "username": shared_user["username"],
- "password": shared_user["password"]
+ "username": test_user["username"],
+ "password": test_user["password"]
}
login_response = await client.post("/api/v1/auth/login", data=login_data)
assert login_response.status_code == 200
@@ -149,12 +149,12 @@ async def test_mark_notification_as_read(self, client: AsyncClient, shared_user:
@pytest.mark.asyncio
async def test_mark_nonexistent_notification_as_read(self, client: AsyncClient,
- shared_user: Dict[str, str]) -> None:
+ test_user: Dict[str, str]) -> None:
"""Test marking a non-existent notification as read."""
# Login first
login_data = {
- "username": shared_user["username"],
- "password": shared_user["password"]
+ "username": test_user["username"],
+ "password": test_user["password"]
}
login_response = await client.post("/api/v1/auth/login", data=login_data)
assert login_response.status_code == 200
@@ -172,12 +172,12 @@ async def test_mark_nonexistent_notification_as_read(self, client: AsyncClient,
assert "not found" in error_data["detail"].lower()
@pytest.mark.asyncio
- async def test_mark_all_notifications_as_read(self, client: AsyncClient, shared_user: Dict[str, str]) -> None:
+ async def test_mark_all_notifications_as_read(self, client: AsyncClient, test_user: Dict[str, str]) -> None:
"""Test marking all notifications as read."""
# Login first
login_data = {
- "username": shared_user["username"],
- "password": shared_user["password"]
+ "username": test_user["username"],
+ "password": test_user["password"]
}
login_response = await client.post("/api/v1/auth/login", data=login_data)
assert login_response.status_code == 200
@@ -198,12 +198,12 @@ async def test_mark_all_notifications_as_read(self, client: AsyncClient, shared_
assert count_data["unread_count"] == 0
@pytest.mark.asyncio
- async def test_get_notification_subscriptions(self, client: AsyncClient, shared_user: Dict[str, str]) -> None:
+ async def test_get_notification_subscriptions(self, client: AsyncClient, test_user: Dict[str, str]) -> None:
"""Test getting user's notification subscriptions."""
# Login first
login_data = {
- "username": shared_user["username"],
- "password": shared_user["password"]
+ "username": test_user["username"],
+ "password": test_user["password"]
}
login_response = await client.post("/api/v1/auth/login", data=login_data)
assert login_response.status_code == 200
@@ -239,12 +239,12 @@ async def test_get_notification_subscriptions(self, client: AsyncClient, shared_
assert subscription.slack_webhook.startswith("http")
@pytest.mark.asyncio
- async def test_update_notification_subscription(self, client: AsyncClient, shared_user: Dict[str, str]) -> None:
+ async def test_update_notification_subscription(self, client: AsyncClient, test_user: Dict[str, str]) -> None:
"""Test updating a notification subscription."""
# Login first
login_data = {
- "username": shared_user["username"],
- "password": shared_user["password"]
+ "username": test_user["username"],
+ "password": test_user["password"]
}
login_response = await client.post("/api/v1/auth/login", data=login_data)
assert login_response.status_code == 200
@@ -284,12 +284,12 @@ async def test_update_notification_subscription(self, client: AsyncClient, share
break
@pytest.mark.asyncio
- async def test_update_webhook_subscription(self, client: AsyncClient, shared_user: Dict[str, str]) -> None:
+ async def test_update_webhook_subscription(self, client: AsyncClient, test_user: Dict[str, str]) -> None:
"""Test updating webhook subscription with URL."""
# Login first
login_data = {
- "username": shared_user["username"],
- "password": shared_user["password"]
+ "username": test_user["username"],
+ "password": test_user["password"]
}
login_response = await client.post("/api/v1/auth/login", data=login_data)
assert login_response.status_code == 200
@@ -316,12 +316,12 @@ async def test_update_webhook_subscription(self, client: AsyncClient, shared_use
assert updated_subscription.severities == update_data["severities"]
@pytest.mark.asyncio
- async def test_update_slack_subscription(self, client: AsyncClient, shared_user: Dict[str, str]) -> None:
+ async def test_update_slack_subscription(self, client: AsyncClient, test_user: Dict[str, str]) -> None:
"""Test updating Slack subscription with webhook."""
# Login first
login_data = {
- "username": shared_user["username"],
- "password": shared_user["password"]
+ "username": test_user["username"],
+ "password": test_user["password"]
}
login_response = await client.post("/api/v1/auth/login", data=login_data)
assert login_response.status_code == 200
@@ -351,12 +351,12 @@ async def test_update_slack_subscription(self, client: AsyncClient, shared_user:
assert updated_subscription.severities == update_data["severities"]
@pytest.mark.asyncio
- async def test_delete_notification(self, client: AsyncClient, shared_user: Dict[str, str]) -> None:
+ async def test_delete_notification(self, client: AsyncClient, test_user: Dict[str, str]) -> None:
"""Test deleting a notification."""
# Login first
login_data = {
- "username": shared_user["username"],
- "password": shared_user["password"]
+ "username": test_user["username"],
+ "password": test_user["password"]
}
login_response = await client.post("/api/v1/auth/login", data=login_data)
assert login_response.status_code == 200
@@ -388,12 +388,12 @@ async def test_delete_notification(self, client: AsyncClient, shared_user: Dict[
assert notification_id not in notification_ids
@pytest.mark.asyncio
- async def test_delete_nonexistent_notification(self, client: AsyncClient, shared_user: Dict[str, str]) -> None:
+ async def test_delete_nonexistent_notification(self, client: AsyncClient, test_user: Dict[str, str]) -> None:
"""Test deleting a non-existent notification."""
# Login first
login_data = {
- "username": shared_user["username"],
- "password": shared_user["password"]
+ "username": test_user["username"],
+ "password": test_user["password"]
}
login_response = await client.post("/api/v1/auth/login", data=login_data)
assert login_response.status_code == 200
@@ -408,12 +408,12 @@ async def test_delete_nonexistent_notification(self, client: AsyncClient, shared
assert "not found" in error_data["detail"].lower()
@pytest.mark.asyncio
- async def test_notification_pagination(self, client: AsyncClient, shared_user: Dict[str, str]) -> None:
+ async def test_notification_pagination(self, client: AsyncClient, test_user: Dict[str, str]) -> None:
"""Test notification pagination."""
# Login first
login_data = {
- "username": shared_user["username"],
- "password": shared_user["password"]
+ "username": test_user["username"],
+ "password": test_user["password"]
}
login_response = await client.post("/api/v1/auth/login", data=login_data)
assert login_response.status_code == 200
@@ -445,13 +445,13 @@ async def test_notification_pagination(self, client: AsyncClient, shared_user: D
@pytest.mark.asyncio
async def test_notifications_isolation_between_users(self, client: AsyncClient,
- shared_user: Dict[str, str],
- shared_admin: Dict[str, str]) -> None:
+ test_user: Dict[str, str],
+ test_admin: Dict[str, str]) -> None:
"""Test that notifications are isolated between users."""
# Login as regular user
user_login_data = {
- "username": shared_user["username"],
- "password": shared_user["password"]
+ "username": test_user["username"],
+ "password": test_user["password"]
}
user_login_response = await client.post("/api/v1/auth/login", data=user_login_data)
assert user_login_response.status_code == 200
@@ -465,8 +465,8 @@ async def test_notifications_isolation_between_users(self, client: AsyncClient,
# Login as admin
admin_login_data = {
- "username": shared_admin["username"],
- "password": shared_admin["password"]
+ "username": test_admin["username"],
+ "password": test_admin["password"]
}
admin_login_response = await client.post("/api/v1/auth/login", data=admin_login_data)
assert admin_login_response.status_code == 200
@@ -483,12 +483,12 @@ async def test_notifications_isolation_between_users(self, client: AsyncClient,
assert len(set(user_notification_ids).intersection(set(admin_notification_ids))) == 0
@pytest.mark.asyncio
- async def test_invalid_notification_channel(self, client: AsyncClient, shared_user: Dict[str, str]) -> None:
+ async def test_invalid_notification_channel(self, client: AsyncClient, test_user: Dict[str, str]) -> None:
"""Test updating subscription with invalid channel."""
# Login first
login_data = {
- "username": shared_user["username"],
- "password": shared_user["password"]
+ "username": test_user["username"],
+ "password": test_user["password"]
}
login_response = await client.post("/api/v1/auth/login", data=login_data)
assert login_response.status_code == 200
diff --git a/backend/tests/integration/test_replay_routes.py b/backend/tests/integration/test_replay_routes.py
index 5c19b3c3..1cdf73ec 100644
--- a/backend/tests/integration/test_replay_routes.py
+++ b/backend/tests/integration/test_replay_routes.py
@@ -23,9 +23,9 @@ class TestReplayRoutes:
"""Test replay endpoints against real backend."""
@pytest.mark.asyncio
- async def test_replay_requires_admin_authentication(self, client: AsyncClient, shared_user: Dict[str, str]) -> None:
+ async def test_replay_requires_admin_authentication(self, client: AsyncClient, test_user: Dict[str, str]) -> None:
"""Test that replay endpoints require admin authentication."""
- # Already authenticated via shared_user fixture
+ # Already authenticated via test_user fixture
# Try to access replay endpoints as non-admin
response = await client.get("/api/v1/replay/sessions")
@@ -37,9 +37,9 @@ async def test_replay_requires_admin_authentication(self, client: AsyncClient, s
for word in ["admin", "forbidden", "denied"])
@pytest.mark.asyncio
- async def test_create_replay_session(self, client: AsyncClient, shared_admin: Dict[str, str]) -> None:
+ async def test_create_replay_session(self, client: AsyncClient, test_admin: Dict[str, str]) -> None:
"""Test creating a replay session."""
- # Already authenticated via shared_admin fixture
+ # Already authenticated via test_admin fixture
# Create replay session
replay_request = ReplayRequest(
@@ -67,9 +67,9 @@ async def test_create_replay_session(self, client: AsyncClient, shared_admin: Di
assert replay_response.message is not None
@pytest.mark.asyncio
- async def test_list_replay_sessions(self, client: AsyncClient, shared_admin: Dict[str, str]) -> None:
+ async def test_list_replay_sessions(self, client: AsyncClient, test_admin: Dict[str, str]) -> None:
"""Test listing replay sessions."""
- # Already authenticated via shared_admin fixture
+ # Already authenticated via test_admin fixture
# List replay sessions
response = await client.get("/api/v1/replay/sessions?limit=10")
@@ -88,9 +88,9 @@ async def test_list_replay_sessions(self, client: AsyncClient, shared_admin: Dic
assert session_summary.created_at is not None
@pytest.mark.asyncio
- async def test_get_replay_session_details(self, client: AsyncClient, shared_admin: Dict[str, str]) -> None:
+ async def test_get_replay_session_details(self, client: AsyncClient, test_admin: Dict[str, str]) -> None:
"""Test getting detailed information about a replay session."""
- # Already authenticated via shared_admin fixture
+ # Already authenticated via test_admin fixture
# Create a session first
replay_request = ReplayRequest(
@@ -121,12 +121,12 @@ async def test_get_replay_session_details(self, client: AsyncClient, shared_admi
assert session.created_at is not None
@pytest.mark.asyncio
- async def test_start_replay_session(self, client: AsyncClient, shared_admin: Dict[str, str]) -> None:
+ async def test_start_replay_session(self, client: AsyncClient, test_admin: Dict[str, str]) -> None:
"""Test starting a replay session."""
# Login as admin
login_data = {
- "username": shared_admin["username"],
- "password": shared_admin["password"]
+ "username": test_admin["username"],
+ "password": test_admin["password"]
}
login_response = await client.post("/api/v1/auth/login", data=login_data)
assert login_response.status_code == 200
@@ -160,12 +160,12 @@ async def test_start_replay_session(self, client: AsyncClient, shared_admin: Dic
assert start_result.message is not None
@pytest.mark.asyncio
- async def test_pause_and_resume_replay_session(self, client: AsyncClient, shared_admin: Dict[str, str]) -> None:
+ async def test_pause_and_resume_replay_session(self, client: AsyncClient, test_admin: Dict[str, str]) -> None:
"""Test pausing and resuming a replay session."""
# Login as admin
login_data = {
- "username": shared_admin["username"],
- "password": shared_admin["password"]
+ "username": test_admin["username"],
+ "password": test_admin["password"]
}
login_response = await client.post("/api/v1/auth/login", data=login_data)
assert login_response.status_code == 200
@@ -215,12 +215,12 @@ async def test_pause_and_resume_replay_session(self, client: AsyncClient, shared
assert resume_result.status in [ReplayStatus.RUNNING, ReplayStatus.COMPLETED]
@pytest.mark.asyncio
- async def test_cancel_replay_session(self, client: AsyncClient, shared_admin: Dict[str, str]) -> None:
+ async def test_cancel_replay_session(self, client: AsyncClient, test_admin: Dict[str, str]) -> None:
"""Test cancelling a replay session."""
# Login as admin
login_data = {
- "username": shared_admin["username"],
- "password": shared_admin["password"]
+ "username": test_admin["username"],
+ "password": test_admin["password"]
}
login_response = await client.post("/api/v1/auth/login", data=login_data)
assert login_response.status_code == 200
@@ -254,12 +254,12 @@ async def test_cancel_replay_session(self, client: AsyncClient, shared_admin: Di
assert cancel_result.message is not None
@pytest.mark.asyncio
- async def test_filter_sessions_by_status(self, client: AsyncClient, shared_admin: Dict[str, str]) -> None:
+ async def test_filter_sessions_by_status(self, client: AsyncClient, test_admin: Dict[str, str]) -> None:
"""Test filtering replay sessions by status."""
# Login as admin
login_data = {
- "username": shared_admin["username"],
- "password": shared_admin["password"]
+ "username": test_admin["username"],
+ "password": test_admin["password"]
}
login_response = await client.post("/api/v1/auth/login", data=login_data)
assert login_response.status_code == 200
@@ -286,12 +286,12 @@ async def test_filter_sessions_by_status(self, client: AsyncClient, shared_admin
assert session.status == status
@pytest.mark.asyncio
- async def test_cleanup_old_sessions(self, client: AsyncClient, shared_admin: Dict[str, str]) -> None:
+ async def test_cleanup_old_sessions(self, client: AsyncClient, test_admin: Dict[str, str]) -> None:
"""Test cleanup of old replay sessions."""
# Login as admin
login_data = {
- "username": shared_admin["username"],
- "password": shared_admin["password"]
+ "username": test_admin["username"],
+ "password": test_admin["password"]
}
login_response = await client.post("/api/v1/auth/login", data=login_data)
assert login_response.status_code == 200
@@ -308,12 +308,12 @@ async def test_cleanup_old_sessions(self, client: AsyncClient, shared_admin: Dic
assert cleanup_result.message is not None
@pytest.mark.asyncio
- async def test_get_nonexistent_session(self, client: AsyncClient, shared_admin: Dict[str, str]) -> None:
+ async def test_get_nonexistent_session(self, client: AsyncClient, test_admin: Dict[str, str]) -> None:
"""Test getting a non-existent replay session."""
# Login as admin
login_data = {
- "username": shared_admin["username"],
- "password": shared_admin["password"]
+ "username": test_admin["username"],
+ "password": test_admin["password"]
}
login_response = await client.post("/api/v1/auth/login", data=login_data)
assert login_response.status_code == 200
@@ -329,12 +329,12 @@ async def test_get_nonexistent_session(self, client: AsyncClient, shared_admin:
assert "detail" in error_data
@pytest.mark.asyncio
- async def test_start_nonexistent_session(self, client: AsyncClient, shared_admin: Dict[str, str]) -> None:
+ async def test_start_nonexistent_session(self, client: AsyncClient, test_admin: Dict[str, str]) -> None:
"""Test starting a non-existent replay session."""
# Login as admin
login_data = {
- "username": shared_admin["username"],
- "password": shared_admin["password"]
+ "username": test_admin["username"],
+ "password": test_admin["password"]
}
login_response = await client.post("/api/v1/auth/login", data=login_data)
assert login_response.status_code == 200
@@ -346,12 +346,12 @@ async def test_start_nonexistent_session(self, client: AsyncClient, shared_admin
assert response.status_code in [400, 404]
@pytest.mark.asyncio
- async def test_replay_session_state_transitions(self, client: AsyncClient, shared_admin: Dict[str, str]) -> None:
+ async def test_replay_session_state_transitions(self, client: AsyncClient, test_admin: Dict[str, str]) -> None:
"""Test valid state transitions for replay sessions."""
# Login as admin
login_data = {
- "username": shared_admin["username"],
- "password": shared_admin["password"]
+ "username": test_admin["username"],
+ "password": test_admin["password"]
}
login_response = await client.post("/api/v1/auth/login", data=login_data)
assert login_response.status_code == 200
@@ -391,12 +391,12 @@ async def test_replay_session_state_transitions(self, client: AsyncClient, share
assert start_again_response.status_code in [200, 400, 409] # Might be idempotent or error
@pytest.mark.asyncio
- async def test_replay_with_complex_filters(self, client: AsyncClient, shared_admin: Dict[str, str]) -> None:
+ async def test_replay_with_complex_filters(self, client: AsyncClient, test_admin: Dict[str, str]) -> None:
"""Test creating replay session with complex filters."""
# Login as admin
login_data = {
- "username": shared_admin["username"],
- "password": shared_admin["password"]
+ "username": test_admin["username"],
+ "password": test_admin["password"]
}
login_response = await client.post("/api/v1/auth/login", data=login_data)
assert login_response.status_code == 200
@@ -416,7 +416,7 @@ async def test_replay_with_complex_filters(self, client: AsyncClient, shared_adm
"end_time": datetime.now(timezone.utc).isoformat(),
"aggregate_id": str(uuid4()),
"correlation_id": str(uuid4()),
- "user_id": shared_admin.get("user_id"),
+ "user_id": test_admin.get("user_id"),
"service_name": "execution-service"
},
"target_topic": "complex-filter-topic",
@@ -437,12 +437,12 @@ async def test_replay_with_complex_filters(self, client: AsyncClient, shared_adm
assert replay_response.status in ["created", "pending"]
@pytest.mark.asyncio
- async def test_replay_session_progress_tracking(self, client: AsyncClient, shared_admin: Dict[str, str]) -> None:
+ async def test_replay_session_progress_tracking(self, client: AsyncClient, test_admin: Dict[str, str]) -> None:
"""Test tracking progress of replay sessions."""
# Login as admin
login_data = {
- "username": shared_admin["username"],
- "password": shared_admin["password"]
+ "username": test_admin["username"],
+ "password": test_admin["password"]
}
login_response = await client.post("/api/v1/auth/login", data=login_data)
assert login_response.status_code == 200
diff --git a/backend/tests/integration/test_saga_routes.py b/backend/tests/integration/test_saga_routes.py
index b4a2449d..b26d7b90 100644
--- a/backend/tests/integration/test_saga_routes.py
+++ b/backend/tests/integration/test_saga_routes.py
@@ -1,15 +1,14 @@
-import uuid
import asyncio
+import uuid
from typing import Dict
import pytest
-from httpx import AsyncClient
-
from app.domain.enums.saga import SagaState
from app.schemas_pydantic.saga import (
SagaListResponse,
SagaStatusResponse,
)
+from httpx import AsyncClient
class TestSagaRoutes:
@@ -25,16 +24,16 @@ async def test_get_saga_requires_auth(self, client: AsyncClient) -> None:
@pytest.mark.asyncio
async def test_get_saga_not_found(
- self, client: AsyncClient, shared_user: Dict[str, str]
+ self, client: AsyncClient, test_user: Dict[str, str]
) -> None:
"""Test getting non-existent saga returns 404."""
- # Already authenticated via shared_user fixture
+ # Already authenticated via test_user fixture
# Try to get non-existent saga
saga_id = str(uuid.uuid4())
response = await client.get(f"/api/v1/sagas/{saga_id}")
assert response.status_code == 404
- assert "Saga not found" in response.json()["detail"]
+ assert "not found" in response.json()["detail"]
@pytest.mark.asyncio
async def test_get_execution_sagas_requires_auth(
@@ -47,10 +46,10 @@ async def test_get_execution_sagas_requires_auth(
@pytest.mark.asyncio
async def test_get_execution_sagas_empty(
- self, client: AsyncClient, shared_user: Dict[str, str]
+ self, client: AsyncClient, test_user: Dict[str, str]
) -> None:
"""Test getting sagas for execution with no sagas."""
- # Already authenticated via shared_user fixture
+ # Already authenticated via test_user fixture
# Get sagas for non-existent execution
execution_id = str(uuid.uuid4())
@@ -60,10 +59,10 @@ async def test_get_execution_sagas_empty(
@pytest.mark.asyncio
async def test_get_execution_sagas_with_state_filter(
- self, client: AsyncClient, shared_user: Dict[str, str]
+ self, client: AsyncClient, test_user: Dict[str, str]
) -> None:
"""Test getting execution sagas filtered by state."""
- # Already authenticated via shared_user fixture
+ # Already authenticated via test_user fixture
# Get sagas filtered by running state
execution_id = str(uuid.uuid4())
@@ -86,10 +85,10 @@ async def test_list_sagas_requires_auth(self, client: AsyncClient) -> None:
@pytest.mark.asyncio
async def test_list_sagas_paginated(
- self, client: AsyncClient, shared_user: Dict[str, str]
+ self, client: AsyncClient, test_user: Dict[str, str]
) -> None:
"""Test listing sagas with pagination."""
- # Already authenticated via shared_user fixture
+ # Already authenticated via test_user fixture
# List sagas with pagination
response = await client.get(
@@ -105,13 +104,13 @@ async def test_list_sagas_paginated(
@pytest.mark.asyncio
async def test_list_sagas_with_state_filter(
- self, client: AsyncClient, shared_user: Dict[str, str]
+ self, client: AsyncClient, test_user: Dict[str, str]
) -> None:
"""Test listing sagas filtered by state."""
# Login first
login_data = {
- "username": shared_user["username"],
- "password": shared_user["password"]
+ "username": test_user["username"],
+ "password": test_user["password"]
}
login_response = await client.post("/api/v1/auth/login", data=login_data)
assert login_response.status_code == 200
@@ -131,13 +130,13 @@ async def test_list_sagas_with_state_filter(
@pytest.mark.asyncio
async def test_list_sagas_large_limit(
- self, client: AsyncClient, shared_user: Dict[str, str]
+ self, client: AsyncClient, test_user: Dict[str, str]
) -> None:
"""Test listing sagas with maximum limit."""
# Login first
login_data = {
- "username": shared_user["username"],
- "password": shared_user["password"]
+ "username": test_user["username"],
+ "password": test_user["password"]
}
login_response = await client.post("/api/v1/auth/login", data=login_data)
assert login_response.status_code == 200
@@ -154,13 +153,13 @@ async def test_list_sagas_large_limit(
@pytest.mark.asyncio
async def test_list_sagas_invalid_limit(
- self, client: AsyncClient, shared_user: Dict[str, str]
+ self, client: AsyncClient, test_user: Dict[str, str]
) -> None:
"""Test listing sagas with invalid limit."""
# Login first
login_data = {
- "username": shared_user["username"],
- "password": shared_user["password"]
+ "username": test_user["username"],
+ "password": test_user["password"]
}
login_response = await client.post("/api/v1/auth/login", data=login_data)
assert login_response.status_code == 200
@@ -181,13 +180,13 @@ async def test_cancel_saga_requires_auth(self, client: AsyncClient) -> None:
@pytest.mark.asyncio
async def test_cancel_saga_not_found(
- self, client: AsyncClient, shared_user: Dict[str, str]
+ self, client: AsyncClient, test_user: Dict[str, str]
) -> None:
"""Test cancelling non-existent saga returns 404."""
# Login first
login_data = {
- "username": shared_user["username"],
- "password": shared_user["password"]
+ "username": test_user["username"],
+ "password": test_user["password"]
}
login_response = await client.post("/api/v1/auth/login", data=login_data)
assert login_response.status_code == 200
@@ -196,20 +195,20 @@ async def test_cancel_saga_not_found(
saga_id = str(uuid.uuid4())
response = await client.post(f"/api/v1/sagas/{saga_id}/cancel")
assert response.status_code == 404
- assert "Saga not found" in response.json()["detail"]
+ assert "not found" in response.json()["detail"]
@pytest.mark.asyncio
async def test_saga_access_control(
self,
client: AsyncClient,
- shared_user: Dict[str, str],
+ test_user: Dict[str, str],
another_user: Dict[str, str]
) -> None:
"""Test that users can only access their own sagas."""
# User 1 lists their sagas
login_data1 = {
- "username": shared_user["username"],
- "password": shared_user["password"]
+ "username": test_user["username"],
+ "password": test_user["password"]
}
login_response1 = await client.post("/api/v1/auth/login", data=login_data1)
assert login_response1.status_code == 200
@@ -241,13 +240,13 @@ async def test_saga_access_control(
@pytest.mark.asyncio
async def test_get_saga_with_details(
- self, client: AsyncClient, shared_user: Dict[str, str]
+ self, client: AsyncClient, test_user: Dict[str, str]
) -> None:
"""Test getting saga with all details when it exists."""
# Login first
login_data = {
- "username": shared_user["username"],
- "password": shared_user["password"]
+ "username": test_user["username"],
+ "password": test_user["password"]
}
login_response = await client.post("/api/v1/auth/login", data=login_data)
assert login_response.status_code == 200
@@ -272,13 +271,13 @@ async def test_get_saga_with_details(
@pytest.mark.asyncio
async def test_list_sagas_with_offset(
- self, client: AsyncClient, shared_user: Dict[str, str]
+ self, client: AsyncClient, test_user: Dict[str, str]
) -> None:
"""Test listing sagas with offset for pagination."""
# Login first
login_data = {
- "username": shared_user["username"],
- "password": shared_user["password"]
+ "username": test_user["username"],
+ "password": test_user["password"]
}
login_response = await client.post("/api/v1/auth/login", data=login_data)
assert login_response.status_code == 200
@@ -308,13 +307,13 @@ async def test_list_sagas_with_offset(
@pytest.mark.asyncio
async def test_cancel_saga_invalid_state(
- self, client: AsyncClient, shared_user: Dict[str, str]
+ self, client: AsyncClient, test_user: Dict[str, str]
) -> None:
"""Test cancelling a saga in invalid state (if one exists)."""
# Login first
login_data = {
- "username": shared_user["username"],
- "password": shared_user["password"]
+ "username": test_user["username"],
+ "password": test_user["password"]
}
login_response = await client.post("/api/v1/auth/login", data=login_data)
assert login_response.status_code == 200
@@ -336,13 +335,13 @@ async def test_cancel_saga_invalid_state(
@pytest.mark.asyncio
async def test_get_execution_sagas_multiple_states(
- self, client: AsyncClient, shared_user: Dict[str, str]
+ self, client: AsyncClient, test_user: Dict[str, str]
) -> None:
"""Test getting execution sagas across different states."""
# Login first
login_data = {
- "username": shared_user["username"],
- "password": shared_user["password"]
+ "username": test_user["username"],
+ "password": test_user["password"]
}
login_response = await client.post("/api/v1/auth/login", data=login_data)
assert login_response.status_code == 200
@@ -368,13 +367,13 @@ async def test_get_execution_sagas_multiple_states(
@pytest.mark.asyncio
async def test_saga_response_structure(
- self, client: AsyncClient, shared_user: Dict[str, str]
+ self, client: AsyncClient, test_user: Dict[str, str]
) -> None:
"""Test that saga responses have correct structure."""
# Login first
login_data = {
- "username": shared_user["username"],
- "password": shared_user["password"]
+ "username": test_user["username"],
+ "password": test_user["password"]
}
login_response = await client.post("/api/v1/auth/login", data=login_data)
assert login_response.status_code == 200
@@ -399,13 +398,13 @@ async def test_saga_response_structure(
@pytest.mark.asyncio
async def test_concurrent_saga_access(
- self, client: AsyncClient, shared_user: Dict[str, str]
+ self, client: AsyncClient, test_user: Dict[str, str]
) -> None:
"""Test concurrent access to saga endpoints."""
# Login first
login_data = {
- "username": shared_user["username"],
- "password": shared_user["password"]
+ "username": test_user["username"],
+ "password": test_user["password"]
}
login_response = await client.post("/api/v1/auth/login", data=login_data)
assert login_response.status_code == 200
diff --git a/backend/tests/integration/test_saved_scripts_routes.py b/backend/tests/integration/test_saved_scripts_routes.py
index b2dd8248..cc42b39c 100644
--- a/backend/tests/integration/test_saved_scripts_routes.py
+++ b/backend/tests/integration/test_saved_scripts_routes.py
@@ -33,9 +33,9 @@ async def test_create_script_requires_authentication(self, client: AsyncClient)
for word in ["not authenticated", "unauthorized", "login"])
@pytest.mark.asyncio
- async def test_create_and_retrieve_saved_script(self, client: AsyncClient, shared_user: Dict[str, str]) -> None:
+ async def test_create_and_retrieve_saved_script(self, client: AsyncClient, test_user: Dict[str, str]) -> None:
"""Test creating and retrieving a saved script."""
- # Already authenticated via shared_user fixture
+ # Already authenticated via test_user fixture
# Create a unique script
unique_id = str(uuid4())[:8]
@@ -89,9 +89,9 @@ async def test_create_and_retrieve_saved_script(self, client: AsyncClient, share
assert retrieved_script.script == script_data["script"]
@pytest.mark.asyncio
- async def test_list_user_scripts(self, client: AsyncClient, shared_user: Dict[str, str]) -> None:
+ async def test_list_user_scripts(self, client: AsyncClient, test_user: Dict[str, str]) -> None:
"""Test listing user's saved scripts."""
- # Already authenticated via shared_user fixture
+ # Already authenticated via test_user fixture
# Create a few scripts
unique_id = str(uuid4())[:8]
@@ -149,9 +149,9 @@ async def test_list_user_scripts(self, client: AsyncClient, shared_user: Dict[st
assert created_id in returned_ids
@pytest.mark.asyncio
- async def test_update_saved_script(self, client: AsyncClient, shared_user: Dict[str, str]) -> None:
+ async def test_update_saved_script(self, client: AsyncClient, test_user: Dict[str, str]) -> None:
"""Test updating a saved script."""
- # Already authenticated via shared_user fixture
+ # Already authenticated via test_user fixture
# Create a script
unique_id = str(uuid4())[:8]
@@ -202,9 +202,9 @@ async def test_update_saved_script(self, client: AsyncClient, shared_user: Dict[
assert updated_script.updated_at > updated_script.created_at
@pytest.mark.asyncio
- async def test_delete_saved_script(self, client: AsyncClient, shared_user: Dict[str, str]) -> None:
+ async def test_delete_saved_script(self, client: AsyncClient, test_user: Dict[str, str]) -> None:
"""Test deleting a saved script."""
- # Already authenticated via shared_user fixture
+ # Already authenticated via test_user fixture
# Create a script to delete
unique_id = str(uuid4())[:8]
@@ -234,13 +234,13 @@ async def test_delete_saved_script(self, client: AsyncClient, shared_user: Dict[
assert "detail" in error_data
@pytest.mark.asyncio
- async def test_cannot_access_other_users_scripts(self, client: AsyncClient, shared_user: Dict[str, str],
- shared_admin: Dict[str, str]) -> None:
+ async def test_cannot_access_other_users_scripts(self, client: AsyncClient, test_user: Dict[str, str],
+ test_admin: Dict[str, str]) -> None:
"""Test that users cannot access scripts created by other users."""
# Create a script as regular user
login_data = {
- "username": shared_user["username"],
- "password": shared_user["password"]
+ "username": test_user["username"],
+ "password": test_user["password"]
}
login_response = await client.post("/api/v1/auth/login", data=login_data)
assert login_response.status_code == 200
@@ -261,8 +261,8 @@ async def test_cannot_access_other_users_scripts(self, client: AsyncClient, shar
# Now login as admin
admin_login_data = {
- "username": shared_admin["username"],
- "password": shared_admin["password"]
+ "username": test_admin["username"],
+ "password": test_admin["password"]
}
admin_login_response = await client.post("/api/v1/auth/login", data=admin_login_data)
assert admin_login_response.status_code == 200
@@ -283,12 +283,12 @@ async def test_cannot_access_other_users_scripts(self, client: AsyncClient, shar
assert user_script_id not in admin_script_ids
@pytest.mark.asyncio
- async def test_script_with_invalid_language(self, client: AsyncClient, shared_user: Dict[str, str]) -> None:
+ async def test_script_with_invalid_language(self, client: AsyncClient, test_user: Dict[str, str]) -> None:
"""Test that invalid language/version combinations are handled."""
# Login first
login_data = {
- "username": shared_user["username"],
- "password": shared_user["password"]
+ "username": test_user["username"],
+ "password": test_user["password"]
}
login_response = await client.post("/api/v1/auth/login", data=login_data)
assert login_response.status_code == 200
@@ -320,12 +320,12 @@ async def test_script_with_invalid_language(self, client: AsyncClient, shared_us
assert response.status_code in [200, 201, 400, 422]
@pytest.mark.asyncio
- async def test_script_name_constraints(self, client: AsyncClient, shared_user: Dict[str, str]) -> None:
+ async def test_script_name_constraints(self, client: AsyncClient, test_user: Dict[str, str]) -> None:
"""Test script name validation and constraints."""
# Login first
login_data = {
- "username": shared_user["username"],
- "password": shared_user["password"]
+ "username": test_user["username"],
+ "password": test_user["password"]
}
login_response = await client.post("/api/v1/auth/login", data=login_data)
assert login_response.status_code == 200
@@ -356,12 +356,12 @@ async def test_script_name_constraints(self, client: AsyncClient, shared_user: D
assert "detail" in error_data
@pytest.mark.asyncio
- async def test_script_content_size_limits(self, client: AsyncClient, shared_user: Dict[str, str]) -> None:
+ async def test_script_content_size_limits(self, client: AsyncClient, test_user: Dict[str, str]) -> None:
"""Test script content size limits."""
# Login first
login_data = {
- "username": shared_user["username"],
- "password": shared_user["password"]
+ "username": test_user["username"],
+ "password": test_user["password"]
}
login_response = await client.post("/api/v1/auth/login", data=login_data)
assert login_response.status_code == 200
@@ -396,12 +396,12 @@ async def test_script_content_size_limits(self, client: AsyncClient, shared_user
assert response.status_code in [200, 201, 400, 413, 422]
@pytest.mark.asyncio
- async def test_update_nonexistent_script(self, client: AsyncClient, shared_user: Dict[str, str]) -> None:
+ async def test_update_nonexistent_script(self, client: AsyncClient, test_user: Dict[str, str]) -> None:
"""Test updating a non-existent script."""
# Login first
login_data = {
- "username": shared_user["username"],
- "password": shared_user["password"]
+ "username": test_user["username"],
+ "password": test_user["password"]
}
login_response = await client.post("/api/v1/auth/login", data=login_data)
assert login_response.status_code == 200
@@ -423,12 +423,12 @@ async def test_update_nonexistent_script(self, client: AsyncClient, shared_user:
assert "detail" in error_data
@pytest.mark.asyncio
- async def test_delete_nonexistent_script(self, client: AsyncClient, shared_user: Dict[str, str]) -> None:
+ async def test_delete_nonexistent_script(self, client: AsyncClient, test_user: Dict[str, str]) -> None:
"""Test deleting a non-existent script."""
# Login first
login_data = {
- "username": shared_user["username"],
- "password": shared_user["password"]
+ "username": test_user["username"],
+ "password": test_user["password"]
}
login_response = await client.post("/api/v1/auth/login", data=login_data)
assert login_response.status_code == 200
@@ -440,12 +440,12 @@ async def test_delete_nonexistent_script(self, client: AsyncClient, shared_user:
assert response.status_code in [404, 403, 204]
@pytest.mark.asyncio
- async def test_scripts_persist_across_sessions(self, client: AsyncClient, shared_user: Dict[str, str]) -> None:
+ async def test_scripts_persist_across_sessions(self, client: AsyncClient, test_user: Dict[str, str]) -> None:
"""Test that scripts persist across login sessions."""
# First session - create script
login_data = {
- "username": shared_user["username"],
- "password": shared_user["password"]
+ "username": test_user["username"],
+ "password": test_user["password"]
}
login_response = await client.post("/api/v1/auth/login", data=login_data)
assert login_response.status_code == 200
diff --git a/backend/tests/integration/test_sse_routes.py b/backend/tests/integration/test_sse_routes.py
index 1078259c..ace4bc48 100644
--- a/backend/tests/integration/test_sse_routes.py
+++ b/backend/tests/integration/test_sse_routes.py
@@ -38,7 +38,7 @@ async def test_sse_requires_authentication(self, client: AsyncClient) -> None:
assert r.status_code == 401
@pytest.mark.asyncio
- async def test_sse_health_status(self, client: AsyncClient, shared_user: Dict[str, str]) -> None:
+ async def test_sse_health_status(self, client: AsyncClient, test_user: Dict[str, str]) -> None:
r = await client.get("/api/v1/events/health")
assert r.status_code == 200
health = SSEHealthResponse(**r.json())
@@ -46,7 +46,7 @@ async def test_sse_health_status(self, client: AsyncClient, shared_user: Dict[st
assert isinstance(health.active_connections, int) and health.active_connections >= 0
@pytest.mark.asyncio
- async def test_notification_stream_service(self, scope, shared_user: Dict[str, str]) -> None: # type: ignore[valid-type]
+ async def test_notification_stream_service(self, scope, test_user: Dict[str, str]) -> None: # type: ignore[valid-type]
"""Test SSE notification stream directly through service (httpx doesn't support SSE streaming)."""
sse_service: SSEService = await scope.get(SSEService)
bus: SSERedisBus = await scope.get(SSERedisBus)
@@ -97,7 +97,7 @@ async def _connected() -> None:
assert len(notif_events) > 0
@pytest.mark.asyncio
- async def test_execution_event_stream_service(self, scope, shared_user: Dict[str, str]) -> None: # type: ignore[valid-type]
+ async def test_execution_event_stream_service(self, scope, test_user: Dict[str, str]) -> None: # type: ignore[valid-type]
"""Test SSE execution stream directly through service (httpx doesn't support SSE streaming)."""
exec_id = f"e-{uuid4().hex[:8]}"
user_id = "test-user-id"
@@ -158,7 +158,7 @@ async def test_sse_route_requires_auth(self, client: AsyncClient) -> None:
assert r.status_code == 401
@pytest.mark.asyncio
- async def test_sse_endpoint_returns_correct_headers(self, client: AsyncClient, shared_user: Dict[str, str]) -> None:
+ async def test_sse_endpoint_returns_correct_headers(self, client: AsyncClient, test_user: Dict[str, str]) -> None:
task = asyncio.create_task(client.get("/api/v1/events/notifications/stream"))
async def _tick() -> None:
@@ -176,7 +176,7 @@ async def _tick() -> None:
assert isinstance(r.json(), dict)
@pytest.mark.asyncio
- async def test_multiple_concurrent_sse_service_connections(self, scope, shared_user: Dict[str, str]) -> None: # type: ignore[valid-type]
+ async def test_multiple_concurrent_sse_service_connections(self, scope, test_user: Dict[str, str]) -> None: # type: ignore[valid-type]
"""Test multiple concurrent SSE connections through the service."""
sse_service: SSEService = await scope.get(SSEService)
diff --git a/backend/tests/unit/.env.unit b/backend/tests/unit/.env.unit
new file mode 100644
index 00000000..f3205c30
--- /dev/null
+++ b/backend/tests/unit/.env.unit
@@ -0,0 +1,4 @@
+TESTING=true
+SECRET_KEY=test-secret-key-for-testing-only-32chars!!
+ENABLE_TRACING=false
+OTEL_SDK_DISABLED=true
diff --git a/backend/tests/unit/conftest.py b/backend/tests/unit/conftest.py
new file mode 100644
index 00000000..617233cf
--- /dev/null
+++ b/backend/tests/unit/conftest.py
@@ -0,0 +1,34 @@
+from pathlib import Path
+
+import pytest
+from dotenv import load_dotenv
+
+# Load unit test env
+unit_env = Path(__file__).parent / ".env.unit"
+load_dotenv(unit_env, override=True)
+
+
+@pytest.fixture(scope="function", autouse=False)
+def _cleanup():
+ """No-op - unit tests don't need DB/Redis cleanup."""
+ yield
+
+
+@pytest.fixture
+def db():
+ raise RuntimeError("Unit tests should not access DB - use mocks or move to integration/")
+
+
+@pytest.fixture
+def redis_client():
+ raise RuntimeError("Unit tests should not access Redis - use mocks or move to integration/")
+
+
+@pytest.fixture
+def client():
+ raise RuntimeError("Unit tests should not use HTTP client - use mocks or move to integration/")
+
+
+@pytest.fixture
+def app():
+ raise RuntimeError("Unit tests should not use full app - use mocks or move to integration/")
diff --git a/backend/tests/unit/core/metrics/test_metrics_context.py b/backend/tests/unit/core/metrics/test_metrics_context.py
index bc1240fd..c73001a9 100644
--- a/backend/tests/unit/core/metrics/test_metrics_context.py
+++ b/backend/tests/unit/core/metrics/test_metrics_context.py
@@ -1,9 +1,13 @@
+import logging
+
from app.core.metrics.context import (
MetricsContext,
get_connection_metrics,
get_coordinator_metrics,
)
+_test_logger = logging.getLogger("test.core.metrics.context")
+
def test_metrics_context_lazy_and_reset() -> None:
"""Test metrics context lazy loading and reset with no-op metrics."""
@@ -13,7 +17,7 @@ def test_metrics_context_lazy_and_reset() -> None:
assert c1 is c2 # same instance per context
d1 = get_coordinator_metrics()
- MetricsContext.reset_all()
+ MetricsContext.reset_all(_test_logger)
# after reset, new instances are created lazily
c3 = get_connection_metrics()
assert c3 is not c1
diff --git a/backend/tests/unit/core/test_exceptions.py b/backend/tests/unit/core/test_exceptions.py
deleted file mode 100644
index dcd6c1ba..00000000
--- a/backend/tests/unit/core/test_exceptions.py
+++ /dev/null
@@ -1,41 +0,0 @@
-from fastapi import FastAPI
-from fastapi.testclient import TestClient
-
-from app.core.exceptions.base import AuthenticationError, IntegrationException, ServiceError
-from app.core.exceptions.handlers import configure_exception_handlers
-
-
-def make_app(raise_exc):
- app = FastAPI()
- configure_exception_handlers(app)
-
- @app.get("/boom")
- def boom(): # type: ignore[no-redef]
- raise raise_exc
-
- return app
-
-
-def test_integration_exception_handler():
- app = make_app(IntegrationException(418, "teapot"))
- with TestClient(app) as c:
- r = c.get("/boom")
- assert r.status_code == 418
- assert r.json()["detail"] == "teapot"
-
-
-def test_authentication_error_handler():
- app = make_app(AuthenticationError("nope"))
- with TestClient(app) as c:
- r = c.get("/boom")
- assert r.status_code == 401
- assert r.json()["detail"] == "nope"
-
-
-def test_service_error_handler():
- app = make_app(ServiceError("oops", status_code=503))
- with TestClient(app) as c:
- r = c.get("/boom")
- assert r.status_code == 503
- assert r.json()["detail"] == "oops"
-
diff --git a/backend/tests/unit/core/test_logging_and_correlation.py b/backend/tests/unit/core/test_logging_and_correlation.py
index 04cea6d9..bad1385f 100644
--- a/backend/tests/unit/core/test_logging_and_correlation.py
+++ b/backend/tests/unit/core/test_logging_and_correlation.py
@@ -84,5 +84,5 @@ async def ping(request: Request) -> JSONResponse:
def test_setup_logger_returns_logger():
- lg = setup_logger()
+ lg = setup_logger(log_level="INFO")
assert hasattr(lg, "info")
diff --git a/backend/tests/unit/db/repositories/__init__.py b/backend/tests/unit/db/repositories/__init__.py
deleted file mode 100644
index e69de29b..00000000
diff --git a/backend/tests/unit/db/repositories/test_admin_events_repository.py b/backend/tests/unit/db/repositories/test_admin_events_repository.py
deleted file mode 100644
index e574e439..00000000
--- a/backend/tests/unit/db/repositories/test_admin_events_repository.py
+++ /dev/null
@@ -1,64 +0,0 @@
-from datetime import datetime, timezone, timedelta
-
-import pytest
-
-from app.db.repositories.admin.admin_events_repository import AdminEventsRepository
-from app.domain.admin import ReplaySession, ReplayQuery
-from app.domain.admin.replay_updates import ReplaySessionUpdate
-from app.domain.enums.replay import ReplayStatus
-from app.domain.events.event_models import EventFields, EventFilter, EventStatistics, Event
-from app.infrastructure.kafka.events.metadata import AvroEventMetadata
-
-pytestmark = pytest.mark.unit
-
-
-@pytest.fixture()
-def repo(db) -> AdminEventsRepository: # type: ignore[valid-type]
- return AdminEventsRepository(db)
-
-
-@pytest.mark.asyncio
-async def test_browse_detail_delete_and_export(repo: AdminEventsRepository, db) -> None: # type: ignore[valid-type]
- now = datetime.now(timezone.utc)
- await db.get_collection("events").insert_many([
- {EventFields.EVENT_ID: "e1", EventFields.EVENT_TYPE: "X", EventFields.TIMESTAMP: now, EventFields.METADATA: AvroEventMetadata(service_name="svc", service_version="1", correlation_id="c1").to_dict()},
- {EventFields.EVENT_ID: "e2", EventFields.EVENT_TYPE: "X", EventFields.TIMESTAMP: now, EventFields.METADATA: AvroEventMetadata(service_name="svc", service_version="1", correlation_id="c1").to_dict()},
- ])
- res = await repo.browse_events(EventFilter())
- assert res.total >= 2
- detail = await repo.get_event_detail("e1")
- assert detail and detail.event.event_id == "e1"
- assert await repo.delete_event("e2") is True
- rows = await repo.export_events_csv(EventFilter())
- assert isinstance(rows, list) and len(rows) >= 1
-
-
-@pytest.mark.asyncio
-async def test_event_stats_and_archive(repo: AdminEventsRepository, db) -> None: # type: ignore[valid-type]
- now = datetime.now(timezone.utc)
- await db.get_collection("events").insert_many([
- {EventFields.EVENT_ID: "e10", EventFields.EVENT_TYPE: "step.completed", EventFields.TIMESTAMP: now, EventFields.METADATA: AvroEventMetadata(service_name="svc", service_version="1", user_id="u1").to_dict()},
- ])
- await db.get_collection("executions").insert_one({"created_at": now, "status": "completed", "resource_usage": {"execution_time_wall_seconds": 1.25}})
- stats = await repo.get_event_stats(hours=1)
- assert isinstance(stats, EventStatistics)
- ev = Event(event_id="a1", event_type="X", event_version="1.0", timestamp=now, metadata=AvroEventMetadata(service_name="s", service_version="1"), payload={})
- assert await repo.archive_event(ev, deleted_by="admin") is True
-
-
-@pytest.mark.asyncio
-async def test_replay_session_flow_and_helpers(repo: AdminEventsRepository, db) -> None: # type: ignore[valid-type]
- # create/get/update
- session = ReplaySession(session_id="s1", status=ReplayStatus.SCHEDULED, total_events=1, correlation_id="corr", created_at=datetime.now(timezone.utc) - timedelta(seconds=5), dry_run=False)
- sid = await repo.create_replay_session(session)
- assert sid == "s1"
- got = await repo.get_replay_session("s1")
- assert got and got.session_id == "s1"
- session_update = ReplaySessionUpdate(status=ReplayStatus.RUNNING)
- assert await repo.update_replay_session("s1", session_update) is True
- detail = await repo.get_replay_status_with_progress("s1")
- assert detail and detail.session.session_id == "s1"
- assert await repo.count_events_for_replay({}) >= 0
- prev = await repo.get_replay_events_preview(event_ids=["e10"]) # from earlier insert
- assert isinstance(prev, dict)
-
diff --git a/backend/tests/unit/db/repositories/test_admin_user_repository.py b/backend/tests/unit/db/repositories/test_admin_user_repository.py
deleted file mode 100644
index c913029b..00000000
--- a/backend/tests/unit/db/repositories/test_admin_user_repository.py
+++ /dev/null
@@ -1,84 +0,0 @@
-import pytest
-from datetime import datetime, timezone
-
-from app.db.repositories.admin.admin_user_repository import AdminUserRepository
-from app.domain.user import UserFields, UserUpdate, PasswordReset
-from app.core.security import SecurityService
-
-pytestmark = pytest.mark.unit
-
-
-@pytest.fixture()
-def repo(db) -> AdminUserRepository: # type: ignore[valid-type]
- return AdminUserRepository(db)
-
-
-@pytest.mark.asyncio
-async def test_list_and_get_user(repo: AdminUserRepository, db) -> None: # type: ignore[valid-type]
- # Insert a user
- await db.get_collection("users").insert_one({
- UserFields.USER_ID: "u1",
- UserFields.USERNAME: "alice",
- UserFields.EMAIL: "alice@example.com",
- UserFields.ROLE: "user",
- UserFields.IS_ACTIVE: True,
- UserFields.IS_SUPERUSER: False,
- UserFields.HASHED_PASSWORD: "h",
- UserFields.CREATED_AT: datetime.now(timezone.utc),
- UserFields.UPDATED_AT: datetime.now(timezone.utc),
- })
- res = await repo.list_users(limit=10)
- assert res.total >= 1 and any(u.username == "alice" for u in res.users)
- user = await repo.get_user_by_id("u1")
- assert user and user.user_id == "u1"
-
-
-@pytest.mark.asyncio
-async def test_update_delete_and_reset_password(repo: AdminUserRepository, db, monkeypatch: pytest.MonkeyPatch) -> None: # type: ignore[valid-type]
- # Insert base user
- await db.get_collection("users").insert_one({
- UserFields.USER_ID: "u1",
- UserFields.USERNAME: "bob",
- UserFields.EMAIL: "bob@example.com",
- UserFields.ROLE: "user",
- UserFields.IS_ACTIVE: True,
- UserFields.IS_SUPERUSER: False,
- UserFields.HASHED_PASSWORD: "h",
- UserFields.CREATED_AT: datetime.now(timezone.utc),
- UserFields.UPDATED_AT: datetime.now(timezone.utc),
- })
- # No updates → returns current
- updated = await repo.update_user("u1", UserUpdate())
- assert updated and updated.user_id == "u1"
- # Delete cascade (collections empty → zeros)
- deleted = await repo.delete_user("u1", cascade=True)
- assert deleted["user"] in (0, 1)
- # Re-insert and reset password
- await db.get_collection("users").insert_one({
- UserFields.USER_ID: "u1",
- UserFields.USERNAME: "bob",
- UserFields.EMAIL: "bob@example.com",
- UserFields.ROLE: "user",
- UserFields.IS_ACTIVE: True,
- UserFields.IS_SUPERUSER: False,
- UserFields.HASHED_PASSWORD: "h",
- UserFields.CREATED_AT: datetime.now(timezone.utc),
- UserFields.UPDATED_AT: datetime.now(timezone.utc),
- })
- monkeypatch.setattr(SecurityService, "get_password_hash", staticmethod(lambda p: "HASHED"))
- pr = PasswordReset(user_id="u1", new_password="secret123")
- assert await repo.reset_user_password(pr) is True
-
-
-@pytest.mark.asyncio
-async def test_list_with_filters_and_reset_invalid(repo: AdminUserRepository, db) -> None: # type: ignore[valid-type]
- # Insert a couple of users
- await db.get_collection("users").insert_many([
- {UserFields.USER_ID: "u1", UserFields.USERNAME: "Alice", UserFields.EMAIL: "a@e.com", UserFields.ROLE: "user", UserFields.IS_ACTIVE: True, UserFields.IS_SUPERUSER: False, UserFields.HASHED_PASSWORD: "h", UserFields.CREATED_AT: datetime.now(timezone.utc), UserFields.UPDATED_AT: datetime.now(timezone.utc)},
- {UserFields.USER_ID: "u2", UserFields.USERNAME: "Bob", UserFields.EMAIL: "b@e.com", UserFields.ROLE: "admin", UserFields.IS_ACTIVE: True, UserFields.IS_SUPERUSER: True, UserFields.HASHED_PASSWORD: "h", UserFields.CREATED_AT: datetime.now(timezone.utc), UserFields.UPDATED_AT: datetime.now(timezone.utc)},
- ])
- res = await repo.list_users(limit=5, offset=0, search="Al", role=None)
- assert any(u.username.lower().startswith("al") for u in res.users) or res.total >= 0
- # invalid password reset (empty)
- with pytest.raises(ValueError):
- await repo.reset_user_password(PasswordReset(user_id="u1", new_password=""))
diff --git a/backend/tests/unit/db/repositories/test_dlq_repository.py b/backend/tests/unit/db/repositories/test_dlq_repository.py
deleted file mode 100644
index 4bf77eeb..00000000
--- a/backend/tests/unit/db/repositories/test_dlq_repository.py
+++ /dev/null
@@ -1,69 +0,0 @@
-from datetime import datetime, timezone
-
-import pytest
-
-from app.db.repositories.dlq_repository import DLQRepository
-from app.domain.enums.events import EventType
-from app.dlq import DLQFields, DLQMessageStatus
-
-pytestmark = pytest.mark.unit
-
-
-@pytest.fixture()
-def repo(db) -> DLQRepository: # type: ignore[valid-type]
- return DLQRepository(db)
-
-
-def make_dlq_doc(eid: str, topic: str, etype: str, status: str = DLQMessageStatus.PENDING) -> dict:
- now = datetime.now(timezone.utc)
- # Build event dict compatible with event schema (top-level fields)
- event: dict[str, object] = {
- "event_type": etype,
- "metadata": {"service_name": "svc", "service_version": "1"},
- }
- if etype == str(EventType.USER_LOGGED_IN):
- event.update({"user_id": "u1", "login_method": "password"})
- elif etype == str(EventType.EXECUTION_STARTED):
- event.update({"execution_id": "x1", "pod_name": "p1"})
- return {
- DLQFields.EVENT: event,
- DLQFields.ORIGINAL_TOPIC: topic,
- DLQFields.ERROR: "err",
- DLQFields.RETRY_COUNT: 0,
- DLQFields.FAILED_AT: now,
- DLQFields.STATUS: status,
- DLQFields.PRODUCER_ID: "p1",
- DLQFields.EVENT_ID: eid,
- }
-
-
-@pytest.mark.asyncio
-async def test_stats_list_get_and_updates(repo: DLQRepository, db) -> None: # type: ignore[valid-type]
- await db.get_collection("dlq_messages").insert_many([
- make_dlq_doc("id1", "t1", str(EventType.USER_LOGGED_IN), DLQMessageStatus.PENDING),
- make_dlq_doc("id2", "t1", str(EventType.USER_LOGGED_IN), DLQMessageStatus.RETRIED),
- make_dlq_doc("id3", "t2", str(EventType.EXECUTION_STARTED), DLQMessageStatus.PENDING),
- ])
- stats = await repo.get_dlq_stats()
- assert isinstance(stats.by_status, dict) and len(stats.by_topic) >= 1
-
- res = await repo.get_messages(limit=2)
- assert res.total >= 3 and len(res.messages) <= 2
- msg = await repo.get_message_by_id("id1")
- assert msg and msg.event_id == "id1"
- assert await repo.mark_message_retried("id1") in (True, False)
- assert await repo.mark_message_discarded("id1", "r") in (True, False)
-
- topics = await repo.get_topics_summary()
- assert any(t.topic == "t1" for t in topics)
-
-
-@pytest.mark.asyncio
-async def test_retry_batch(repo: DLQRepository) -> None:
- class Manager:
- async def retry_message_manually(self, eid: str) -> bool: # noqa: ARG002
- return True
-
- result = await repo.retry_messages_batch(["missing"], Manager())
- # Missing messages cause failures
- assert result.total == 1 and result.failed >= 1
diff --git a/backend/tests/unit/db/repositories/test_event_repository.py b/backend/tests/unit/db/repositories/test_event_repository.py
deleted file mode 100644
index 5fe402f4..00000000
--- a/backend/tests/unit/db/repositories/test_event_repository.py
+++ /dev/null
@@ -1,66 +0,0 @@
-from datetime import datetime, timezone, timedelta
-
-import pytest
-
-from app.db.repositories.event_repository import EventRepository
-from app.domain.events.event_models import Event, EventFields, EventFilter
-from app.infrastructure.kafka.events.metadata import AvroEventMetadata
-
-pytestmark = pytest.mark.unit
-
-
-@pytest.fixture()
-def repo(db) -> EventRepository: # type: ignore[valid-type]
- return EventRepository(db)
-
-
-def make_event(event_id: str, etype: str = "UserLoggedIn", user: str | None = "u1", agg: str | None = "agg1") -> Event:
- return Event(
- event_id=event_id,
- event_type=etype,
- event_version="1.0",
- timestamp=datetime.now(timezone.utc),
- metadata=AvroEventMetadata(service_name="svc", service_version="1", user_id=user, correlation_id="c1"),
- payload={"k": 1, "execution_id": agg} if agg else {"k": 1},
- aggregate_id=agg,
- )
-
-
-@pytest.mark.asyncio
-async def test_store_get_and_queries(repo: EventRepository, db) -> None: # type: ignore[valid-type]
- e1 = make_event("e1", etype="A", agg="x1")
- e2 = make_event("e2", etype="B", agg="x2")
- await repo.store_event(e1)
- await repo.store_events_batch([e2])
- got = await repo.get_event("e1")
- assert got and got.event_id == "e1"
-
- now = datetime.now(timezone.utc)
- by_type = await repo.get_events_by_type("A", start_time=now - timedelta(days=1), end_time=now + timedelta(days=1))
- assert any(ev.event_id == "e1" for ev in by_type)
- by_agg = await repo.get_events_by_aggregate("x2")
- assert any(ev.event_id == "e2" for ev in by_agg)
- by_corr = await repo.get_events_by_correlation("c1")
- assert len(by_corr.events) >= 2
- by_user = await repo.get_events_by_user("u1", limit=10)
- assert len(by_user) >= 2
- exec_events = await repo.get_execution_events("x1")
- assert any(ev.event_id == "e1" for ev in exec_events.events)
-
-
-@pytest.mark.asyncio
-async def test_statistics_and_search_and_delete(repo: EventRepository, db) -> None: # type: ignore[valid-type]
- now = datetime.now(timezone.utc)
- await db.get_collection("events").insert_many([
- {EventFields.EVENT_ID: "e3", EventFields.EVENT_TYPE: "C", EventFields.EVENT_VERSION: "1.0", EventFields.TIMESTAMP: now, EventFields.METADATA: AvroEventMetadata(service_name="svc", service_version="1").to_dict(), EventFields.PAYLOAD: {}},
- ])
- stats = await repo.get_event_statistics(start_time=now - timedelta(days=1), end_time=now + timedelta(days=1))
- assert stats.total_events >= 1
-
- # search requires text index; guard if index not present
- try:
- res = await repo.search_events("test", filters=None, limit=10, skip=0)
- assert isinstance(res, list)
- except Exception:
- # Accept environments without text index
- pass
diff --git a/backend/tests/unit/db/repositories/test_execution_repository.py b/backend/tests/unit/db/repositories/test_execution_repository.py
deleted file mode 100644
index e0150af0..00000000
--- a/backend/tests/unit/db/repositories/test_execution_repository.py
+++ /dev/null
@@ -1,41 +0,0 @@
-import pytest
-from uuid import uuid4
-from datetime import datetime, timezone
-
-from app.db.repositories.execution_repository import ExecutionRepository
-from app.domain.enums.execution import ExecutionStatus
-from app.domain.execution.models import DomainExecution, ResourceUsageDomain
-
-
-@pytest.mark.asyncio
-async def test_execution_crud_and_query(db) -> None: # type: ignore[valid-type]
- repo = ExecutionRepository(db)
-
- # Create
- e = DomainExecution(
- script="print('hello')",
- lang="python",
- lang_version="3.11",
- user_id=str(uuid4()),
- resource_usage=ResourceUsageDomain(0.0, 0, 0, 0),
- )
- created = await repo.create_execution(e)
- assert created.execution_id
-
- # Get
- got = await repo.get_execution(e.execution_id)
- assert got and got.script.startswith("print") and got.status == ExecutionStatus.QUEUED
-
- # Update
- ok = await repo.update_execution(e.execution_id, {"status": ExecutionStatus.RUNNING.value, "stdout": "ok"})
- assert ok is True
- got2 = await repo.get_execution(e.execution_id)
- assert got2 and got2.status == ExecutionStatus.RUNNING
-
- # List
- items = await repo.get_executions({"user_id": e.user_id}, limit=10, skip=0, sort=[("created_at", 1)])
- assert any(x.execution_id == e.execution_id for x in items)
-
- # Delete
- assert await repo.delete_execution(e.execution_id) is True
- assert await repo.get_execution(e.execution_id) is None
diff --git a/backend/tests/unit/db/repositories/test_notification_repository.py b/backend/tests/unit/db/repositories/test_notification_repository.py
deleted file mode 100644
index 2fd5d89b..00000000
--- a/backend/tests/unit/db/repositories/test_notification_repository.py
+++ /dev/null
@@ -1,92 +0,0 @@
-from datetime import datetime, UTC, timedelta
-
-import pytest
-
-from app.db.repositories.notification_repository import NotificationRepository
-from app.domain.enums.notification import NotificationChannel, NotificationStatus, NotificationSeverity
-from app.domain.enums.notification import NotificationChannel as NC
-from app.domain.enums.user import UserRole
-from app.domain.notification import (
- DomainNotification,
- DomainNotificationSubscription,
-)
-from app.domain.user import UserFields
-
-pytestmark = pytest.mark.unit
-
-
-@pytest.fixture()
-def repo(db) -> NotificationRepository: # type: ignore[valid-type]
- return NotificationRepository(db)
-
-
-@pytest.mark.asyncio
-async def test_create_indexes_and_crud(repo: NotificationRepository) -> None:
- await repo.create_indexes() # should not raise
-
- n = DomainNotification(
- user_id="u1",
- severity=NotificationSeverity.MEDIUM,
- tags=["execution", "completed"],
- channel=NotificationChannel.IN_APP,
- subject="sub",
- body="body",
- )
- _id = await repo.create_notification(n)
- assert _id
- # Modify and update
- n.subject = "updated"
- n.body = "new body"
- assert await repo.update_notification(n) is True
- got = await repo.get_notification(n.notification_id, n.user_id)
- assert got and got.notification_id == n.notification_id
- assert await repo.mark_as_read(n.notification_id, n.user_id) is True
- assert await repo.mark_all_as_read(n.user_id) >= 0
- assert await repo.delete_notification(n.notification_id, n.user_id) is True
-
-
-@pytest.mark.asyncio
-async def test_list_count_unread_and_pending(repo: NotificationRepository, db) -> None: # type: ignore[valid-type]
- now = datetime.now(UTC)
- # Seed notifications
- await db.get_collection("notifications").insert_many([
- {"notification_id": "n1", "user_id": "u1", "severity": NotificationSeverity.MEDIUM, "tags": ["execution"], "channel": NotificationChannel.IN_APP, "subject": "s", "body": "b", "status": NotificationStatus.PENDING, "created_at": now},
- {"notification_id": "n2", "user_id": "u1", "severity": NotificationSeverity.LOW, "tags": ["completed"], "channel": NotificationChannel.IN_APP, "subject": "s", "body": "b", "status": NotificationStatus.DELIVERED, "created_at": now},
- ])
- lst = await repo.list_notifications("u1")
- assert len(lst) >= 2
- assert await repo.count_notifications("u1") >= 2
- assert await repo.get_unread_count("u1") >= 0
-
- # Pending and scheduled
- pending = await repo.find_pending_notifications()
- assert any(n.status == NotificationStatus.PENDING for n in pending)
- await db.get_collection("notifications").insert_one({
- "notification_id": "n3", "user_id": "u1", "severity": NotificationSeverity.MEDIUM, "tags": ["execution"],
- "channel": NotificationChannel.IN_APP, "subject": "s", "body": "b", "status": NotificationStatus.PENDING,
- "created_at": now, "scheduled_for": now + timedelta(seconds=1)
- })
- scheduled = await repo.find_scheduled_notifications()
- assert isinstance(scheduled, list)
- assert await repo.cleanup_old_notifications(days=0) >= 0
-
-
-@pytest.mark.asyncio
-async def test_subscriptions_and_user_queries(repo: NotificationRepository, db) -> None: # type: ignore[valid-type]
- sub = DomainNotificationSubscription(user_id="u1", channel=NotificationChannel.IN_APP, severities=[])
- await repo.upsert_subscription("u1", NotificationChannel.IN_APP, sub)
- got = await repo.get_subscription("u1", NotificationChannel.IN_APP)
- assert got and got.user_id == "u1"
- subs = await repo.get_all_subscriptions("u1")
- assert len(subs) == len(list(NC))
-
- # Users by role and active users
- await db.get_collection("users").insert_many([
- {UserFields.USER_ID: "u1", UserFields.USERNAME: "A", UserFields.EMAIL: "a@e.com", UserFields.ROLE: "user", UserFields.IS_ACTIVE: True},
- {UserFields.USER_ID: "u2", UserFields.USERNAME: "B", UserFields.EMAIL: "b@e.com", UserFields.ROLE: "admin", UserFields.IS_ACTIVE: True},
- ])
- ids = await repo.get_users_by_roles([UserRole.USER])
- assert "u1" in ids or isinstance(ids, list)
- await db.get_collection("executions").insert_one({"execution_id": "e1", "user_id": "u2", "created_at": datetime.now(UTC)})
- active = await repo.get_active_users(days=1)
- assert set(active) >= {"u2"} or isinstance(active, list)
diff --git a/backend/tests/unit/db/repositories/test_replay_repository.py b/backend/tests/unit/db/repositories/test_replay_repository.py
deleted file mode 100644
index 8f8aaf2e..00000000
--- a/backend/tests/unit/db/repositories/test_replay_repository.py
+++ /dev/null
@@ -1,50 +0,0 @@
-from datetime import datetime, timezone
-
-import pytest
-
-from app.db.repositories.replay_repository import ReplayRepository
-from app.domain.admin.replay_updates import ReplaySessionUpdate
-from app.domain.enums.replay import ReplayStatus, ReplayType
-from app.domain.replay import ReplayConfig, ReplayFilter
-from app.schemas_pydantic.replay_models import ReplaySession
-
-pytestmark = pytest.mark.unit
-
-
-@pytest.fixture()
-def repo(db) -> ReplayRepository: # type: ignore[valid-type]
- return ReplayRepository(db)
-
-
-@pytest.mark.asyncio
-async def test_indexes_and_session_crud(repo: ReplayRepository) -> None:
- await repo.create_indexes()
- config = ReplayConfig(replay_type=ReplayType.EXECUTION, filter=ReplayFilter())
- session = ReplaySession(session_id="s1", status=ReplayStatus.CREATED, created_at=datetime.now(timezone.utc), config=config)
- await repo.save_session(session)
- got = await repo.get_session("s1")
- assert got and got.session_id == "s1"
- lst = await repo.list_sessions(limit=5)
- assert any(s.session_id == "s1" for s in lst)
- assert await repo.update_session_status("s1", ReplayStatus.RUNNING) is True
- session_update = ReplaySessionUpdate(status=ReplayStatus.COMPLETED)
- assert await repo.update_replay_session("s1", session_update) is True
-
-
-@pytest.mark.asyncio
-async def test_count_fetch_events_and_delete(repo: ReplayRepository, db) -> None: # type: ignore[valid-type]
- now = datetime.now(timezone.utc)
- # Insert events
- await db.get_collection("events").insert_many([
- {"event_id": "e1", "timestamp": now, "execution_id": "x1", "event_type": "T", "metadata": {"user_id": "u1"}},
- {"event_id": "e2", "timestamp": now, "execution_id": "x2", "event_type": "T", "metadata": {"user_id": "u1"}},
- {"event_id": "e3", "timestamp": now, "execution_id": "x3", "event_type": "U", "metadata": {"user_id": "u2"}},
- ])
- cnt = await repo.count_events(ReplayFilter())
- assert cnt >= 3
- batches = []
- async for b in repo.fetch_events(ReplayFilter(), batch_size=2):
- batches.append(b)
- assert sum(len(b) for b in batches) >= 3
- # Delete old sessions (none match date predicate likely)
- assert await repo.delete_old_sessions("2000-01-01T00:00:00Z") >= 0
diff --git a/backend/tests/unit/db/repositories/test_saga_repository.py b/backend/tests/unit/db/repositories/test_saga_repository.py
deleted file mode 100644
index 0e3c8f0b..00000000
--- a/backend/tests/unit/db/repositories/test_saga_repository.py
+++ /dev/null
@@ -1,48 +0,0 @@
-from datetime import datetime, timezone
-
-import pytest
-
-from app.db.repositories.saga_repository import SagaRepository
-from app.domain.enums.saga import SagaState
-from app.domain.saga.models import Saga, SagaFilter, SagaListResult
-
-pytestmark = pytest.mark.unit
-
-
-@pytest.fixture()
-def repo(db) -> SagaRepository: # type: ignore[valid-type]
- return SagaRepository(db)
-
-
-@pytest.mark.asyncio
-async def test_saga_crud_and_queries(repo: SagaRepository, db) -> None: # type: ignore[valid-type]
- now = datetime.now(timezone.utc)
- # Insert saga docs
- await db.get_collection("sagas").insert_many([
- {"saga_id": "s1", "saga_name": "test", "execution_id": "e1", "state": "running", "created_at": now, "updated_at": now},
- {"saga_id": "s2", "saga_name": "test2", "execution_id": "e2", "state": "completed", "created_at": now, "updated_at": now, "completed_at": now},
- ])
- saga = await repo.get_saga("s1")
- assert saga and saga.saga_id == "s1"
- lst = await repo.get_sagas_by_execution("e1")
- assert len(lst) >= 1
-
- f = SagaFilter(execution_ids=["e1"])
- result = await repo.list_sagas(f, limit=2)
- assert isinstance(result, SagaListResult)
-
- assert await repo.update_saga_state("s1", SagaState.COMPLETED) in (True, False)
-
- # user execution ids
- await db.get_collection("executions").insert_many([
- {"execution_id": "e1", "user_id": "u1"},
- {"execution_id": "e2", "user_id": "u1"},
- ])
- ids = await repo.get_user_execution_ids("u1")
- assert set(ids) == {"e1", "e2"}
-
- counts = await repo.count_sagas_by_state()
- assert isinstance(counts, dict) and ("running" in counts or "completed" in counts)
-
- stats = await repo.get_saga_statistics()
- assert isinstance(stats, dict) and "total" in stats
diff --git a/backend/tests/unit/db/repositories/test_sse_repository.py b/backend/tests/unit/db/repositories/test_sse_repository.py
deleted file mode 100644
index fde10f35..00000000
--- a/backend/tests/unit/db/repositories/test_sse_repository.py
+++ /dev/null
@@ -1,41 +0,0 @@
-import pytest
-
-from app.db.repositories.sse_repository import SSERepository
-from app.domain.enums.execution import ExecutionStatus
-
-pytestmark = pytest.mark.unit
-
-
-@pytest.mark.asyncio
-async def test_get_execution_status(db) -> None: # type: ignore[valid-type]
- repo = SSERepository(db)
- await db.get_collection("executions").insert_one({"execution_id": "e1", "status": "running"})
- status = await repo.get_execution_status("e1")
- assert status is not None
- assert status.status == ExecutionStatus.RUNNING
- assert status.execution_id == "e1"
-
-
-@pytest.mark.asyncio
-async def test_get_execution_status_none(db) -> None: # type: ignore[valid-type]
- repo = SSERepository(db)
- assert await repo.get_execution_status("missing") is None
-
-
-@pytest.mark.asyncio
-async def test_get_execution(db) -> None: # type: ignore[valid-type]
- repo = SSERepository(db)
- await db.get_collection("executions").insert_one({
- "execution_id": "e1",
- "status": "queued",
- "resource_usage": {}
- })
- doc = await repo.get_execution("e1")
- assert doc is not None
- assert doc.execution_id == "e1"
-
-
-@pytest.mark.asyncio
-async def test_get_execution_not_found(db) -> None: # type: ignore[valid-type]
- repo = SSERepository(db)
- assert await repo.get_execution("missing") is None
diff --git a/backend/tests/unit/db/repositories/test_user_repository.py b/backend/tests/unit/db/repositories/test_user_repository.py
deleted file mode 100644
index 3c43ca12..00000000
--- a/backend/tests/unit/db/repositories/test_user_repository.py
+++ /dev/null
@@ -1,49 +0,0 @@
-import pytest
-from datetime import datetime, timezone
-
-from app.db.repositories.user_repository import UserRepository
-from app.domain.user.user_models import User as DomainUser, UserUpdate
-from app.domain.enums.user import UserRole
-
-pytestmark = pytest.mark.unit
-
-
-@pytest.mark.asyncio
-async def test_create_get_update_delete_user(db) -> None: # type: ignore[valid-type]
- repo = UserRepository(db)
-
- # Create user
- user = DomainUser(
- user_id="", # let repo assign
- username="alice",
- email="alice@example.com",
- role=UserRole.USER,
- is_active=True,
- is_superuser=False,
- hashed_password="h",
- created_at=datetime.now(timezone.utc),
- updated_at=datetime.now(timezone.utc),
- )
- created = await repo.create_user(user)
- assert created.user_id
-
- # Get by username
- fetched = await repo.get_user("alice")
- assert fetched and fetched.username == "alice"
-
- # Get by id
- by_id = await repo.get_user_by_id(created.user_id)
- assert by_id and by_id.user_id == created.user_id
-
- # List with search + role
- users = await repo.list_users(limit=10, offset=0, search="ali", role=UserRole.USER)
- assert any(u.username == "alice" for u in users)
-
- # Update
- upd = UserUpdate(email="alice2@example.com")
- updated = await repo.update_user(created.user_id, upd)
- assert updated and updated.email == "alice2@example.com"
-
- # Delete
- assert await repo.delete_user(created.user_id) is True
- assert await repo.get_user("alice") is None
diff --git a/backend/tests/unit/db/repositories/test_user_settings_repository.py b/backend/tests/unit/db/repositories/test_user_settings_repository.py
deleted file mode 100644
index d9fdf48d..00000000
--- a/backend/tests/unit/db/repositories/test_user_settings_repository.py
+++ /dev/null
@@ -1,46 +0,0 @@
-from datetime import datetime, timezone, timedelta
-
-import pytest
-
-from app.db.repositories.user_settings_repository import UserSettingsRepository
-from app.domain.enums.events import EventType
-from app.domain.user.settings_models import DomainUserSettings
-
-pytestmark = pytest.mark.unit
-
-
-@pytest.mark.asyncio
-async def test_user_settings_snapshot_and_events(db) -> None: # type: ignore[valid-type]
- repo = UserSettingsRepository(db)
-
- # Create indexes (should not raise)
- await repo.create_indexes()
-
- # Snapshot CRUD
- us = DomainUserSettings(user_id="u1")
- await repo.create_snapshot(us)
- got = await repo.get_snapshot("u1")
- assert got and got.user_id == "u1"
-
- # Insert events and query
- now = datetime.now(timezone.utc)
- await db.get_collection("events").insert_many([
- {
- "aggregate_id": "user_settings_u1",
- "event_type": str(EventType.USER_SETTINGS_UPDATED),
- "timestamp": now,
- "payload": {}
- },
- {
- "aggregate_id": "user_settings_u1",
- "event_type": str(EventType.USER_THEME_CHANGED),
- "timestamp": now,
- "payload": {}
- },
- ])
- evs = await repo.get_settings_events("u1", [EventType.USER_SETTINGS_UPDATED], since=now - timedelta(days=1))
- assert any(e.event_type == EventType.USER_SETTINGS_UPDATED for e in evs)
-
- # Counting helpers
- assert await repo.count_events_for_user("u1") >= 2
- assert await repo.count_events_since_snapshot("u1") >= 0
diff --git a/backend/tests/unit/dlq/test_dlq_models.py b/backend/tests/unit/dlq/test_dlq_models.py
deleted file mode 100644
index b1104c79..00000000
--- a/backend/tests/unit/dlq/test_dlq_models.py
+++ /dev/null
@@ -1,109 +0,0 @@
-import json
-from datetime import datetime, timezone
-
-from app.dlq import (
- AgeStatistics,
- DLQFields,
- DLQMessageFilter,
- DLQMessageStatus,
- DLQStatistics,
- EventTypeStatistic,
- RetryPolicy,
- RetryStrategy,
- TopicStatistic,
-)
-from app.domain.enums.events import EventType
-from app.events.schema.schema_registry import SchemaRegistryManager
-from app.infrastructure.kafka.events.metadata import AvroEventMetadata
-from app.infrastructure.kafka.events.user import UserLoggedInEvent
-from app.infrastructure.mappers.dlq_mapper import DLQMapper
-
-
-def _make_event() -> UserLoggedInEvent:
- from app.domain.enums.auth import LoginMethod
-
- return UserLoggedInEvent(
- user_id="u1",
- login_method=LoginMethod.PASSWORD,
- metadata=AvroEventMetadata(service_name="svc", service_version="1"),
- )
-
-
-def test_dlqmessage_mapper_roundtrip_minimal() -> None:
- ev = _make_event()
- data = {
- DLQFields.EVENT: ev.to_dict(),
- DLQFields.ORIGINAL_TOPIC: "t",
- DLQFields.ERROR: "err",
- DLQFields.RETRY_COUNT: 2,
- DLQFields.FAILED_AT: datetime.now(timezone.utc).isoformat(),
- DLQFields.STATUS: DLQMessageStatus.PENDING,
- DLQFields.PRODUCER_ID: "p1",
- }
- parsed = DLQMapper.from_mongo_document(data)
- assert parsed.original_topic == "t"
- assert parsed.event_type == str(ev.event_type)
-
-
-def test_from_kafka_message_and_headers() -> None:
- ev = _make_event()
- payload = {
- "event": ev.to_dict(),
- "original_topic": "t",
- "error": "E",
- "retry_count": 1,
- "failed_at": datetime.now(timezone.utc).isoformat(),
- "producer_id": "p",
- }
-
- class Msg:
- def value(self):
- return json.dumps(payload).encode()
-
- def headers(self):
- return [("k", b"v")]
-
- def offset(self):
- return 10
-
- def partition(self):
- return 0
-
- m = DLQMapper.from_kafka_message(Msg(), SchemaRegistryManager())
- assert m.original_topic == "t"
- assert m.headers.get("k") == "v"
- assert m.dlq_offset == 10
-
-
-def test_retry_policy_bounds() -> None:
- msg = DLQMapper.from_failed_event(_make_event(), "t", "e", "p", retry_count=0)
- # Immediate
- p1 = RetryPolicy(topic="t", strategy=RetryStrategy.IMMEDIATE)
- assert p1.should_retry(msg) is True
- assert isinstance(p1.get_next_retry_time(msg), datetime)
- # Fixed interval
- p2 = RetryPolicy(topic="t", strategy=RetryStrategy.FIXED_INTERVAL, base_delay_seconds=1)
- t2 = p2.get_next_retry_time(msg)
- assert (t2 - datetime.now(timezone.utc)).total_seconds() <= 2
- # Exponential backoff adds jitter but stays below max
- p3 = RetryPolicy(topic="t", strategy=RetryStrategy.EXPONENTIAL_BACKOFF, base_delay_seconds=1, max_delay_seconds=10)
- t3 = p3.get_next_retry_time(msg)
- assert (t3 - datetime.now(timezone.utc)).total_seconds() <= 11
- # Manual never retries
- p4 = RetryPolicy(topic="t", strategy=RetryStrategy.MANUAL)
- assert p4.should_retry(msg) is False
-
-
-def test_filter_and_stats_models() -> None:
- f = DLQMessageFilter(status=DLQMessageStatus.PENDING, topic="t", event_type=EventType.EXECUTION_REQUESTED)
- q = DLQMapper.filter_to_query(f)
- assert q[DLQFields.STATUS] == DLQMessageStatus.PENDING
- assert q[DLQFields.ORIGINAL_TOPIC] == "t"
- assert q[DLQFields.EVENT_TYPE] == EventType.EXECUTION_REQUESTED
-
- ts = TopicStatistic(topic="t", count=2, avg_retry_count=1.5)
- es = EventTypeStatistic(event_type="X", count=3)
- ages = AgeStatistics(min_age_seconds=1, max_age_seconds=10, avg_age_seconds=5)
- stats = DLQStatistics(by_status={"pending": 1}, by_topic=[ts], by_event_type=[es], age_stats=ages)
- assert stats.by_status["pending"] == 1
- assert isinstance(stats.timestamp, datetime)
diff --git a/backend/tests/unit/events/test_admin_utils.py b/backend/tests/unit/events/test_admin_utils.py
deleted file mode 100644
index 62751384..00000000
--- a/backend/tests/unit/events/test_admin_utils.py
+++ /dev/null
@@ -1,20 +0,0 @@
-import os
-
-import pytest
-
-from app.events.admin_utils import AdminUtils
-
-
-@pytest.mark.kafka
-@pytest.mark.asyncio
-async def test_admin_utils_real_topic_checks() -> None:
- prefix = os.environ.get("KAFKA_TOPIC_PREFIX", "test.")
- topic = f"{prefix}adminutils.{os.environ.get('PYTEST_SESSION_ID','sid')}"
- au = AdminUtils()
-
- # Ensure topic exists (idempotent)
- res = await au.ensure_topics_exist([(topic, 1)])
- assert res.get(topic) in (True, False) # Some clusters may report exists
-
- exists = await au.check_topic_exists(topic)
- assert exists is True
diff --git a/backend/tests/unit/events/test_event_dispatcher.py b/backend/tests/unit/events/test_event_dispatcher.py
index a38b6224..28f7c92d 100644
--- a/backend/tests/unit/events/test_event_dispatcher.py
+++ b/backend/tests/unit/events/test_event_dispatcher.py
@@ -1,8 +1,12 @@
+import logging
+
from app.domain.enums.events import EventType
from app.events.core import EventDispatcher
from app.infrastructure.kafka.events.base import BaseEvent
from tests.helpers import make_execution_requested_event
+
+_test_logger = logging.getLogger("test.events.event_dispatcher")
def make_event():
return make_execution_requested_event(execution_id="e1")
@@ -13,7 +17,7 @@ async def _async_noop(_: BaseEvent) -> None:
def test_register_and_remove_handler() -> None:
- disp = EventDispatcher()
+ disp = EventDispatcher(logger=_test_logger)
# Register via direct method
disp.register_handler(EventType.EXECUTION_REQUESTED, _async_noop)
@@ -26,7 +30,7 @@ def test_register_and_remove_handler() -> None:
def test_decorator_registration() -> None:
- disp = EventDispatcher()
+ disp = EventDispatcher(logger=_test_logger)
@disp.register(EventType.EXECUTION_REQUESTED)
async def handler(ev: BaseEvent) -> None: # noqa: ARG001
@@ -36,7 +40,7 @@ async def handler(ev: BaseEvent) -> None: # noqa: ARG001
async def test_dispatch_metrics_processed_and_skipped() -> None:
- disp = EventDispatcher()
+ disp = EventDispatcher(logger=_test_logger)
called = {"n": 0}
@disp.register(EventType.EXECUTION_REQUESTED)
diff --git a/backend/tests/unit/events/test_metadata_model.py b/backend/tests/unit/events/test_metadata_model.py
index 94afa349..71440ce7 100644
--- a/backend/tests/unit/events/test_metadata_model.py
+++ b/backend/tests/unit/events/test_metadata_model.py
@@ -1,20 +1,6 @@
from app.infrastructure.kafka.events.metadata import AvroEventMetadata
-def test_to_dict() -> None:
- m = AvroEventMetadata(service_name="svc", service_version="1.0")
- d = m.to_dict()
- assert d["service_name"] == "svc"
- assert d["service_version"] == "1.0"
-
-
-def test_from_dict() -> None:
- m = AvroEventMetadata.from_dict({"service_name": "a", "service_version": "2", "user_id": "u"})
- assert m.service_name == "a"
- assert m.service_version == "2"
- assert m.user_id == "u"
-
-
def test_with_correlation() -> None:
m = AvroEventMetadata(service_name="svc", service_version="1")
m2 = m.with_correlation("cid")
diff --git a/backend/tests/unit/events/test_schema_registry_manager.py b/backend/tests/unit/events/test_schema_registry_manager.py
index 458b2323..9d155b5d 100644
--- a/backend/tests/unit/events/test_schema_registry_manager.py
+++ b/backend/tests/unit/events/test_schema_registry_manager.py
@@ -1,13 +1,17 @@
+import logging
+
import pytest
from app.events.schema.schema_registry import SchemaRegistryManager
 from app.infrastructure.kafka.events.execution import ExecutionRequestedEvent
 from app.infrastructure.kafka.events.metadata import AvroEventMetadata
 from app.infrastructure.kafka.events.pod import PodCreatedEvent
+
+_test_logger = logging.getLogger("test.events.schema_registry_manager")
def test_deserialize_json_execution_requested() -> None:
- m = SchemaRegistryManager()
+ m = SchemaRegistryManager(logger=_test_logger)
data = {
"event_type": "execution_requested",
"execution_id": "e1",
@@ -32,7 +36,7 @@ def test_deserialize_json_execution_requested() -> None:
def test_deserialize_json_missing_type_raises() -> None:
- m = SchemaRegistryManager()
+ m = SchemaRegistryManager(logger=_test_logger)
with pytest.raises(ValueError):
m.deserialize_json({})
diff --git a/backend/tests/unit/infrastructure/mappers/__init__.py b/backend/tests/unit/infrastructure/mappers/__init__.py
deleted file mode 100644
index e69de29b..00000000
diff --git a/backend/tests/unit/infrastructure/mappers/test_admin_mapper.py b/backend/tests/unit/infrastructure/mappers/test_admin_mapper.py
deleted file mode 100644
index a63a6c2c..00000000
--- a/backend/tests/unit/infrastructure/mappers/test_admin_mapper.py
+++ /dev/null
@@ -1,113 +0,0 @@
-from datetime import datetime, timezone
-
-import pytest
-from app.domain.admin import (
- AuditAction,
- AuditLogEntry,
- ExecutionLimits,
- MonitoringSettings,
- SecuritySettings,
- SystemSettings,
-)
-from app.domain.user import User as DomainAdminUser
-from app.domain.user import UserCreation, UserRole, UserUpdate
-from app.infrastructure.mappers import AuditLogMapper, SettingsMapper, UserMapper
-from app.schemas_pydantic.user import User as ServiceUser
-
-pytestmark = pytest.mark.unit
-
-
-def _now() -> datetime:
- return datetime.now(timezone.utc)
-
-
-def test_user_mapper_roundtrip_and_validation() -> None:
- user = DomainAdminUser(
- user_id="u1",
- username="alice",
- email="alice@example.com",
- role=UserRole.USER,
- is_active=True,
- is_superuser=False,
- hashed_password="h",
- created_at=_now(),
- updated_at=_now(),
- )
-
- doc = UserMapper.to_mongo_document(user)
- back = UserMapper.from_mongo_document(doc)
- assert back.user_id == user.user_id and back.email == user.email
-
- # invalid email
- doc_bad = {**doc, "email": "bad"}
- with pytest.raises(ValueError):
- UserMapper.from_mongo_document(doc_bad)
-
- # missing required field
- bad2 = doc.copy(); bad2.pop("username")
- with pytest.raises(ValueError):
- UserMapper.from_mongo_document(bad2)
-
-
-def test_user_mapper_from_service_and_update_dict() -> None:
- service_user = ServiceUser(
- user_id="u2",
- username="bob",
- email="bob@example.com",
- role=UserRole.ADMIN,
- is_active=True,
- is_superuser=True,
- )
- domain_user = UserMapper.from_pydantic_service_user(service_user)
- assert domain_user.user_id == "u2" and domain_user.is_superuser is True
-
- upd = UserUpdate(username="new", email="new@example.com", role=UserRole.USER, is_active=False)
- upd_dict = UserMapper.to_update_dict(upd)
- assert upd_dict["username"] == "new" and upd_dict["role"] == UserRole.USER.value
-
- with pytest.raises(ValueError):
- UserMapper.to_update_dict(UserUpdate(email="bad"))
-
- creation = UserCreation(username="c", email="c@example.com", password="12345678")
- cdict = UserMapper.user_creation_to_dict(creation)
- assert "created_at" in cdict and "updated_at" in cdict
-
-
-def test_settings_mapper_roundtrip_defaults_and_custom() -> None:
- # defaults
- exec_limits = SettingsMapper.execution_limits_from_dict(None)
- assert exec_limits.max_timeout_seconds == 300
- sec = SettingsMapper.security_settings_from_dict(None)
- assert sec.password_min_length == 8
- mon = SettingsMapper.monnitoring_settings_from_dict if False else SettingsMapper.monitoring_settings_from_dict # type: ignore[attr-defined]
- mon_settings = mon(None)
- assert mon_settings.log_level is not None
-
- # to_dict/from_dict
- limits = ExecutionLimits(max_timeout_seconds=10, max_memory_mb=256, max_cpu_cores=1, max_concurrent_executions=2)
- sec = SecuritySettings(password_min_length=12, session_timeout_minutes=30, max_login_attempts=3, lockout_duration_minutes=5)
- mon = MonitoringSettings(metrics_retention_days=7, enable_tracing=False)
- sys = SystemSettings(execution_limits=limits, security_settings=sec, monitoring_settings=mon)
-
- sys_dict = SettingsMapper.system_settings_to_dict(sys)
- sys_back = SettingsMapper.system_settings_from_dict(sys_dict)
- assert sys_back.execution_limits.max_memory_mb == 256 and sys_back.monitoring_settings.enable_tracing is False
-
- # pydantic dict
- pyd = SettingsMapper.system_settings_to_pydantic_dict(sys)
- sys_back2 = SettingsMapper.system_settings_from_pydantic(pyd)
- assert sys_back2.security_settings.password_min_length == 12
-
-
-def test_audit_log_mapper_roundtrip() -> None:
- entry = AuditLogEntry(
- action=AuditAction.SYSTEM_SETTINGS_UPDATED,
- user_id="u1",
- username="alice",
- timestamp=_now(),
- changes={"k": "v"},
- reason="init",
- )
- d = AuditLogMapper.to_dict(entry)
- e2 = AuditLogMapper.from_dict(d)
- assert e2.action == entry.action and e2.reason == "init"
diff --git a/backend/tests/unit/infrastructure/mappers/test_dlq_mapper.py b/backend/tests/unit/infrastructure/mappers/test_dlq_mapper.py
deleted file mode 100644
index 185c320c..00000000
--- a/backend/tests/unit/infrastructure/mappers/test_dlq_mapper.py
+++ /dev/null
@@ -1,361 +0,0 @@
-"""Tests for DLQ mapper."""
-
-import json
-from datetime import datetime, timezone
-from unittest.mock import MagicMock, patch
-
-import pytest
-from app.dlq.models import (
- DLQFields,
- DLQMessage,
- DLQMessageFilter,
- DLQMessageStatus,
- DLQMessageUpdate,
-)
-from app.domain.enums.events import EventType
-from app.infrastructure.mappers.dlq_mapper import DLQMapper
-from confluent_kafka import Message
-
-from tests.helpers import make_execution_requested_event
-
-
-@pytest.fixture
-def sample_event():
- """Create a sample event for testing."""
- return make_execution_requested_event(
- execution_id="exec-123",
- script="print('test')",
- )
-
-
-@pytest.fixture
-def sample_dlq_message(sample_event):
- """Create a sample DLQ message."""
- return DLQMessage(
- event=sample_event,
- original_topic="execution-events",
- error="Test error",
- retry_count=2,
- failed_at=datetime.now(timezone.utc),
- status=DLQMessageStatus.PENDING,
- producer_id="test-producer",
- event_id="event-123",
- created_at=datetime.now(timezone.utc),
- last_updated=datetime.now(timezone.utc),
- next_retry_at=datetime.now(timezone.utc),
- retried_at=datetime.now(timezone.utc),
- discarded_at=datetime.now(timezone.utc),
- discard_reason="Max retries exceeded",
- dlq_offset=100,
- dlq_partition=1,
- last_error="Connection timeout",
- )
-
-
-class TestDLQMapper:
- """Test DLQ mapper."""
-
- def test_to_mongo_document_full(self, sample_dlq_message):
- """Test converting DLQ message to MongoDB document with all fields."""
- doc = DLQMapper.to_mongo_document(sample_dlq_message)
-
- assert doc[DLQFields.EVENT] == sample_dlq_message.event.to_dict()
- assert doc[DLQFields.ORIGINAL_TOPIC] == "execution-events"
- assert doc[DLQFields.ERROR] == "Test error"
- assert doc[DLQFields.RETRY_COUNT] == 2
- assert doc[DLQFields.STATUS] == DLQMessageStatus.PENDING
- assert doc[DLQFields.PRODUCER_ID] == "test-producer"
- assert doc[DLQFields.EVENT_ID] == "event-123"
- assert DLQFields.CREATED_AT in doc
- assert DLQFields.LAST_UPDATED in doc
- assert DLQFields.NEXT_RETRY_AT in doc
- assert DLQFields.RETRIED_AT in doc
- assert DLQFields.DISCARDED_AT in doc
- assert doc[DLQFields.DISCARD_REASON] == "Max retries exceeded"
- assert doc[DLQFields.DLQ_OFFSET] == 100
- assert doc[DLQFields.DLQ_PARTITION] == 1
- assert doc[DLQFields.LAST_ERROR] == "Connection timeout"
-
- def test_to_mongo_document_minimal(self, sample_event):
- """Test converting minimal DLQ message to MongoDB document."""
- msg = DLQMessage(
- event=sample_event,
- original_topic="test-topic",
- error="Error",
- retry_count=0,
- failed_at=datetime.now(timezone.utc),
- status=DLQMessageStatus.PENDING,
- producer_id="producer",
- )
-
- doc = DLQMapper.to_mongo_document(msg)
-
- assert doc[DLQFields.EVENT] == sample_event.to_dict()
- assert doc[DLQFields.ORIGINAL_TOPIC] == "test-topic"
- assert doc[DLQFields.ERROR] == "Error"
- assert doc[DLQFields.RETRY_COUNT] == 0
- # event_id is extracted from event in __post_init__ if not provided
- assert doc[DLQFields.EVENT_ID] == sample_event.event_id
- # created_at is set in __post_init__ if not provided
- assert DLQFields.CREATED_AT in doc
- assert DLQFields.LAST_UPDATED not in doc
- assert DLQFields.NEXT_RETRY_AT not in doc
- assert DLQFields.RETRIED_AT not in doc
- assert DLQFields.DISCARDED_AT not in doc
- assert DLQFields.DISCARD_REASON not in doc
- assert DLQFields.DLQ_OFFSET not in doc
- assert DLQFields.DLQ_PARTITION not in doc
- assert DLQFields.LAST_ERROR not in doc
-
- def test_from_mongo_document_full(self, sample_dlq_message):
- """Test creating DLQ message from MongoDB document with all fields."""
- doc = DLQMapper.to_mongo_document(sample_dlq_message)
-
- with patch("app.infrastructure.mappers.dlq_mapper.SchemaRegistryManager") as mock_registry:
- mock_registry.return_value.deserialize_json.return_value = sample_dlq_message.event
-
- msg = DLQMapper.from_mongo_document(doc)
-
- assert msg.event == sample_dlq_message.event
- assert msg.original_topic == "execution-events"
- assert msg.error == "Test error"
- assert msg.retry_count == 2
- assert msg.status == DLQMessageStatus.PENDING
- assert msg.producer_id == "test-producer"
- assert msg.event_id == "event-123"
- assert msg.discard_reason == "Max retries exceeded"
- assert msg.dlq_offset == 100
- assert msg.dlq_partition == 1
- assert msg.last_error == "Connection timeout"
-
- def test_from_mongo_document_minimal(self, sample_event):
- """Test creating DLQ message from minimal MongoDB document."""
- doc = {
- DLQFields.EVENT: sample_event.to_dict(),
- DLQFields.FAILED_AT: datetime.now(timezone.utc),
- }
-
- with patch("app.infrastructure.mappers.dlq_mapper.SchemaRegistryManager") as mock_registry:
- mock_registry.return_value.deserialize_json.return_value = sample_event
-
- msg = DLQMapper.from_mongo_document(doc)
-
- assert msg.event == sample_event
- assert msg.original_topic == ""
- assert msg.error == ""
- assert msg.retry_count == 0
- assert msg.status == DLQMessageStatus.PENDING
- assert msg.producer_id == "unknown"
-
- def test_from_mongo_document_with_string_datetime(self, sample_event):
- """Test creating DLQ message from document with string datetime."""
- now = datetime.now(timezone.utc)
- doc = {
- DLQFields.EVENT: sample_event.to_dict(),
- DLQFields.FAILED_AT: now.isoformat(),
- DLQFields.CREATED_AT: now.isoformat(),
- }
-
- with patch("app.infrastructure.mappers.dlq_mapper.SchemaRegistryManager") as mock_registry:
- mock_registry.return_value.deserialize_json.return_value = sample_event
-
- msg = DLQMapper.from_mongo_document(doc)
-
- assert msg.failed_at.replace(microsecond=0) == now.replace(microsecond=0)
- assert msg.created_at.replace(microsecond=0) == now.replace(microsecond=0)
-
- def test_from_mongo_document_missing_failed_at(self, sample_event):
- """Test creating DLQ message from document without failed_at raises error."""
- doc = {
- DLQFields.EVENT: sample_event.to_dict(),
- }
-
- with patch("app.infrastructure.mappers.dlq_mapper.SchemaRegistryManager") as mock_registry:
- mock_registry.return_value.deserialize_json.return_value = sample_event
-
- with pytest.raises(ValueError, match="Missing failed_at"):
- DLQMapper.from_mongo_document(doc)
-
- def test_from_mongo_document_invalid_failed_at(self, sample_event):
- """Test creating DLQ message with invalid failed_at raises error."""
- doc = {
- DLQFields.EVENT: sample_event.to_dict(),
- DLQFields.FAILED_AT: None,
- }
-
- with patch("app.infrastructure.mappers.dlq_mapper.SchemaRegistryManager") as mock_registry:
- mock_registry.return_value.deserialize_json.return_value = sample_event
-
- with pytest.raises(ValueError, match="Missing failed_at"):
- DLQMapper.from_mongo_document(doc)
-
- def test_from_mongo_document_invalid_event(self):
- """Test creating DLQ message with invalid event raises error."""
- doc = {
- DLQFields.FAILED_AT: datetime.now(timezone.utc),
- DLQFields.EVENT: "not a dict",
- }
-
- with patch("app.infrastructure.mappers.dlq_mapper.SchemaRegistryManager"):
- with pytest.raises(ValueError, match="Missing or invalid event data"):
- DLQMapper.from_mongo_document(doc)
-
- def test_from_mongo_document_invalid_datetime_type(self, sample_event):
- """Test creating DLQ message with invalid datetime type raises error."""
- doc = {
- DLQFields.EVENT: sample_event.to_dict(),
- DLQFields.FAILED_AT: 12345, # Invalid type
- }
-
- with patch("app.infrastructure.mappers.dlq_mapper.SchemaRegistryManager") as mock_registry:
- mock_registry.return_value.deserialize_json.return_value = sample_event
-
- with pytest.raises(ValueError, match="Invalid datetime type"):
- DLQMapper.from_mongo_document(doc)
-
- def test_from_kafka_message(self, sample_event):
- """Test creating DLQ message from Kafka message."""
- mock_msg = MagicMock(spec=Message)
-
- event_data = {
- "event": sample_event.to_dict(),
- "original_topic": "test-topic",
- "error": "Test error",
- "retry_count": 3,
- "failed_at": datetime.now(timezone.utc).isoformat(),
- "producer_id": "test-producer",
- }
- mock_msg.value.return_value = json.dumps(event_data).encode("utf-8")
- mock_msg.headers.return_value = [
- ("trace-id", b"123"),
- ("correlation-id", b"456"),
- ]
- mock_msg.offset.return_value = 200
- mock_msg.partition.return_value = 2
-
- mock_registry = MagicMock()
- mock_registry.deserialize_json.return_value = sample_event
-
- msg = DLQMapper.from_kafka_message(mock_msg, mock_registry)
-
- assert msg.event == sample_event
- assert msg.original_topic == "test-topic"
- assert msg.error == "Test error"
- assert msg.retry_count == 3
- assert msg.producer_id == "test-producer"
- assert msg.dlq_offset == 200
- assert msg.dlq_partition == 2
- assert msg.headers["trace-id"] == "123"
- assert msg.headers["correlation-id"] == "456"
-
- def test_from_kafka_message_no_value(self):
- """Test creating DLQ message from Kafka message without value raises error."""
- mock_msg = MagicMock(spec=Message)
- mock_msg.value.return_value = None
-
- mock_registry = MagicMock()
-
- with pytest.raises(ValueError, match="Message has no value"):
- DLQMapper.from_kafka_message(mock_msg, mock_registry)
-
- def test_from_kafka_message_minimal(self, sample_event):
- """Test creating DLQ message from minimal Kafka message."""
- mock_msg = MagicMock(spec=Message)
-
- event_data = {
- "event": sample_event.to_dict(),
- }
- mock_msg.value.return_value = json.dumps(event_data).encode("utf-8")
- mock_msg.headers.return_value = None
- mock_msg.offset.return_value = -1 # Invalid offset
- mock_msg.partition.return_value = -1 # Invalid partition
-
- mock_registry = MagicMock()
- mock_registry.deserialize_json.return_value = sample_event
-
- msg = DLQMapper.from_kafka_message(mock_msg, mock_registry)
-
- assert msg.event == sample_event
- assert msg.original_topic == "unknown"
- assert msg.error == "Unknown error"
- assert msg.retry_count == 0
- assert msg.producer_id == "unknown"
- assert msg.dlq_offset is None
- assert msg.dlq_partition is None
- assert msg.headers == {}
-
- def test_from_failed_event(self, sample_event):
- """Test creating DLQ message from failed event."""
- msg = DLQMapper.from_failed_event(
- event=sample_event,
- original_topic="test-topic",
- error="Processing failed",
- producer_id="producer-123",
- retry_count=5,
- )
-
- assert msg.event == sample_event
- assert msg.original_topic == "test-topic"
- assert msg.error == "Processing failed"
- assert msg.producer_id == "producer-123"
- assert msg.retry_count == 5
- assert msg.status == DLQMessageStatus.PENDING
- assert msg.failed_at is not None
-
- def test_update_to_mongo_full(self):
- """Test converting DLQ message update to MongoDB update document."""
- update = DLQMessageUpdate(
- status=DLQMessageStatus.RETRIED,
- retry_count=3,
- next_retry_at=datetime.now(timezone.utc),
- retried_at=datetime.now(timezone.utc),
- discarded_at=datetime.now(timezone.utc),
- discard_reason="Too many retries",
- last_error="Connection timeout",
- extra={"custom_field": "value"},
- )
-
- doc = DLQMapper.update_to_mongo(update)
-
- assert doc[str(DLQFields.STATUS)] == DLQMessageStatus.RETRIED
- assert doc[str(DLQFields.RETRY_COUNT)] == 3
- assert str(DLQFields.NEXT_RETRY_AT) in doc
- assert str(DLQFields.RETRIED_AT) in doc
- assert str(DLQFields.DISCARDED_AT) in doc
- assert doc[str(DLQFields.DISCARD_REASON)] == "Too many retries"
- assert doc[str(DLQFields.LAST_ERROR)] == "Connection timeout"
- assert doc["custom_field"] == "value"
- assert str(DLQFields.LAST_UPDATED) in doc
-
- def test_update_to_mongo_minimal(self):
- """Test converting minimal DLQ message update to MongoDB update document."""
- update = DLQMessageUpdate(status=DLQMessageStatus.DISCARDED)
-
- doc = DLQMapper.update_to_mongo(update)
-
- assert doc[str(DLQFields.STATUS)] == DLQMessageStatus.DISCARDED
- assert str(DLQFields.LAST_UPDATED) in doc
- assert str(DLQFields.RETRY_COUNT) not in doc
- assert str(DLQFields.NEXT_RETRY_AT) not in doc
-
- def test_filter_to_query_full(self):
- """Test converting DLQ message filter to MongoDB query."""
- f = DLQMessageFilter(
- status=DLQMessageStatus.PENDING,
- topic="test-topic",
- event_type=EventType.EXECUTION_REQUESTED,
- )
-
- query = DLQMapper.filter_to_query(f)
-
- assert query[DLQFields.STATUS] == DLQMessageStatus.PENDING
- assert query[DLQFields.ORIGINAL_TOPIC] == "test-topic"
- assert query[DLQFields.EVENT_TYPE] == EventType.EXECUTION_REQUESTED
-
- def test_filter_to_query_empty(self):
- """Test converting empty DLQ message filter to MongoDB query."""
- f = DLQMessageFilter()
-
- query = DLQMapper.filter_to_query(f)
-
- assert query == {}
diff --git a/backend/tests/unit/infrastructure/mappers/test_event_mapper_extended.py b/backend/tests/unit/infrastructure/mappers/test_event_mapper_extended.py
deleted file mode 100644
index 04189d28..00000000
--- a/backend/tests/unit/infrastructure/mappers/test_event_mapper_extended.py
+++ /dev/null
@@ -1,272 +0,0 @@
-"""Extended tests for event mapper to achieve 95%+ coverage."""
-
-from datetime import datetime, timezone
-
-import pytest
-from app.domain.events.event_models import (
- ArchivedEvent,
- Event,
- EventFields,
- EventFilter,
-)
-from app.infrastructure.kafka.events.metadata import AvroEventMetadata
-from app.infrastructure.mappers.event_mapper import (
- ArchivedEventMapper,
- EventExportRowMapper,
- EventFilterMapper,
- EventMapper,
-)
-
-
-@pytest.fixture
-def sample_metadata():
- """Create sample event metadata."""
- return AvroEventMetadata(
- service_name="test-service",
- service_version="1.0.0",
- correlation_id="corr-123",
- user_id="user-456",
- request_id="req-789",
- )
-
-
-@pytest.fixture
-def sample_event(sample_metadata):
- """Create a sample event with all optional fields."""
- return Event(
- event_id="event-123",
- event_type="test.event",
- event_version="2.0",
- timestamp=datetime(2024, 1, 15, 10, 30, 0, tzinfo=timezone.utc),
- metadata=sample_metadata,
- payload={"key": "value", "nested": {"data": 123}},
- aggregate_id="agg-456",
- stored_at=datetime(2024, 1, 15, 10, 30, 1, tzinfo=timezone.utc),
- ttl_expires_at=datetime(2024, 2, 15, 10, 30, 0, tzinfo=timezone.utc),
- status="processed",
- error="Some error occurred",
- )
-
-
-@pytest.fixture
-def minimal_event():
- """Create a minimal event without optional fields."""
- return Event(
- event_id="event-minimal",
- event_type="minimal.event",
- event_version="1.0",
- timestamp=datetime.now(timezone.utc),
- metadata=AvroEventMetadata(service_name="minimal-service", service_version="1.0.0"),
- payload={},
- )
-
-
-class TestEventMapper:
- """Test EventMapper with all branches."""
-
- def test_to_mongo_document_with_all_fields(self, sample_event):
- """Test converting event to MongoDB document with all optional fields."""
- doc = EventMapper.to_mongo_document(sample_event)
-
- assert doc[EventFields.EVENT_ID] == "event-123"
- assert doc[EventFields.EVENT_TYPE] == "test.event"
- assert doc[EventFields.EVENT_VERSION] == "2.0"
- assert doc[EventFields.TIMESTAMP] == sample_event.timestamp
- assert doc[EventFields.PAYLOAD] == {"key": "value", "nested": {"data": 123}}
- assert doc[EventFields.AGGREGATE_ID] == "agg-456"
- assert doc[EventFields.STORED_AT] == sample_event.stored_at
- assert doc[EventFields.TTL_EXPIRES_AT] == sample_event.ttl_expires_at
- assert doc[EventFields.STATUS] == "processed"
- assert doc[EventFields.ERROR] == "Some error occurred"
-
- def test_to_mongo_document_minimal(self, minimal_event):
- """Test converting minimal event to MongoDB document."""
- doc = EventMapper.to_mongo_document(minimal_event)
-
- assert doc[EventFields.EVENT_ID] == "event-minimal"
- assert doc[EventFields.EVENT_TYPE] == "minimal.event"
- assert EventFields.AGGREGATE_ID not in doc
- assert EventFields.STORED_AT not in doc
- assert EventFields.TTL_EXPIRES_AT not in doc
- assert EventFields.STATUS not in doc
- assert EventFields.ERROR not in doc
-
-
-class TestArchivedEventMapper:
- """Test ArchivedEventMapper with all branches."""
-
- def test_to_mongo_document_with_all_fields(self, sample_event):
- """Test converting archived event with all deletion fields."""
- archived = ArchivedEvent(
- event_id=sample_event.event_id,
- event_type=sample_event.event_type,
- event_version=sample_event.event_version,
- timestamp=sample_event.timestamp,
- metadata=sample_event.metadata,
- payload=sample_event.payload,
- aggregate_id=sample_event.aggregate_id,
- stored_at=sample_event.stored_at,
- ttl_expires_at=sample_event.ttl_expires_at,
- status=sample_event.status,
- error=sample_event.error,
- deleted_at=datetime(2024, 1, 20, 15, 0, 0, tzinfo=timezone.utc),
- deleted_by="admin-user",
- deletion_reason="Data cleanup",
- )
-
- doc = ArchivedEventMapper.to_mongo_document(archived)
-
- assert doc[EventFields.EVENT_ID] == sample_event.event_id
- assert doc[EventFields.DELETED_AT] == archived.deleted_at
- assert doc[EventFields.DELETED_BY] == "admin-user"
- assert doc[EventFields.DELETION_REASON] == "Data cleanup"
-
- def test_to_mongo_document_minimal_deletion_info(self, minimal_event):
- """Test converting archived event with minimal deletion info."""
- archived = ArchivedEvent(
- event_id=minimal_event.event_id,
- event_type=minimal_event.event_type,
- event_version=minimal_event.event_version,
- timestamp=minimal_event.timestamp,
- metadata=minimal_event.metadata,
- payload=minimal_event.payload,
- deleted_at=None,
- deleted_by=None,
- deletion_reason=None,
- )
-
- doc = ArchivedEventMapper.to_mongo_document(archived)
-
- assert doc[EventFields.EVENT_ID] == minimal_event.event_id
- assert EventFields.DELETED_AT not in doc
- assert EventFields.DELETED_BY not in doc
- assert EventFields.DELETION_REASON not in doc
-
-
-class TestEventFilterMapper:
- """Test EventFilterMapper with all branches."""
-
- def test_to_mongo_query_full(self):
- """Test converting filter with all fields to MongoDB query."""
- filt = EventFilter(
- event_types=["type1", "type2"],
- aggregate_id="agg-123",
- correlation_id="corr-456",
- user_id="user-789",
- service_name="test-service",
- start_time=datetime(2024, 1, 1, tzinfo=timezone.utc),
- end_time=datetime(2024, 1, 31, tzinfo=timezone.utc),
- text_search="search term",
- )
- # Add status attribute dynamically
- filt.status = "completed"
-
- query = EventFilterMapper.to_mongo_query(filt)
-
- assert query[EventFields.EVENT_TYPE] == {"$in": ["type1", "type2"]}
- assert query[EventFields.AGGREGATE_ID] == "agg-123"
- assert query[EventFields.METADATA_CORRELATION_ID] == "corr-456"
- assert query[EventFields.METADATA_USER_ID] == "user-789"
- assert query[EventFields.METADATA_SERVICE_NAME] == "test-service"
- assert query[EventFields.STATUS] == "completed"
- assert query[EventFields.TIMESTAMP]["$gte"] == filt.start_time
- assert query[EventFields.TIMESTAMP]["$lte"] == filt.end_time
- assert query["$text"] == {"$search": "search term"}
-
- def test_to_mongo_query_with_search_text(self):
- """Test converting filter with search_text field."""
- filt = EventFilter(search_text="another search")
-
- query = EventFilterMapper.to_mongo_query(filt)
-
- assert query["$text"] == {"$search": "another search"}
-
- def test_to_mongo_query_only_start_time(self):
- """Test converting filter with only start_time."""
- filt = EventFilter(
- start_time=datetime(2024, 1, 1, tzinfo=timezone.utc),
- end_time=None,
- )
-
- query = EventFilterMapper.to_mongo_query(filt)
-
- assert query[EventFields.TIMESTAMP] == {"$gte": filt.start_time}
-
- def test_to_mongo_query_only_end_time(self):
- """Test converting filter with only end_time."""
- filt = EventFilter(
- start_time=None,
- end_time=datetime(2024, 1, 31, tzinfo=timezone.utc),
- )
-
- query = EventFilterMapper.to_mongo_query(filt)
-
- assert query[EventFields.TIMESTAMP] == {"$lte": filt.end_time}
-
- def test_to_mongo_query_minimal(self):
- """Test converting minimal filter."""
- filt = EventFilter()
-
- query = EventFilterMapper.to_mongo_query(filt)
-
- assert query == {}
-
- def test_to_mongo_query_with_individual_fields(self):
- """Test converting filter with individual fields set."""
- # Test each field individually to ensure all branches are covered
-
- # Test with event_types
- filt = EventFilter(event_types=["test"])
- query = EventFilterMapper.to_mongo_query(filt)
- assert EventFields.EVENT_TYPE in query
-
- # Test with aggregate_id
- filt = EventFilter(aggregate_id="agg-1")
- query = EventFilterMapper.to_mongo_query(filt)
- assert EventFields.AGGREGATE_ID in query
-
- # Test with correlation_id
- filt = EventFilter(correlation_id="corr-1")
- query = EventFilterMapper.to_mongo_query(filt)
- assert EventFields.METADATA_CORRELATION_ID in query
-
- # Test with user_id
- filt = EventFilter(user_id="user-1")
- query = EventFilterMapper.to_mongo_query(filt)
- assert EventFields.METADATA_USER_ID in query
-
- # Test with service_name
- filt = EventFilter(service_name="service-1")
- query = EventFilterMapper.to_mongo_query(filt)
- assert EventFields.METADATA_SERVICE_NAME in query
-
-
-class TestEventExportRowMapper:
- """Test EventExportRowMapper."""
-
- def test_from_event_with_all_fields(self, sample_event):
- """Test creating export row from event with all fields."""
- row = EventExportRowMapper.from_event(sample_event)
-
- assert row.event_id == "event-123"
- assert row.event_type == "test.event"
- assert row.correlation_id == "corr-123"
- assert row.aggregate_id == "agg-456"
- assert row.user_id == "user-456"
- assert row.service == "test-service"
- assert row.status == "processed"
- assert row.error == "Some error occurred"
-
- def test_from_event_minimal(self, minimal_event):
- """Test creating export row from minimal event."""
- row = EventExportRowMapper.from_event(minimal_event)
-
- assert row.event_id == "event-minimal"
- assert row.event_type == "minimal.event"
- # correlation_id is auto-generated, so it won't be empty
- assert row.correlation_id != ""
- assert row.aggregate_id == ""
- assert row.user_id == ""
- assert row.service == "minimal-service"
- assert row.status == ""
- assert row.error == ""
diff --git a/backend/tests/unit/infrastructure/mappers/test_infra_event_mapper.py b/backend/tests/unit/infrastructure/mappers/test_infra_event_mapper.py
deleted file mode 100644
index 77211da6..00000000
--- a/backend/tests/unit/infrastructure/mappers/test_infra_event_mapper.py
+++ /dev/null
@@ -1,75 +0,0 @@
-from datetime import datetime, timezone
-
-import pytest
-from app.domain.events.event_models import (
- Event,
- EventSummary,
-)
-from app.infrastructure.kafka.events.metadata import AvroEventMetadata
-from app.infrastructure.mappers import (
- ArchivedEventMapper,
- EventExportRowMapper,
- EventMapper,
- EventSummaryMapper,
-)
-
-pytestmark = pytest.mark.unit
-
-
-def _event(eid: str = "e1") -> Event:
- return Event(
- event_id=eid,
- event_type="X",
- event_version="1.0",
- timestamp=datetime.now(timezone.utc),
- metadata=AvroEventMetadata(service_name="svc", service_version="1", user_id="u1"),
- payload={"k": 1},
- aggregate_id="agg",
- status="ok",
- error=None,
- )
-
-
-def test_event_mapper_to_from_mongo() -> None:
- ev = _event()
- doc = EventMapper.to_mongo_document(ev)
- assert doc["event_id"] == ev.event_id and doc["payload"]["k"] == 1
-
- # from_mongo_document should move extra fields into payload
- mongo_doc = doc | {"custom": 123}
- back = EventMapper.from_mongo_document(mongo_doc)
- assert back.payload.get("custom") == 123
-
-
-def test_summary_mapper() -> None:
- e = _event()
- summary = EventSummary(
- event_id=e.event_id, event_type=e.event_type, timestamp=e.timestamp, aggregate_id=e.aggregate_id
- )
- s2 = EventSummaryMapper.from_mongo_document(
- {"event_id": summary.event_id, "event_type": summary.event_type, "timestamp": summary.timestamp}
- )
- assert s2.event_id == summary.event_id
-
-
-def test_archived_export_mapper() -> None:
- e = _event()
- arch = ArchivedEventMapper.from_event(e, deleted_by="admin", deletion_reason="r")
- assert arch.deleted_by == "admin"
- arch_doc = ArchivedEventMapper.to_mongo_document(arch)
- assert "_deleted_at" in arch_doc
- assert "_deleted_by" in arch_doc
- assert "_deletion_reason" in arch_doc
-
- row = type("Row", (), {})()
- row.event_id = e.event_id
- row.event_type = e.event_type
- row.timestamp = e.timestamp.isoformat()
- row.correlation_id = e.metadata.correlation_id or ""
- row.aggregate_id = e.aggregate_id or ""
- row.user_id = e.metadata.user_id or ""
- row.service = e.metadata.service_name
- row.status = e.status or ""
- row.error = e.error or ""
- ed = EventExportRowMapper.to_dict(row)
- assert ed["Event ID"] == e.event_id
diff --git a/backend/tests/unit/infrastructure/mappers/test_rate_limit_mapper.py b/backend/tests/unit/infrastructure/mappers/test_rate_limit_mapper.py
deleted file mode 100644
index 3a1ed0ce..00000000
--- a/backend/tests/unit/infrastructure/mappers/test_rate_limit_mapper.py
+++ /dev/null
@@ -1,52 +0,0 @@
-from datetime import datetime, timedelta, timezone
-
-import pytest
-from app.domain.rate_limit.rate_limit_models import (
- EndpointGroup,
- RateLimitAlgorithm,
- RateLimitConfig,
- RateLimitRule,
- UserRateLimit,
-)
-from app.infrastructure.mappers import (
- RateLimitConfigMapper,
- RateLimitRuleMapper,
- UserRateLimitMapper,
-)
-
-pytestmark = pytest.mark.unit
-
-
-def test_rule_mapper_roundtrip_defaults() -> None:
- rule = RateLimitRule(endpoint_pattern=r"^/api", group=EndpointGroup.API, requests=10, window_seconds=60)
- d = RateLimitRuleMapper.to_dict(rule)
- r2 = RateLimitRuleMapper.from_dict(d)
- assert r2.endpoint_pattern == rule.endpoint_pattern and r2.algorithm == RateLimitAlgorithm.SLIDING_WINDOW
-
-
-def test_user_rate_limit_mapper_roundtrip_and_dates() -> None:
- now = datetime.now(timezone.utc)
- u = UserRateLimit(user_id="u1", rules=[
- RateLimitRule(endpoint_pattern="/x", group=EndpointGroup.API, requests=1, window_seconds=1)], notes="n")
- d = UserRateLimitMapper.to_dict(u)
- u2 = UserRateLimitMapper.from_dict(d)
- assert u2.user_id == "u1" and len(u2.rules) == 1 and isinstance(u2.created_at, datetime)
-
- # from string timestamps
- d["created_at"] = now.isoformat()
- d["updated_at"] = (now + timedelta(seconds=1)).isoformat()
- u3 = UserRateLimitMapper.from_dict(d)
- assert u3.created_at <= u3.updated_at
-
-
-def test_config_mapper_roundtrip_and_json() -> None:
- cfg = RateLimitConfig(
- default_rules=[RateLimitRule(endpoint_pattern="/a", group=EndpointGroup.API, requests=1, window_seconds=1)],
- user_overrides={"u": UserRateLimit(user_id="u")}, global_enabled=False, redis_ttl=10)
- d = RateLimitConfigMapper.to_dict(cfg)
- c2 = RateLimitConfigMapper.from_dict(d)
- assert c2.redis_ttl == 10 and len(c2.default_rules) == 1 and "u" in c2.user_overrides
-
- js = RateLimitConfigMapper.model_dump_json(cfg)
- c3 = RateLimitConfigMapper.model_validate_json(js)
- assert isinstance(c3, RateLimitConfig) and c3.global_enabled is False
diff --git a/backend/tests/unit/infrastructure/mappers/test_rate_limit_mapper_extended.py b/backend/tests/unit/infrastructure/mappers/test_rate_limit_mapper_extended.py
deleted file mode 100644
index 72363534..00000000
--- a/backend/tests/unit/infrastructure/mappers/test_rate_limit_mapper_extended.py
+++ /dev/null
@@ -1,321 +0,0 @@
-"""Extended tests for rate limit mapper to achieve 95%+ coverage."""
-
-from datetime import datetime, timezone
-
-import pytest
-from app.domain.rate_limit import (
- EndpointGroup,
- RateLimitAlgorithm,
- RateLimitConfig,
- RateLimitRule,
- UserRateLimit,
-)
-from app.infrastructure.mappers.rate_limit_mapper import (
- RateLimitConfigMapper,
- RateLimitRuleMapper,
- UserRateLimitMapper,
-)
-
-
-@pytest.fixture
-def sample_rule():
- """Create a sample rate limit rule."""
- return RateLimitRule(
- endpoint_pattern="/api/*",
- group=EndpointGroup.PUBLIC,
- requests=100,
- window_seconds=60,
- burst_multiplier=2.0,
- algorithm=RateLimitAlgorithm.TOKEN_BUCKET,
- priority=10,
- enabled=True,
- )
-
-
-@pytest.fixture
-def sample_user_limit(sample_rule):
- """Create a sample user rate limit."""
- return UserRateLimit(
- user_id="user-123",
- bypass_rate_limit=False,
- global_multiplier=1.5,
- rules=[sample_rule],
- created_at=datetime(2024, 1, 1, 10, 0, 0, tzinfo=timezone.utc),
- updated_at=datetime(2024, 1, 1, 12, 0, 0, tzinfo=timezone.utc),
- notes="Test user with custom limits",
- )
-
-
-class TestUserRateLimitMapper:
- """Test UserRateLimitMapper with focus on uncovered branches."""
-
- def test_from_dict_without_created_at(self):
- """Test creating user rate limit without created_at field (lines 65-66)."""
- data = {
- "user_id": "user-test",
- "bypass_rate_limit": True,
- "global_multiplier": 2.0,
- "rules": [],
- "created_at": None, # Explicitly None
- "updated_at": "2024-01-15T10:00:00",
- "notes": "Test without created_at",
- }
-
- user_limit = UserRateLimitMapper.from_dict(data)
-
- assert user_limit.user_id == "user-test"
- # Should default to current time
- assert user_limit.created_at is not None
- assert isinstance(user_limit.created_at, datetime)
-
- def test_from_dict_without_updated_at(self):
- """Test creating user rate limit without updated_at field (lines 71-72)."""
- data = {
- "user_id": "user-test2",
- "bypass_rate_limit": False,
- "global_multiplier": 1.0,
- "rules": [],
- "created_at": "2024-01-15T09:00:00",
- "updated_at": None, # Explicitly None
- "notes": "Test without updated_at",
- }
-
- user_limit = UserRateLimitMapper.from_dict(data)
-
- assert user_limit.user_id == "user-test2"
- # Should default to current time
- assert user_limit.updated_at is not None
- assert isinstance(user_limit.updated_at, datetime)
-
- def test_from_dict_missing_timestamps(self):
- """Test creating user rate limit with missing timestamp fields."""
- data = {
- "user_id": "user-test3",
- "bypass_rate_limit": False,
- "global_multiplier": 1.0,
- "rules": [],
- # No created_at or updated_at fields at all
- }
-
- user_limit = UserRateLimitMapper.from_dict(data)
-
- assert user_limit.user_id == "user-test3"
- # Both should default to current time
- assert user_limit.created_at is not None
- assert user_limit.updated_at is not None
- assert isinstance(user_limit.created_at, datetime)
- assert isinstance(user_limit.updated_at, datetime)
-
- def test_from_dict_with_empty_string_timestamps(self):
- """Test creating user rate limit with empty string timestamps."""
- data = {
- "user_id": "user-test4",
- "bypass_rate_limit": False,
- "global_multiplier": 1.0,
- "rules": [],
- "created_at": "", # Empty string (falsy)
- "updated_at": "", # Empty string (falsy)
- }
-
- user_limit = UserRateLimitMapper.from_dict(data)
-
- assert user_limit.user_id == "user-test4"
- # Both should default to current time when falsy
- assert user_limit.created_at is not None
- assert user_limit.updated_at is not None
-
- def test_from_dict_with_zero_timestamps(self):
- """Test creating user rate limit with zero/falsy timestamps."""
- data = {
- "user_id": "user-test5",
- "bypass_rate_limit": False,
- "global_multiplier": 1.0,
- "rules": [],
- "created_at": 0, # Falsy number
- "updated_at": 0, # Falsy number
- }
-
- user_limit = UserRateLimitMapper.from_dict(data)
-
- assert user_limit.user_id == "user-test5"
- # Both should default to current time when falsy
- assert user_limit.created_at is not None
- assert user_limit.updated_at is not None
-
- def test_model_dump(self, sample_user_limit):
- """Test model_dump method (line 87)."""
- result = UserRateLimitMapper.model_dump(sample_user_limit)
-
- assert result["user_id"] == "user-123"
- assert result["bypass_rate_limit"] is False
- assert result["global_multiplier"] == 1.5
- assert len(result["rules"]) == 1
- assert result["notes"] == "Test user with custom limits"
- # Check it's the same as to_dict
- assert result == UserRateLimitMapper.to_dict(sample_user_limit)
-
- def test_model_dump_with_minimal_data(self):
- """Test model_dump with minimal user rate limit."""
- minimal_limit = UserRateLimit(
- user_id="minimal-user",
- bypass_rate_limit=False,
- global_multiplier=1.0,
- rules=[],
- created_at=datetime.now(timezone.utc),
- updated_at=datetime.now(timezone.utc),
- notes=None,
- )
-
- result = UserRateLimitMapper.model_dump(minimal_limit)
-
- assert result["user_id"] == "minimal-user"
- assert result["bypass_rate_limit"] is False
- assert result["global_multiplier"] == 1.0
- assert result["rules"] == []
- assert result["notes"] is None
-
- def test_from_dict_with_datetime_objects(self):
- """Test from_dict when timestamps are already datetime objects."""
- now = datetime.now(timezone.utc)
- data = {
- "user_id": "user-datetime",
- "bypass_rate_limit": False,
- "global_multiplier": 1.0,
- "rules": [],
- "created_at": now, # Already a datetime
- "updated_at": now, # Already a datetime
- }
-
- user_limit = UserRateLimitMapper.from_dict(data)
-
- assert user_limit.user_id == "user-datetime"
- assert user_limit.created_at == now
- assert user_limit.updated_at == now
-
- def test_from_dict_with_mixed_timestamp_types(self):
- """Test from_dict with one string and one None timestamp."""
- data = {
- "user_id": "user-mixed",
- "bypass_rate_limit": False,
- "global_multiplier": 1.0,
- "rules": [],
- "created_at": "2024-01-15T10:00:00", # String
- "updated_at": None, # None
- }
-
- user_limit = UserRateLimitMapper.from_dict(data)
-
- assert user_limit.user_id == "user-mixed"
- assert user_limit.created_at.year == 2024
- assert user_limit.created_at.month == 1
- assert user_limit.created_at.day == 15
- assert user_limit.updated_at is not None # Should be set to current time
-
-
-class TestRateLimitRuleMapper:
- """Additional tests for RateLimitRuleMapper."""
-
- def test_from_dict_with_defaults(self):
- """Test creating rule from dict with minimal data (using defaults)."""
- data = {
- "endpoint_pattern": "/api/test",
- "group": "public",
- "requests": 50,
- "window_seconds": 30,
- # Missing optional fields
- }
-
- rule = RateLimitRuleMapper.from_dict(data)
-
- assert rule.endpoint_pattern == "/api/test"
- assert rule.group == EndpointGroup.PUBLIC
- assert rule.requests == 50
- assert rule.window_seconds == 30
- # Check defaults
- assert rule.burst_multiplier == 1.5
- assert rule.algorithm == RateLimitAlgorithm.SLIDING_WINDOW
- assert rule.priority == 0
- assert rule.enabled is True
-
-
-class TestRateLimitConfigMapper:
- """Additional tests for RateLimitConfigMapper."""
-
- def test_model_validate_json(self):
- """Test model_validate_json method."""
- json_str = """
- {
- "default_rules": [
- {
- "endpoint_pattern": "/api/*",
- "group": "public",
- "requests": 100,
- "window_seconds": 60,
- "burst_multiplier": 1.5,
- "algorithm": "sliding_window",
- "priority": 0,
- "enabled": true
- }
- ],
- "user_overrides": {
- "user-123": {
- "user_id": "user-123",
- "bypass_rate_limit": true,
- "global_multiplier": 2.0,
- "rules": [],
- "created_at": null,
- "updated_at": null,
- "notes": "VIP user"
- }
- },
- "global_enabled": true,
- "redis_ttl": 7200
- }
- """
-
- config = RateLimitConfigMapper.model_validate_json(json_str)
-
- assert len(config.default_rules) == 1
- assert config.default_rules[0].endpoint_pattern == "/api/*"
- assert "user-123" in config.user_overrides
- assert config.user_overrides["user-123"].bypass_rate_limit is True
- assert config.global_enabled is True
- assert config.redis_ttl == 7200
-
- def test_model_validate_json_bytes(self):
- """Test model_validate_json with bytes input."""
- json_bytes = b'{"default_rules": [], "user_overrides": {}, "global_enabled": false, "redis_ttl": 3600}'
-
- config = RateLimitConfigMapper.model_validate_json(json_bytes)
-
- assert config.default_rules == []
- assert config.user_overrides == {}
- assert config.global_enabled is False
- assert config.redis_ttl == 3600
-
- def test_model_dump_json(self):
- """Test model_dump_json method."""
- config = RateLimitConfig(
- default_rules=[
- RateLimitRule(
- endpoint_pattern="/test",
- group=EndpointGroup.ADMIN,
- requests=1000,
- window_seconds=60,
- )
- ],
- user_overrides={},
- global_enabled=True,
- redis_ttl=3600,
- )
-
- json_str = RateLimitConfigMapper.model_dump_json(config)
-
- assert isinstance(json_str, str)
- # Parse it back to verify
- import json
- data = json.loads(json_str)
- assert len(data["default_rules"]) == 1
- assert data["default_rules"][0]["endpoint_pattern"] == "/test"
- assert data["global_enabled"] is True
- assert data["redis_ttl"] == 3600
diff --git a/backend/tests/unit/infrastructure/mappers/test_replay_api_mapper.py b/backend/tests/unit/infrastructure/mappers/test_replay_api_mapper.py
deleted file mode 100644
index e21d307b..00000000
--- a/backend/tests/unit/infrastructure/mappers/test_replay_api_mapper.py
+++ /dev/null
@@ -1,391 +0,0 @@
-"""Tests for replay API mapper."""
-
-from datetime import datetime, timezone
-
-import pytest
-from app.domain.enums.events import EventType
-from app.domain.enums.replay import ReplayStatus, ReplayTarget, ReplayType
-from app.domain.replay import ReplayConfig, ReplayFilter, ReplaySessionState
-from app.infrastructure.mappers.replay_api_mapper import ReplayApiMapper
-from app.schemas_pydantic.replay import ReplayRequest
-
-
-@pytest.fixture
-def sample_filter():
- """Create a sample replay filter."""
- return ReplayFilter(
- execution_id="exec-123",
- event_types=[EventType.EXECUTION_REQUESTED, EventType.EXECUTION_COMPLETED],
- start_time=datetime(2024, 1, 1, tzinfo=timezone.utc),
- end_time=datetime(2024, 1, 31, tzinfo=timezone.utc),
- user_id="user-456",
- service_name="test-service",
- custom_query={"status": "completed"},
- exclude_event_types=[EventType.EXECUTION_FAILED],
- )
-
-
-@pytest.fixture
-def sample_config(sample_filter):
- """Create a sample replay config."""
- return ReplayConfig(
- replay_type=ReplayType.EXECUTION,
- target=ReplayTarget.KAFKA,
- filter=sample_filter,
- speed_multiplier=2.0,
- preserve_timestamps=True,
- batch_size=100,
- max_events=1000,
- target_topics={EventType.EXECUTION_REQUESTED: "test-topic"},
- target_file_path="/tmp/replay.json",
- skip_errors=True,
- retry_failed=False,
- retry_attempts=3,
- enable_progress_tracking=True,
- )
-
-
-@pytest.fixture
-def sample_session_state(sample_config):
- """Create a sample replay session state."""
- return ReplaySessionState(
- session_id="session-789",
- config=sample_config,
- status=ReplayStatus.RUNNING,
- total_events=500,
- replayed_events=250,
- failed_events=5,
- skipped_events=10,
- created_at=datetime(2024, 1, 1, 10, 0, 0, tzinfo=timezone.utc),
- started_at=datetime(2024, 1, 1, 10, 1, 0, tzinfo=timezone.utc),
- completed_at=datetime(2024, 1, 1, 10, 11, 0, tzinfo=timezone.utc),
- last_event_at=datetime(2024, 1, 1, 10, 10, 30, tzinfo=timezone.utc),
- errors=[{"error": "Error 1", "timestamp": "2024-01-01T10:05:00"}, {"error": "Error 2", "timestamp": "2024-01-01T10:06:00"}],
- )
-
-
-class TestReplayApiMapper:
- """Test replay API mapper."""
-
- def test_filter_to_schema_full(self, sample_filter):
- """Test converting replay filter to schema with all fields."""
- schema = ReplayApiMapper.filter_to_schema(sample_filter)
-
- assert schema.execution_id == "exec-123"
- assert schema.event_types == ["execution_requested", "execution_completed"]
- assert schema.start_time == datetime(2024, 1, 1, tzinfo=timezone.utc)
- assert schema.end_time == datetime(2024, 1, 31, tzinfo=timezone.utc)
- assert schema.user_id == "user-456"
- assert schema.service_name == "test-service"
- assert schema.custom_query == {"status": "completed"}
- assert schema.exclude_event_types == ["execution_failed"]
-
- def test_filter_to_schema_minimal(self):
- """Test converting minimal replay filter to schema."""
- filter_obj = ReplayFilter()
-
- schema = ReplayApiMapper.filter_to_schema(filter_obj)
-
- assert schema.execution_id is None
- assert schema.event_types is None
- assert schema.start_time is None
- assert schema.end_time is None
- assert schema.user_id is None
- assert schema.service_name is None
- assert schema.custom_query is None
- assert schema.exclude_event_types is None
-
- def test_filter_to_schema_no_event_types(self):
- """Test converting replay filter with no event types."""
- filter_obj = ReplayFilter(
- execution_id="exec-456",
- event_types=None,
- exclude_event_types=None,
- )
-
- schema = ReplayApiMapper.filter_to_schema(filter_obj)
-
- assert schema.execution_id == "exec-456"
- assert schema.event_types is None
- assert schema.exclude_event_types is None
-
- def test_config_to_schema_full(self, sample_config):
- """Test converting replay config to schema with all fields."""
- schema = ReplayApiMapper.config_to_schema(sample_config)
-
- assert schema.replay_type == ReplayType.EXECUTION
- assert schema.target == ReplayTarget.KAFKA
- assert schema.filter is not None
- assert schema.filter.execution_id == "exec-123"
- assert schema.speed_multiplier == 2.0
- assert schema.preserve_timestamps is True
- assert schema.batch_size == 100
- assert schema.max_events == 1000
- assert schema.target_topics == {"execution_requested": "test-topic"}
- assert schema.target_file_path == "/tmp/replay.json"
- assert schema.skip_errors is True
- assert schema.retry_failed is False
- assert schema.retry_attempts == 3
- assert schema.enable_progress_tracking is True
-
- def test_config_to_schema_minimal(self):
- """Test converting minimal replay config to schema."""
- config = ReplayConfig(
- replay_type=ReplayType.TIME_RANGE,
- target=ReplayTarget.FILE,
- filter=ReplayFilter(),
- )
-
- schema = ReplayApiMapper.config_to_schema(config)
-
- assert schema.replay_type == ReplayType.TIME_RANGE
- assert schema.target == ReplayTarget.FILE
- assert schema.filter is not None
- # Default values from ReplayConfig
- assert schema.speed_multiplier == 1.0
- assert schema.preserve_timestamps is False
- assert schema.batch_size == 100
- assert schema.max_events is None
- assert schema.target_topics == {}
- assert schema.target_file_path is None
- assert schema.skip_errors is True
- assert schema.retry_failed is False
- assert schema.retry_attempts == 3
- assert schema.enable_progress_tracking is True
-
- def test_config_to_schema_no_target_topics(self):
- """Test converting replay config with no target topics."""
- config = ReplayConfig(
- replay_type=ReplayType.EXECUTION,
- target=ReplayTarget.KAFKA,
- filter=ReplayFilter(),
- target_topics=None,
- )
-
- schema = ReplayApiMapper.config_to_schema(config)
-
- assert schema.target_topics == {}
-
- def test_session_to_response(self, sample_session_state):
- """Test converting session state to response."""
- response = ReplayApiMapper.session_to_response(sample_session_state)
-
- assert response.session_id == "session-789"
- assert response.config is not None
- assert response.config.replay_type == ReplayType.EXECUTION
- assert response.status == ReplayStatus.RUNNING
- assert response.total_events == 500
- assert response.replayed_events == 250
- assert response.failed_events == 5
- assert response.skipped_events == 10
- assert response.created_at == datetime(2024, 1, 1, 10, 0, 0, tzinfo=timezone.utc)
- assert response.started_at == datetime(2024, 1, 1, 10, 1, 0, tzinfo=timezone.utc)
- assert response.completed_at == datetime(2024, 1, 1, 10, 11, 0, tzinfo=timezone.utc)
- assert response.last_event_at == datetime(2024, 1, 1, 10, 10, 30, tzinfo=timezone.utc)
- assert response.errors == [{"error": "Error 1", "timestamp": "2024-01-01T10:05:00"}, {"error": "Error 2", "timestamp": "2024-01-01T10:06:00"}]
-
- def test_session_to_summary_with_duration(self, sample_session_state):
- """Test converting session state to summary with duration calculation."""
- summary = ReplayApiMapper.session_to_summary(sample_session_state)
-
- assert summary.session_id == "session-789"
- assert summary.replay_type == ReplayType.EXECUTION
- assert summary.target == ReplayTarget.KAFKA
- assert summary.status == ReplayStatus.RUNNING
- assert summary.total_events == 500
- assert summary.replayed_events == 250
- assert summary.failed_events == 5
- assert summary.skipped_events == 10
- assert summary.created_at == datetime(2024, 1, 1, 10, 0, 0, tzinfo=timezone.utc)
- assert summary.started_at == datetime(2024, 1, 1, 10, 1, 0, tzinfo=timezone.utc)
- assert summary.completed_at == datetime(2024, 1, 1, 10, 11, 0, tzinfo=timezone.utc)
-
- # Duration should be 600 seconds (10 minutes)
- assert summary.duration_seconds == 600.0
-
- # Throughput should be 250 events / 600 seconds
- assert summary.throughput_events_per_second == pytest.approx(250 / 600.0)
-
- def test_session_to_summary_no_duration(self):
- """Test converting session state to summary without completed time."""
- state = ReplaySessionState(
- session_id="session-001",
- config=ReplayConfig(
- replay_type=ReplayType.TIME_RANGE,
- target=ReplayTarget.FILE,
- filter=ReplayFilter(),
- ),
- status=ReplayStatus.RUNNING,
- total_events=100,
- replayed_events=50,
- failed_events=0,
- skipped_events=0,
- created_at=datetime.now(timezone.utc),
- started_at=datetime.now(timezone.utc),
- completed_at=None, # Not completed yet
- )
-
- summary = ReplayApiMapper.session_to_summary(state)
-
- assert summary.duration_seconds is None
- assert summary.throughput_events_per_second is None
-
- def test_session_to_summary_zero_duration(self):
- """Test converting session state with zero duration."""
- now = datetime.now(timezone.utc)
- state = ReplaySessionState(
- session_id="session-002",
- config=ReplayConfig(
- replay_type=ReplayType.TIME_RANGE,
- target=ReplayTarget.FILE,
- filter=ReplayFilter(),
- ),
- status=ReplayStatus.COMPLETED,
- total_events=0,
- replayed_events=0,
- failed_events=0,
- skipped_events=0,
- created_at=now,
- started_at=now,
- completed_at=now, # Same time as started
- )
-
- summary = ReplayApiMapper.session_to_summary(state)
-
- assert summary.duration_seconds == 0.0
- # Throughput should be None when duration is 0
- assert summary.throughput_events_per_second is None
-
- def test_session_to_summary_no_start_time(self):
- """Test converting session state without start time."""
- state = ReplaySessionState(
- session_id="session-003",
- config=ReplayConfig(
- replay_type=ReplayType.TIME_RANGE,
- target=ReplayTarget.FILE,
- filter=ReplayFilter(),
- ),
- status=ReplayStatus.CREATED,
- total_events=100,
- replayed_events=0,
- failed_events=0,
- skipped_events=0,
- created_at=datetime.now(timezone.utc),
- started_at=None, # Not started yet
- completed_at=None,
- )
-
- summary = ReplayApiMapper.session_to_summary(state)
-
- assert summary.duration_seconds is None
- assert summary.throughput_events_per_second is None
-
- def test_request_to_filter_full(self):
- """Test converting replay request to filter with all fields."""
- request = ReplayRequest(
- replay_type=ReplayType.EXECUTION,
- target=ReplayTarget.KAFKA,
- execution_id="exec-999",
- event_types=[EventType.EXECUTION_REQUESTED],
- start_time=datetime(2024, 1, 1, tzinfo=timezone.utc),
- end_time=datetime(2024, 1, 31, tzinfo=timezone.utc),
- user_id="user-999",
- service_name="service-999",
- )
-
- filter_obj = ReplayApiMapper.request_to_filter(request)
-
- assert filter_obj.execution_id == "exec-999"
- assert filter_obj.event_types == [EventType.EXECUTION_REQUESTED]
- assert filter_obj.start_time == datetime(2024, 1, 1, tzinfo=timezone.utc)
- assert filter_obj.end_time == datetime(2024, 1, 31, tzinfo=timezone.utc)
- assert filter_obj.user_id == "user-999"
- assert filter_obj.service_name == "service-999"
-
- def test_request_to_filter_with_none_times(self):
- """Test converting replay request with None times."""
- request = ReplayRequest(
- replay_type=ReplayType.TIME_RANGE,
- target=ReplayTarget.FILE,
- start_time=None,
- end_time=None,
- )
-
- filter_obj = ReplayApiMapper.request_to_filter(request)
-
- assert filter_obj.start_time is None
- assert filter_obj.end_time is None
-
- def test_request_to_config_full(self):
- """Test converting replay request to config with all fields."""
- request = ReplayRequest(
- replay_type=ReplayType.EXECUTION,
- target=ReplayTarget.KAFKA,
- execution_id="exec-888",
- event_types=[EventType.EXECUTION_COMPLETED],
- start_time=datetime(2024, 1, 1, tzinfo=timezone.utc),
- end_time=datetime(2024, 1, 31, tzinfo=timezone.utc),
- user_id="user-888",
- service_name="service-888",
- speed_multiplier=3.0,
- preserve_timestamps=False,
- batch_size=50,
- max_events=500,
- skip_errors=False,
- target_file_path="/tmp/output.json",
- )
-
- config = ReplayApiMapper.request_to_config(request)
-
- assert config.replay_type == ReplayType.EXECUTION
- assert config.target == ReplayTarget.KAFKA
- assert config.filter.execution_id == "exec-888"
- assert config.filter.event_types == [EventType.EXECUTION_COMPLETED]
- assert config.speed_multiplier == 3.0
- assert config.preserve_timestamps is False
- assert config.batch_size == 50
- assert config.max_events == 500
- assert config.skip_errors is False
- assert config.target_file_path == "/tmp/output.json"
-
- def test_request_to_config_minimal(self):
- """Test converting minimal replay request to config."""
- request = ReplayRequest(
- replay_type=ReplayType.TIME_RANGE,
- target=ReplayTarget.FILE,
- )
-
- config = ReplayApiMapper.request_to_config(request)
-
- assert config.replay_type == ReplayType.TIME_RANGE
- assert config.target == ReplayTarget.FILE
- assert config.filter is not None
- # Default values from ReplayConfig
- assert config.speed_multiplier == 1.0
- assert config.preserve_timestamps is False
- assert config.batch_size == 100
- assert config.max_events is None
- assert config.skip_errors is True
- assert config.target_file_path is None
-
- def test_op_to_response(self):
- """Test converting operation to response."""
- response = ReplayApiMapper.op_to_response(
- session_id="session-777",
- status=ReplayStatus.COMPLETED,
- message="Replay completed successfully",
- )
-
- assert response.session_id == "session-777"
- assert response.status == ReplayStatus.COMPLETED
- assert response.message == "Replay completed successfully"
-
- def test_cleanup_to_response(self):
- """Test converting cleanup to response."""
- response = ReplayApiMapper.cleanup_to_response(
- removed_sessions=5,
- message="Cleaned up 5 old sessions",
- )
-
- assert response.removed_sessions == 5
- assert response.message == "Cleaned up 5 old sessions"
diff --git a/backend/tests/unit/infrastructure/mappers/test_replay_mapper.py b/backend/tests/unit/infrastructure/mappers/test_replay_mapper.py
deleted file mode 100644
index 08942ea3..00000000
--- a/backend/tests/unit/infrastructure/mappers/test_replay_mapper.py
+++ /dev/null
@@ -1,51 +0,0 @@
-from datetime import datetime, timezone
-
-import pytest
-from app.domain.admin import (
- ReplayQuery,
- ReplaySession,
- ReplaySessionStatusDetail,
-)
-from app.domain.enums.replay import ReplayStatus
-from app.infrastructure.mappers import ReplayQueryMapper, ReplaySessionMapper
-
-pytestmark = pytest.mark.unit
-
-
-def _session() -> ReplaySession:
- return ReplaySession(
- session_id="s1",
- status=ReplayStatus.SCHEDULED,
- total_events=10,
- correlation_id="c",
- created_at=datetime.now(timezone.utc),
- dry_run=True,
- )
-
-
-def test_replay_session_mapper_roundtrip_and_status_helpers() -> None:
- s = _session()
- d = ReplaySessionMapper.to_dict(s)
- s2 = ReplaySessionMapper.from_dict(d)
- assert s2.session_id == s.session_id and s2.status == s.status
-
- info = ReplaySessionMapper.to_status_info(s2)
- di = ReplaySessionMapper.status_info_to_dict(info)
- assert di["session_id"] == s.session_id and di["status"] == s.status.value
-
- detail = ReplaySessionStatusDetail(session=s2, estimated_completion=datetime.now(timezone.utc))
- dd = ReplaySessionMapper.status_detail_to_dict(detail)
- assert dd["session_id"] == s.session_id and "execution_results" in dd
-
-
-def test_replay_query_mapper() -> None:
- q = ReplayQuery(event_ids=["e1"], correlation_id="x", aggregate_id="a")
- mq = ReplayQueryMapper.to_mongodb_query(q)
- assert "event_id" in mq and mq["metadata.correlation_id"] == "x" and mq["aggregate_id"] == "a"
-
- # time window
- q2 = ReplayQuery(start_time=datetime.now(timezone.utc), end_time=datetime.now(timezone.utc))
- mq2 = ReplayQueryMapper.to_mongodb_query(q2)
- assert "timestamp" in mq2 and "$gte" in mq2["timestamp"] and "$lte" in mq2["timestamp"]
-
-
diff --git a/backend/tests/unit/infrastructure/mappers/test_replay_mapper_extended.py b/backend/tests/unit/infrastructure/mappers/test_replay_mapper_extended.py
deleted file mode 100644
index 1c16328b..00000000
--- a/backend/tests/unit/infrastructure/mappers/test_replay_mapper_extended.py
+++ /dev/null
@@ -1,418 +0,0 @@
-"""Extended tests for replay mapper to achieve 95%+ coverage."""
-
-from datetime import datetime, timezone
-
-import pytest
-from app.domain.admin import (
- ReplayQuery,
- ReplaySession,
- ReplaySessionStatusDetail,
- ReplaySessionStatusInfo,
-)
-from app.domain.enums.events import EventType
-from app.domain.enums.replay import ReplayStatus, ReplayTarget, ReplayType
-from app.domain.replay import ReplayConfig, ReplayFilter, ReplaySessionState
-from app.infrastructure.mappers.replay_mapper import (
- ReplayApiMapper,
- ReplayQueryMapper,
- ReplaySessionMapper,
- ReplayStateMapper,
-)
-from app.schemas_pydantic.admin_events import EventReplayRequest
-
-
-@pytest.fixture
-def sample_replay_session():
- """Create a sample replay session with all optional fields."""
- return ReplaySession(
- session_id="session-123",
- type="replay_session",
- status=ReplayStatus.RUNNING,
- total_events=100,
- replayed_events=50,
- failed_events=5,
- skipped_events=10,
- correlation_id="corr-456",
- created_at=datetime(2024, 1, 1, 10, 0, 0, tzinfo=timezone.utc),
- started_at=datetime(2024, 1, 1, 10, 1, 0, tzinfo=timezone.utc),
- completed_at=datetime(2024, 1, 1, 10, 30, 0, tzinfo=timezone.utc),
- error="Some error occurred",
- created_by="admin-user",
- target_service="test-service",
- dry_run=False,
- triggered_executions=["exec-1", "exec-2"],
- )
-
-
-@pytest.fixture
-def minimal_replay_session():
- """Create a minimal replay session without optional fields."""
- return ReplaySession(
- session_id="session-456",
- status=ReplayStatus.SCHEDULED,
- total_events=10,
- correlation_id="corr-789",
- created_at=datetime.now(timezone.utc),
- dry_run=True,
- )
-
-
-@pytest.fixture
-def sample_replay_config():
- """Create a sample replay config."""
- return ReplayConfig(
- replay_type=ReplayType.EXECUTION,
- target=ReplayTarget.KAFKA,
- filter=ReplayFilter(
- execution_id="exec-123",
- event_types=[EventType.EXECUTION_REQUESTED],
- start_time=datetime(2024, 1, 1, tzinfo=timezone.utc),
- end_time=datetime(2024, 1, 31, tzinfo=timezone.utc),
- ),
- speed_multiplier=2.0,
- preserve_timestamps=True,
- batch_size=100,
- max_events=1000,
- )
-
-
-@pytest.fixture
-def sample_replay_session_state(sample_replay_config):
- """Create a sample replay session state."""
- return ReplaySessionState(
- session_id="state-123",
- config=sample_replay_config,
- status=ReplayStatus.RUNNING,
- total_events=500,
- replayed_events=250,
- failed_events=10,
- skipped_events=5,
- created_at=datetime(2024, 1, 1, 9, 0, 0, tzinfo=timezone.utc),
- started_at=datetime(2024, 1, 1, 9, 1, 0, tzinfo=timezone.utc),
- completed_at=datetime(2024, 1, 1, 9, 30, 0, tzinfo=timezone.utc),
- last_event_at=datetime(2024, 1, 1, 9, 29, 30, tzinfo=timezone.utc),
- errors=["Error 1", "Error 2"],
- )
-
-
-class TestReplaySessionMapper:
- """Extended tests for ReplaySessionMapper."""
-
- def test_to_dict_with_all_optional_fields(self, sample_replay_session):
- """Test converting session to dict with all optional fields present."""
- result = ReplaySessionMapper.to_dict(sample_replay_session)
-
- assert result["session_id"] == "session-123"
- assert result["started_at"] == sample_replay_session.started_at
- assert result["completed_at"] == sample_replay_session.completed_at
- assert result["error"] == "Some error occurred"
- assert result["created_by"] == "admin-user"
- assert result["target_service"] == "test-service"
- assert result["triggered_executions"] == ["exec-1", "exec-2"]
-
- def test_to_dict_without_optional_fields(self, minimal_replay_session):
- """Test converting session to dict without optional fields."""
- result = ReplaySessionMapper.to_dict(minimal_replay_session)
-
- assert result["session_id"] == "session-456"
- assert "started_at" not in result
- assert "completed_at" not in result
- assert "error" not in result
- assert "created_by" not in result
- assert "target_service" not in result
-
- def test_from_dict_with_missing_fields(self):
- """Test creating session from dict with missing fields."""
- data = {} # Minimal data
-
- session = ReplaySessionMapper.from_dict(data)
-
- assert session.session_id == ""
- assert session.type == "replay_session"
- assert session.status == ReplayStatus.SCHEDULED
- assert session.total_events == 0
- assert session.replayed_events == 0
- assert session.failed_events == 0
- assert session.skipped_events == 0
- assert session.correlation_id == ""
- assert session.dry_run is False
- assert session.triggered_executions == []
-
- def test_status_detail_to_dict_without_estimated_completion(self, sample_replay_session):
- """Test converting status detail without estimated_completion."""
- detail = ReplaySessionStatusDetail(
- session=sample_replay_session,
- estimated_completion=None, # No estimated completion
- execution_results={"success": 10, "failed": 2},
- )
-
- result = ReplaySessionMapper.status_detail_to_dict(detail)
-
- assert result["session_id"] == "session-123"
- assert "estimated_completion" not in result
- assert result["execution_results"] == {"success": 10, "failed": 2}
-
- def test_status_detail_to_dict_with_estimated_completion(self, sample_replay_session):
- """Test converting status detail with estimated_completion."""
- estimated = datetime(2024, 1, 1, 11, 0, 0, tzinfo=timezone.utc)
- detail = ReplaySessionStatusDetail(
- session=sample_replay_session,
- estimated_completion=estimated,
- )
-
- result = ReplaySessionMapper.status_detail_to_dict(detail)
-
- assert result["estimated_completion"] == estimated
-
- def test_to_status_info(self, sample_replay_session):
- """Test converting session to status info."""
- info = ReplaySessionMapper.to_status_info(sample_replay_session)
-
- assert isinstance(info, ReplaySessionStatusInfo)
- assert info.session_id == sample_replay_session.session_id
- assert info.status == sample_replay_session.status
- assert info.total_events == sample_replay_session.total_events
- assert info.replayed_events == sample_replay_session.replayed_events
- assert info.failed_events == sample_replay_session.failed_events
- assert info.skipped_events == sample_replay_session.skipped_events
- assert info.correlation_id == sample_replay_session.correlation_id
- assert info.created_at == sample_replay_session.created_at
- assert info.started_at == sample_replay_session.started_at
- assert info.completed_at == sample_replay_session.completed_at
- assert info.error == sample_replay_session.error
- assert info.progress_percentage == sample_replay_session.progress_percentage
-
-
-class TestReplayQueryMapper:
- """Extended tests for ReplayQueryMapper."""
-
- def test_to_mongodb_query_with_start_time_only(self):
- """Test query with only start_time."""
- query = ReplayQuery(
- start_time=datetime(2024, 1, 1, tzinfo=timezone.utc)
- )
-
- result = ReplayQueryMapper.to_mongodb_query(query)
-
- assert "timestamp" in result
- assert "$gte" in result["timestamp"]
- assert "$lte" not in result["timestamp"]
-
- def test_to_mongodb_query_with_end_time_only(self):
- """Test query with only end_time."""
- query = ReplayQuery(
- end_time=datetime(2024, 12, 31, tzinfo=timezone.utc)
- )
-
- result = ReplayQueryMapper.to_mongodb_query(query)
-
- assert "timestamp" in result
- assert "$lte" in result["timestamp"]
- assert "$gte" not in result["timestamp"]
-
- def test_to_mongodb_query_empty(self):
- """Test empty query."""
- query = ReplayQuery()
-
- result = ReplayQueryMapper.to_mongodb_query(query)
-
- assert result == {}
-
-
-class TestReplayApiMapper:
- """Tests for ReplayApiMapper."""
-
- def test_request_to_query_full(self):
- """Test converting full request to query."""
- request = EventReplayRequest(
- event_ids=["ev-1", "ev-2"],
- correlation_id="api-corr-123",
- aggregate_id="api-agg-456",
- start_time=datetime(2024, 1, 1, tzinfo=timezone.utc),
- end_time=datetime(2024, 1, 31, tzinfo=timezone.utc),
- )
-
- query = ReplayApiMapper.request_to_query(request)
-
- assert query.event_ids == ["ev-1", "ev-2"]
- assert query.correlation_id == "api-corr-123"
- assert query.aggregate_id == "api-agg-456"
- assert query.start_time == datetime(2024, 1, 1, tzinfo=timezone.utc)
- assert query.end_time == datetime(2024, 1, 31, tzinfo=timezone.utc)
-
- def test_request_to_query_minimal(self):
- """Test converting minimal request to query."""
- request = EventReplayRequest()
-
- query = ReplayApiMapper.request_to_query(request)
-
- assert query.event_ids is None
- assert query.correlation_id is None
- assert query.aggregate_id is None
- assert query.start_time is None
- assert query.end_time is None
-
-
-class TestReplayStateMapper:
- """Tests for ReplayStateMapper."""
-
- def test_to_mongo_document_full(self, sample_replay_session_state):
- """Test converting session state to mongo document with all fields."""
- doc = ReplayStateMapper.to_mongo_document(sample_replay_session_state)
-
- assert doc["session_id"] == "state-123"
- assert doc["status"] == ReplayStatus.RUNNING
- assert doc["total_events"] == 500
- assert doc["replayed_events"] == 250
- assert doc["failed_events"] == 10
- assert doc["skipped_events"] == 5
- assert doc["created_at"] == sample_replay_session_state.created_at
- assert doc["started_at"] == sample_replay_session_state.started_at
- assert doc["completed_at"] == sample_replay_session_state.completed_at
- assert doc["last_event_at"] == sample_replay_session_state.last_event_at
- assert doc["errors"] == ["Error 1", "Error 2"]
- assert "config" in doc
-
- def test_to_mongo_document_minimal(self):
- """Test converting minimal session state to mongo document."""
- minimal_state = ReplaySessionState(
- session_id="minimal-123",
- config=ReplayConfig(
- replay_type=ReplayType.TIME_RANGE,
- target=ReplayTarget.FILE,
- filter=ReplayFilter(),
- ),
- status=ReplayStatus.SCHEDULED,
- )
-
- doc = ReplayStateMapper.to_mongo_document(minimal_state)
-
- assert doc["session_id"] == "minimal-123"
- assert doc["status"] == ReplayStatus.SCHEDULED
- assert doc["total_events"] == 0
- assert doc["replayed_events"] == 0
- assert doc["failed_events"] == 0
- assert doc["skipped_events"] == 0
- assert doc["started_at"] is None
- assert doc["completed_at"] is None
- assert doc["last_event_at"] is None
- assert doc["errors"] == []
-
- def test_to_mongo_document_without_attributes(self):
- """Test converting object without expected attributes."""
- # Create a mock object without some attributes
- class MockSession:
- session_id = "mock-123"
- config = ReplayConfig(
- replay_type=ReplayType.TIME_RANGE,
- target=ReplayTarget.FILE,
- filter=ReplayFilter(),
- )
- status = ReplayStatus.RUNNING
- created_at = datetime.now(timezone.utc)
-
- mock_session = MockSession()
- doc = ReplayStateMapper.to_mongo_document(mock_session)
-
- # Should use getattr with defaults
- assert doc["total_events"] == 0
- assert doc["replayed_events"] == 0
- assert doc["failed_events"] == 0
- assert doc["skipped_events"] == 0
- assert doc["started_at"] is None
- assert doc["completed_at"] is None
- assert doc["last_event_at"] is None
- assert doc["errors"] == []
-
- def test_from_mongo_document_full(self):
- """Test creating session state from full mongo document."""
- doc = {
- "session_id": "from-mongo-123",
- "config": {
- "replay_type": "execution",
- "target": "kafka",
- "filter": {
- "execution_id": "exec-999",
- "event_types": ["execution_requested"],
- },
- "speed_multiplier": 3.0,
- "batch_size": 50,
- },
- "status": ReplayStatus.COMPLETED,
- "total_events": 100,
- "replayed_events": 100,
- "failed_events": 0,
- "skipped_events": 0,
- "started_at": datetime(2024, 1, 1, 10, 0, 0, tzinfo=timezone.utc),
- "completed_at": datetime(2024, 1, 1, 10, 10, 0, tzinfo=timezone.utc),
- "last_event_at": datetime(2024, 1, 1, 10, 9, 50, tzinfo=timezone.utc),
- "errors": ["Warning 1"],
- }
-
- state = ReplayStateMapper.from_mongo_document(doc)
-
- assert state.session_id == "from-mongo-123"
- assert state.status == ReplayStatus.COMPLETED
- assert state.total_events == 100
- assert state.replayed_events == 100
- assert state.failed_events == 0
- assert state.skipped_events == 0
- assert state.started_at == doc["started_at"]
- assert state.completed_at == doc["completed_at"]
- assert state.last_event_at == doc["last_event_at"]
- assert state.errors == ["Warning 1"]
-
- def test_from_mongo_document_minimal(self):
- """Test creating session state from minimal mongo document."""
- doc = {
- "config": {
- "replay_type": "time_range",
- "target": "kafka",
- "filter": {}, # Empty filter is valid
- }
- } # Minimal valid document
-
- state = ReplayStateMapper.from_mongo_document(doc)
-
- assert state.session_id == ""
- assert state.status == ReplayStatus.SCHEDULED # Default
- assert state.total_events == 0
- assert state.replayed_events == 0
- assert state.failed_events == 0
- assert state.skipped_events == 0
- assert state.started_at is None
- assert state.completed_at is None
- assert state.last_event_at is None
- assert state.errors == []
-
- def test_from_mongo_document_with_string_status(self):
- """Test creating session state with string status."""
- doc = {
- "session_id": "string-status-123",
- "status": "running", # String instead of enum
- "config": {
- "replay_type": "time_range",
- "target": "kafka",
- "filter": {},
- },
- }
-
- state = ReplayStateMapper.from_mongo_document(doc)
-
- assert state.status == ReplayStatus.RUNNING
-
- def test_from_mongo_document_with_enum_status(self):
- """Test creating session state with enum status."""
- doc = {
- "session_id": "enum-status-123",
- "status": ReplayStatus.FAILED, # Already an enum
- "config": {
- "replay_type": "execution",
- "target": "kafka",
- "filter": {},
- },
- }
-
- state = ReplayStateMapper.from_mongo_document(doc)
-
- assert state.status == ReplayStatus.FAILED
diff --git a/backend/tests/unit/infrastructure/mappers/test_saga_mapper.py b/backend/tests/unit/infrastructure/mappers/test_saga_mapper.py
deleted file mode 100644
index 6800c08f..00000000
--- a/backend/tests/unit/infrastructure/mappers/test_saga_mapper.py
+++ /dev/null
@@ -1,67 +0,0 @@
-from datetime import datetime, timezone
-
-import pytest
-from app.domain.enums.saga import SagaState
-from app.domain.saga.models import Saga, SagaFilter, SagaInstance
-from app.infrastructure.mappers import (
- SagaFilterMapper,
- SagaInstanceMapper,
- SagaMapper,
-)
-
-pytestmark = pytest.mark.unit
-
-
-def _saga() -> Saga:
- return Saga(
- saga_id="s1",
- saga_name="demo",
- execution_id="e1",
- state=SagaState.RUNNING,
- current_step="a",
- completed_steps=["a"],
- compensated_steps=[],
- context_data={"k": "v"},
- error_message=None,
- )
-
-
-def test_saga_mapper_to_from_mongo() -> None:
- s = _saga()
- m = SagaMapper()
- doc = m.to_mongo(s)
- s2 = m.from_mongo({**doc})
- assert s2.saga_id == s.saga_id and s2.state == s.state
-
-
-def test_saga_instance_mapper_roundtrip_and_clean_context() -> None:
- inst = SagaInstance(saga_name="demo", execution_id="e1", context_data={"_private": "x", "ok": object()})
- md = SagaInstanceMapper.to_mongo(inst)
- assert "_private" not in md["context_data"] and "ok" in md["context_data"]
-
- doc = {
- "saga_id": inst.saga_id,
- "saga_name": inst.saga_name,
- "execution_id": inst.execution_id,
- "state": "completed",
- "completed_steps": ["a"],
- "compensated_steps": [],
- "context_data": {"a": 1},
- "retry_count": 1,
- "created_at": datetime.now(timezone.utc),
- "updated_at": datetime.now(timezone.utc),
- }
- inst2 = SagaInstanceMapper.from_mongo(doc)
- assert inst2.state in (SagaState.COMPLETED, SagaState.CREATED)
-
- # bad state falls back to CREATED
- doc_bad = doc | {"state": "???"}
- inst3 = SagaInstanceMapper.from_mongo(doc_bad)
- assert inst3.state == SagaState.CREATED
-
-
-def test_saga_filter_mapper_to_query() -> None:
- f = SagaFilter(state=SagaState.COMPLETED, execution_ids=["e1"], saga_name="demo", error_status=False)
- fq = SagaFilterMapper().to_mongodb_query(f)
- assert fq["state"] == SagaState.COMPLETED.value and fq["execution_id"]["$in"] == ["e1"] and fq[
- "error_message"] is None
diff --git a/backend/tests/unit/infrastructure/mappers/test_saga_mapper_extended.py b/backend/tests/unit/infrastructure/mappers/test_saga_mapper_extended.py
deleted file mode 100644
index bd28091d..00000000
--- a/backend/tests/unit/infrastructure/mappers/test_saga_mapper_extended.py
+++ /dev/null
@@ -1,327 +0,0 @@
-"""Extended tests for saga mapper to achieve 95%+ coverage."""
-
-from datetime import datetime, timezone
-
-import pytest
-from app.domain.enums.saga import SagaState
-from app.domain.saga.models import Saga, SagaFilter, SagaInstance
-from app.infrastructure.mappers.saga_mapper import (
- SagaFilterMapper,
- SagaInstanceMapper,
- SagaMapper,
-)
-
-
-@pytest.fixture
-def sample_saga():
- """Create a sample saga with all fields."""
- return Saga(
- saga_id="saga-123",
- saga_name="test-saga",
- execution_id="exec-456",
- state=SagaState.RUNNING,
- current_step="step-2",
- completed_steps=["step-1"],
- compensated_steps=[],
- context_data={"key": "value", "_private": "secret", "user_id": "user-789"},
- error_message="Some error",
- created_at=datetime(2024, 1, 1, 10, 0, 0, tzinfo=timezone.utc),
- updated_at=datetime(2024, 1, 1, 10, 30, 0, tzinfo=timezone.utc),
- completed_at=datetime(2024, 1, 1, 11, 0, 0, tzinfo=timezone.utc),
- retry_count=2,
- )
-
-
-@pytest.fixture
-def sample_saga_instance():
- """Create a sample saga instance."""
- return SagaInstance(
- saga_id="inst-123",
- saga_name="test-instance",
- execution_id="exec-789",
- state=SagaState.COMPENSATING,
- current_step="compensate-1",
- completed_steps=["step-1", "step-2"],
- compensated_steps=["step-2"],
- context_data={"data": "test", "_internal": "hidden"},
- error_message="Failed step",
- created_at=datetime(2024, 1, 2, 9, 0, 0, tzinfo=timezone.utc),
- updated_at=datetime(2024, 1, 2, 9, 30, 0, tzinfo=timezone.utc),
- completed_at=None,
- retry_count=1,
- )
-
-
-class TestSagaMapper:
- """Extended tests for SagaMapper."""
-
- def test_to_mongo_with_private_keys(self, sample_saga):
- """Test that private keys (starting with '_') are filtered out."""
- mapper = SagaMapper()
- doc = mapper.to_mongo(sample_saga)
-
- # Private key should be filtered out
- assert "_private" not in doc["context_data"]
- assert "key" in doc["context_data"]
- assert "user_id" in doc["context_data"]
-
- def test_to_mongo_with_none_context(self):
- """Test handling of None context_data."""
- saga = Saga(
- saga_id="saga-001",
- saga_name="test",
- execution_id="exec-001",
- state=SagaState.CREATED,
- context_data=None,
- )
-
- mapper = SagaMapper()
- doc = mapper.to_mongo(saga)
-
- assert doc["context_data"] == {}
-
- def test_to_mongo_with_non_dict_context(self):
- """Test handling of non-dict context_data."""
- saga = Saga(
- saga_id="saga-002",
- saga_name="test",
- execution_id="exec-002",
- state=SagaState.CREATED,
- context_data="not a dict", # Invalid but testing defensive code
- )
-
- mapper = SagaMapper()
- doc = mapper.to_mongo(saga)
-
- # Should return the non-dict value as-is (line 38 checks isinstance)
- assert doc["context_data"] == "not a dict"
-
- def test_from_instance(self, sample_saga_instance):
- """Test converting SagaInstance to Saga."""
- mapper = SagaMapper()
- saga = mapper.from_instance(sample_saga_instance)
-
- assert saga.saga_id == sample_saga_instance.saga_id
- assert saga.saga_name == sample_saga_instance.saga_name
- assert saga.execution_id == sample_saga_instance.execution_id
- assert saga.state == sample_saga_instance.state
- assert saga.current_step == sample_saga_instance.current_step
- assert saga.completed_steps == sample_saga_instance.completed_steps
- assert saga.compensated_steps == sample_saga_instance.compensated_steps
- assert saga.context_data == sample_saga_instance.context_data
- assert saga.error_message == sample_saga_instance.error_message
- assert saga.retry_count == sample_saga_instance.retry_count
-
-
-class TestSagaInstanceMapper:
- """Extended tests for SagaInstanceMapper."""
-
- def test_from_mongo_with_invalid_state(self):
- """Test from_mongo with invalid state value that triggers exception."""
- doc = {
- "saga_id": "saga-123",
- "saga_name": "test",
- "execution_id": "exec-123",
- "state": 12345, # Invalid state (not string or SagaState)
- "completed_steps": [],
- "compensated_steps": [],
- "context_data": {},
- "retry_count": 0,
- }
-
- instance = SagaInstanceMapper.from_mongo(doc)
-
- # Should fall back to CREATED on exception (line 127)
- assert instance.state == SagaState.CREATED
-
- def test_from_mongo_with_saga_state_enum(self):
- """Test from_mongo when state is already a SagaState enum."""
- doc = {
- "saga_id": "saga-124",
- "saga_name": "test",
- "execution_id": "exec-124",
- "state": SagaState.COMPLETED, # Already an enum
- "completed_steps": ["step-1"],
- "compensated_steps": [],
- "context_data": {},
- "retry_count": 1,
- }
-
- instance = SagaInstanceMapper.from_mongo(doc)
-
- assert instance.state == SagaState.COMPLETED
-
- def test_from_mongo_without_datetime_fields(self):
- """Test from_mongo without created_at and updated_at."""
- doc = {
- "saga_id": "saga-125",
- "saga_name": "test",
- "execution_id": "exec-125",
- "state": "running",
- "completed_steps": [],
- "compensated_steps": [],
- "context_data": {},
- "retry_count": 0,
- # No created_at or updated_at
- }
-
- instance = SagaInstanceMapper.from_mongo(doc)
-
- assert instance.saga_id == "saga-125"
- # Should have default datetime values
- assert instance.created_at is not None
- assert instance.updated_at is not None
-
- def test_from_mongo_with_datetime_fields(self):
- """Test from_mongo with created_at and updated_at present."""
- now = datetime.now(timezone.utc)
- doc = {
- "saga_id": "saga-126",
- "saga_name": "test",
- "execution_id": "exec-126",
- "state": "running",
- "completed_steps": [],
- "compensated_steps": [],
- "context_data": {},
- "retry_count": 0,
- "created_at": now,
- "updated_at": now,
- }
-
- instance = SagaInstanceMapper.from_mongo(doc)
-
- assert instance.created_at == now
- assert instance.updated_at == now
-
- def test_to_mongo_with_various_context_types(self):
- """Test to_mongo with different value types in context_data."""
-
- class CustomObject:
- def __str__(self):
- return "custom_str"
-
- class BadObject:
- def __str__(self):
- raise ValueError("Cannot convert")
-
- instance = SagaInstance(
- saga_name="test",
- execution_id="exec-127",
- context_data={
- "_private": "should be skipped",
- "string": "test",
- "int": 42,
- "float": 3.14,
- "bool": True,
- "list": [1, 2, 3],
- "dict": {"nested": "value"},
- "none": None,
- "custom": CustomObject(),
- "bad": BadObject(),
- }
- )
-
- doc = SagaInstanceMapper.to_mongo(instance)
-
- # Check filtered context
- context = doc["context_data"]
- assert "_private" not in context
- assert context["string"] == "test"
- assert context["int"] == 42
- assert context["float"] == 3.14
- assert context["bool"] is True
- assert context["list"] == [1, 2, 3]
- assert context["dict"] == {"nested": "value"}
- assert context["none"] is None
- assert context["custom"] == "custom_str" # Converted to string
- assert "bad" not in context # Skipped due to exception
-
- def test_to_mongo_with_state_without_value_attr(self):
- """Test to_mongo when state doesn't have 'value' attribute."""
- instance = SagaInstance(
- saga_name="test",
- execution_id="exec-128",
- )
- # Mock the state to not have 'value' attribute
- instance.state = "MOCK_STATE" # String instead of enum
-
- doc = SagaInstanceMapper.to_mongo(instance)
-
- # Should use str(state) fallback (line 171)
- assert doc["state"] == "MOCK_STATE"
-
-
-class TestSagaFilterMapper:
- """Extended tests for SagaFilterMapper."""
-
- def test_to_mongodb_query_with_error_status_true(self):
- """Test filter with error_status=True (has errors)."""
- filter_obj = SagaFilter(
- error_status=True # Looking for sagas with errors
- )
-
- mapper = SagaFilterMapper()
- query = mapper.to_mongodb_query(filter_obj)
-
- assert query["error_message"] == {"$ne": None}
-
- def test_to_mongodb_query_with_error_status_false(self):
- """Test filter with error_status=False (no errors)."""
- filter_obj = SagaFilter(
- error_status=False # Looking for sagas without errors
- )
-
- mapper = SagaFilterMapper()
- query = mapper.to_mongodb_query(filter_obj)
-
- assert query["error_message"] is None
-
- def test_to_mongodb_query_with_created_after_only(self):
- """Test filter with only created_after."""
- after_date = datetime(2024, 1, 1, tzinfo=timezone.utc)
- filter_obj = SagaFilter(
- created_after=after_date
- )
-
- mapper = SagaFilterMapper()
- query = mapper.to_mongodb_query(filter_obj)
-
- assert query["created_at"] == {"$gte": after_date}
-
- def test_to_mongodb_query_with_created_before_only(self):
- """Test filter with only created_before."""
- before_date = datetime(2024, 12, 31, tzinfo=timezone.utc)
- filter_obj = SagaFilter(
- created_before=before_date
- )
-
- mapper = SagaFilterMapper()
- query = mapper.to_mongodb_query(filter_obj)
-
- assert query["created_at"] == {"$lte": before_date}
-
- def test_to_mongodb_query_with_both_dates(self):
- """Test filter with both created_after and created_before."""
- after_date = datetime(2024, 1, 1, tzinfo=timezone.utc)
- before_date = datetime(2024, 12, 31, tzinfo=timezone.utc)
- filter_obj = SagaFilter(
- created_after=after_date,
- created_before=before_date
- )
-
- mapper = SagaFilterMapper()
- query = mapper.to_mongodb_query(filter_obj)
-
- assert query["created_at"] == {
- "$gte": after_date,
- "$lte": before_date
- }
-
- def test_to_mongodb_query_empty_filter(self):
- """Test empty filter produces empty query."""
- filter_obj = SagaFilter()
-
- mapper = SagaFilterMapper()
- query = mapper.to_mongodb_query(filter_obj)
-
- assert query == {}
diff --git a/backend/tests/unit/infrastructure/mappers/test_saved_script_mapper.py b/backend/tests/unit/infrastructure/mappers/test_saved_script_mapper.py
deleted file mode 100644
index 018684de..00000000
--- a/backend/tests/unit/infrastructure/mappers/test_saved_script_mapper.py
+++ /dev/null
@@ -1,259 +0,0 @@
-"""Tests for saved script mapper to achieve 95%+ coverage."""
-
-from datetime import datetime, timezone
-from unittest.mock import patch
-from uuid import UUID
-
-import pytest
-from app.domain.saved_script.models import (
- DomainSavedScriptCreate,
- DomainSavedScriptUpdate,
-)
-from app.infrastructure.mappers.saved_script_mapper import SavedScriptMapper
-
-
-@pytest.fixture
-def sample_create_script():
- """Create a sample script creation object with all fields."""
- return DomainSavedScriptCreate(
- name="Test Script",
- script="print('Hello, World!')",
- lang="python",
- lang_version="3.11",
- description="A test script for unit testing",
- )
-
-
-@pytest.fixture
-def sample_create_script_minimal():
- """Create a minimal script creation object."""
- return DomainSavedScriptCreate(
- name="Minimal Script",
- script="console.log('test')",
- )
-
-
-@pytest.fixture
-def sample_update_all_fields():
- """Create an update object with all fields."""
- return DomainSavedScriptUpdate(
- name="Updated Name",
- script="print('Updated')",
- lang="python",
- lang_version="3.12",
- description="Updated description",
- )
-
-
-@pytest.fixture
-def sample_update_partial():
- """Create an update object with only some fields."""
- return DomainSavedScriptUpdate(
- name="New Name",
- script=None,
- lang=None,
- lang_version=None,
- description="New description",
- )
-
-
-@pytest.fixture
-def sample_mongo_document():
- """Create a sample MongoDB document with all fields."""
- return {
- "_id": "mongo_id_123",
- "script_id": "script-123",
- "user_id": "user-456",
- "name": "DB Script",
- "script": "def main(): pass",
- "lang": "python",
- "lang_version": "3.10",
- "description": "Script from database",
- "created_at": datetime(2024, 1, 1, 10, 0, 0, tzinfo=timezone.utc),
- "updated_at": datetime(2024, 1, 1, 11, 0, 0, tzinfo=timezone.utc),
- "extra_field": "should be ignored",
- }
-
-
-class TestSavedScriptMapper:
- """Test SavedScriptMapper methods."""
-
- def test_to_insert_document_with_all_fields(self, sample_create_script):
- """Test creating insert document with all fields."""
- user_id = "test-user-123"
-
- with patch('app.infrastructure.mappers.saved_script_mapper.uuid4') as mock_uuid:
- mock_uuid.return_value = UUID('12345678-1234-5678-1234-567812345678')
-
- doc = SavedScriptMapper.to_insert_document(sample_create_script, user_id)
-
- assert doc["script_id"] == "12345678-1234-5678-1234-567812345678"
- assert doc["user_id"] == user_id
- assert doc["name"] == "Test Script"
- assert doc["script"] == "print('Hello, World!')"
- assert doc["lang"] == "python"
- assert doc["lang_version"] == "3.11"
- assert doc["description"] == "A test script for unit testing"
- assert isinstance(doc["created_at"], datetime)
- assert isinstance(doc["updated_at"], datetime)
- assert doc["created_at"] == doc["updated_at"] # Should be same timestamp
-
- def test_to_insert_document_with_minimal_fields(self, sample_create_script_minimal):
- """Test creating insert document with minimal fields (using defaults)."""
- user_id = "minimal-user"
-
- doc = SavedScriptMapper.to_insert_document(sample_create_script_minimal, user_id)
-
- assert doc["user_id"] == user_id
- assert doc["name"] == "Minimal Script"
- assert doc["script"] == "console.log('test')"
- assert doc["lang"] == "python" # Default value
- assert doc["lang_version"] == "3.11" # Default value
- assert doc["description"] is None # Optional field
- assert "script_id" in doc
- assert "created_at" in doc
- assert "updated_at" in doc
-
- def test_to_update_dict_with_all_fields(self, sample_update_all_fields):
- """Test converting update object with all fields to dict."""
- update_dict = SavedScriptMapper.to_update_dict(sample_update_all_fields)
-
- assert update_dict["name"] == "Updated Name"
- assert update_dict["script"] == "print('Updated')"
- assert update_dict["lang"] == "python"
- assert update_dict["lang_version"] == "3.12"
- assert update_dict["description"] == "Updated description"
- assert "updated_at" in update_dict
- assert isinstance(update_dict["updated_at"], datetime)
-
- def test_to_update_dict_with_none_fields(self, sample_update_partial):
- """Test that None fields are filtered out from update dict."""
- update_dict = SavedScriptMapper.to_update_dict(sample_update_partial)
-
- assert update_dict["name"] == "New Name"
- assert "script" not in update_dict # None value should be filtered
- assert "lang" not in update_dict # None value should be filtered
- assert "lang_version" not in update_dict # None value should be filtered
- assert update_dict["description"] == "New description"
- assert "updated_at" in update_dict
-
- def test_to_update_dict_with_only_updated_at(self):
- """Test update with all fields None except updated_at."""
- update = DomainSavedScriptUpdate() # All fields default to None
-
- update_dict = SavedScriptMapper.to_update_dict(update)
-
- # Only updated_at should be present (it has a default factory)
- assert len(update_dict) == 1
- assert "updated_at" in update_dict
- assert isinstance(update_dict["updated_at"], datetime)
-
- def test_from_mongo_document_with_all_fields(self, sample_mongo_document):
- """Test converting MongoDB document to domain model with all fields."""
- script = SavedScriptMapper.from_mongo_document(sample_mongo_document)
-
- assert script.script_id == "script-123"
- assert script.user_id == "user-456"
- assert script.name == "DB Script"
- assert script.script == "def main(): pass"
- assert script.lang == "python"
- assert script.lang_version == "3.10"
- assert script.description == "Script from database"
- assert script.created_at == datetime(2024, 1, 1, 10, 0, 0, tzinfo=timezone.utc)
- assert script.updated_at == datetime(2024, 1, 1, 11, 0, 0, tzinfo=timezone.utc)
- # Extra field should be ignored
- assert not hasattr(script, "extra_field")
- assert not hasattr(script, "_id")
-
- def test_from_mongo_document_with_missing_optional_fields(self):
- """Test converting MongoDB document with missing optional fields."""
- doc = {
- "script_id": "minimal-123",
- "user_id": "minimal-user",
- "name": "Minimal",
- "script": "pass",
- "lang": "python",
- "lang_version": "3.9",
- # No description, created_at, or updated_at
- }
-
- script = SavedScriptMapper.from_mongo_document(doc)
-
- assert script.script_id == "minimal-123"
- assert script.user_id == "minimal-user"
- assert script.name == "Minimal"
- assert script.script == "pass"
- assert script.lang == "python"
- assert script.lang_version == "3.9"
- assert script.description is None # Should use dataclass default
- # created_at and updated_at should use dataclass defaults
- assert isinstance(script.created_at, datetime)
- assert isinstance(script.updated_at, datetime)
-
- def test_from_mongo_document_with_non_string_fields(self):
- """Test type coercion when fields are not strings."""
- doc = {
- "script_id": 123, # Integer instead of string
- "user_id": 456, # Integer instead of string
- "name": 789, # Integer instead of string
- "script": {"code": "test"}, # Dict instead of string
- "lang": ["python"], # List instead of string
- "lang_version": 3.11, # Float instead of string
- "description": "Valid description",
- "created_at": datetime(2024, 1, 1, tzinfo=timezone.utc),
- "updated_at": datetime(2024, 1, 2, tzinfo=timezone.utc),
- }
-
- script = SavedScriptMapper.from_mongo_document(doc)
-
- # All fields should be coerced to strings
- assert script.script_id == "123"
- assert script.user_id == "456"
- assert script.name == "789"
- assert script.script == "{'code': 'test'}"
- assert script.lang == "['python']"
- assert script.lang_version == "3.11"
- assert script.description == "Valid description"
-
- def test_from_mongo_document_empty(self):
- """Test converting empty MongoDB document should fail."""
- doc = {}
-
- # Should raise TypeError since required fields are missing
- with pytest.raises(TypeError) as exc_info:
- SavedScriptMapper.from_mongo_document(doc)
-
- assert "missing" in str(exc_info.value).lower()
-
- def test_from_mongo_document_only_unknown_fields(self):
- """Test converting document with only unknown fields should fail."""
- doc = {
- "_id": "some_id",
- "unknown_field1": "value1",
- "unknown_field2": "value2",
- "not_in_dataclass": "value3",
- }
-
- # Should raise TypeError since required fields are missing
- with pytest.raises(TypeError) as exc_info:
- SavedScriptMapper.from_mongo_document(doc)
-
- assert "missing" in str(exc_info.value).lower()
-
- def test_from_mongo_document_partial_string_fields(self):
- """Test with some string fields present and some missing should fail."""
- doc = {
- "script_id": "id-123",
- "user_id": 999, # Non-string, should be coerced
- "name": "Test",
- # script is missing - required field
- "lang": "javascript",
- # lang_version is missing
- "description": None, # Explicitly None
- }
-
- # Should raise TypeError since required field 'script' is missing
- with pytest.raises(TypeError) as exc_info:
- SavedScriptMapper.from_mongo_document(doc)
-
- assert "missing" in str(exc_info.value).lower()
diff --git a/backend/tests/unit/schemas_pydantic/test_events_schemas.py b/backend/tests/unit/schemas_pydantic/test_events_schemas.py
index f3121cbb..30ef50c2 100644
--- a/backend/tests/unit/schemas_pydantic/test_events_schemas.py
+++ b/backend/tests/unit/schemas_pydantic/test_events_schemas.py
@@ -1,25 +1,7 @@
-import math
-from datetime import datetime, timezone, timedelta
-
import pytest
+from app.schemas_pydantic.events import EventFilterRequest
from app.domain.enums.common import SortOrder
-from app.domain.enums.events import EventType
-from app.infrastructure.kafka.events.metadata import AvroEventMetadata
-from app.schemas_pydantic.events import (
- EventAggregationRequest,
- EventBase,
- EventFilterRequest,
- EventInDB,
- EventListResponse,
- EventProjection,
- EventQuery,
- EventResponse,
- EventStatistics,
- PublishEventRequest,
- PublishEventResponse,
- ResourceUsage,
-)
def test_event_filter_request_sort_validator_accepts_allowed_fields():
@@ -34,78 +16,3 @@ def test_event_filter_request_sort_validator_accepts_allowed_fields():
def test_event_filter_request_sort_validator_rejects_invalid():
with pytest.raises(ValueError):
EventFilterRequest(sort_by="not-a-field")
-
-
-def test_event_base_and_in_db_defaults_and_metadata():
- meta = AvroEventMetadata(service_name="tests", service_version="1.0", user_id="u1")
- ev = EventBase(
- event_type=EventType.EXECUTION_REQUESTED,
- metadata=meta,
- payload={"execution_id": "e1"},
- )
- assert ev.event_id and ev.timestamp.tzinfo is not None
- edb = EventInDB(**ev.model_dump())
- assert isinstance(edb.stored_at, datetime)
- assert isinstance(edb.ttl_expires_at, datetime)
- # ttl should be after stored_at by ~30 days
- assert edb.ttl_expires_at > edb.stored_at
-
-
-def test_publish_event_request_and_response():
- req = PublishEventRequest(
- event_type=EventType.EXECUTION_REQUESTED,
- payload={"x": 1},
- aggregate_id="agg",
- )
- assert req.event_type is EventType.EXECUTION_REQUESTED
- resp = PublishEventResponse(event_id="e", status="queued", timestamp=datetime.now(timezone.utc))
- assert resp.status == "queued"
-
-
-def test_event_query_schema_and_list_response():
- q = EventQuery(
- event_types=[EventType.EXECUTION_REQUESTED, EventType.POD_CREATED],
- user_id="u1",
- start_time=datetime.now(timezone.utc) - timedelta(hours=1),
- end_time=datetime.now(timezone.utc),
- limit=50,
- skip=0,
- )
- assert len(q.event_types or []) == 2 and q.limit == 50
-
- # Minimal list response compose/decompose
- er = EventResponse(
- event_id="id",
- event_type=EventType.POD_CREATED,
- event_version="1.0",
- timestamp=datetime.now(timezone.utc),
- metadata={},
- payload={},
- )
- lst = EventListResponse(events=[er], total=1, limit=1, skip=0, has_more=False)
- assert lst.total == 1 and not lst.has_more
-
-
-def test_event_projection_and_statistics_examples():
- proj = EventProjection(
- name="exec_summary",
- source_events=[EventType.EXECUTION_REQUESTED, EventType.EXECUTION_COMPLETED],
- aggregation_pipeline=[{"$match": {"event_type": str(EventType.EXECUTION_REQUESTED)}}],
- output_collection="summary",
- )
- assert proj.refresh_interval_seconds == 300
-
- stats = EventStatistics(
- total_events=2,
- events_by_type={str(EventType.EXECUTION_REQUESTED): 1},
- events_by_service={"svc": 2},
- events_by_hour=[{"hour": "2025-01-01 00:00", "count": 2}],
- )
- assert stats.total_events == 2
-
-
-def test_resource_usage_schema():
- ru = ResourceUsage(cpu_seconds=1.5, memory_mb_seconds=256.0, disk_io_mb=10.0, network_io_mb=5.0)
- dumped = ru.model_dump()
- assert math.isclose(dumped["cpu_seconds"], 1.5)
-
diff --git a/backend/tests/unit/schemas_pydantic/test_execution_schemas.py b/backend/tests/unit/schemas_pydantic/test_execution_schemas.py
index 2ff863f4..38e59401 100644
--- a/backend/tests/unit/schemas_pydantic/test_execution_schemas.py
+++ b/backend/tests/unit/schemas_pydantic/test_execution_schemas.py
@@ -20,4 +20,3 @@ def test_execution_request_unsupported_version_raises():
with pytest.raises(ValueError) as e:
ExecutionRequest(script="print(1)", lang="python", lang_version="9.9")
assert "Version '9.9' not supported for python" in str(e.value)
-
diff --git a/backend/tests/unit/schemas_pydantic/test_health_dashboard_schemas.py b/backend/tests/unit/schemas_pydantic/test_health_dashboard_schemas.py
deleted file mode 100644
index fb1f0d02..00000000
--- a/backend/tests/unit/schemas_pydantic/test_health_dashboard_schemas.py
+++ /dev/null
@@ -1,100 +0,0 @@
-from datetime import datetime, timezone
-
-from app.schemas_pydantic.health_dashboard import (
- CategoryHealthResponse,
- CategoryHealthStatistics,
- CategoryServices,
- DependencyEdge,
- DependencyGraph,
- DependencyNode,
- DetailedHealthStatus,
- HealthAlert,
- HealthCheckConfig,
- HealthCheckState,
- HealthDashboardResponse,
- HealthMetricsSummary,
- HealthStatistics,
- HealthTrend,
- ServiceHealth,
- ServiceHealthDetails,
- ServiceHistoryDataPoint,
- ServiceHistoryResponse,
- ServiceHistorySummary,
- ServiceRealtimeStatus,
- ServiceDependenciesResponse,
-)
-from app.domain.enums.health import AlertSeverity
-
-
-def _now() -> datetime:
- return datetime.now(timezone.utc)
-
-
-def test_alert_and_metrics_and_trend_models():
- alert = HealthAlert(
- id="a1", severity=AlertSeverity.CRITICAL, service="backend", status="unhealthy", message="down",
- timestamp=_now(), duration_ms=12.3
- )
- assert alert.severity is AlertSeverity.CRITICAL
-
- metrics = HealthMetricsSummary(
- total_checks=10, healthy_checks=7, failed_checks=3, avg_check_duration_ms=5.5, total_failures_24h=3, uptime_percentage_24h=99.1
- )
- assert metrics.total_checks == 10
-
- trend = HealthTrend(timestamp=_now(), status="ok", healthy_count=10, unhealthy_count=0, degraded_count=0)
- assert trend.healthy_count == 10
-
-
-def test_service_health_and_dashboard_models():
- svc = ServiceHealth(name="backend", status="healthy", uptime_percentage=99.9, last_check=_now(), message="ok", critical=False)
- dash = HealthDashboardResponse(
- overall_status="healthy", last_updated=_now(), services=[svc], statistics={"total": 1}, alerts=[], trends=[]
- )
- assert dash.overall_status == "healthy"
-
-
-def test_category_services_and_detailed_status():
- cat = CategoryServices(status="healthy", message="ok", duration_ms=1.0, details={"k": "v"})
- stats = HealthStatistics(total_checks=10, healthy=9, degraded=1, unhealthy=0, unknown=0)
- detailed = DetailedHealthStatus(
- timestamp=_now().isoformat(), overall_status="healthy", categories={"core": {"db": cat}}, statistics=stats
- )
- assert detailed.categories["core"]["db"].status == "healthy"
-
-
-def test_dependency_graph_and_service_dependencies():
- nodes = [DependencyNode(id="svcA", label="Service A", status="healthy", critical=False, message="ok")]
- edges = [DependencyEdge(**{"from": "svcA", "to": "svcB", "critical": True})]
- graph = DependencyGraph(nodes=nodes, edges=edges)
- assert graph.edges[0].from_service == "svcA" and graph.edges[0].to_service == "svcB"
-
- from app.schemas_pydantic.health_dashboard import ServiceImpactAnalysis
- impact = {"svcA": ServiceImpactAnalysis(status="ok", affected_services=[], is_critical=False)}
- dep = ServiceDependenciesResponse(
- dependency_graph=graph,
- impact_analysis=impact,
- total_services=1,
- healthy_services=1,
- critical_services_down=0,
- )
- assert dep.total_services == 1
-
-
-def test_service_health_details_and_history():
- cfg = HealthCheckConfig(type="http", critical=True, interval_seconds=10.0, timeout_seconds=2.0, failure_threshold=3)
- state = HealthCheckState(consecutive_failures=0, consecutive_successes=5)
- details = ServiceHealthDetails(
- name="backend", status="healthy", message="ok", duration_ms=1.2, timestamp=_now(), check_config=cfg, state=state
- )
- assert details.state.consecutive_successes == 5
-
- dp = ServiceHistoryDataPoint(timestamp=_now(), status="ok", duration_ms=1.0, healthy=True)
- summary = ServiceHistorySummary(uptime_percentage=99.9, total_checks=10, healthy_checks=9, failure_count=1)
- hist = ServiceHistoryResponse(service_name="backend", time_range_hours=24, data_points=[dp], summary=summary)
- assert hist.time_range_hours == 24
-
-
-def test_realtime_status_model():
- rt = ServiceRealtimeStatus(status="ok", message="fine", duration_ms=2.0, last_check=_now(), details={})
- assert rt.status == "ok"
diff --git a/backend/tests/unit/schemas_pydantic/test_notification_schemas.py b/backend/tests/unit/schemas_pydantic/test_notification_schemas.py
index 00a0c7d4..14b304bc 100644
--- a/backend/tests/unit/schemas_pydantic/test_notification_schemas.py
+++ b/backend/tests/unit/schemas_pydantic/test_notification_schemas.py
@@ -1,19 +1,12 @@
from datetime import UTC, datetime, timedelta
import pytest
+
from app.domain.enums.notification import NotificationChannel, NotificationSeverity, NotificationStatus
-from app.schemas_pydantic.notification import (
- Notification,
- NotificationBatch,
- NotificationListResponse,
- NotificationResponse,
- NotificationStats,
- NotificationSubscription,
- SubscriptionUpdate,
-)
+from app.schemas_pydantic.notification import Notification, NotificationBatch
-def test_notification_scheduled_for_future_validation():
+def test_notification_scheduled_for_must_be_future():
n = Notification(
user_id="u1",
channel=NotificationChannel.IN_APP,
@@ -43,35 +36,6 @@ def test_notification_batch_validation_limits():
with pytest.raises(ValueError):
NotificationBatch(notifications=[])
- # Upper bound: >1000 should fail
many = [n1.model_copy() for _ in range(1001)]
with pytest.raises(ValueError):
NotificationBatch(notifications=many)
-
-
-def test_notification_response_and_list():
- n = Notification(user_id="u1", channel=NotificationChannel.IN_APP, subject="s", body="b")
- resp = NotificationResponse(
- notification_id=n.notification_id,
- channel=n.channel,
- status=n.status,
- subject=n.subject,
- body=n.body,
- action_url=None,
- created_at=n.created_at,
- read_at=None,
- severity=n.severity,
- tags=[],
- )
- lst = NotificationListResponse(notifications=[resp], total=1, unread_count=1)
- assert lst.unread_count == 1
-
-
-def test_subscription_models_and_stats():
- sub = NotificationSubscription(user_id="u1", channel=NotificationChannel.IN_APP)
- upd = SubscriptionUpdate(enabled=True)
- assert sub.enabled is True and upd.enabled is True
-
- now = datetime.now(UTC)
- stats = NotificationStats(start_date=now - timedelta(days=1), end_date=now)
- assert stats.total_sent == 0 and stats.delivery_rate == 0.0
diff --git a/backend/tests/unit/schemas_pydantic/test_replay_models_schemas.py b/backend/tests/unit/schemas_pydantic/test_replay_models_schemas.py
deleted file mode 100644
index 98fff483..00000000
--- a/backend/tests/unit/schemas_pydantic/test_replay_models_schemas.py
+++ /dev/null
@@ -1,44 +0,0 @@
-from datetime import datetime, timezone
-
-from app.domain.enums.events import EventType
-from app.domain.enums.replay import ReplayStatus, ReplayTarget, ReplayType
-from app.domain.replay.models import ReplayConfig as DomainReplayConfig, ReplayFilter as DomainReplayFilter
-from app.schemas_pydantic.replay_models import ReplayConfigSchema, ReplayFilterSchema, ReplaySession
-
-
-def test_replay_filter_schema_from_domain():
- df = DomainReplayFilter(
- execution_id="e1",
- event_types=[EventType.EXECUTION_REQUESTED],
- exclude_event_types=[EventType.POD_CREATED],
- start_time=datetime.now(timezone.utc),
- end_time=datetime.now(timezone.utc),
- user_id="u1",
- service_name="svc",
- custom_query={"x": 1},
- )
- sf = ReplayFilterSchema.from_domain(df)
- assert sf.event_types == [str(EventType.EXECUTION_REQUESTED)]
- assert sf.exclude_event_types == [str(EventType.POD_CREATED)]
-
-
-def test_replay_config_schema_from_domain_and_key_conversion():
- df = DomainReplayFilter(event_types=[EventType.EXECUTION_REQUESTED])
- cfg = DomainReplayConfig(
- replay_type=ReplayType.TIME_RANGE,
- target=ReplayTarget.KAFKA,
- filter=df,
- target_topics={EventType.EXECUTION_REQUESTED: "execution-events"},
- max_events=10,
- )
- sc = ReplayConfigSchema.model_validate(cfg)
- assert sc.target_topics == {str(EventType.EXECUTION_REQUESTED): "execution-events"}
- assert sc.max_events == 10
-
-
-def test_replay_session_coerces_config_from_domain():
- df = DomainReplayFilter()
- cfg = DomainReplayConfig(replay_type=ReplayType.TIME_RANGE, filter=df)
- session = ReplaySession(config=cfg)
- assert session.status == ReplayStatus.CREATED
- assert isinstance(session.config, ReplayConfigSchema)
diff --git a/backend/tests/unit/schemas_pydantic/test_saga_schemas.py b/backend/tests/unit/schemas_pydantic/test_saga_schemas.py
deleted file mode 100644
index 290446c4..00000000
--- a/backend/tests/unit/schemas_pydantic/test_saga_schemas.py
+++ /dev/null
@@ -1,26 +0,0 @@
-from datetime import datetime, timezone
-
-from app.domain.enums.saga import SagaState
-from app.domain.saga.models import Saga
-from app.schemas_pydantic.saga import SagaStatusResponse
-
-
-def test_saga_status_response_from_domain():
- s = Saga(
- saga_id="s1",
- saga_name="exec-saga",
- execution_id="e1",
- state=SagaState.RUNNING,
- current_step="allocate",
- completed_steps=["validate"],
- compensated_steps=[],
- error_message=None,
- created_at=datetime.now(timezone.utc),
- updated_at=datetime.now(timezone.utc),
- completed_at=None,
- retry_count=1,
- )
- resp = SagaStatusResponse.from_domain(s)
- assert resp.saga_id == "s1" and resp.current_step == "allocate"
- assert resp.created_at.endswith("Z") is False # isoformat without enforced Z; just ensure string
-
diff --git a/backend/tests/unit/services/coordinator/test_queue_manager.py b/backend/tests/unit/services/coordinator/test_queue_manager.py
index d21c7363..e3151a16 100644
--- a/backend/tests/unit/services/coordinator/test_queue_manager.py
+++ b/backend/tests/unit/services/coordinator/test_queue_manager.py
@@ -1,8 +1,12 @@
+import logging
+
import pytest
from app.services.coordinator.queue_manager import QueueManager, QueuePriority
from tests.helpers import make_execution_requested_event
+_test_logger = logging.getLogger("test.services.coordinator.queue_manager")
+
def ev(execution_id: str, priority: int = QueuePriority.NORMAL.value):
return make_execution_requested_event(execution_id=execution_id, priority=priority)
@@ -10,7 +14,7 @@ def ev(execution_id: str, priority: int = QueuePriority.NORMAL.value):
@pytest.mark.asyncio
async def test_requeue_execution_increments_priority():
- qm = QueueManager(max_queue_size=10)
+ qm = QueueManager(max_queue_size=10, logger=_test_logger)
await qm.start()
# Use NORMAL priority which can be incremented to LOW
e = ev("x", priority=QueuePriority.NORMAL.value)
@@ -23,7 +27,7 @@ async def test_requeue_execution_increments_priority():
@pytest.mark.asyncio
async def test_queue_stats_empty_and_after_add():
- qm = QueueManager(max_queue_size=5)
+ qm = QueueManager(max_queue_size=5, logger=_test_logger)
await qm.start()
stats0 = await qm.get_queue_stats()
assert stats0["total_size"] == 0
diff --git a/backend/tests/unit/services/coordinator/test_resource_manager.py b/backend/tests/unit/services/coordinator/test_resource_manager.py
index a7971e57..5e1df687 100644
--- a/backend/tests/unit/services/coordinator/test_resource_manager.py
+++ b/backend/tests/unit/services/coordinator/test_resource_manager.py
@@ -1,11 +1,15 @@
+import logging
+
import pytest
from app.services.coordinator.resource_manager import ResourceManager
+_test_logger = logging.getLogger("test.services.coordinator.resource_manager")
+
@pytest.mark.asyncio
async def test_request_allocation_defaults_and_limits() -> None:
- rm = ResourceManager(total_cpu_cores=8.0, total_memory_mb=16384, total_gpu_count=0)
+ rm = ResourceManager(total_cpu_cores=8.0, total_memory_mb=16384, total_gpu_count=0, logger=_test_logger)
# Default for python
alloc = await rm.request_allocation("e1", "python")
@@ -23,7 +27,7 @@ async def test_request_allocation_defaults_and_limits() -> None:
@pytest.mark.asyncio
async def test_release_and_can_allocate() -> None:
- rm = ResourceManager(total_cpu_cores=4.0, total_memory_mb=8192, total_gpu_count=0)
+ rm = ResourceManager(total_cpu_cores=4.0, total_memory_mb=8192, total_gpu_count=0, logger=_test_logger)
a = await rm.request_allocation("e1", "python", requested_cpu=1.0, requested_memory_mb=512)
assert a is not None
@@ -43,7 +47,7 @@ async def test_release_and_can_allocate() -> None:
@pytest.mark.asyncio
async def test_resource_stats() -> None:
- rm = ResourceManager(total_cpu_cores=2.0, total_memory_mb=4096, total_gpu_count=0)
+ rm = ResourceManager(total_cpu_cores=2.0, total_memory_mb=4096, total_gpu_count=0, logger=_test_logger)
# Make sure the allocation succeeds
alloc = await rm.request_allocation("e1", "python", requested_cpu=0.5, requested_memory_mb=256)
assert alloc is not None, "Allocation should have succeeded"
diff --git a/backend/tests/unit/services/idempotency/test_idempotency_manager.py b/backend/tests/unit/services/idempotency/test_idempotency_manager.py
index aab5f33a..df1b2092 100644
--- a/backend/tests/unit/services/idempotency/test_idempotency_manager.py
+++ b/backend/tests/unit/services/idempotency/test_idempotency_manager.py
@@ -1,3 +1,4 @@
+import logging
from unittest.mock import MagicMock
import pytest
@@ -11,6 +12,9 @@
pytestmark = pytest.mark.unit
+# Test logger
+_test_logger = logging.getLogger("test.idempotency_manager")
+
class TestIdempotencyKeyStrategy:
def test_event_based(self) -> None:
@@ -84,7 +88,7 @@ def test_custom_config(self) -> None:
def test_manager_generate_key_variants() -> None:
repo = MagicMock()
- mgr = IdempotencyManager(IdempotencyConfig(), repo)
+ mgr = IdempotencyManager(IdempotencyConfig(), repo, _test_logger)
ev = MagicMock(spec=BaseEvent)
ev.event_type = "t"
ev.event_id = "e"
diff --git a/backend/tests/unit/services/idempotency/test_middleware.py b/backend/tests/unit/services/idempotency/test_middleware.py
index d5ca17f7..c4b19acf 100644
--- a/backend/tests/unit/services/idempotency/test_middleware.py
+++ b/backend/tests/unit/services/idempotency/test_middleware.py
@@ -1,4 +1,5 @@
import asyncio
+import logging
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
@@ -13,6 +14,8 @@
from app.domain.enums.events import EventType
from app.domain.enums.kafka import KafkaTopic
+_test_logger = logging.getLogger("test.services.idempotency.middleware")
+
pytestmark = pytest.mark.unit
@@ -42,7 +45,8 @@ def idempotent_event_handler(self, mock_handler, mock_idempotency_manager):
idempotency_manager=mock_idempotency_manager,
key_strategy="event_based",
ttl_seconds=3600,
- cache_result=True
+ cache_result=True,
+ logger=_test_logger
)
@pytest.mark.asyncio
@@ -54,7 +58,8 @@ async def test_call_with_fields(self, mock_handler, mock_idempotency_manager, ev
handler=mock_handler,
idempotency_manager=mock_idempotency_manager,
key_strategy="content_hash",
- fields=fields
+ fields=fields,
+ logger=_test_logger
)
idempotency_result = IdempotencyResult(
diff --git a/backend/tests/unit/services/pod_monitor/test_event_mapper.py b/backend/tests/unit/services/pod_monitor/test_event_mapper.py
index 0dcb35e8..48a36d4b 100644
--- a/backend/tests/unit/services/pod_monitor/test_event_mapper.py
+++ b/backend/tests/unit/services/pod_monitor/test_event_mapper.py
@@ -1,4 +1,5 @@
import json
+import logging
import pytest
from app.domain.enums.storage import ExecutionErrorType
@@ -19,13 +20,15 @@
pytestmark = pytest.mark.unit
+_test_logger = logging.getLogger("test.services.pod_monitor.event_mapper")
+
def _ctx(pod: Pod, event_type: str = "ADDED") -> PodContext:
return PodContext(pod=pod, execution_id="e1", metadata=AvroEventMetadata(service_name="t", service_version="1"), phase=pod.status.phase or "", event_type=event_type)
def test_pending_running_and_succeeded_mapping() -> None:
- pem = PodEventMapper(k8s_api=FakeApi(json.dumps({"stdout": "ok", "stderr": "", "exit_code": 0, "resource_usage": {"execution_time_wall_seconds": 0, "cpu_time_jiffies": 0, "clk_tck_hertz": 0, "peak_memory_kb": 0}})))
+ pem = PodEventMapper(k8s_api=FakeApi(json.dumps({"stdout": "ok", "stderr": "", "exit_code": 0, "resource_usage": {"execution_time_wall_seconds": 0, "cpu_time_jiffies": 0, "clk_tck_hertz": 0, "peak_memory_kb": 0}})), logger=_test_logger)
# Pending -> scheduled (set execution-id label and PodScheduled condition)
pend = Pod("p", "Pending")
@@ -60,7 +63,8 @@ def __init__(self, t, s): self.type=t; self.status=s
def test_failed_timeout_and_deleted() -> None:
- pem = PodEventMapper(k8s_api=FakeApi(""))
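+ # Exit code 137 corresponds to SIGKILL (128 + 9), matching the DeadlineExceeded timeout pod below.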
+ valid_logs = json.dumps({"stdout": "", "stderr": "", "exit_code": 137, "resource_usage": {}})
+ pem = PodEventMapper(k8s_api=FakeApi(valid_logs), logger=_test_logger)
# Timeout via DeadlineExceeded
pod_to = Pod("p", "Failed", cs=[ContainerStatus(State(terminated=Terminated(137)))], reason="DeadlineExceeded", adl=5)
@@ -69,21 +73,25 @@ def test_failed_timeout_and_deleted() -> None:
assert ev.event_type.value == "execution_timeout" and ev.timeout_seconds == 5
# Failed: terminated exit_code nonzero, message used as stderr, error type defaults to SCRIPT_ERROR
+ # Note: ExecutionFailedEvent can have None resource_usage when logs extraction fails
+ pem_no_logs = PodEventMapper(k8s_api=FakeApi(""), logger=_test_logger)
pod_fail = Pod("p2", "Failed", cs=[ContainerStatus(State(terminated=Terminated(2, message="boom")))])
pod_fail.metadata.labels = {"execution-id": "e2"}
- evf = pem.map_pod_event(pod_fail, "MODIFIED")[0]
+ evf = pem_no_logs.map_pod_event(pod_fail, "MODIFIED")[0]
assert evf.event_type.value == "execution_failed" and evf.error_type in {ExecutionErrorType.SCRIPT_ERROR}
# Deleted -> terminated when container terminated present (exit code 0 returns completed for DELETED)
+ valid_logs_0 = json.dumps({"stdout": "", "stderr": "", "exit_code": 0, "resource_usage": {}})
+ pem_completed = PodEventMapper(k8s_api=FakeApi(valid_logs_0), logger=_test_logger)
pod_del = Pod("p3", "Failed", cs=[ContainerStatus(State(terminated=Terminated(0, reason="Completed")))])
pod_del.metadata.labels = {"execution-id": "e3"}
- evd = pem.map_pod_event(pod_del, "DELETED")[0]
+ evd = pem_completed.map_pod_event(pod_del, "DELETED")[0]
# For DELETED event with exit code 0, it returns execution_completed, not pod_terminated
assert evd.event_type.value == "execution_completed"
def test_extract_id_and_metadata_priority_and_duplicates() -> None:
- pem = PodEventMapper(k8s_api=FakeApi(""))
+ pem = PodEventMapper(k8s_api=FakeApi(""), logger=_test_logger)
# From label
p = Pod("any", "Pending")
@@ -113,7 +121,7 @@ def test_scheduled_requires_condition() -> None:
class Cond:
def __init__(self, t, s): self.type=t; self.status=s
- pem = PodEventMapper(k8s_api=FakeApi(""))
+ pem = PodEventMapper(k8s_api=FakeApi(""), logger=_test_logger)
pod = Pod("p", "Pending")
# No conditions -> None
assert pem._map_scheduled(_ctx(pod)) is None
@@ -129,16 +137,16 @@ def __init__(self, t, s): self.type=t; self.status=s
def test_parse_and_log_paths_and_analyze_failure_variants(caplog) -> None:
# _parse_executor_output line-by-line
line_json = '{"stdout":"x","stderr":"","exit_code":3,"resource_usage":{}}'
- pem = PodEventMapper(k8s_api=FakeApi("junk\n" + line_json))
+ pem = PodEventMapper(k8s_api=FakeApi("junk\n" + line_json), logger=_test_logger)
pod = Pod("p", "Succeeded", cs=[ContainerStatus(State(terminated=Terminated(0)))])
logs = pem._extract_logs(pod)
assert logs.exit_code == 3 and logs.stdout == "x"
- # _extract_logs: no api
- pem2 = PodEventMapper(k8s_api=None)
- assert pem2._extract_logs(pod).exit_code is None
+ # _extract_logs: no api -> returns None
+ pem2 = PodEventMapper(k8s_api=None, logger=_test_logger)
+ assert pem2._extract_logs(pod) is None
- # _extract_logs exceptions -> 404/400/generic branches
+ # _extract_logs exceptions -> 404/400/generic branches, all return None
class _API404(FakeApi):
def read_namespaced_pod_log(self, *a, **k): raise Exception("404 Not Found")
class _API400(FakeApi):
@@ -146,12 +154,12 @@ def read_namespaced_pod_log(self, *a, **k): raise Exception("400 Bad Request")
class _APIGen(FakeApi):
def read_namespaced_pod_log(self, *a, **k): raise Exception("boom")
- pem404 = PodEventMapper(k8s_api=_API404(""))
- assert pem404._extract_logs(pod).exit_code is None
- pem400 = PodEventMapper(k8s_api=_API400(""))
- assert pem400._extract_logs(pod).exit_code is None
- pemg = PodEventMapper(k8s_api=_APIGen(""))
- assert pemg._extract_logs(pod).exit_code is None
+ pem404 = PodEventMapper(k8s_api=_API404(""), logger=_test_logger)
+ assert pem404._extract_logs(pod) is None
+ pem400 = PodEventMapper(k8s_api=_API400(""), logger=_test_logger)
+ assert pem400._extract_logs(pod) is None
+ pemg = PodEventMapper(k8s_api=_APIGen(""), logger=_test_logger)
+ assert pemg._extract_logs(pod) is None
# _analyze_failure: Evicted
pod_e = Pod("p", "Failed")
@@ -173,7 +181,8 @@ def read_namespaced_pod_log(self, *a, **k): raise Exception("boom")
def test_all_containers_succeeded_and_cache_behavior() -> None:
- pem = PodEventMapper(k8s_api=FakeApi(""))
+ valid_logs = json.dumps({"stdout": "", "stderr": "", "exit_code": 0, "resource_usage": {}})
+ pem = PodEventMapper(k8s_api=FakeApi(valid_logs), logger=_test_logger)
term0 = ContainerStatus(State(terminated=Terminated(0)))
term0b = ContainerStatus(State(terminated=Terminated(0)))
pod = Pod("p", "Failed", cs=[term0, term0b])
diff --git a/backend/tests/unit/services/pod_monitor/test_monitor.py b/backend/tests/unit/services/pod_monitor/test_monitor.py
index bf7731fc..4aff899f 100644
--- a/backend/tests/unit/services/pod_monitor/test_monitor.py
+++ b/backend/tests/unit/services/pod_monitor/test_monitor.py
@@ -1,4 +1,5 @@
import asyncio
+import logging
import types
import pytest
@@ -9,10 +10,57 @@
pytestmark = pytest.mark.unit
+# Test logger for all tests
+_test_logger = logging.getLogger("test.pod_monitor")
+
+
+# ===== Shared stubs for k8s mocking =====
+
+
+class _Cfg:
+ host = "https://k8s"
+ ssl_ca_cert = None
+
+
+class _K8sConfig:
+ def load_incluster_config(self):
+ pass
+
+ def load_kube_config(self, config_file=None):  # noqa: ARG002
+ pass
+
+
+class _Conf:
+ @staticmethod
+ def get_default_copy():
+ return _Cfg()
+
+
+class _ApiClient:
+ def __init__(self, cfg):  # noqa: ARG002
+ pass
+
+
+class _Core:
+ def __init__(self, api):  # noqa: ARG002
+ pass
+
+ def get_api_resources(self):
+ return None
+
+
+class _Watch:
+ def __init__(self):
+ pass
+
+ def stop(self):
+ pass
+
class _SpyMapper:
def __init__(self) -> None:
self.cleared = False
+
def clear_cache(self) -> None:
self.cleared = True
@@ -27,17 +75,27 @@ def stop(self):
return None
-class _FakeProducer:
- async def start(self):
- return None
- async def stop(self):
- return None
- async def produce(self, *a, **k): # noqa: ARG002
- return None
- # Adapter looks at _producer._producer is not None for health
- @property
- def _producer(self):
- return object()
+class _FakeKafkaEventService:
+ """Fake KafkaEventService for testing."""
+
+ def __init__(self):
+ self.published_events = []
+
+ async def publish_base_event(self, event, key=None):
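+ # Record the call; return the event's id when present, else a placeholder id.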
+ self.published_events.append((event, key))
+ return event.event_id if hasattr(event, "event_id") else "fake-id"
+
+
+def _patch_k8s(monkeypatch, k8s_config=None, conf=None, api_client=None, core=None, watch=None):
+ """Helper to patch k8s modules with defaults or custom stubs."""
+ monkeypatch.setattr("app.services.pod_monitor.monitor.k8s_config", k8s_config or _K8sConfig())
+ monkeypatch.setattr("app.services.pod_monitor.monitor.k8s_client.Configuration", conf or _Conf)
+ monkeypatch.setattr("app.services.pod_monitor.monitor.k8s_client.ApiClient", api_client or _ApiClient)
+ monkeypatch.setattr("app.services.pod_monitor.monitor.k8s_client.CoreV1Api", core or _Core)
+ monkeypatch.setattr("app.services.pod_monitor.monitor.watch", types.SimpleNamespace(Watch=watch or _Watch))
+
+
+# ===== Tests =====
@pytest.mark.asyncio
@@ -45,16 +103,17 @@ async def test_start_and_stop_lifecycle(monkeypatch) -> None:
cfg = PodMonitorConfig()
cfg.enable_state_reconciliation = False
- pm = PodMonitor(cfg, producer=_FakeProducer())
- # Avoid real k8s client init; keep our spy mapper in place
- pm._initialize_kubernetes_client = lambda: None # type: ignore[assignment]
+ pm = PodMonitor(cfg, kafka_event_service=_FakeKafkaEventService(), logger=_test_logger)
+ pm._initialize_kubernetes_client = lambda: None
spy = _SpyMapper()
- pm._event_mapper = spy # type: ignore[assignment]
+ pm._event_mapper = spy
pm._v1 = _StubV1()
pm._watch = _StubWatch()
+
async def _quick_watch():
return None
- pm._watch_pods = _quick_watch # type: ignore[assignment]
+
+ pm._watch_pods = _quick_watch
await pm.start()
assert pm.state.name == "RUNNING"
@@ -65,43 +124,10 @@ async def _quick_watch():
def test_initialize_kubernetes_client_paths(monkeypatch) -> None:
cfg = PodMonitorConfig()
- # Create stubs for k8s modules
- class _Cfg:
- host = "https://k8s"
- ssl_ca_cert = None
-
- class _K8sConfig:
- def load_incluster_config(self): pass # noqa: D401, E701
- def load_kube_config(self, config_file=None): pass # noqa: D401, E701, ARG002
-
- class _Conf:
- @staticmethod
- def get_default_copy():
- return _Cfg()
-
- class _ApiClient:
- def __init__(self, cfg): # noqa: ARG002
- pass
-
- class _Core:
- def __init__(self, api): # noqa: ARG002
- self._ok = True
- def get_api_resources(self):
- return None
+ _patch_k8s(monkeypatch)
- class _Watch:
- def __init__(self): pass
-
- # Patch modules used by monitor
- monkeypatch.setattr("app.services.pod_monitor.monitor.k8s_config", _K8sConfig())
- monkeypatch.setattr("app.services.pod_monitor.monitor.k8s_client.Configuration", _Conf)
- monkeypatch.setattr("app.services.pod_monitor.monitor.k8s_client.ApiClient", _ApiClient)
- monkeypatch.setattr("app.services.pod_monitor.monitor.k8s_client.CoreV1Api", _Core)
- monkeypatch.setattr("app.services.pod_monitor.monitor.watch", types.SimpleNamespace(Watch=_Watch))
-
- pm = PodMonitor(cfg, producer=_FakeProducer())
+ pm = PodMonitor(cfg, kafka_event_service=_FakeKafkaEventService(), logger=_test_logger)
pm._initialize_kubernetes_client()
- # After init, client/watch set and event mapper rebuilt
assert pm._v1 is not None and pm._watch is not None
@@ -110,85 +136,39 @@ async def test_watch_pod_events_flow_and_publish(monkeypatch) -> None:
cfg = PodMonitorConfig()
cfg.enable_state_reconciliation = False
- pm = PodMonitor(cfg, producer=_FakeProducer())
+ pm = PodMonitor(cfg, kafka_event_service=_FakeKafkaEventService(), logger=_test_logger)
- # Use real mapper with fake API so mapping yields events
from app.services.pod_monitor.event_mapper import PodEventMapper as PEM
- pm._event_mapper = PEM(k8s_api=FakeApi("{}"))
- # Fake v1 and watch
+ pm._event_mapper = PEM(k8s_api=FakeApi("{}"), logger=_test_logger)
+
class V1:
def list_namespaced_pod(self, **kwargs): # noqa: ARG002
return None
pm._v1 = V1()
- # Construct a pod that maps to completed
pod = make_pod(name="p", phase="Succeeded", labels={"execution-id": "e1"}, term_exit=0, resource_version="rv1")
pm._watch = make_watch([{"type": "MODIFIED", "object": pod}], resource_version="rv2")
- # Speed up
pm._state = pm.state.__class__.RUNNING
await pm._watch_pod_events()
- # resource version updated from stream
assert pm._last_resource_version == "rv2"
@pytest.mark.asyncio
async def test_process_raw_event_invalid_and_handle_watch_error(monkeypatch) -> None:
cfg = PodMonitorConfig()
- pm = PodMonitor(cfg, producer=_FakeProducer())
+ pm = PodMonitor(cfg, kafka_event_service=_FakeKafkaEventService(), logger=_test_logger)
- # Invalid event shape
await pm._process_raw_event({})
- # Backoff progression without real sleep by setting base delay to 0
pm.config.watch_reconnect_delay = 0
pm._reconnect_attempts = 0
- await pm._handle_watch_error() # 1
- await pm._handle_watch_error() # 2
+ await pm._handle_watch_error()
+ await pm._handle_watch_error()
assert pm._reconnect_attempts >= 2
-@pytest.mark.asyncio
-async def test_unified_producer_adapter() -> None:
- from app.services.pod_monitor.monitor import UnifiedProducerAdapter
-
- class _TrackerProducer:
- def __init__(self):
- self.events = []
- self._producer = object()
- async def produce(self, event_to_produce, key=None):
- self.events.append((event_to_produce, key))
-
- tracker = _TrackerProducer()
- adapter = UnifiedProducerAdapter(tracker)
-
- # Test send_event success
- class _Event:
- pass
- event = _Event()
- success = await adapter.send_event(event, "topic", "key")
- assert success is True and tracker.events == [(event, "key")]
-
- # Test is_healthy
- assert await adapter.is_healthy() is True
-
- # Test send_event failure
- class _FailProducer:
- _producer = object()
- async def produce(self, *a, **k):
- raise RuntimeError("boom")
-
- fail_adapter = UnifiedProducerAdapter(_FailProducer())
- success = await fail_adapter.send_event(_Event(), "topic")
- assert success is False
-
- # Test is_healthy with None producer
- class _NoneProducer:
- _producer = None
- assert await UnifiedProducerAdapter(_NoneProducer()).is_healthy() is False
-
-
@pytest.mark.asyncio
async def test_get_status() -> None:
cfg = PodMonitorConfig()
@@ -196,13 +176,13 @@ async def test_get_status() -> None:
cfg.label_selector = "app=test"
cfg.enable_state_reconciliation = True
- pm = PodMonitor(cfg, producer=_FakeProducer())
+ pm = PodMonitor(cfg, kafka_event_service=_FakeKafkaEventService(), logger=_test_logger)
pm._tracked_pods = {"pod1", "pod2"}
pm._reconnect_attempts = 3
pm._last_resource_version = "v123"
status = await pm.get_status()
- assert "idle" in status["state"].lower() # Check state contains idle
+ assert "idle" in status["state"].lower()
assert status["tracked_pods"] == 2
assert status["reconnect_attempts"] == 3
assert status["last_resource_version"] == "v123"
@@ -217,28 +197,26 @@ async def test_reconciliation_loop_and_state(monkeypatch) -> None:
cfg.enable_state_reconciliation = True
cfg.reconcile_interval_seconds = 0.01
- pm = PodMonitor(cfg, producer=_FakeProducer())
+ pm = PodMonitor(cfg, kafka_event_service=_FakeKafkaEventService(), logger=_test_logger)
pm._state = pm.state.__class__.RUNNING
reconcile_called = []
+
async def mock_reconcile():
reconcile_called.append(True)
from app.services.pod_monitor.monitor import ReconciliationResult
- return ReconciliationResult(
- missing_pods={"p1"},
- extra_pods={"p2"},
- duration_seconds=0.1,
- success=True
- )
+
+ return ReconciliationResult(missing_pods={"p1"}, extra_pods={"p2"}, duration_seconds=0.1, success=True)
pm._reconcile_state = mock_reconcile
- # Run reconciliation loop until first reconcile
evt = asyncio.Event()
+
async def wrapped_reconcile():
res = await mock_reconcile()
evt.set()
return res
+
pm._reconcile_state = wrapped_reconcile
task = asyncio.create_task(pm._reconciliation_loop())
@@ -257,9 +235,8 @@ async def test_reconcile_state_success(monkeypatch) -> None:
cfg.namespace = "test"
cfg.label_selector = "app=test"
- pm = PodMonitor(cfg, producer=_FakeProducer())
+ pm = PodMonitor(cfg, kafka_event_service=_FakeKafkaEventService(), logger=_test_logger)
- # Mock K8s API: provide a sync function suitable for asyncio.to_thread
def sync_list(namespace, label_selector): # noqa: ARG002
return types.SimpleNamespace(
items=[
@@ -269,12 +246,13 @@ def sync_list(namespace, label_selector): # noqa: ARG002
)
pm._v1 = types.SimpleNamespace(list_namespaced_pod=sync_list)
- pm._tracked_pods = {"pod2", "pod3"} # pod1 missing, pod3 extra
+ pm._tracked_pods = {"pod2", "pod3"}
- # Mock process_pod_event
processed = []
+
async def mock_process(event):
processed.append(event.pod.metadata.name)
+
pm._process_pod_event = mock_process
result = await pm._reconcile_state()
@@ -289,7 +267,7 @@ async def mock_process(event):
@pytest.mark.asyncio
async def test_reconcile_state_no_v1_api() -> None:
cfg = PodMonitorConfig()
- pm = PodMonitor(cfg, producer=_FakeProducer())
+ pm = PodMonitor(cfg, kafka_event_service=_FakeKafkaEventService(), logger=_test_logger)
pm._v1 = None
result = await pm._reconcile_state()
@@ -300,7 +278,7 @@ async def test_reconcile_state_no_v1_api() -> None:
@pytest.mark.asyncio
async def test_reconcile_state_exception() -> None:
cfg = PodMonitorConfig()
- pm = PodMonitor(cfg, producer=_FakeProducer())
+ pm = PodMonitor(cfg, kafka_event_service=_FakeKafkaEventService(), logger=_test_logger)
class FailV1:
def list_namespaced_pod(self, *a, **k):
@@ -312,6 +290,7 @@ def list_namespaced_pod(self, *a, **k):
assert result.success is False
assert "API error" in result.error
+
@pytest.mark.asyncio
async def test_process_pod_event_full_flow() -> None:
from app.services.pod_monitor.monitor import PodEvent, WatchEventType
@@ -319,26 +298,26 @@ async def test_process_pod_event_full_flow() -> None:
cfg = PodMonitorConfig()
cfg.ignored_pod_phases = ["Unknown"]
- pm = PodMonitor(cfg, producer=_FakeProducer())
+ pm = PodMonitor(cfg, kafka_event_service=_FakeKafkaEventService(), logger=_test_logger)
- # Mock event mapper
class MockMapper:
def map_pod_event(self, pod, event_type):
class Event:
event_type = types.SimpleNamespace(value="test_event")
metadata = types.SimpleNamespace(correlation_id=None)
aggregate_id = "agg1"
+
return [Event()]
pm._event_mapper = MockMapper()
- # Mock publish
published = []
+
async def mock_publish(event, pod):
published.append(event)
+
pm._publish_event = mock_publish
- # Create test pod events
event = PodEvent(
event_type=WatchEventType.ADDED,
pod=make_pod(name="test-pod", phase="Running"),
@@ -350,7 +329,6 @@ async def mock_publish(event, pod):
assert pm._last_resource_version == "v1"
assert len(published) == 1
- # Test DELETED event
event_del = PodEvent(
event_type=WatchEventType.DELETED,
pod=make_pod(name="test-pod", phase="Succeeded"),
@@ -361,7 +339,6 @@ async def mock_publish(event, pod):
assert "test-pod" not in pm._tracked_pods
assert pm._last_resource_version == "v2"
- # Test ignored phase
event_ignored = PodEvent(
event_type=WatchEventType.ADDED,
pod=make_pod(name="ignored-pod", phase="Unknown"),
@@ -370,7 +347,7 @@ async def mock_publish(event, pod):
published.clear()
await pm._process_pod_event(event_ignored)
- assert len(published) == 0 # Should be skipped
+ assert len(published) == 0
@pytest.mark.asyncio
@@ -378,9 +355,8 @@ async def test_process_pod_event_exception_handling() -> None:
from app.services.pod_monitor.monitor import PodEvent, WatchEventType
cfg = PodMonitorConfig()
- pm = PodMonitor(cfg, producer=_FakeProducer())
+ pm = PodMonitor(cfg, kafka_event_service=_FakeKafkaEventService(), logger=_test_logger)
- # Mock event mapper to raise exception
class FailMapper:
def map_pod_event(self, pod, event_type):
raise RuntimeError("Mapping failed")
@@ -388,12 +364,9 @@ def map_pod_event(self, pod, event_type):
pm._event_mapper = FailMapper()
event = PodEvent(
- event_type=WatchEventType.ADDED,
- pod=make_pod(name="fail-pod", phase="Pending"),
- resource_version=None
+ event_type=WatchEventType.ADDED, pod=make_pod(name="fail-pod", phase="Pending"), resource_version=None
)
- # Should not raise, just log error
await pm._process_pod_event(event)
@@ -402,45 +375,21 @@ async def test_publish_event_full_flow() -> None:
from app.domain.enums.events import EventType
cfg = PodMonitorConfig()
+ fake_service = _FakeKafkaEventService()
+ pm = PodMonitor(cfg, kafka_event_service=fake_service, logger=_test_logger)
- # Track published events
- published = []
-
- class TrackerProducer:
- def __init__(self):
- self._producer = object()
- async def produce(self, event_to_produce, key=None):
- published.append((event_to_produce, key))
- async def is_healthy(self):
- return True
-
- from app.services.pod_monitor.monitor import UnifiedProducerAdapter
- pm = PodMonitor(cfg, producer=_FakeProducer())
- pm._producer = UnifiedProducerAdapter(TrackerProducer())
-
- # Create test event and pod
class Event:
event_type = EventType.EXECUTION_COMPLETED
metadata = types.SimpleNamespace(correlation_id=None)
aggregate_id = "exec1"
execution_id = "exec1"
+ event_id = "evt-123"
pod = make_pod(name="test-pod", phase="Succeeded", labels={"execution-id": "exec1"})
await pm._publish_event(Event(), pod)
- assert len(published) == 1
- assert published[0][1] == "exec1" # key
-
- # Test unhealthy producer
- class UnhealthyProducer:
- _producer = None
- async def is_healthy(self):
- return False
-
- pm._producer = UnifiedProducerAdapter(UnhealthyProducer())
- published.clear()
- await pm._publish_event(Event(), pod)
- assert len(published) == 0 # Should not publish
+ assert len(fake_service.published_events) == 1
+ assert fake_service.published_events[0][1] == "exec1"
@pytest.mark.asyncio
@@ -448,25 +397,24 @@ async def test_publish_event_exception_handling() -> None:
from app.domain.enums.events import EventType
cfg = PodMonitorConfig()
- pm = PodMonitor(cfg, producer=_FakeProducer())
- # Mock producer that raises exception
- class ExceptionProducer:
- _producer = object()
- async def is_healthy(self):
- raise RuntimeError("Health check failed")
+ class FailingKafkaEventService:
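+ # Raises on every publish so the test can assert that _publish_event swallows the error instead of propagating it.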
+ async def publish_base_event(self, event, key=None):
+ raise RuntimeError("Publish failed")
- from app.services.pod_monitor.monitor import UnifiedProducerAdapter
- pm._producer = UnifiedProducerAdapter(ExceptionProducer())
+ pm = PodMonitor(cfg, kafka_event_service=FailingKafkaEventService(), logger=_test_logger)
class Event:
event_type = EventType.EXECUTION_STARTED
+ metadata = types.SimpleNamespace(correlation_id=None)
+ aggregate_id = None
+ execution_id = None
class Pod:
metadata = None
status = None
- # Should not raise, just log error
+ # Should not raise - errors are caught and logged
await pm._publish_event(Event(), Pod())
@@ -475,23 +423,23 @@ async def test_handle_watch_error_max_attempts() -> None:
cfg = PodMonitorConfig()
cfg.max_reconnect_attempts = 2
- pm = PodMonitor(cfg, producer=_FakeProducer())
+ pm = PodMonitor(cfg, kafka_event_service=_FakeKafkaEventService(), logger=_test_logger)
pm._state = pm.state.__class__.RUNNING
pm._reconnect_attempts = 2
await pm._handle_watch_error()
- # Should stop after max attempts
assert pm._state == pm.state.__class__.STOPPING
@pytest.mark.asyncio
async def test_watch_pods_main_loop(monkeypatch) -> None:
cfg = PodMonitorConfig()
- pm = PodMonitor(cfg, producer=_FakeProducer())
+ pm = PodMonitor(cfg, kafka_event_service=_FakeKafkaEventService(), logger=_test_logger)
pm._state = pm.state.__class__.RUNNING
watch_count = []
+
async def mock_watch():
watch_count.append(1)
if len(watch_count) > 2:
@@ -512,14 +460,14 @@ async def test_watch_pods_api_exception(monkeypatch) -> None:
from kubernetes.client.rest import ApiException
cfg = PodMonitorConfig()
- pm = PodMonitor(cfg, producer=_FakeProducer())
+ pm = PodMonitor(cfg, kafka_event_service=_FakeKafkaEventService(), logger=_test_logger)
pm._state = pm.state.__class__.RUNNING
async def mock_watch():
- # 410 Gone error
raise ApiException(status=410)
error_handled = []
+
async def mock_handle():
error_handled.append(True)
pm._state = pm.state.__class__.STOPPED
@@ -536,13 +484,14 @@ async def mock_handle():
@pytest.mark.asyncio
async def test_watch_pods_generic_exception(monkeypatch) -> None:
cfg = PodMonitorConfig()
- pm = PodMonitor(cfg, producer=_FakeProducer())
+ pm = PodMonitor(cfg, kafka_event_service=_FakeKafkaEventService(), logger=_test_logger)
pm._state = pm.state.__class__.RUNNING
async def mock_watch():
raise RuntimeError("Unexpected error")
error_handled = []
+
async def mock_handle():
error_handled.append(True)
pm._state = pm.state.__class__.STOPPED
@@ -555,55 +504,46 @@ async def mock_handle():
@pytest.mark.asyncio
-async def test_create_pod_monitor_context_manager() -> None:
+async def test_create_pod_monitor_context_manager(monkeypatch) -> None:
from app.services.pod_monitor.monitor import create_pod_monitor
+ _patch_k8s(monkeypatch)
+
cfg = PodMonitorConfig()
cfg.enable_state_reconciliation = False
- producer = _FakeProducer()
-
- async with create_pod_monitor(cfg, producer) as monitor:
- # Override kubernetes initialization
- monitor._initialize_kubernetes_client = lambda: None
- monitor._v1 = _StubV1()
- monitor._watch = _StubWatch()
- async def _fast_watch():
- return None
- monitor._watch_pods = _fast_watch
+ fake_service = _FakeKafkaEventService()
- # Monitor should be started
+ async with create_pod_monitor(cfg, fake_service, _test_logger) as monitor:
assert monitor.state == monitor.state.__class__.RUNNING
- # Monitor should be stopped after context exit
assert monitor.state == monitor.state.__class__.STOPPED
@pytest.mark.asyncio
async def test_start_already_running() -> None:
cfg = PodMonitorConfig()
- pm = PodMonitor(cfg, producer=_FakeProducer())
+ pm = PodMonitor(cfg, kafka_event_service=_FakeKafkaEventService(), logger=_test_logger)
pm._state = pm.state.__class__.RUNNING
- await pm.start() # Should log warning and return
+ await pm.start()
@pytest.mark.asyncio
async def test_stop_already_stopped() -> None:
cfg = PodMonitorConfig()
- pm = PodMonitor(cfg, producer=_FakeProducer())
+ pm = PodMonitor(cfg, kafka_event_service=_FakeKafkaEventService(), logger=_test_logger)
pm._state = pm.state.__class__.STOPPED
- await pm.stop() # Should return immediately
+ await pm.stop()
@pytest.mark.asyncio
async def test_stop_with_tasks() -> None:
cfg = PodMonitorConfig()
- pm = PodMonitor(cfg, producer=_FakeProducer())
+ pm = PodMonitor(cfg, kafka_event_service=_FakeKafkaEventService(), logger=_test_logger)
pm._state = pm.state.__class__.RUNNING
- # Create dummy tasks
async def dummy_task():
await asyncio.Event().wait()
@@ -620,50 +560,42 @@ async def dummy_task():
def test_update_resource_version() -> None:
cfg = PodMonitorConfig()
- pm = PodMonitor(cfg, producer=_FakeProducer())
+ pm = PodMonitor(cfg, kafka_event_service=_FakeKafkaEventService(), logger=_test_logger)
- # With valid stream
class Stream:
_stop_event = types.SimpleNamespace(resource_version="v123")
pm._update_resource_version(Stream())
assert pm._last_resource_version == "v123"
- # With invalid stream (no _stop_event)
class BadStream:
pass
- pm._update_resource_version(BadStream()) # Should not raise
+ pm._update_resource_version(BadStream())
@pytest.mark.asyncio
async def test_process_raw_event_with_metadata() -> None:
cfg = PodMonitorConfig()
- pm = PodMonitor(cfg, producer=_FakeProducer())
+ pm = PodMonitor(cfg, kafka_event_service=_FakeKafkaEventService(), logger=_test_logger)
- # Mock process_pod_event
processed = []
+
async def mock_process(event):
processed.append(event)
+
pm._process_pod_event = mock_process
- # Valid event with metadata
raw_event = {
- 'type': 'ADDED',
- 'object': types.SimpleNamespace(
- metadata=types.SimpleNamespace(resource_version='v1')
- )
+ "type": "ADDED",
+ "object": types.SimpleNamespace(metadata=types.SimpleNamespace(resource_version="v1")),
}
await pm._process_raw_event(raw_event)
assert len(processed) == 1
- assert processed[0].resource_version == 'v1'
+ assert processed[0].resource_version == "v1"
- # Event without metadata
- raw_event_no_meta = {
- 'type': 'MODIFIED',
- 'object': types.SimpleNamespace(metadata=None)
- }
+ raw_event_no_meta = {"type": "MODIFIED", "object": types.SimpleNamespace(metadata=None)}
await pm._process_raw_event(raw_event_no_meta)
assert len(processed) == 2
@@ -674,39 +606,18 @@ def test_initialize_kubernetes_client_in_cluster(monkeypatch) -> None:
cfg = PodMonitorConfig()
cfg.in_cluster = True
- # Create stubs for k8s modules
load_incluster_called = []
- class _K8sConfig:
+ class TrackingK8sConfig:
def load_incluster_config(self):
load_incluster_called.append(True)
- def load_kube_config(self, config_file=None): pass # noqa: ARG002
- class _Conf:
- @staticmethod
- def get_default_copy():
- return types.SimpleNamespace(host="https://k8s", ssl_ca_cert=None)
-
- class _ApiClient:
- def __init__(self, cfg): pass # noqa: ARG002
-
- class _Core:
- def __init__(self, api): # noqa: ARG002
- self._ok = True
- def get_api_resources(self):
- return None
-
- class _Watch:
- def __init__(self): pass
+ def load_kube_config(self, config_file=None):
+ pass # noqa: ARG002
- # Patch modules
- monkeypatch.setattr("app.services.pod_monitor.monitor.k8s_config", _K8sConfig())
- monkeypatch.setattr("app.services.pod_monitor.monitor.k8s_client.Configuration", _Conf)
- monkeypatch.setattr("app.services.pod_monitor.monitor.k8s_client.ApiClient", _ApiClient)
- monkeypatch.setattr("app.services.pod_monitor.monitor.k8s_client.CoreV1Api", _Core)
- monkeypatch.setattr("app.services.pod_monitor.monitor.watch", types.SimpleNamespace(Watch=_Watch))
+ _patch_k8s(monkeypatch, k8s_config=TrackingK8sConfig())
- pm = PodMonitor(cfg, producer=_FakeProducer())
+ pm = PodMonitor(cfg, kafka_event_service=_FakeKafkaEventService(), logger=_test_logger)
pm._initialize_kubernetes_client()
assert len(load_incluster_called) == 1
@@ -719,52 +630,36 @@ def test_initialize_kubernetes_client_with_kubeconfig_path(monkeypatch) -> None:
load_kube_called_with = []
- class _K8sConfig:
- def load_incluster_config(self): pass
+ class TrackingK8sConfig:
+ def load_incluster_config(self):
+ pass
+
def load_kube_config(self, config_file=None):
load_kube_called_with.append(config_file)
- class _Conf:
+ class ConfWithCert:
@staticmethod
def get_default_copy():
return types.SimpleNamespace(host="https://k8s", ssl_ca_cert="cert")
- class _ApiClient:
- def __init__(self, cfg): pass # noqa: ARG002
-
- class _Core:
- def __init__(self, api): # noqa: ARG002
- self._ok = True
- def get_api_resources(self):
- return None
-
- class _Watch:
- def __init__(self): pass
+ _patch_k8s(monkeypatch, k8s_config=TrackingK8sConfig(), conf=ConfWithCert)
- # Patch modules
- monkeypatch.setattr("app.services.pod_monitor.monitor.k8s_config", _K8sConfig())
- monkeypatch.setattr("app.services.pod_monitor.monitor.k8s_client.Configuration", _Conf)
- monkeypatch.setattr("app.services.pod_monitor.monitor.k8s_client.ApiClient", _ApiClient)
- monkeypatch.setattr("app.services.pod_monitor.monitor.k8s_client.CoreV1Api", _Core)
- monkeypatch.setattr("app.services.pod_monitor.monitor.watch", types.SimpleNamespace(Watch=_Watch))
-
- pm = PodMonitor(cfg, producer=_FakeProducer())
+ pm = PodMonitor(cfg, kafka_event_service=_FakeKafkaEventService(), logger=_test_logger)
pm._initialize_kubernetes_client()
assert load_kube_called_with == ["/custom/kubeconfig"]
def test_initialize_kubernetes_client_exception(monkeypatch) -> None:
- import pytest
cfg = PodMonitorConfig()
- class _K8sConfig:
+ class FailingK8sConfig:
def load_kube_config(self, config_file=None):
raise Exception("K8s config error")
- monkeypatch.setattr("app.services.pod_monitor.monitor.k8s_config", _K8sConfig())
+ monkeypatch.setattr("app.services.pod_monitor.monitor.k8s_config", FailingK8sConfig())
- pm = PodMonitor(cfg, producer=_FakeProducer())
+ pm = PodMonitor(cfg, kafka_event_service=_FakeKafkaEventService(), logger=_test_logger)
with pytest.raises(Exception) as exc_info:
pm._initialize_kubernetes_client()
@@ -777,14 +672,14 @@ async def test_watch_pods_api_exception_other_status(monkeypatch) -> None:
from kubernetes.client.rest import ApiException
cfg = PodMonitorConfig()
- pm = PodMonitor(cfg, producer=_FakeProducer())
+ pm = PodMonitor(cfg, kafka_event_service=_FakeKafkaEventService(), logger=_test_logger)
pm._state = pm.state.__class__.RUNNING
async def mock_watch():
- # Non-410 API error
raise ApiException(status=500)
error_handled = []
+
async def mock_handle():
error_handled.append(True)
pm._state = pm.state.__class__.STOPPED
@@ -798,11 +693,9 @@ async def mock_handle():
@pytest.mark.asyncio
async def test_watch_pod_events_no_watch_or_v1() -> None:
- import pytest
cfg = PodMonitorConfig()
- pm = PodMonitor(cfg, producer=_FakeProducer())
+ pm = PodMonitor(cfg, kafka_event_service=_FakeKafkaEventService(), logger=_test_logger)
- # No watch
pm._watch = None
pm._v1 = _StubV1()
@@ -811,7 +704,6 @@ async def test_watch_pod_events_no_watch_or_v1() -> None:
assert "Watch or API not initialized" in str(exc_info.value)
- # No v1
pm._watch = _StubWatch()
pm._v1 = None
@@ -827,9 +719,8 @@ async def test_watch_pod_events_with_field_selector() -> None:
cfg.field_selector = "status.phase=Running"
cfg.enable_state_reconciliation = False
- pm = PodMonitor(cfg, producer=_FakeProducer())
+ pm = PodMonitor(cfg, kafka_event_service=_FakeKafkaEventService(), logger=_test_logger)
- # Mock v1 and watch
watch_kwargs = []
class V1:
@@ -848,7 +739,6 @@ def stream(self, func, **kwargs):
await pm._watch_pod_events()
- # Check field_selector was included
assert any("field_selector" in kw for kw in watch_kwargs)
@@ -858,19 +748,15 @@ async def test_reconciliation_loop_exception() -> None:
cfg.enable_state_reconciliation = True
cfg.reconcile_interval_seconds = 0.01
- pm = PodMonitor(cfg, producer=_FakeProducer())
+ pm = PodMonitor(cfg, kafka_event_service=_FakeKafkaEventService(), logger=_test_logger)
pm._state = pm.state.__class__.RUNNING
- async def mock_reconcile():
- raise RuntimeError("Reconcile error")
-
- pm._reconcile_state = mock_reconcile
-
- # Run reconciliation loop until it hits the exception once
hit = asyncio.Event()
+
async def raising():
hit.set()
raise RuntimeError("Reconcile error")
+
pm._reconcile_state = raising
task = asyncio.create_task(pm._reconciliation_loop())
@@ -880,15 +766,13 @@ async def raising():
with pytest.raises(asyncio.CancelledError):
await task
- # Should handle exception and continue
-
@pytest.mark.asyncio
async def test_start_with_reconciliation() -> None:
cfg = PodMonitorConfig()
cfg.enable_state_reconciliation = True
- pm = PodMonitor(cfg, producer=_FakeProducer())
+ pm = PodMonitor(cfg, kafka_event_service=_FakeKafkaEventService(), logger=_test_logger)
pm._initialize_kubernetes_client = lambda: None
pm._v1 = _StubV1()
pm._watch = _StubWatch()
diff --git a/backend/tests/unit/services/result_processor/test_processor.py b/backend/tests/unit/services/result_processor/test_processor.py
index 8362eeb8..4c44dd59 100644
--- a/backend/tests/unit/services/result_processor/test_processor.py
+++ b/backend/tests/unit/services/result_processor/test_processor.py
@@ -1,3 +1,5 @@
+import logging
+
import pytest
from unittest.mock import MagicMock
@@ -8,6 +10,8 @@
pytestmark = pytest.mark.unit
+_test_logger = logging.getLogger("test.services.result_processor.processor")
+
class TestResultProcessorConfig:
def test_default_values(self):
@@ -27,7 +31,7 @@ def test_custom_values(self):
def test_create_dispatcher_registers_handlers():
- rp = ResultProcessor(execution_repo=MagicMock(), producer=MagicMock(), idempotency_manager=MagicMock())
+ rp = ResultProcessor(execution_repo=MagicMock(), producer=MagicMock(), idempotency_manager=MagicMock(), logger=_test_logger)
dispatcher = rp._create_dispatcher()
assert dispatcher is not None
assert EventType.EXECUTION_COMPLETED in dispatcher._handlers
diff --git a/backend/tests/unit/services/saga/test_execution_saga_steps.py b/backend/tests/unit/services/saga/test_execution_saga_steps.py
index ebfd32e2..ee57f431 100644
--- a/backend/tests/unit/services/saga/test_execution_saga_steps.py
+++ b/backend/tests/unit/services/saga/test_execution_saga_steps.py
@@ -1,5 +1,6 @@
import pytest
+from app.domain.saga import DomainResourceAllocation
from app.services.saga.execution_saga import (
ValidateExecutionStep,
AllocateResourcesStep,
@@ -39,16 +40,24 @@ async def test_validate_execution_step_success_and_failures() -> None:
class _FakeAllocRepo:
- def __init__(self, active: int = 0, ok: bool = True) -> None:
+ def __init__(self, active: int = 0, alloc_id: str = "alloc-1") -> None:
self.active = active
- self.ok = ok
+ self.alloc_id = alloc_id
self.released: list[str] = []
async def count_active(self, language: str) -> int: # noqa: ARG002
return self.active
- async def create_allocation(self, _id: str, **_kwargs) -> bool: # noqa: ARG002
- return self.ok
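+ # Mirrors the allocation repository interface: accepts a create-data object, returns a DomainResourceAllocation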
+ async def create_allocation(self, create_data) -> DomainResourceAllocation: # noqa: ARG002
+ return DomainResourceAllocation(
+ allocation_id=self.alloc_id,
+ execution_id=create_data.execution_id,
+ language=create_data.language,
+ cpu_request=create_data.cpu_request,
+ memory_request=create_data.memory_request,
+ cpu_limit=create_data.cpu_limit,
+ memory_limit=create_data.memory_limit,
+ )
async def release_allocation(self, allocation_id: str) -> None:
self.released.append(allocation_id)
@@ -58,13 +67,13 @@ async def release_allocation(self, allocation_id: str) -> None:
async def test_allocate_resources_step_paths() -> None:
ctx = SagaContext("s1", "e1")
ctx.set("execution_id", "e1")
- ok = await AllocateResourcesStep(alloc_repo=_FakeAllocRepo(active=0, ok=True)).execute(ctx, _req())
- assert ok is True and ctx.get("resources_allocated") is True and ctx.get("allocation_id") == "e1"
+ ok = await AllocateResourcesStep(alloc_repo=_FakeAllocRepo(active=0, alloc_id="alloc-1")).execute(ctx, _req())
+ assert ok is True and ctx.get("resources_allocated") is True and ctx.get("allocation_id") == "alloc-1"
# Limit exceeded
ctx2 = SagaContext("s2", "e2")
ctx2.set("execution_id", "e2")
- ok2 = await AllocateResourcesStep(alloc_repo=_FakeAllocRepo(active=100, ok=True)).execute(ctx2, _req())
+ ok2 = await AllocateResourcesStep(alloc_repo=_FakeAllocRepo(active=100)).execute(ctx2, _req())
assert ok2 is False
# Missing repo
@@ -73,12 +82,6 @@ async def test_allocate_resources_step_paths() -> None:
ok3 = await AllocateResourcesStep(alloc_repo=None).execute(ctx3, _req())
assert ok3 is False
- # Create allocation returns False -> failure path hitting line 92
- ctx4 = SagaContext("s4", "e4")
- ctx4.set("execution_id", "e4")
- ok4 = await AllocateResourcesStep(alloc_repo=_FakeAllocRepo(active=0, ok=False)).execute(ctx4, _req())
- assert ok4 is False
-
@pytest.mark.asyncio
async def test_queue_and_monitor_steps() -> None:
diff --git a/backend/tests/unit/services/sse/test_kafka_redis_bridge.py b/backend/tests/unit/services/sse/test_kafka_redis_bridge.py
index ffe1cfea..e4b0cded 100644
--- a/backend/tests/unit/services/sse/test_kafka_redis_bridge.py
+++ b/backend/tests/unit/services/sse/test_kafka_redis_bridge.py
@@ -1,4 +1,5 @@
import asyncio
+import logging
import pytest
pytestmark = pytest.mark.unit
@@ -6,6 +7,8 @@
from app.domain.enums.events import EventType
from app.services.sse.kafka_redis_bridge import SSEKafkaRedisBridge
+_test_logger = logging.getLogger("test.services.sse.kafka_redis_bridge")
+
class _FakeSchema: ...
@@ -51,6 +54,7 @@ async def test_register_and_route_events_without_kafka() -> None:
settings=_FakeSettings(),
event_metrics=_FakeEventMetrics(),
sse_bus=_FakeBus(),
+ logger=_test_logger,
)
disp = _StubDispatcher()
diff --git a/backend/tests/unit/services/sse/test_shutdown_manager.py b/backend/tests/unit/services/sse/test_shutdown_manager.py
index b39b76eb..64ac5211 100644
--- a/backend/tests/unit/services/sse/test_shutdown_manager.py
+++ b/backend/tests/unit/services/sse/test_shutdown_manager.py
@@ -1,5 +1,12 @@
+import asyncio
+import logging
+
import pytest
+from app.services.sse.sse_shutdown_manager import SSEShutdownManager
+
+_test_logger = logging.getLogger("test.services.sse.shutdown_manager")
+
class DummyRouter:
def __init__(self): self.stopped = False
@@ -9,7 +16,7 @@ async def stop(self): self.stopped = True # noqa: ANN001
@pytest.mark.asyncio
async def test_shutdown_graceful_notify_and_drain():
- mgr = SSEShutdownManager(drain_timeout=1.0, notification_timeout=0.01, force_close_timeout=0.1)
+ mgr = SSEShutdownManager(drain_timeout=1.0, notification_timeout=0.01, force_close_timeout=0.1, logger=_test_logger)
# Register two connections and arrange that they unregister when notified
ev1 = await mgr.register_connection("e1", "c1")
@@ -33,7 +40,7 @@ async def on_shutdown(event, cid): # noqa: ANN001
@pytest.mark.asyncio
async def test_shutdown_force_close_calls_router_stop_and_rejects_new():
- mgr = SSEShutdownManager(drain_timeout=0.01, notification_timeout=0.01, force_close_timeout=0.01)
+ mgr = SSEShutdownManager(drain_timeout=0.01, notification_timeout=0.01, force_close_timeout=0.01, logger=_test_logger)
router = DummyRouter()
mgr.set_router(router)
@@ -53,15 +60,9 @@ async def test_shutdown_force_close_calls_router_stop_and_rejects_new():
assert ev2 is None
-import asyncio
-import pytest
-
-from app.services.sse.sse_shutdown_manager import SSEShutdownManager
-
-
@pytest.mark.asyncio
async def test_get_shutdown_status_transitions():
- m = SSEShutdownManager(drain_timeout=0.01, notification_timeout=0.0, force_close_timeout=0.0)
+ m = SSEShutdownManager(drain_timeout=0.01, notification_timeout=0.0, force_close_timeout=0.0, logger=_test_logger)
st0 = m.get_shutdown_status()
assert st0.phase == "ready"
await m.initiate_shutdown()
diff --git a/backend/tests/unit/services/sse/test_sse_service.py b/backend/tests/unit/services/sse/test_sse_service.py
index dc8c588c..63299b4e 100644
--- a/backend/tests/unit/services/sse/test_sse_service.py
+++ b/backend/tests/unit/services/sse/test_sse_service.py
@@ -1,4 +1,5 @@
import asyncio
+import logging
from datetime import datetime, timezone
from typing import Any, Type
@@ -7,6 +8,8 @@
pytestmark = pytest.mark.unit
+_test_logger = logging.getLogger("test.services.sse.sse_service")
+
from app.domain.enums.events import EventType
from app.domain.execution import DomainExecution, ResourceUsageDomain
from app.domain.sse import ShutdownStatus, SSEHealthDomain
@@ -119,7 +122,7 @@ async def test_execution_stream_closes_on_failed_event() -> None:
repo = _FakeRepo()
bus = _FakeBus()
sm = _FakeShutdown()
- svc = SSEService(repository=repo, router=_FakeRouter(), sse_bus=bus, shutdown_manager=sm, settings=_FakeSettings())
+ svc = SSEService(repository=repo, router=_FakeRouter(), sse_bus=bus, shutdown_manager=sm, settings=_FakeSettings(), logger=_test_logger)
agen = svc.create_execution_stream("exec-1", user_id="u1")
first = await agen.__anext__()
@@ -156,7 +159,7 @@ async def test_execution_stream_result_stored_includes_result_payload() -> None:
)
bus = _FakeBus()
sm = _FakeShutdown()
- svc = SSEService(repository=repo, router=_FakeRouter(), sse_bus=bus, shutdown_manager=sm, settings=_FakeSettings())
+ svc = SSEService(repository=repo, router=_FakeRouter(), sse_bus=bus, shutdown_manager=sm, settings=_FakeSettings(), logger=_test_logger)
agen = svc.create_execution_stream("exec-2", user_id="u1")
await agen.__anext__() # connected
@@ -179,7 +182,7 @@ async def test_notification_stream_connected_and_heartbeat_and_message() -> None
sm = _FakeShutdown()
settings = _FakeSettings()
settings.SSE_HEARTBEAT_INTERVAL = 0 # emit immediately
- svc = SSEService(repository=repo, router=_FakeRouter(), sse_bus=bus, shutdown_manager=sm, settings=settings)
+ svc = SSEService(repository=repo, router=_FakeRouter(), sse_bus=bus, shutdown_manager=sm, settings=settings, logger=_test_logger)
agen = svc.create_notification_stream("u1")
connected = await agen.__anext__()
@@ -214,7 +217,7 @@ async def test_notification_stream_connected_and_heartbeat_and_message() -> None
@pytest.mark.asyncio
async def test_health_status_shape() -> None:
- svc = SSEService(repository=_FakeRepo(), router=_FakeRouter(), sse_bus=_FakeBus(), shutdown_manager=_FakeShutdown(), settings=_FakeSettings())
+ svc = SSEService(repository=_FakeRepo(), router=_FakeRouter(), sse_bus=_FakeBus(), shutdown_manager=_FakeShutdown(), settings=_FakeSettings(), logger=_test_logger)
h = await svc.get_health_status()
assert isinstance(h, SSEHealthDomain)
assert h.active_consumers == 3 and h.active_executions == 2
diff --git a/backend/tests/unit/services/sse/test_sse_shutdown_manager.py b/backend/tests/unit/services/sse/test_sse_shutdown_manager.py
index 825e98ad..4e7300b3 100644
--- a/backend/tests/unit/services/sse/test_sse_shutdown_manager.py
+++ b/backend/tests/unit/services/sse/test_sse_shutdown_manager.py
@@ -1,10 +1,14 @@
import asyncio
+import logging
+
import pytest
pytestmark = pytest.mark.unit
from app.services.sse.sse_shutdown_manager import SSEShutdownManager
+_test_logger = logging.getLogger("test.services.sse.sse_shutdown_manager")
+
class _FakeRouter:
def __init__(self) -> None:
@@ -16,7 +20,7 @@ async def stop(self) -> None:
@pytest.mark.asyncio
async def test_register_unregister_and_shutdown_flow() -> None:
- mgr = SSEShutdownManager(drain_timeout=0.5, notification_timeout=0.1, force_close_timeout=0.1)
+ mgr = SSEShutdownManager(drain_timeout=0.5, notification_timeout=0.1, force_close_timeout=0.1, logger=_test_logger)
mgr.set_router(_FakeRouter())
# Register two connections
@@ -47,7 +51,7 @@ async def _is_notifying():
@pytest.mark.asyncio
async def test_reject_new_connection_during_shutdown() -> None:
- mgr = SSEShutdownManager(drain_timeout=0.1, notification_timeout=0.01, force_close_timeout=0.01)
+ mgr = SSEShutdownManager(drain_timeout=0.1, notification_timeout=0.01, force_close_timeout=0.01, logger=_test_logger)
# Pre-register one active connection to reflect realistic state
e = await mgr.register_connection("e", "c0")
assert e is not None
diff --git a/backend/uv.lock b/backend/uv.lock
index 5d057738..7f656c60 100644
--- a/backend/uv.lock
+++ b/backend/uv.lock
@@ -196,6 +196,22 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/df/73/b6e24bd22e6720ca8ee9a85a0c4a2971af8497d8f3193fa05390cbd46e09/backoff-2.2.1-py3-none-any.whl", hash = "sha256:63579f9a0628e06278f7e47b7d7d5b6ce20dc65c5e96a6f3ca99a6adca0396e8", size = 15148, upload-time = "2022-10-05T19:19:30.546Z" },
]
+[[package]]
+name = "beanie"
+version = "2.0.1"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+ { name = "click" },
+ { name = "lazy-model" },
+ { name = "pydantic" },
+ { name = "pymongo" },
+ { name = "typing-extensions" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/af/c0/85857d44d1c59d8bb546bd01e7d128ae08fc9e84e3f3c5c84b365b55ea48/beanie-2.0.1.tar.gz", hash = "sha256:aad0365cba578f5686446ed0960ead140a2231cbbfa8d492220f712c5e0c06b4", size = 171502, upload-time = "2025-11-20T18:45:51.518Z" }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/29/54/8c9a4ab2d82242074671cc35b1dd2a906c3c36b3a5c80e914c76fa9f45b7/beanie-2.0.1-py3-none-any.whl", hash = "sha256:3aad6cc0e40fb8d256a0a3fdeca92a7b3d3c1f9f47ff377c9ecd2221285e1009", size = 87693, upload-time = "2025-11-20T18:45:50.321Z" },
+]
+
[[package]]
name = "blinker"
version = "1.8.2"
@@ -981,6 +997,7 @@ dependencies = [
{ name = "attrs" },
{ name = "avro-python3" },
{ name = "backoff" },
+ { name = "beanie" },
{ name = "blinker" },
{ name = "brotli" },
{ name = "cachetools" },
@@ -1020,7 +1037,6 @@ dependencies = [
{ name = "markdown-it-py" },
{ name = "markupsafe" },
{ name = "mdurl" },
- { name = "motor" },
{ name = "msgpack" },
{ name = "multidict" },
{ name = "oauthlib" },
@@ -1123,6 +1139,7 @@ requires-dist = [
{ name = "attrs", specifier = "==25.3.0" },
{ name = "avro-python3", specifier = "==1.10.2" },
{ name = "backoff", specifier = "==2.2.1" },
+ { name = "beanie", specifier = "==2.0.1" },
{ name = "blinker", specifier = "==1.8.2" },
{ name = "brotli", specifier = "==1.2.0" },
{ name = "cachetools", specifier = "==6.2.0" },
@@ -1162,7 +1179,6 @@ requires-dist = [
{ name = "markdown-it-py", specifier = "==4.0.0" },
{ name = "markupsafe", specifier = "==3.0.2" },
{ name = "mdurl", specifier = "==0.1.2" },
- { name = "motor", specifier = "==3.6.0" },
{ name = "msgpack", specifier = "==1.1.0" },
{ name = "multidict", specifier = "==6.7.0" },
{ name = "oauthlib", specifier = "==3.2.2" },
@@ -1200,7 +1216,7 @@ requires-dist = [
{ name = "pydantic-settings", specifier = "==2.5.2" },
{ name = "pygments", specifier = "==2.19.2" },
{ name = "pyjwt", specifier = "==2.9.0" },
- { name = "pymongo", specifier = "==4.9.2" },
+ { name = "pymongo", specifier = "==4.12.1" },
{ name = "pyparsing", specifier = "==3.2.3" },
{ name = "python-dateutil", specifier = "==2.9.0.post0" },
{ name = "python-dotenv", specifier = "==1.0.1" },
@@ -1366,6 +1382,18 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/fb/a8/17f5e28cecdbd6d48127c22abdb794740803491f422a11905c4569d8e139/kubernetes-31.0.0-py2.py3-none-any.whl", hash = "sha256:bf141e2d380c8520eada8b351f4e319ffee9636328c137aa432bc486ca1200e1", size = 1857013, upload-time = "2024-09-20T03:16:06.05Z" },
]
+[[package]]
+name = "lazy-model"
+version = "0.4.0"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+ { name = "pydantic" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/72/85/e25dc36dee49cf0726c03a1558b5c311a17095bc9361bcbf47226cb3075a/lazy-model-0.4.0.tar.gz", hash = "sha256:a851d85d0b518b0b9c8e626bbee0feb0494c0e0cb5636550637f032dbbf9c55f", size = 8256, upload-time = "2025-08-07T20:05:34.737Z" }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/5c/54/653ea0d7c578741e9867ccf0cbf47b7eac09ff22e4238f311ac20671a911/lazy_model-0.4.0-py3-none-any.whl", hash = "sha256:95ea59551c1ac557a2c299f75803c56cc973923ef78c67ea4839a238142f7927", size = 13749, upload-time = "2025-08-07T20:05:36.303Z" },
+]
+
[[package]]
name = "limits"
version = "3.13.0"
@@ -1494,18 +1522,6 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/b3/38/89ba8ad64ae25be8de66a6d463314cf1eb366222074cfda9ee839c56a4b4/mdurl-0.1.2-py3-none-any.whl", hash = "sha256:84008a41e51615a49fc9966191ff91509e3c40b939176e643fd50a5c2196b8f8", size = 9979, upload-time = "2022-08-14T12:40:09.779Z" },
]
-[[package]]
-name = "motor"
-version = "3.6.0"
-source = { registry = "https://pypi.org/simple" }
-dependencies = [
- { name = "pymongo" },
-]
-sdist = { url = "https://files.pythonhosted.org/packages/6a/d1/06af0527fd02d49b203db70dba462e47275a3c1094f830fdaf090f0cb20c/motor-3.6.0.tar.gz", hash = "sha256:0ef7f520213e852bf0eac306adf631aabe849227d8aec900a2612512fb9c5b8d", size = 278447, upload-time = "2024-09-18T16:51:37.747Z" }
-wheels = [
- { url = "https://files.pythonhosted.org/packages/b4/c2/bba4dce0dc56e49d95c270c79c9330ed19e6b71a2a633aecf53e7e1f04c9/motor-3.6.0-py3-none-any.whl", hash = "sha256:9f07ed96f1754963d4386944e1b52d403a5350c687edc60da487d66f98dbf894", size = 74802, upload-time = "2024-09-18T16:51:35.761Z" },
-]
-
[[package]]
name = "msgpack"
version = "1.1.0"
@@ -2355,31 +2371,40 @@ wheels = [
[[package]]
name = "pymongo"
-version = "4.9.2"
+version = "4.12.1"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "dnspython" },
]
-sdist = { url = "https://files.pythonhosted.org/packages/fb/43/d5e8993bd43e6f9cbe985e8ae1398eb73309e88694ac2ea618eacbc9cea2/pymongo-4.9.2.tar.gz", hash = "sha256:3e63535946f5df7848307b9031aa921f82bb0cbe45f9b0c3296f2173f9283eb0", size = 1889366, upload-time = "2024-10-02T16:35:35.307Z" }
-wheels = [
- { url = "https://files.pythonhosted.org/packages/a1/08/7d95aab0463dc5a2c460a0b4e50a45a743afbe20986f47f87a9a88f43c0c/pymongo-4.9.2-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:8083bbe8cb10bb33dca4d93f8223dd8d848215250bb73867374650bac5fe69e1", size = 941617, upload-time = "2024-10-02T16:34:27.178Z" },
- { url = "https://files.pythonhosted.org/packages/bb/28/40613d8d97fc33bf2b9187446a6746925623aa04a9a27c9b058e97076f7a/pymongo-4.9.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:a1b8c636bf557c7166e3799bbf1120806ca39e3f06615b141c88d9c9ceae4d8c", size = 941394, upload-time = "2024-10-02T16:34:28.562Z" },
- { url = "https://files.pythonhosted.org/packages/df/b2/7f1a0d75f538c0dcaa004ea69e28706fa3ca72d848e0a5a7dafd30939fff/pymongo-4.9.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8aac5dce28454f47576063fbad31ea9789bba67cab86c95788f97aafd810e65b", size = 1907396, upload-time = "2024-10-02T16:34:30.263Z" },
- { url = "https://files.pythonhosted.org/packages/ba/70/9304bae47a361a4b12adb5be714bad41478c0e5bc3d6cf403b328d6398a0/pymongo-4.9.2-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d1d5e7123af1fddf15b2b53e58f20bf5242884e671bcc3860f5e954fe13aeddd", size = 1986029, upload-time = "2024-10-02T16:34:32.346Z" },
- { url = "https://files.pythonhosted.org/packages/ae/51/ac0378d001995c4a705da64a4a2b8e1732f95de5080b752d69f452930cc7/pymongo-4.9.2-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:fe97c847b56d61e533a7af0334193d6b28375b9189effce93129c7e4733794a9", size = 1949088, upload-time = "2024-10-02T16:34:33.916Z" },
- { url = "https://files.pythonhosted.org/packages/1a/30/e93dc808039dc29fc47acee64f128aa650aacae3e4b57b68e01ff1001cda/pymongo-4.9.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:96ad54433a996e2d1985a9cd8fc82538ca8747c95caae2daf453600cc8c317f9", size = 1910516, upload-time = "2024-10-02T16:34:35.953Z" },
- { url = "https://files.pythonhosted.org/packages/2b/34/895b9cad3bd5342d5ab51a853ed3a814840ce281d55c6928968e9f3f49f5/pymongo-4.9.2-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:98b9cade40f5b13e04492a42ae215c3721099be1014ddfe0fbd23f27e4f62c0c", size = 1860499, upload-time = "2024-10-02T16:34:37.727Z" },
- { url = "https://files.pythonhosted.org/packages/24/7e/167818f324bf2122d45551680671a3c6406a345d3fcace4e737f57bda4e4/pymongo-4.9.2-cp312-cp312-win32.whl", hash = "sha256:dde6068ae7c62ea8ee2c5701f78c6a75618cada7e11f03893687df87709558de", size = 901282, upload-time = "2024-10-02T16:34:39.128Z" },
- { url = "https://files.pythonhosted.org/packages/12/6b/b7ffa7114177fc1c60ae529512b82629ff7e25d19be88e97f2d0ddd16717/pymongo-4.9.2-cp312-cp312-win_amd64.whl", hash = "sha256:e1ab6cd7cd2d38ffc7ccdc79fdc166c7a91a63f844a96e3e6b2079c054391c68", size = 924925, upload-time = "2024-10-02T16:34:40.859Z" },
- { url = "https://files.pythonhosted.org/packages/5b/d6/b57ef5f376e2e171218a98b8c30dfd001aa5cac6338aa7f3ca76e6315667/pymongo-4.9.2-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:1ad79d6a74f439a068caf9a1e2daeabc20bf895263435484bbd49e90fbea7809", size = 995233, upload-time = "2024-10-02T16:34:42.437Z" },
- { url = "https://files.pythonhosted.org/packages/32/80/4ec79e36e99f86a063d297a334883fb5115ad70e9af46142b8dc33f636fa/pymongo-4.9.2-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:877699e21703717507cbbea23e75b419f81a513b50b65531e1698df08b2d7094", size = 995025, upload-time = "2024-10-02T16:34:44.032Z" },
- { url = "https://files.pythonhosted.org/packages/c4/fd/8f5464321fdf165700f10aec93b07a75c3537be593291ac2f8c8f5f69bd0/pymongo-4.9.2-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bc9322ce7cf116458a637ac10517b0c5926a8211202be6dbdc51dab4d4a9afc8", size = 2167429, upload-time = "2024-10-02T16:34:45.519Z" },
- { url = "https://files.pythonhosted.org/packages/da/42/0f749d805d17f5b17f48f2ee1aaf2a74e67939607b87b245e5ec9b4c1452/pymongo-4.9.2-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:cca029f46acf475504eedb33c7839f030c4bc4f946dcba12d9a954cc48850b79", size = 2258834, upload-time = "2024-10-02T16:34:47.324Z" },
- { url = "https://files.pythonhosted.org/packages/b8/52/b0c1b8e9cbeae234dd1108a906f30b680755533b7229f9f645d7e7adad25/pymongo-4.9.2-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:2c8c861e77527eec5a4b7363c16030dd0374670b620b08a5300f97594bbf5a40", size = 2216412, upload-time = "2024-10-02T16:34:48.747Z" },
- { url = "https://files.pythonhosted.org/packages/4d/20/53395473a1023bb6a670b68fbfa937664c75b354c2444463075ff43523e2/pymongo-4.9.2-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1fc70326ae71b3c7b8d6af82f46bb71dafdba3c8f335b29382ae9cf263ef3a5c", size = 2168891, upload-time = "2024-10-02T16:34:50.702Z" },
- { url = "https://files.pythonhosted.org/packages/01/b7/fa4030279d8a4a9c0a969a719b6b89da8a59795b5cdf129ef553fce6d1f2/pymongo-4.9.2-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ba9d2f6df977fee24437f82f7412460b0628cd6b961c4235c9cff71577a5b61f", size = 2109380, upload-time = "2024-10-02T16:34:52.493Z" },
- { url = "https://files.pythonhosted.org/packages/f3/55/f252972a039fc6bfca748625c5080d6f88801eb61f118fe79cde47342d6a/pymongo-4.9.2-cp313-cp313-win32.whl", hash = "sha256:b3254769e708bc4aa634745c262081d13c841a80038eff3afd15631540a1d227", size = 946962, upload-time = "2024-10-02T16:34:53.967Z" },
- { url = "https://files.pythonhosted.org/packages/7b/36/88d8438699ba09b714dece00a4a7462330c1d316f5eaa28db450572236f6/pymongo-4.9.2-cp313-cp313-win_amd64.whl", hash = "sha256:169b85728cc17800344ba17d736375f400ef47c9fbb4c42910c4b3e7c0247382", size = 975113, upload-time = "2024-10-02T16:34:56.646Z" },
+sdist = { url = "https://files.pythonhosted.org/packages/85/27/3634b2e8d88ad210ee6edac69259c698aefed4a79f0f7356cd625d5c423c/pymongo-4.12.1.tar.gz", hash = "sha256:8921bac7f98cccb593d76c4d8eaa1447e7d537ba9a2a202973e92372a05bd1eb", size = 2165515, upload-time = "2025-04-29T18:46:23.62Z" }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/59/dd/b684de28bfaf7e296538601c514d4613f98b77cfa1de323c7b160f4e04d0/pymongo-4.12.1-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:a7b771aa2f0854ddf7861e8ce2365f29df9159393543d047e43d8475bc4b8813", size = 910797, upload-time = "2025-04-29T18:44:57.783Z" },
+ { url = "https://files.pythonhosted.org/packages/e8/80/4fadd5400a4fbe57e7ea0349f132461d5dfc46c124937600f5044290d817/pymongo-4.12.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:34fd8681b6fa6e1025dd1000004f6b81cbf1961f145b8c58bd15e3957976068d", size = 910489, upload-time = "2025-04-29T18:45:01.089Z" },
+ { url = "https://files.pythonhosted.org/packages/4e/83/303be22944312cc28e3a357556d21971c388189bf90aebc79e752afa2452/pymongo-4.12.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:981e19b8f1040247dee5f7879e45f640f7e21a4d87eabb19283ce5a2927dd2e7", size = 1689142, upload-time = "2025-04-29T18:45:03.008Z" },
+ { url = "https://files.pythonhosted.org/packages/a4/67/f4e8506caf001ab9464df2562e3e022b7324e7c10a979ce1b55b006f2445/pymongo-4.12.1-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:c9a487dc1fe92736987a156325d3d9c66cbde6eac658b2875f5f222b6d82edca", size = 1753373, upload-time = "2025-04-29T18:45:04.874Z" },
+ { url = "https://files.pythonhosted.org/packages/2e/7c/22d65c2a4e3e941b345b8cc164b3b53f2c1d0db581d4991817b6375ef507/pymongo-4.12.1-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:1525051c13984365c4a9b88ee2d63009fae277921bc89a0d323b52c51f91cbac", size = 1722399, upload-time = "2025-04-29T18:45:06.726Z" },
+ { url = "https://files.pythonhosted.org/packages/07/0d/32fd1ebafd0090510fb4820d175fe35d646e5b28c71ad9c36cb3ce554567/pymongo-4.12.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5ad689e0e4f364809084f9e5888b2dcd6f0431b682a1c68f3fdf241e20e14475", size = 1692374, upload-time = "2025-04-29T18:45:08.552Z" },
+ { url = "https://files.pythonhosted.org/packages/e3/9c/d7a30ce6b983c3955c225e3038dafb4f299281775323f58b378f2a7e6e59/pymongo-4.12.1-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:8f9b18abca210c2917041ab2a380c12f6ddd2810844f1d64afb39caf8a15425e", size = 1651490, upload-time = "2025-04-29T18:45:10.658Z" },
+ { url = "https://files.pythonhosted.org/packages/29/b3/7902d73df1d088ec0c60c19ef4bd7894c6e6e4dfbfd7ab4ae4fbedc9427c/pymongo-4.12.1-cp312-cp312-win32.whl", hash = "sha256:d9d90fec041c6d695a639c26ca83577aa74383f5e3744fd7931537b208d5a1b5", size = 879521, upload-time = "2025-04-29T18:45:12.993Z" },
+ { url = "https://files.pythonhosted.org/packages/8c/68/a17ff6472e6be12bae75f5d11db4e3dccc55e02dcd4e66cd87871790a20e/pymongo-4.12.1-cp312-cp312-win_amd64.whl", hash = "sha256:d004b13e4f03d73a3ad38505ba84b61a2c8ba0a304f02fe1b27bfc986c244192", size = 897765, upload-time = "2025-04-29T18:45:15.296Z" },
+ { url = "https://files.pythonhosted.org/packages/0c/4d/e6654f3ec6819980cbad77795ccf2275cd65d6df41375a22cdbbccef8416/pymongo-4.12.1-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:90de2b060d69c22658ada162a5380a0f88cb8c0149023241b9e379732bd36152", size = 965051, upload-time = "2025-04-29T18:45:17.516Z" },
+ { url = "https://files.pythonhosted.org/packages/54/95/627a047c32789544a938abfd9311c914e622cb036ad16866e7e1b9b80239/pymongo-4.12.1-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:edf4e05331ac875d3b27b4654b74d81e44607af4aa7d6bcd4a31801ca164e6fd", size = 964732, upload-time = "2025-04-29T18:45:19.478Z" },
+ { url = "https://files.pythonhosted.org/packages/8f/6d/7a604e3ab5399f8fe1ca88abdbf7e54ceb6cf03e64f68b2ed192d9a5eaf5/pymongo-4.12.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fa7a817c9afb7b8775d98c469ddb3fe9c17daf53225394c1a74893cf45d3ade9", size = 1953037, upload-time = "2025-04-29T18:45:22.115Z" },
+ { url = "https://files.pythonhosted.org/packages/d5/d5/269388e7b0d02d35f55440baf1e0120320b6db1b555eaed7117d04b35402/pymongo-4.12.1-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f9d142ca531694e9324b3c9ba86c0e905c5f857599c4018a386c4dc02ca490fa", size = 2030467, upload-time = "2025-04-29T18:45:24.069Z" },
+ { url = "https://files.pythonhosted.org/packages/4b/d0/04a6b48d6ca3fc2ff156185a3580799a748cf713239d6181e91234a663d3/pymongo-4.12.1-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:d5d4c0461f5cd84d9fe87d5a84b1bc16371c4dd64d56dcfe5e69b15c0545a5ac", size = 1994139, upload-time = "2025-04-29T18:45:26.215Z" },
+ { url = "https://files.pythonhosted.org/packages/ad/65/0567052d52c0ac8aaa4baa700b39cdd1cf2481d2e59bd9817a3daf169ca0/pymongo-4.12.1-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:43afd2f39182731ac9fb81bbc9439d539e4bd2eda72cdee829d2fa906a1c4d37", size = 1954947, upload-time = "2025-04-29T18:45:28.423Z" },
+ { url = "https://files.pythonhosted.org/packages/c5/5b/db25747b288218dbdd97e9aeff6a3bfa3f872efb4ed06fa8bec67b2a121e/pymongo-4.12.1-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:827ac668c003da7b175b8e5f521850e2c182b4638a3dec96d97f0866d5508a1e", size = 1904374, upload-time = "2025-04-29T18:45:30.943Z" },
+ { url = "https://files.pythonhosted.org/packages/fc/1e/6d0eb040c02ae655fafd63bd737e96d7e832eecfd0bd37074d0066f94a78/pymongo-4.12.1-cp313-cp313-win32.whl", hash = "sha256:7c2269b37f034124a245eaeb34ce031cee64610437bd597d4a883304babda3cd", size = 925869, upload-time = "2025-04-29T18:45:32.998Z" },
+ { url = "https://files.pythonhosted.org/packages/59/b9/459da646d9750529f04e7e686f0cd8dd40174138826574885da334c01b16/pymongo-4.12.1-cp313-cp313-win_amd64.whl", hash = "sha256:3b28ecd1305b89089be14f137ffbdf98a3b9f5c8dbbb2be4dec084f2813fbd5f", size = 948411, upload-time = "2025-04-29T18:45:35.445Z" },
+ { url = "https://files.pythonhosted.org/packages/c9/c3/75be116159f210811656ec615b2248f63f1bc9dd1ce641e18db2552160f0/pymongo-4.12.1-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:f27b22a8215caff68bdf46b5b61ccd843a68334f2aa4658e8d5ecb5d3fbebb3b", size = 1021562, upload-time = "2025-04-29T18:45:37.433Z" },
+ { url = "https://files.pythonhosted.org/packages/cd/d1/2e8e368cad1c126a68365a6f53feaade58f9a16bd5f7a69f218af119b0e9/pymongo-4.12.1-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:5e9d23a3c290cf7409515466a7f11069b70e38ea2b786bbd7437bdc766c9e176", size = 1021553, upload-time = "2025-04-29T18:45:39.344Z" },
+ { url = "https://files.pythonhosted.org/packages/17/6e/a6460bc1e3d3f5f46cc151417427b2687a6f87972fd68a33961a37c114df/pymongo-4.12.1-cp313-cp313t-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:efeb430f7ca8649a6544a50caefead343d1fd096d04b6b6a002c6ce81148a85c", size = 2281736, upload-time = "2025-04-29T18:45:41.462Z" },
+ { url = "https://files.pythonhosted.org/packages/1a/e2/9e1d6f1a492bb02116074baa832716805a0552d757c176e7c5f40867ca80/pymongo-4.12.1-cp313-cp313t-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:a34e4a08bbcff56fdee86846afbc9ce751de95706ca189463e01bf5de3dd9927", size = 2368964, upload-time = "2025-04-29T18:45:43.579Z" },
+ { url = "https://files.pythonhosted.org/packages/fa/df/88143016eca77e79e38cf072476c70dd360962934430447dabc9c6bef6df/pymongo-4.12.1-cp313-cp313t-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b063344e0282537f05dbb11147591cbf58fc09211e24fc374749e343f880910a", size = 2327834, upload-time = "2025-04-29T18:45:45.847Z" },
+ { url = "https://files.pythonhosted.org/packages/3c/0d/df2998959b52cd5682b11e6eee1b0e0c104c07abd99c9cde5a871bb299fd/pymongo-4.12.1-cp313-cp313t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c3f7941e01b3e5d4bfb3b4711425e809df8c471b92d1da8d6fab92c7e334a4cb", size = 2279126, upload-time = "2025-04-29T18:45:48.445Z" },
+ { url = "https://files.pythonhosted.org/packages/fb/3e/102636f5aaf97ccfa2a156c253a89f234856a0cd252fa602d4bf077ba3c0/pymongo-4.12.1-cp313-cp313t-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b41235014031739f32be37ff13992f51091dae9a5189d3bcc22a5bf81fd90dae", size = 2218136, upload-time = "2025-04-29T18:45:50.57Z" },
+ { url = "https://files.pythonhosted.org/packages/44/c9/1b534c9d8d91d9d98310f2d955c5331fb522bd2a0105bd1fc31771d53758/pymongo-4.12.1-cp313-cp313t-win32.whl", hash = "sha256:9a1f07fe83a8a34651257179bd38d0f87bd9d90577fcca23364145c5e8ba1bc0", size = 974747, upload-time = "2025-04-29T18:45:52.66Z" },
+ { url = "https://files.pythonhosted.org/packages/08/e2/7d3a30ac905c99ea93729e03d2bb3d16fec26a789e98407d61cb368ab4bb/pymongo-4.12.1-cp313-cp313t-win_amd64.whl", hash = "sha256:46d86cf91ee9609d0713242a1d99fa9e9c60b4315e1a067b9a9e769bedae629d", size = 1003332, upload-time = "2025-04-29T18:45:54.631Z" },
]
[[package]]
diff --git a/backend/workers/Dockerfile.coordinator b/backend/workers/Dockerfile.coordinator
index a6567552..ae97091b 100644
--- a/backend/workers/Dockerfile.coordinator
+++ b/backend/workers/Dockerfile.coordinator
@@ -5,4 +5,4 @@ FROM base
COPY . .
# Run the coordinator service
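+# Plain python suffices here, assuming the base image already puts the project environment on PATH (no uv run at runtime)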
-CMD ["uv", "run", "python", "-m", "workers.run_coordinator"]
+CMD ["python", "-m", "workers.run_coordinator"]
diff --git a/backend/workers/Dockerfile.dlq_processor b/backend/workers/Dockerfile.dlq_processor
index 0ab0ceda..3c53a72f 100644
--- a/backend/workers/Dockerfile.dlq_processor
+++ b/backend/workers/Dockerfile.dlq_processor
@@ -5,4 +5,4 @@ FROM base
COPY . .
# Run DLQ processor
-CMD ["uv", "run", "python", "workers/dlq_processor.py"]
+CMD ["python", "workers/dlq_processor.py"]
diff --git a/backend/workers/Dockerfile.event_replay b/backend/workers/Dockerfile.event_replay
index c551666d..948da191 100644
--- a/backend/workers/Dockerfile.event_replay
+++ b/backend/workers/Dockerfile.event_replay
@@ -5,4 +5,4 @@ FROM base
COPY . .
# Run event replay service
-CMD ["uv", "run", "python", "workers/run_event_replay.py"]
+CMD ["python", "workers/run_event_replay.py"]
diff --git a/backend/workers/Dockerfile.k8s_worker b/backend/workers/Dockerfile.k8s_worker
index a04dc146..dc02131a 100644
--- a/backend/workers/Dockerfile.k8s_worker
+++ b/backend/workers/Dockerfile.k8s_worker
@@ -5,4 +5,4 @@ FROM base
COPY . .
# Run Kubernetes worker
-CMD ["uv", "run", "python", "workers/run_k8s_worker.py"]
+CMD ["python", "workers/run_k8s_worker.py"]
diff --git a/backend/workers/Dockerfile.pod_monitor b/backend/workers/Dockerfile.pod_monitor
index ff57da38..77fe71dd 100644
--- a/backend/workers/Dockerfile.pod_monitor
+++ b/backend/workers/Dockerfile.pod_monitor
@@ -5,4 +5,4 @@ FROM base
COPY . .
# Run pod monitor
-CMD ["uv", "run", "python", "workers/run_pod_monitor.py"]
+CMD ["python", "workers/run_pod_monitor.py"]
diff --git a/backend/workers/Dockerfile.result_processor b/backend/workers/Dockerfile.result_processor
index 7356d6ae..9c878459 100644
--- a/backend/workers/Dockerfile.result_processor
+++ b/backend/workers/Dockerfile.result_processor
@@ -5,4 +5,4 @@ FROM base
COPY . .
# Run result processor
-CMD ["uv", "run", "python", "workers/run_result_processor.py"]
+CMD ["python", "workers/run_result_processor.py"]
diff --git a/backend/workers/Dockerfile.saga_orchestrator b/backend/workers/Dockerfile.saga_orchestrator
index 3218b92c..f69a90c5 100644
--- a/backend/workers/Dockerfile.saga_orchestrator
+++ b/backend/workers/Dockerfile.saga_orchestrator
@@ -5,4 +5,4 @@ FROM base
COPY . .
# Run saga orchestrator
-CMD ["uv", "run", "python", "workers/run_saga_orchestrator.py"]
+CMD ["python", "workers/run_saga_orchestrator.py"]
diff --git a/backend/workers/dlq_processor.py b/backend/workers/dlq_processor.py
index a3725c9e..727f8e8e 100644
--- a/backend/workers/dlq_processor.py
+++ b/backend/workers/dlq_processor.py
@@ -1,14 +1,18 @@
import asyncio
+import os
import signal
from typing import Optional
-from app.core.database_context import Database, DBClient
-from app.core.logging import logger
+from app.core.database_context import DBClient
+from app.core.logging import setup_logger
from app.dlq import DLQMessage, RetryPolicy, RetryStrategy
from app.dlq.manager import DLQManager, create_dlq_manager
from app.domain.enums.kafka import KafkaTopic
+from app.events.schema.schema_registry import create_schema_registry_manager
from app.settings import get_settings
-from motor.motor_asyncio import AsyncIOMotorClient
+from pymongo.asynchronous.mongo_client import AsyncMongoClient
+
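+# Module-level logger for this worker; the LOG_LEVEL env var overrides the INFO default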
+logger = setup_logger(os.environ.get("LOG_LEVEL", "INFO"))
def _configure_retry_policies(manager: DLQManager) -> None:
@@ -100,18 +104,20 @@ async def alert_on_discard(message: DLQMessage, reason: str) -> None:
async def main() -> None:
settings = get_settings()
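+ # pymongo's native AsyncMongoClient replaces motor's AsyncIOMotorClient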
- db_client: DBClient = AsyncIOMotorClient(
+ db_client: DBClient = AsyncMongoClient(
settings.MONGODB_URL,
tz_aware=True,
serverSelectionTimeoutMS=5000,
)
db_name = settings.DATABASE_NAME
- database: Database = db_client[db_name]
+ _ = db_client[db_name] # Resolve the database handle; the ping below verifies connectivity
await db_client.admin.command("ping")
logger.info(f"Connected to database: {db_name}")
+ schema_registry = create_schema_registry_manager(logger)
manager = create_dlq_manager(
- database=database,
+ schema_registry=schema_registry,
+ logger=logger,
dlq_topic=KafkaTopic.DEAD_LETTER_QUEUE,
retry_topic_suffix="-retry",
)
diff --git a/backend/workers/run_coordinator.py b/backend/workers/run_coordinator.py
index b97bf6db..29f1c1dd 100644
--- a/backend/workers/run_coordinator.py
+++ b/backend/workers/run_coordinator.py
@@ -12,20 +12,21 @@
def main() -> None:
"""Main entry point for coordinator worker"""
+ settings = get_settings()
+
# Setup logging
- setup_logger()
+ logger = setup_logger(settings.LOG_LEVEL)
# Configure root logger for worker
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
- logger = logging.getLogger(__name__)
logger.info("Starting ExecutionCoordinator worker...")
# Initialize tracing
- settings = get_settings()
if settings.ENABLE_TRACING:
init_tracing(
service_name=GroupId.EXECUTION_COORDINATOR,
+ logger=logger,
service_version=settings.TRACING_SERVICE_VERSION,
enable_console_exporter=False,
sampling_rate=settings.TRACING_SAMPLING_RATE,
diff --git a/backend/workers/run_event_replay.py b/backend/workers/run_event_replay.py
index 3d34ddf5..c2fda75b 100644
--- a/backend/workers/run_event_replay.py
+++ b/backend/workers/run_event_replay.py
@@ -5,14 +5,15 @@
from app.core.database_context import DBClient
from app.core.logging import setup_logger
from app.core.tracing import init_tracing
+from app.db.docs import ALL_DOCUMENTS
from app.db.repositories.replay_repository import ReplayRepository
-from app.db.schema.schema_manager import SchemaManager
from app.events.core import ProducerConfig, UnifiedProducer
from app.events.event_store import create_event_store
from app.events.schema.schema_registry import SchemaRegistryManager
from app.services.event_replay.replay_service import EventReplayService
from app.settings import get_settings
-from motor.motor_asyncio import AsyncIOMotorClient
+from beanie import init_beanie
+from pymongo.asynchronous.mongo_client import AsyncMongoClient
async def cleanup_task(replay_service: EventReplayService, interval_hours: int = 6) -> None:
@@ -28,15 +29,13 @@ async def cleanup_task(replay_service: EventReplayService, interval_hours: int =
logger.error(f"Error during cleanup: {e}")
-async def run_replay_service() -> None:
+async def run_replay_service(logger: logging.Logger) -> None:
"""Run the event replay service with cleanup task"""
- logger = logging.getLogger(__name__)
-
# Get settings
settings = get_settings()
# Create database connection
- db_client: DBClient = AsyncIOMotorClient(settings.MONGODB_URL, tz_aware=True, serverSelectionTimeoutMS=5000)
+ db_client: DBClient = AsyncMongoClient(settings.MONGODB_URL, tz_aware=True, serverSelectionTimeoutMS=5000)
db_name = settings.DATABASE_NAME
database = db_client[db_name]
@@ -44,26 +43,24 @@ async def run_replay_service() -> None:
await db_client.admin.command("ping")
logger.info(f"Connected to database: {db_name}")
- # Ensure DB schema
- await SchemaManager(database).apply_all()
+ # Initialize Beanie ODM (indexes are idempotently created via Document.Settings.indexes)
+ await init_beanie(database=database, document_models=ALL_DOCUMENTS)
# Initialize services
- schema_registry = SchemaRegistryManager()
+ schema_registry = SchemaRegistryManager(logger)
producer_config = ProducerConfig(bootstrap_servers=settings.KAFKA_BOOTSTRAP_SERVERS)
- producer = UnifiedProducer(producer_config, schema_registry)
+ producer = UnifiedProducer(producer_config, schema_registry, logger)
# Create event store
- event_store = create_event_store(db=database, schema_registry=schema_registry)
-
- # Ensure schema (indexes) for this worker process
- schema_manager = SchemaManager(database)
- await schema_manager.apply_all()
+ event_store = create_event_store(schema_registry=schema_registry, logger=logger)
# Create repository
- replay_repository = ReplayRepository(database)
+ replay_repository = ReplayRepository(logger)
# Create replay service
- replay_service = EventReplayService(repository=replay_repository, producer=producer, event_store=event_store)
+ replay_service = EventReplayService(
+ repository=replay_repository, producer=producer, event_store=event_store, logger=logger
+ )
logger.info("Event replay service initialized")
async with AsyncExitStack() as stack:
@@ -86,27 +83,28 @@ async def _cancel_task() -> None:
def main() -> None:
"""Main entry point for event replay service"""
+ settings = get_settings()
+
# Setup logging
- setup_logger()
+ logger = setup_logger(settings.LOG_LEVEL)
# Configure root logger for worker
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
- logger = logging.getLogger(__name__)
logger.info("Starting Event Replay Service...")
# Initialize tracing
- settings = get_settings()
if settings.ENABLE_TRACING:
init_tracing(
service_name="event-replay",
+ logger=logger,
service_version=settings.TRACING_SERVICE_VERSION,
enable_console_exporter=False,
sampling_rate=settings.TRACING_SAMPLING_RATE,
)
logger.info("Tracing initialized for Event Replay Service")
- asyncio.run(run_replay_service())
+ asyncio.run(run_replay_service(logger))
if __name__ == "__main__":
diff --git a/backend/workers/run_k8s_worker.py b/backend/workers/run_k8s_worker.py
index 47c297a5..0e080e6b 100644
--- a/backend/workers/run_k8s_worker.py
+++ b/backend/workers/run_k8s_worker.py
@@ -12,20 +12,21 @@
def main() -> None:
"""Main entry point for Kubernetes worker"""
+ settings = get_settings()
+
# Setup logging
- setup_logger()
+ logger = setup_logger(settings.LOG_LEVEL)
# Configure root logger for worker
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
- logger = logging.getLogger(__name__)
logger.info("Starting KubernetesWorker...")
# Initialize tracing
- settings = get_settings()
if settings.ENABLE_TRACING:
init_tracing(
service_name=GroupId.K8S_WORKER,
+ logger=logger,
service_version=settings.TRACING_SERVICE_VERSION,
enable_console_exporter=False,
sampling_rate=settings.TRACING_SAMPLING_RATE,
diff --git a/backend/workers/run_pod_monitor.py b/backend/workers/run_pod_monitor.py
index 74bfb1be..ebed2fc8 100644
--- a/backend/workers/run_pod_monitor.py
+++ b/backend/workers/run_pod_monitor.py
@@ -12,20 +12,21 @@
def main() -> None:
"""Main entry point for pod monitor worker"""
+ settings = get_settings()
+
# Setup logging
- setup_logger()
+ logger = setup_logger(settings.LOG_LEVEL)
# Configure root logger for worker
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
- logger = logging.getLogger(__name__)
logger.info("Starting PodMonitor worker...")
# Initialize tracing
- settings = get_settings()
if settings.ENABLE_TRACING:
init_tracing(
service_name=GroupId.POD_MONITOR,
+ logger=logger,
service_version=settings.TRACING_SERVICE_VERSION,
enable_console_exporter=False,
sampling_rate=settings.TRACING_SAMPLING_RATE,
diff --git a/backend/workers/run_result_processor.py b/backend/workers/run_result_processor.py
index 21468431..2150b112 100644
--- a/backend/workers/run_result_processor.py
+++ b/backend/workers/run_result_processor.py
@@ -10,20 +10,21 @@
def main() -> None:
"""Main entry point for result processor worker"""
+ settings = get_settings()
+
# Setup logging
- setup_logger()
+ logger = setup_logger(settings.LOG_LEVEL)
# Configure root logger for worker
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
- logger = logging.getLogger(__name__)
logger.info("Starting ResultProcessor worker...")
# Initialize tracing
- settings = get_settings()
if settings.ENABLE_TRACING:
init_tracing(
service_name=GroupId.RESULT_PROCESSOR,
+ logger=logger,
service_version=settings.TRACING_SERVICE_VERSION,
enable_console_exporter=False,
sampling_rate=settings.TRACING_SAMPLING_RATE,
diff --git a/backend/workers/run_saga_orchestrator.py b/backend/workers/run_saga_orchestrator.py
index 819a297b..53e45a83 100644
--- a/backend/workers/run_saga_orchestrator.py
+++ b/backend/workers/run_saga_orchestrator.py
@@ -5,9 +5,9 @@
from app.core.database_context import DBClient
from app.core.logging import setup_logger
from app.core.tracing import init_tracing
+from app.db.docs import ALL_DOCUMENTS
from app.db.repositories.resource_allocation_repository import ResourceAllocationRepository
from app.db.repositories.saga_repository import SagaRepository
-from app.db.schema.schema_manager import SchemaManager
from app.domain.enums.kafka import GroupId
from app.domain.saga.models import SagaConfig
from app.events.core import ProducerConfig, UnifiedProducer
@@ -17,7 +17,8 @@
from app.services.idempotency.redis_repository import RedisIdempotencyRepository
from app.services.saga import create_saga_orchestrator
from app.settings import get_settings
-from motor.motor_asyncio import AsyncIOMotorClient
+from beanie import init_beanie
+from pymongo.asynchronous.mongo_client import AsyncMongoClient
async def run_saga_orchestrator() -> None:
@@ -27,7 +28,7 @@ async def run_saga_orchestrator() -> None:
logger = logging.getLogger(__name__)
# Create database connection
- db_client: DBClient = AsyncIOMotorClient(settings.MONGODB_URL, tz_aware=True, serverSelectionTimeoutMS=5000)
+ db_client: DBClient = AsyncMongoClient(settings.MONGODB_URL, tz_aware=True, serverSelectionTimeoutMS=5000)
db_name = settings.DATABASE_NAME
database = db_client[db_name]
@@ -35,26 +36,26 @@ async def run_saga_orchestrator() -> None:
await db_client.admin.command("ping")
logger.info(f"Connected to database: {db_name}")
- # Ensure DB schema (indexes/validators)
- await SchemaManager(database).apply_all()
+ # Initialize Beanie ODM (indexes are idempotently created via Document.Settings.indexes)
+ await init_beanie(database=database, document_models=ALL_DOCUMENTS)
# Initialize schema registry
logger.info("Initializing schema registry...")
- schema_registry_manager = SchemaRegistryManager()
+ schema_registry_manager = SchemaRegistryManager(logger)
await schema_registry_manager.initialize_schemas()
# Initialize Kafka producer
logger.info("Initializing Kafka producer...")
producer_config = ProducerConfig(bootstrap_servers=settings.KAFKA_BOOTSTRAP_SERVERS)
- producer = UnifiedProducer(producer_config, schema_registry_manager)
+ producer = UnifiedProducer(producer_config, schema_registry_manager, logger)
await producer.start()
# Create event store (schema ensured separately)
logger.info("Creating event store...")
- event_store = create_event_store(db=database, schema_registry=schema_registry_manager, ttl_days=90)
+ event_store = create_event_store(schema_registry=schema_registry_manager, logger=logger, ttl_days=90)
# Create repository and idempotency manager (Redis-backed)
- saga_repository = SagaRepository(database)
+ saga_repository = SagaRepository()
r = redis.Redis(
host=settings.REDIS_HOST,
port=settings.REDIS_PORT,
@@ -67,8 +68,8 @@ async def run_saga_orchestrator() -> None:
socket_timeout=5,
)
idem_repo = RedisIdempotencyRepository(r, key_prefix="idempotency")
- idempotency_manager = create_idempotency_manager(repository=idem_repo, config=IdempotencyConfig())
- resource_allocation_repository = ResourceAllocationRepository(database)
+ idempotency_manager = create_idempotency_manager(repository=idem_repo, config=IdempotencyConfig(), logger=logger)
+ resource_allocation_repository = ResourceAllocationRepository()
# Create saga orchestrator
saga_config = SagaConfig(
@@ -111,26 +112,27 @@ async def run_saga_orchestrator() -> None:
await producer.stop()
await idempotency_manager.close()
await r.aclose()
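+ # AsyncMongoClient.close() is a coroutine, unlike motor's synchronous close()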
- db_client.close()
+ await db_client.close()
logger.info("Saga orchestrator shutdown complete")
def main() -> None:
"""Main entry point for saga orchestrator worker"""
+ settings = get_settings()
+
# Setup logging
- setup_logger()
+ logger = setup_logger(settings.LOG_LEVEL)
# Configure root logger for worker
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
- logger = logging.getLogger(__name__)
logger.info("Starting Saga Orchestrator worker...")
# Initialize tracing
- settings = get_settings()
if settings.ENABLE_TRACING:
init_tracing(
service_name=GroupId.SAGA_ORCHESTRATOR,
+ logger=logger,
service_version=settings.TRACING_SERVICE_VERSION,
enable_console_exporter=False,
sampling_rate=settings.TRACING_SAMPLING_RATE,
diff --git a/docker-compose.ci.yaml b/docker-compose.ci.yaml
new file mode 100644
index 00000000..3b92dd9d
--- /dev/null
+++ b/docker-compose.ci.yaml
@@ -0,0 +1,240 @@
+# CI-optimized Docker Compose configuration
+#
+# Usage:
+# Backend integration tests (infra only, no builds):
+# docker compose -f docker-compose.ci.yaml up -d --wait
+#
+# Frontend E2E tests (full stack with builds):
+# docker compose -f docker-compose.ci.yaml --profile full up -d --wait
+#
+# Key differences from docker-compose.yaml:
+# - KRaft Kafka (no Zookeeper) - simpler, faster startup
+# - No SASL/TLS for Kafka - not needed for tests
+# - Profiles separate infra from app services
+# - Minimal services for fast CI
+
+services:
+ # =============================================================================
+ # INFRASTRUCTURE SERVICES (no profile = always started)
+ # =============================================================================
+
+ mongo:
+ image: mongo:8.0
+ container_name: mongo
+ ports:
+ - "27017:27017"
+ environment:
+ MONGO_INITDB_ROOT_USERNAME: root
+ MONGO_INITDB_ROOT_PASSWORD: rootpassword
+ MONGO_INITDB_DATABASE: integr8scode
+ tmpfs:
+ - /data/db # Use tmpfs for faster CI
+ networks:
+ - ci-network
+ healthcheck:
+ test: mongosh --eval 'db.runCommand("ping").ok' --quiet
+ interval: 2s
+ timeout: 3s
+ retries: 15
+ start_period: 5s
+
+ redis:
+ image: redis:7-alpine
+ container_name: redis
+ ports:
+ - "6379:6379"
+ command: redis-server --maxmemory 128mb --maxmemory-policy allkeys-lru --save ""
+ networks:
+ - ci-network
+ healthcheck:
+ test: ["CMD", "redis-cli", "ping"]
+ interval: 2s
+ timeout: 2s
+ retries: 10
+ start_period: 2s
+
+ # KRaft mode Kafka - official Apache image, no Zookeeper needed
+ kafka:
+ image: apache/kafka:3.9.0
+ container_name: kafka
+ ports:
+ - "9092:9092"
+ environment:
+ # KRaft mode configuration
+ KAFKA_NODE_ID: 1
+ KAFKA_PROCESS_ROLES: broker,controller
+ KAFKA_CONTROLLER_QUORUM_VOTERS: 1@localhost:9093
+ # Listeners: CONTROLLER for raft, HOST for external, DOCKER for internal
+ KAFKA_LISTENERS: CONTROLLER://localhost:9093,HOST://0.0.0.0:9092,DOCKER://0.0.0.0:29092
+ KAFKA_ADVERTISED_LISTENERS: HOST://localhost:9092,DOCKER://kafka:29092
+ KAFKA_LISTENER_SECURITY_PROTOCOL_MAP: CONTROLLER:PLAINTEXT,HOST:PLAINTEXT,DOCKER:PLAINTEXT
+ KAFKA_CONTROLLER_LISTENER_NAMES: CONTROLLER
+ KAFKA_INTER_BROKER_LISTENER_NAME: DOCKER
+ # CI optimizations
+ KAFKA_AUTO_CREATE_TOPICS_ENABLE: "true"
+ KAFKA_OFFSETS_TOPIC_REPLICATION_FACTOR: 1
+ KAFKA_TRANSACTION_STATE_LOG_REPLICATION_FACTOR: 1
+ KAFKA_TRANSACTION_STATE_LOG_MIN_ISR: 1
+ KAFKA_NUM_PARTITIONS: 1
+ KAFKA_DEFAULT_REPLICATION_FACTOR: 1
+ # Reduce memory usage
+ KAFKA_HEAP_OPTS: "-Xms256m -Xmx512m"
+ networks:
+ - ci-network
+ healthcheck:
+ test: /opt/kafka/bin/kafka-broker-api-versions.sh --bootstrap-server localhost:9092 || exit 1
+ interval: 2s
+ timeout: 5s
+ retries: 30
+ start_period: 10s
+
+ schema-registry:
+ image: confluentinc/cp-schema-registry:7.5.0
+ container_name: schema-registry
+ ports:
+ - "8081:8081"
+ environment:
+ SCHEMA_REGISTRY_HOST_NAME: schema-registry
+ SCHEMA_REGISTRY_KAFKASTORE_BOOTSTRAP_SERVERS: kafka:29092
+ SCHEMA_REGISTRY_LISTENERS: http://0.0.0.0:8081
+ SCHEMA_REGISTRY_HEAP_OPTS: "-Xms128m -Xmx256m"
+ depends_on:
+ kafka:
+ condition: service_healthy
+ networks:
+ - ci-network
+ healthcheck:
+ test: curl -f http://localhost:8081/config || exit 1
+ interval: 2s
+ timeout: 3s
+ retries: 20
+ start_period: 3s
+
+ # =============================================================================
+ # APPLICATION SERVICES (profile: full - only for E2E tests)
+ # =============================================================================
+
+ # Shared base image for backend
+ base:
+ build:
+ context: ./backend
+ dockerfile: Dockerfile.base
+ image: integr8scode-base:latest
+ profiles: ["full"]
+
+ # Certificate generator for TLS
+ shared-ca:
+ image: alpine:latest
+ profiles: ["full"]
+ volumes:
+ - shared_ca:/shared_ca
+ command: sh -c "mkdir -p /shared_ca && chmod 777 /shared_ca && sleep 1"
+ networks:
+ - ci-network
+
+ cert-generator:
+ build:
+ context: ./cert-generator
+ dockerfile: Dockerfile
+ profiles: ["full"]
+ volumes:
+ - ./backend/certs:/backend-certs
+ - ./frontend/certs:/frontend-certs
+ - shared_ca:/shared_ca
+ - ./backend:/backend
+ environment:
+ - SHARED_CA_DIR=/shared_ca
+ - BACKEND_CERT_DIR=/backend-certs
+ - FRONTEND_CERT_DIR=/frontend-certs
+ - CI=true
+ extra_hosts:
+ - "host.docker.internal:host-gateway"
+ restart: "no"
+ network_mode: host
+ depends_on:
+ shared-ca:
+ condition: service_completed_successfully
+
+ backend:
+ build:
+ context: ./backend
+ dockerfile: Dockerfile
+ additional_contexts:
+ base: service:base
+ profiles: ["full"]
+ container_name: backend
+ ports:
+ - "443:443"
+ environment:
+ - SERVER_HOST=0.0.0.0
+ - TESTING=true
+ - MONGODB_URL=mongodb://root:rootpassword@mongo:27017/integr8scode?authSource=admin
+ - KAFKA_BOOTSTRAP_SERVERS=kafka:29092
+ - SCHEMA_REGISTRY_URL=http://schema-registry:8081
+ - REDIS_HOST=redis
+ - REDIS_PORT=6379
+ - OTEL_SDK_DISABLED=true
+ - ENABLE_TRACING=false
+ - SECRET_KEY=ci-test-secret-key-for-testing-only-32chars!!
+ volumes:
+ - ./backend/certs:/app/certs:ro
+ - shared_ca:/shared_ca:ro
+ - ./backend/kubeconfig.yaml:/app/kubeconfig.yaml:ro
+ extra_hosts:
+ - "host.docker.internal:host-gateway"
+ depends_on:
+ base:
+ condition: service_completed_successfully
+ cert-generator:
+ condition: service_completed_successfully
+ mongo:
+ condition: service_healthy
+ redis:
+ condition: service_healthy
+ kafka:
+ condition: service_healthy
+ schema-registry:
+ condition: service_healthy
+ networks:
+ - ci-network
+ healthcheck:
+ test: ["CMD-SHELL", "curl -k -f -s https://localhost:443/api/v1/health/live || exit 1"]
+ interval: 5s
+ timeout: 5s
+ retries: 20
+ start_period: 30s
+
+ frontend:
+ build:
+ context: ./frontend
+ dockerfile: Dockerfile
+ profiles: ["full"]
+ container_name: frontend
+ ports:
+ - "5001:5001"
+ environment:
+ - VITE_BACKEND_URL=https://backend:443
+ - NODE_EXTRA_CA_CERTS=/shared_ca/mkcert-ca.pem
+ volumes:
+ - ./frontend/certs:/app/certs:ro
+ - shared_ca:/shared_ca:ro
+ depends_on:
+ cert-generator:
+ condition: service_completed_successfully
+ backend:
+ condition: service_healthy
+ networks:
+ - ci-network
+ healthcheck:
+ test: ["CMD-SHELL", "curl -k -f -s https://localhost:5001/ || exit 1"]
+ interval: 5s
+ timeout: 5s
+ retries: 20
+ start_period: 30s
+
+volumes:
+ shared_ca:
+
+networks:
+ ci-network:
+ driver: bridge
diff --git a/docker-compose.yaml b/docker-compose.yaml
index 89af7332..f68ec656 100644
--- a/docker-compose.yaml
+++ b/docker-compose.yaml
@@ -82,6 +82,8 @@ services:
condition: service_completed_successfully
cert-generator:
condition: service_completed_successfully
+ user-seed:
+ condition: service_completed_successfully
mongo:
condition: service_healthy
redis:
@@ -91,7 +93,9 @@ services:
schema-registry:
condition: service_healthy
volumes:
- - ./backend:/app
+ - ./backend/app:/app/app
+ - ./backend/workers:/app/workers
+ - ./backend/scripts:/app/scripts
- ./backend/certs:/app/certs:ro
- shared_ca:/shared_ca:ro
- ./backend/kubeconfig.yaml:/app/kubeconfig.yaml:ro
@@ -113,10 +117,10 @@ services:
healthcheck:
# Simpler, reliable healthcheck: curl fails non-zero for HTTP >=400 with -f
test: ["CMD-SHELL", "curl -k -f -s https://localhost:443/api/v1/health/live >/dev/null || exit 1"]
- interval: 5s
- timeout: 5s
- retries: 10
- start_period: 15s
+ interval: 3s
+ timeout: 3s
+ retries: 50
+ start_period: 10s
frontend:
container_name: frontend
@@ -363,7 +367,7 @@ services:
environment:
- KAFKA_BOOTSTRAP_SERVERS=kafka:29092
- SCHEMA_REGISTRY_URL=http://schema-registry:8081
- command: ["uv", "run", "python", "-m", "scripts.create_topics"]
+ command: ["python", "-m", "scripts.create_topics"]
networks:
- app-network
restart: "no" # Run once and exit
@@ -385,7 +389,7 @@ services:
- MONGODB_URL=mongodb://${MONGO_ROOT_USER:-root}:${MONGO_ROOT_PASSWORD:-rootpassword}@mongo:27017/integr8scode?authSource=admin
- DEFAULT_USER_PASSWORD=${DEFAULT_USER_PASSWORD:-user123}
- ADMIN_USER_PASSWORD=${ADMIN_USER_PASSWORD:-admin123}
- command: ["uv", "run", "python", "-m", "scripts.seed_users"]
+ command: ["python", "-m", "scripts.seed_users"]
networks:
- app-network
restart: "no" # Run once and exit
@@ -452,7 +456,8 @@ services:
- KUBECONFIG=/app/kubeconfig.yaml
- KAFKA_CONSUMER_GROUP_ID=k8s-worker
volumes:
- - ./backend:/app:ro
+ - ./backend/app:/app/app:ro
+ - ./backend/workers:/app/workers:ro
networks:
- app-network
extra_hosts:
@@ -486,7 +491,8 @@ services:
- KUBECONFIG=/app/kubeconfig.yaml
- KAFKA_CONSUMER_GROUP_ID=pod-monitor
volumes:
- - ./backend:/app:ro
+ - ./backend/app:/app/app:ro
+ - ./backend/workers:/app/workers:ro
networks:
- app-network
extra_hosts:
@@ -522,7 +528,8 @@ services:
- KAFKA_CONSUMER_GROUP_ID=result-processor-group
- KUBECONFIG=/app/kubeconfig.yaml
volumes:
- - ./backend:/app:ro
+ - ./backend/app:/app/app:ro
+ - ./backend/workers:/app/workers:ro
networks:
- app-network
extra_hosts:
@@ -593,7 +600,7 @@ services:
condition: service_completed_successfully
mongo:
condition: service_started
- command: ["uv", "run", "python", "workers/run_event_replay.py"]
+ command: ["python", "workers/run_event_replay.py"]
env_file:
- ./backend/.env
environment:
@@ -624,7 +631,7 @@ services:
condition: service_completed_successfully
mongo:
condition: service_started
- command: ["uv", "run", "python", "workers/dlq_processor.py"]
+ command: ["python", "workers/dlq_processor.py"]
env_file:
- ./backend/.env
environment:
diff --git a/docs/architecture/domain-exceptions.md b/docs/architecture/domain-exceptions.md
new file mode 100644
index 00000000..c04b4553
--- /dev/null
+++ b/docs/architecture/domain-exceptions.md
@@ -0,0 +1,206 @@
+# Domain exceptions
+
+This document explains how the backend handles errors using domain exceptions. It covers the exception hierarchy, how services use them, and how the middleware maps them to HTTP responses.
+
+## Why domain exceptions
+
+Services used to raise `HTTPException` directly with status codes like 404 or 422. That worked, but it created tight coupling between business logic and HTTP semantics. A service that raises `HTTPException(status_code=404)` knows it's running behind an HTTP API, which breaks down as soon as you want to reuse that service from a CLI tool, a message consumer, or a test harness.
+
+Domain exceptions fix this by letting services speak in business terms. A service raises `ExecutionNotFoundError(execution_id)` instead of `HTTPException(404, "Execution not found")`. The exception handler middleware maps domain exceptions to HTTP responses in one place. Services stay transport-agnostic, tests assert on meaningful exception types, and the mapping logic lives where it belongs.
+
+## Exception hierarchy
+
+All domain exceptions inherit from `DomainError`, which lives in `app/domain/exceptions.py`. The base classes map to HTTP status codes:
+
+| Base class | HTTP status | Use case |
+|------------|-------------|----------|
+| `NotFoundError` | 404 | Entity doesn't exist |
+| `ValidationError` | 422 | Invalid input or state |
+| `ThrottledError` | 429 | Rate limit exceeded |
+| `ConflictError` | 409 | Concurrent modification or duplicate |
+| `UnauthorizedError` | 401 | Missing or invalid credentials |
+| `ForbiddenError` | 403 | Authenticated but not allowed |
+| `InvalidStateError` | 400 | Operation invalid for current state |
+| `InfrastructureError` | 500 | External system failure |
+
+Each domain module defines specific exceptions that inherit from these bases. The hierarchy looks like this:
+
+```
+DomainError
+├── NotFoundError
+│ ├── ExecutionNotFoundError
+│ ├── SagaNotFoundError
+│ ├── NotificationNotFoundError
+│ ├── SavedScriptNotFoundError
+│ ├── ReplaySessionNotFoundError
+│ └── UserNotFoundError
+├── ValidationError
+│ ├── RuntimeNotSupportedError
+│ └── NotificationValidationError
+├── ThrottledError
+│ └── NotificationThrottledError
+├── ConflictError
+│ └── SagaConcurrencyError
+├── UnauthorizedError
+│ ├── AuthenticationRequiredError
+│ ├── InvalidCredentialsError
+│ └── TokenExpiredError
+├── ForbiddenError
+│ ├── SagaAccessDeniedError
+│ ├── AdminAccessRequiredError
+│ └── CSRFValidationError
+├── InvalidStateError
+│ └── SagaInvalidStateError
+└── InfrastructureError
+ ├── EventPublishError
+ ├── SagaCompensationError
+ ├── SagaTimeoutError
+ └── ReplayOperationError
+```
+
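+The base classes are intentionally thin: they carry a human-readable `message` that the exception handler later returns as `detail`. A minimal sketch of the shape (illustrative, not the exact definitions in `app/domain/exceptions.py`):
+
+```python
+class DomainError(Exception):
+    """Base class for all domain exceptions."""
+
+    def __init__(self, message: str) -> None:
+        self.message = message
+        super().__init__(message)
+
+
+class NotFoundError(DomainError):
+    """Entity does not exist (maps to HTTP 404)."""
+
+
+class ValidationError(DomainError):
+    """Invalid input or state (maps to HTTP 422)."""
+```
+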
+## Exception locations
+
+Domain exceptions live in their respective domain modules:
+
+| Module | File | Exceptions |
+|--------|------|------------|
+| Base | `app/domain/exceptions.py` | `DomainError`, `NotFoundError`, `ValidationError`, etc. |
+| Execution | `app/domain/execution/exceptions.py` | `ExecutionNotFoundError`, `RuntimeNotSupportedError`, `EventPublishError` |
+| Saga | `app/domain/saga/exceptions.py` | `SagaNotFoundError`, `SagaAccessDeniedError`, `SagaInvalidStateError`, `SagaCompensationError`, `SagaTimeoutError`, `SagaConcurrencyError` |
+| Notification | `app/domain/notification/exceptions.py` | `NotificationNotFoundError`, `NotificationThrottledError`, `NotificationValidationError` |
+| Saved Script | `app/domain/saved_script/exceptions.py` | `SavedScriptNotFoundError` |
+| Replay | `app/domain/replay/exceptions.py` | `ReplaySessionNotFoundError`, `ReplayOperationError` |
+| User/Auth | `app/domain/user/exceptions.py` | `AuthenticationRequiredError`, `InvalidCredentialsError`, `TokenExpiredError`, `CSRFValidationError`, `AdminAccessRequiredError`, `UserNotFoundError` |
+
+## Rich constructors
+
+Specific exceptions have constructors that capture context for logging and debugging. Instead of just a message string, they take structured arguments:
+
+```python
+class SagaAccessDeniedError(ForbiddenError):
+ def __init__(self, saga_id: str, user_id: str) -> None:
+ self.saga_id = saga_id
+ self.user_id = user_id
+ super().__init__(f"Access denied to saga '{saga_id}' for user '{user_id}'")
+
+class NotificationThrottledError(ThrottledError):
+ def __init__(self, user_id: str, limit: int, window_hours: int) -> None:
+ self.user_id = user_id
+ self.limit = limit
+ self.window_hours = window_hours
+ super().__init__(f"Rate limit exceeded for user '{user_id}': max {limit} per {window_hours}h")
+```
+
+This means you can log `exc.saga_id` or `exc.limit` without parsing the message, and tests can assert on specific fields.
+
+## Exception handler
+
+The middleware in `app/core/exceptions/handlers.py` catches all `DomainError` subclasses and maps them to JSON responses:
+
+```python
+def configure_exception_handlers(app: FastAPI) -> None:
+ @app.exception_handler(DomainError)
+ async def domain_error_handler(request: Request, exc: DomainError) -> JSONResponse:
+ status_code = _map_to_status_code(exc)
+ return JSONResponse(
+ status_code=status_code,
+ content={"detail": exc.message, "type": type(exc).__name__},
+ )
+
+def _map_to_status_code(exc: DomainError) -> int:
+ if isinstance(exc, NotFoundError): return 404
+ if isinstance(exc, ValidationError): return 422
+ if isinstance(exc, ThrottledError): return 429
+ if isinstance(exc, ConflictError): return 409
+ if isinstance(exc, UnauthorizedError): return 401
+ if isinstance(exc, ForbiddenError): return 403
+ if isinstance(exc, InvalidStateError): return 400
+ if isinstance(exc, InfrastructureError): return 500
+ return 500
+```
+
+The response includes the exception type name, so clients can distinguish between `ExecutionNotFoundError` and `SagaNotFoundError` even though both return 404.
+
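+For example, a caller can branch on the `type` field instead of parsing `detail`. A sketch for illustration only; the route path and the use of `httpx` are assumptions, not the project's actual client code:
+
+```python
+import httpx
+
+
+async def is_missing_execution(client: httpx.AsyncClient, execution_id: str) -> bool:
+    # Route path is illustrative; the point is branching on the "type" field.
+    resp = await client.get(f"/api/v1/executions/{execution_id}")
+    if resp.status_code != 404:
+        resp.raise_for_status()
+        return False
+    body = resp.json()
+    # Both ExecutionNotFoundError and SagaNotFoundError return 404;
+    # "type" tells them apart without parsing the "detail" message.
+    return body.get("type") == "ExecutionNotFoundError"
+```
+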
+## Using exceptions in services
+
+Services import exceptions from their domain module and raise them instead of `HTTPException`:
+
+```python
+# Before (coupled to HTTP)
+from fastapi import HTTPException
+
+async def get_execution(self, execution_id: str) -> DomainExecution:
+ execution = await self.repo.get_execution(execution_id)
+ if not execution:
+ raise HTTPException(status_code=404, detail="Execution not found")
+ return execution
+
+# After (transport-agnostic)
+from app.domain.execution import ExecutionNotFoundError
+
+async def get_execution(self, execution_id: str) -> DomainExecution:
+ execution = await self.repo.get_execution(execution_id)
+ if not execution:
+ raise ExecutionNotFoundError(execution_id)
+ return execution
+```
+
+The service no longer knows about HTTP. It raises a domain exception that describes what went wrong in business terms. The middleware handles the translation to HTTP.
+
+## Testing with domain exceptions
+
+Tests can assert on specific exception types and their fields:
+
+```python
+import pytest
+from app.domain.saga import SagaNotFoundError, SagaAccessDeniedError
+
+async def test_saga_not_found():
+ with pytest.raises(SagaNotFoundError) as exc_info:
+ await service.get_saga("nonexistent-id")
+ assert exc_info.value.identifier == "nonexistent-id"
+
+async def test_saga_access_denied():
+ with pytest.raises(SagaAccessDeniedError) as exc_info:
+ await service.get_saga_with_access_check(saga_id, unauthorized_user)
+ assert exc_info.value.saga_id == saga_id
+ assert exc_info.value.user_id == unauthorized_user.user_id
+```
+
+This is more precise than asserting on HTTP status codes and parsing error messages.
+
+## Adding new exceptions
+
+When adding a new exception:
+
+1. Choose the right base class based on the HTTP status it should map to
+2. Put it in the appropriate domain module's `exceptions.py`
+3. Export it from the module's `__init__.py`
+4. Use a rich constructor if the exception needs to carry context
+5. Raise it from the service layer, not the API layer
+
+Example for a new "quota exceeded" exception:
+
+```python
+# app/domain/execution/exceptions.py
+from app.domain.exceptions import ThrottledError
+
+class ExecutionQuotaExceededError(ThrottledError):
+ def __init__(self, user_id: str, current: int, limit: int) -> None:
+ self.user_id = user_id
+ self.current = current
+ self.limit = limit
+ super().__init__(f"Execution quota exceeded for user '{user_id}': {current}/{limit}")
+```
+
+The handler automatically maps it to 429 because it inherits from `ThrottledError`.
+
+## What stays as HTTPException
+
+API routes can still use `HTTPException` for route-level concerns that don't belong in the service layer:
+
+- Request validation that FastAPI doesn't catch (rare)
+- Authentication checks in route dependencies
+- Route-specific access control before calling services
+
+The general rule: if it's about the business domain, use domain exceptions. If it's about HTTP mechanics at the route level, `HTTPException` is fine.
diff --git a/docs/architecture/event-storage.md b/docs/architecture/event-storage.md
new file mode 100644
index 00000000..287aa8da
--- /dev/null
+++ b/docs/architecture/event-storage.md
@@ -0,0 +1,100 @@
+# Event storage architecture
+
+## Two collections, one purpose
+
+The system maintains *two separate MongoDB collections* for events: `events` and `event_store`. This implements a hybrid CQRS pattern where writes and reads are optimized for different use cases.
+
+## EventDocument vs EventStoreDocument
+
+**event_store** is the system's *permanent audit log* — an immutable append-only record of everything that happened:
+
+- Sourced from Kafka via `EventStoreConsumer`
+- No TTL — events persist indefinitely
+- Used for replay, compliance, and forensics
+- Single writer: the event store consumer
+
+**events** is an *operational projection* — a working copy optimized for day-to-day queries:
+
+- Sourced from application code via `KafkaEventService`
+- 30-day TTL — old events expire automatically
+- Used for admin dashboards, user-facing queries, analytics
+- Written by any service publishing events
+
+Both collections share identical schemas. The difference is *retention and purpose*.
+
+## Write flow
+
+When application code publishes an event, it flows through two paths:
+
+```mermaid
+graph TD
+ App[Application Code] --> KES[KafkaEventService.publish_event]
+ KES --> ER[EventRepository.store_event]
+ ER --> Events[(events collection)]
+ KES --> Producer[UnifiedProducer]
+ Producer --> Kafka[(Kafka)]
+ Kafka --> ESC[EventStoreConsumer]
+ ESC --> ES[EventStore.store_batch]
+ ES --> EventStore[(event_store collection)]
+```
+
+1. `KafkaEventService.publish_event()` stores to `events` collection AND publishes to Kafka
+2. `EventStoreConsumer` consumes from Kafka and stores to `event_store` collection
+
+This dual-write ensures:
+
+- **Immediate availability**: Events appear in `events` instantly for operational queries
+- **Permanent record**: Events flow through Kafka to `event_store` for audit trail
+- **Decoupling**: If Kafka consumer falls behind, operational queries remain fast
+
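+In code, the first leg of the dual write is a store-then-publish call inside `KafkaEventService.publish_event()`. A minimal sketch under assumptions: the `produce()` method name and the protocol shapes below are illustrative stand-ins, not the real `UnifiedProducer` or `EventRepository` interfaces:
+
+```python
+from typing import Any, Protocol
+
+
+class EventRepository(Protocol):
+    async def store_event(self, event: Any) -> str: ...
+
+
+class EventProducer(Protocol):
+    async def produce(self, event: Any) -> None: ...
+
+
+class KafkaEventService:
+    """Sketch of the dual-write path: direct MongoDB write plus Kafka publish."""
+
+    def __init__(self, repository: EventRepository, producer: EventProducer) -> None:
+        self._repository = repository
+        self._producer = producer
+
+    async def publish_event(self, event: Any) -> None:
+        # Leg 1: direct write, so the event is immediately queryable in `events`.
+        await self._repository.store_event(event)
+        # Leg 2: Kafka publish; EventStoreConsumer later persists it to `event_store`.
+        await self._producer.produce(event)
+```
+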
+## Read patterns
+
+Different repositories query different collections based on use case:
+
+| Repository | Collection | Use Case |
+|------------|------------|----------|
+| `EventRepository` | events | User-facing queries, recent events |
+| `AdminEventsRepository` | events | Admin dashboard, analytics |
+| `EventStore` | event_store | Replay, audit, historical queries |
+
+The admin console and user-facing features query `events` for fast access to recent data. The event store is reserved for replay scenarios and compliance needs.
+
+## Why not just one collection?
+
+**Storage costs**: The `events` collection with 30-day TTL keeps storage bounded. Without TTL, event volume would grow unbounded — problematic for operational queries that scan recent data.
+
+**Query performance**: Operational queries (last 24 hours, user's recent events) benefit from a smaller, indexed dataset. Scanning a years-long audit log for recent events wastes resources.
+
+**Retention policies**: Different data has different retention requirements. Operational data can expire. Audit logs often cannot.
+
+**Failure isolation**: If the event store consumer falls behind (Kafka lag), operational queries remain unaffected. The `events` collection stays current through direct writes.
+
+## Pod monitor integration
+
+The `PodMonitor` watches Kubernetes pods and publishes lifecycle events. These events must appear in both collections:
+
+```mermaid
+graph LR
+ K8s[Kubernetes Watch] --> PM[PodMonitor]
+ PM --> KES[KafkaEventService.publish_base_event]
+ KES --> Events[(events)]
+ KES --> Kafka[(Kafka)]
+ Kafka --> EventStore[(event_store)]
+```
+
+`PodMonitor` uses `KafkaEventService.publish_base_event()` to:
+
+1. Store pre-built events to `events` collection
+2. Publish to Kafka for downstream consumers and `event_store`
+
+This ensures pod events appear in admin dashboards immediately while maintaining the permanent audit trail.
+
+## Key files
+
+- `db/docs/event.py` — `EventDocument` and `EventStoreDocument` definitions
+- `db/repositories/event_repository.py` — operational event queries
+- `db/repositories/admin/admin_events_repository.py` — admin dashboard queries
+- `events/event_store.py` — permanent event store operations
+- `events/event_store_consumer.py` — Kafka to event_store consumer
+- `services/kafka_event_service.py` — unified publish (store + Kafka)
+- `services/pod_monitor/monitor.py` — pod lifecycle events
diff --git a/docs/architecture/model-conversion.md b/docs/architecture/model-conversion.md
new file mode 100644
index 00000000..83719588
--- /dev/null
+++ b/docs/architecture/model-conversion.md
@@ -0,0 +1,215 @@
+# Model Conversion Patterns
+
+This document describes the patterns for converting between domain models, Pydantic schemas, and ODM documents.
+
+## Core Principles
+
+1. **Domain models are dataclasses** - pure Python, no framework dependencies
+2. **Pydantic models are for boundaries** - API schemas, ODM documents, Kafka events
+3. **No custom converter methods** - no `to_dict()`, `from_dict()`, `from_response()`, etc.
+4. **Conversion at boundaries** - happens in repositories and services, not in models
+
+## Model Layers
+
+```
+┌─────────────────────────────────────────────────────────────┐
+│ API Layer (Pydantic schemas) │
+│ app/schemas_pydantic/ │
+├─────────────────────────────────────────────────────────────┤
+│ Service Layer │
+│ app/services/ │
+├─────────────────────────────────────────────────────────────┤
+│ Domain Layer (dataclasses) │
+│ app/domain/ │
+├─────────────────────────────────────────────────────────────┤
+│ Infrastructure Layer (Pydantic/ODM) │
+│ app/db/docs/, app/infrastructure/kafka/events/ │
+└─────────────────────────────────────────────────────────────┘
+```
+
+## Conversion Patterns
+
+### Dataclass to Dict
+
+Use `asdict()` with a dict comprehension for enum conversion and `None` filtering:
+
+```python
+from dataclasses import asdict
+
+# With enum conversion and None filtering
+update_dict = {
+ k: (v.value if hasattr(v, "value") else v)
+ for k, v in asdict(domain_obj).items()
+ if v is not None
+}
+
+# Without None filtering (keep all values)
+data = {
+ k: (v.value if hasattr(v, "value") else v)
+ for k, v in asdict(domain_obj).items()
+}
+```
+
+### Pydantic to Dict
+
+Use `model_dump()` directly:
+
+```python
+# Exclude None values
+data = pydantic_obj.model_dump(exclude_none=True)
+
+# Include all values
+data = pydantic_obj.model_dump()
+
+# JSON-compatible (datetimes as ISO strings)
+data = pydantic_obj.model_dump(mode="json")
+```
+
+### Dict to Pydantic
+
+Use `model_validate()` or constructor unpacking:
+
+```python
+# From dict
+obj = SomeModel.model_validate(data)
+
+# With unpacking
+obj = SomeModel(**data)
+```
+
+### Pydantic to Pydantic
+
+Use `model_validate()` when models have `from_attributes=True`:
+
+```python
+class User(BaseModel):
+ model_config = ConfigDict(from_attributes=True)
+ ...
+
+# Convert between compatible Pydantic models
+user = User.model_validate(user_response)
+```
+
+### Dict to Dataclass
+
+Use constructor unpacking:
+
+```python
+# Direct unpacking
+domain_obj = DomainModel(**data)
+
+# With nested conversion
+domain_obj = DomainModel(
+ **{
+ **doc.model_dump(exclude={"id", "revision_id"}),
+ "metadata": DomainMetadata(**doc.metadata.model_dump()),
+ }
+)
+```
+
+## Examples
+
+### Repository: Saving Domain to Document
+
+```python
+async def store_event(self, event: Event) -> str:
+ data = asdict(event)
+ # Convert nested dataclass with enum handling
+ data["metadata"] = {
+ k: (v.value if hasattr(v, "value") else v)
+ for k, v in asdict(event.metadata).items()
+ }
+ doc = EventDocument(**data)
+ await doc.insert()
+```
+
+### Repository: Loading Document to Domain
+
+```python
+async def get_event(self, event_id: str) -> Event | None:
+ doc = await EventDocument.find_one({"event_id": event_id})
+ if not doc:
+ return None
+ return Event(
+ **{
+ **doc.model_dump(exclude={"id", "revision_id"}),
+ "metadata": DomainMetadata(**doc.metadata.model_dump()),
+ }
+ )
+```
+
+### Repository: Updating with Typed Input
+
+```python
+async def update_session(self, session_id: str, updates: SessionUpdate) -> bool:
+ update_dict = {
+ k: (v.value if hasattr(v, "value") else v)
+ for k, v in asdict(updates).items()
+ if v is not None
+ }
+ if not update_dict:
+ return False
+ doc = await SessionDocument.find_one({"session_id": session_id})
+ if not doc:
+ return False
+ await doc.set(update_dict)
+ return True
+```
+
+### Service: Converting Between Pydantic Models
+
+```python
+# In API route
+user = User.model_validate(current_user)
+
+# In service converting Kafka metadata to domain
+domain_metadata = DomainEventMetadata(**avro_metadata.model_dump())
+```
+
+## Anti-Patterns
+
+### Don't: Custom Converter Methods
+
+```python
+# BAD - adds unnecessary abstraction
+class MyModel:
+ def to_dict(self) -> dict:
+ return {...}
+
+ @classmethod
+ def from_dict(cls, data: dict) -> "MyModel":
+ return cls(...)
+```
+
+### Don't: Pydantic in Domain Layer
+
+```python
+# BAD - domain should be framework-agnostic
+from pydantic import BaseModel
+
+class DomainEntity(BaseModel): # Wrong!
+ ...
+```
+
+### Don't: Manual Field-by-Field Conversion
+
+```python
+# BAD - verbose and error-prone
+def from_response(cls, resp):
+ return cls(
+ field1=resp.field1,
+ field2=resp.field2,
+ field3=resp.field3,
+ ...
+ )
+```
+
+## Summary
+
+| From | To | Method |
+|------|-----|--------|
+| Dataclass | Dict | `{k: (v.value if hasattr(v, "value") else v) for k, v in asdict(obj).items()}` |
+| Pydantic | Dict | `obj.model_dump()` |
+| Dict | Pydantic | `Model.model_validate(data)` or `Model(**data)` |
+| Pydantic | Pydantic | `TargetModel.model_validate(source)` |
+| Dict | Dataclass | `DataclassModel(**data)` |
diff --git a/docs/architecture/pydantic-dataclasses.md b/docs/architecture/pydantic-dataclasses.md
new file mode 100644
index 00000000..b169958c
--- /dev/null
+++ b/docs/architecture/pydantic-dataclasses.md
@@ -0,0 +1,173 @@
+# Pydantic dataclasses
+
+This document explains why domain models use `pydantic.dataclasses.dataclass` instead of the standard library
+`dataclasses.dataclass`. It covers the problem with nested dict conversion, the solution, and migration considerations.
+
+## Why pydantic dataclasses
+
+Domain models are dataclasses that represent business entities like `DomainUserSettings`, `DomainExecution`, and `Saga`.
+These models often have nested structures - for example, `DomainUserSettings` contains `DomainNotificationSettings` and
+`DomainEditorSettings` as nested dataclasses.
+
+The problem appears when loading data from MongoDB. Beanie documents are Pydantic models, and calling `model_dump()` on
+them returns plain Python dicts, including nested dicts for nested models. When you pass these dicts to a stdlib
+dataclass constructor, nested dicts stay as dicts instead of being converted to their proper dataclass types.
+
+```python
+# Data from MongoDB via Beanie document.model_dump()
+data = {
+ "user_id": "user123",
+ "notifications": {
+ "execution_completed": False,
+ "execution_failed": True
+ }
+}
+
+# With stdlib dataclass - FAILS
+settings = DomainUserSettings(**data)
+settings.notifications.execution_completed  # AttributeError: 'dict' object has no attribute 'execution_completed'
+
+# With pydantic dataclass - WORKS
+settings = DomainUserSettings(**data)
+settings.notifications.execution_completed # Returns False
+```
+
+Pydantic dataclasses use type annotations to automatically convert nested dicts into the correct dataclass instances. No
+reflection, no isinstance checks, no manual conversion code.
+
+## What pydantic dataclasses provide
+
+Pydantic dataclasses are a drop-in replacement for stdlib dataclasses with added features:
+
+| Feature | stdlib | pydantic |
+|------------------------|--------|----------|
+| Nested dict conversion | No | Yes |
+| Enum from string | No | Yes |
+| Type validation | No | Yes |
+| String-to-int coercion | No | Yes |
+| `asdict()` | Yes | Yes |
+| `is_dataclass()` | Yes | Yes |
+| `__dataclass_fields__` | Yes | Yes |
+| `field()` | Yes | Yes |
+| `__post_init__` | Yes | Yes |
+| `replace()` | Yes | Yes |
+| frozen/eq/hash | Yes | Yes |
+| Inheritance | Yes | Yes |
+
+The migration requires changing one import:
+
+```python
+# Before
+from dataclasses import dataclass
+
+# After
+from pydantic.dataclasses import dataclass
+```
+
+Everything else stays the same. The `field` function still comes from stdlib `dataclasses`.
+
+## Performance
+
+Pydantic dataclasses add validation overhead at construction time:
+
+| Operation | stdlib | pydantic | Ratio |
+|--------------------|-------------|-------------|-------------|
+| Creation from dict | 0.2 us | 1.4 us | 6x slower |
+| Attribute access | 4.1 ms/100k | 4.6 ms/100k | 1.1x slower |
+
+The creation overhead is negligible for typical usage patterns - domain models are created during request handling, not
+in tight loops. Attribute access after construction has no meaningful overhead.
+
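+These numbers are indicative and vary with model size and hardware. A rough way to reproduce this kind of comparison yourself (the model names here are made up for the benchmark):
+
+```python
+import timeit
+from dataclasses import dataclass as std_dataclass
+
+from pydantic.dataclasses import dataclass as pyd_dataclass
+
+
+@std_dataclass
+class StdSettings:
+    user_id: str
+    max_retries: int
+
+
+@pyd_dataclass
+class PydSettings:
+    user_id: str
+    max_retries: int
+
+
+data = {"user_id": "user123", "max_retries": 3}
+
+# Construction cost: pydantic validates types at __init__ time, stdlib does not.
+print("stdlib  :", timeit.timeit(lambda: StdSettings(**data), number=100_000))
+print("pydantic:", timeit.timeit(lambda: PydSettings(**data), number=100_000))
+```
+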
+## Domain model locations
+
+All domain models live in `app/domain/` and use pydantic dataclasses:
+
+| Module | File | Key models |
+|--------------|---------------------------------------|----------------------------------------------------------------------------|
+| User | `app/domain/user/settings_models.py` | `DomainUserSettings`, `DomainNotificationSettings`, `DomainEditorSettings` |
+| User | `app/domain/user/user_models.py` | `User`, `UserCreation`, `UserUpdate` |
+| Execution | `app/domain/execution/models.py` | `DomainExecution`, `ExecutionResultDomain` |
+| Events | `app/domain/events/event_models.py` | `Event`, `EventFilter`, `EventQuery` |
+| Events | `app/domain/events/event_metadata.py` | `EventMetadata` |
+| Saga | `app/domain/saga/models.py` | `Saga`, `SagaInstance`, `SagaConfig` |
+| Replay | `app/domain/replay/models.py` | `ReplaySessionState` |
+| Notification | `app/domain/notification/models.py` | `DomainNotification`, `DomainNotificationSubscription` |
+| Admin | `app/domain/admin/settings_models.py` | `SystemSettings`, `ExecutionLimits` |
+
+## Using domain models in repositories
+
+Repositories that load from MongoDB convert Beanie documents to domain models:
+
+```python
+from app.domain.user.settings_models import DomainUserSettings
+
+class UserSettingsRepository:
+ async def get_snapshot(self, user_id: str) -> DomainUserSettings | None:
+ doc = await UserSettingsDocument.find_one({"user_id": user_id})
+ if not doc:
+ return None
+ # Pydantic dataclass handles nested conversion automatically
+ return DomainUserSettings(**doc.model_dump(exclude={"id", "revision_id"}))
+```
+
+No manual conversion of nested fields needed. The type annotations on `DomainUserSettings` tell pydantic how to convert
+each nested dict.
+
+## Validation behavior
+
+Pydantic dataclasses validate input data at construction time. Invalid data raises `ValidationError`:
+
+```python
+# Invalid enum value
+DomainUserSettings(user_id="u1", theme="invalid_theme")
+# ValidationError: Input should be 'light', 'dark' or 'auto'
+
+# Invalid type
+DomainNotificationSettings(execution_completed="not_a_bool")
+# ValidationError: Input should be a valid boolean
+```
+
+This catches data problems at the boundary where data enters the domain, rather than later during processing. Services
+can trust that domain models contain valid data.
+
+## What stays as Pydantic BaseModel
+
+Some classes still use `pydantic.BaseModel` instead of dataclasses:
+
+- Beanie documents (require BaseModel for ODM features)
+- Request/response schemas (FastAPI integration)
+- Configuration models with complex validation
+- Classes that need `model_validate()`, `model_json_schema()`, or other BaseModel methods
+
+The rule: use pydantic dataclasses for domain models that represent business entities. Use BaseModel for infrastructure
+concerns like documents, schemas, and configs.
+
+## Adding new domain models
+
+When creating a new domain model:
+
+1. Import dataclass from pydantic: `from pydantic.dataclasses import dataclass`
+2. Import field from stdlib if needed: `from dataclasses import field`
+3. Define the class with `@dataclass` decorator
+4. Use type annotations - pydantic uses them for conversion and validation
+5. Put nested dataclasses before the parent class that uses them
+
+```python
+from dataclasses import field
+from datetime import datetime
+from pydantic.dataclasses import dataclass
+
+@dataclass
+class NestedModel:
+ value: int
+ label: str = "default"
+
+@dataclass
+class ParentModel:
+ id: str
+ nested: NestedModel
+ items: list[str] = field(default_factory=list)
+ created_at: datetime = field(default_factory=datetime.utcnow)
+```
+
+The model automatically handles nested dict conversion, enum parsing, and type coercion.
diff --git a/docs/components/saga/resource-allocation.md b/docs/components/saga/resource-allocation.md
new file mode 100644
index 00000000..0a542455
--- /dev/null
+++ b/docs/components/saga/resource-allocation.md
@@ -0,0 +1,106 @@
+# Resource allocation
+
+## Why it exists
+
+When you run code on Integr8sCode, behind the scenes the system spins up a Kubernetes pod with specific CPU and memory limits. But what happens if a thousand users all hit "Run" at the same time? Without some form of throttling, you'd either exhaust cluster resources or start rejecting requests with cryptic errors.
+
+The resource allocation system acts as a booking ledger for execution resources. Before creating a pod, the system checks how many executions are already running for that language and either proceeds or backs off. If something goes wrong mid-execution, the allocation gets released so it doesn't count against the limit forever.
+
+## How it works
+
+Resource allocation is a step within the execution saga — the distributed transaction pattern that orchestrates the full lifecycle of running user code. The `AllocateResourcesStep` runs early in the saga, right after validation:
+
+```
+ExecutionRequested
+ → ValidateExecutionStep
+ → AllocateResourcesStep ← creates allocation record
+ → QueueExecutionStep
+ → CreatePodStep
+ → MonitorExecutionStep
+```
+
+When the step executes, it does two things:
+
+1. **Counts active allocations** for the requested language
+2. **Enforces a concurrency limit** (currently 100 per language)
+
+If the limit hasn't been reached, it creates an allocation record with status "active". If any later step fails, the saga's compensation mechanism kicks in and `ReleaseResourcesCompensation` marks the allocation as "released".
+
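+Condensed to its essentials, the step body looks roughly like the sketch below. This is a simplification under assumptions: the real `AllocateResourcesStep` is a saga step class rather than a free function, and it takes CPU and memory settings from the execution request instead of hardcoding them:
+
+```python
+from __future__ import annotations
+
+MAX_CONCURRENT_PER_LANGUAGE = 100  # the hardcoded limit described in "Configuration" below
+
+
+async def allocate_resources(repo: ResourceAllocationRepository, execution_id: str, language: str) -> None:
+    """Sketch of AllocateResourcesStep: count active allocations, enforce the limit, record the allocation."""
+    active_count = await repo.count_active(language)
+    if active_count >= MAX_CONCURRENT_PER_LANGUAGE:
+        # Failing here fails the saga before any pod is created.
+        raise ValueError("Resource limit exceeded")
+
+    await repo.create_allocation(
+        allocation_id=execution_id,  # allocation_id mirrors execution_id
+        execution_id=execution_id,
+        language=language,
+        cpu_request="100m",
+        memory_request="128Mi",
+        cpu_limit="500m",
+        memory_limit="256Mi",
+    )
+```
+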
+## Data model
+
+Each allocation tracks everything needed for resource accounting:
+
+| Field | Description |
+|-------|-------------|
+| `allocation_id` | Unique identifier (same as execution_id) |
+| `execution_id` | The execution this allocation belongs to |
+| `language` | Programming language (used for per-language limits) |
+| `cpu_request` / `cpu_limit` | Kubernetes CPU settings |
+| `memory_request` / `memory_limit` | Kubernetes memory settings |
+| `status` | Either "active" or "released" |
+| `allocated_at` | When the allocation was created |
+| `released_at` | When it was released (if applicable) |
+
+## Saga compensation
+
+The beauty of tying resource allocation to the saga pattern is automatic cleanup. Consider this failure scenario:
+
+```mermaid
+graph TD
+ A[ValidateExecutionStep] -->|success| B[AllocateResourcesStep]
+ B -->|success, creates allocation| C[QueueExecutionStep]
+ C -->|success| D[CreatePodStep]
+ D -->|FAILS| E[Compensation begins]
+ E --> F[DeletePodCompensation]
+ F --> G[RemoveFromQueueCompensation]
+ G --> H[ReleaseResourcesCompensation]
+ H -->|marks allocation released| I[Done]
+```
+
+If pod creation fails for any reason, the saga automatically runs compensation steps in reverse order. The `ReleaseResourcesCompensation` step finds the allocation record and marks it as "released", freeing up that slot for another execution.
+
+Without this pattern, you'd need manual cleanup logic scattered throughout the codebase, likely with edge cases where allocations leak and slowly eat into your concurrency budget.
+
+## Repository interface
+
+The `ResourceAllocationRepository` provides three simple operations:
+
+```python
+# Check how many executions are running for a language
+count = await repo.count_active("python")
+
+# Create a new allocation
+await repo.create_allocation(
+ allocation_id=execution_id,
+ execution_id=execution_id,
+ language="python",
+ cpu_request="100m",
+ memory_request="128Mi",
+ cpu_limit="500m",
+ memory_limit="256Mi",
+)
+
+# Release an allocation (called during compensation)
+await repo.release_allocation(allocation_id)
+```
+
+The repository uses MongoDB through Beanie ODM. Allocations are indexed by `allocation_id`, `execution_id`, `language`, and `status` for efficient queries.
+
+## Failure modes
+
+**If MongoDB is unavailable:** The allocation step fails, which fails the saga, which means no pod gets created. The user sees an error, but the system stays consistent.
+
+**If an allocation leaks (never released):** This would happen if both the saga and its compensation failed catastrophically. The allocation would stay "active" forever, counting against the limit. In practice, you'd want a periodic cleanup job to release stale allocations older than, say, the maximum execution timeout plus a buffer.
+
+**If the count query is slow:** Under high load, counting active allocations could become a bottleneck. The current implementation uses a simple count query, but this could be optimized with a cached counter or Redis if needed.
+
+## Configuration
+
+The concurrency limit is currently hardcoded to 100 per language in `AllocateResourcesStep`. To change this, modify the check in `execution_saga.py`:
+
+```python
+if active_count >= 100: # <- adjust this value
+ raise ValueError("Resource limit exceeded")
+```
+
+Future improvements could make this configurable per-language or dynamically adjustable based on cluster capacity.
diff --git a/docs/operations/cicd.md b/docs/operations/cicd.md
index 7469f192..c0940c4f 100644
--- a/docs/operations/cicd.md
+++ b/docs/operations/cicd.md
@@ -105,9 +105,14 @@ graph TD
### Base image
The base image (`Dockerfile.base`) contains Python, system dependencies, and all pip packages. It
-uses [uv](https://docs.astral.sh/uv/) to install dependencies from the lockfile, ensuring reproducible builds. The base
-includes gcc, curl, and compression libraries needed by some Python packages. Separating base from application means
-dependency changes rebuild the base layer while code changes only rebuild the thin application layer.
+uses [uv](https://docs.astral.sh/uv/) to install dependencies from the lockfile with `uv sync --locked --no-dev`,
+ensuring reproducible builds without development tools. The base includes gcc, curl, and compression libraries needed
+by some Python packages.
+
+The image sets `PATH="/app/.venv/bin:$PATH"` so services can run Python directly without `uv run` at startup. This
+avoids dependency resolution at container start, making services launch in seconds rather than minutes. Separating base
+from application means dependency changes rebuild the base layer while code changes only rebuild the thin application
+layer. See [Docker build strategy](deployment.md#docker-build-strategy) for details on the local development setup.
### Build contexts
diff --git a/docs/operations/deployment.md b/docs/operations/deployment.md
index bc80665b..82890546 100644
--- a/docs/operations/deployment.md
+++ b/docs/operations/deployment.md
@@ -57,6 +57,46 @@ DEFAULT_USER_PASSWORD=mypass ADMIN_USER_PASSWORD=myadmin ./deploy.sh dev
Hot reloading works for the backend since the source directory is mounted into the container. Changes to Python files
trigger Uvicorn to restart automatically. The frontend runs its own dev server with similar behavior.
+### Docker build strategy
+
+The backend uses a multi-stage build with a shared base image to keep startup fast. All Python dependencies are
+installed at build time, so containers start in seconds rather than waiting for package downloads.
+
+```
+Dockerfile.base Dockerfile (backend) Dockerfile.* (workers)
+┌──────────────────┐ ┌──────────────────────┐ ┌──────────────────────┐
+│ python:3.12-slim │ │ FROM base │ │ FROM base │
+│ system deps │────▶│ COPY app, workers │ │ COPY app, workers │
+│ uv sync --locked │ │ entrypoint.sh │ │ CMD ["python", ...] │
+│ ENV PATH=.venv │ └──────────────────────┘ └──────────────────────┘
+└──────────────────┘
+```
+
+The base image (`Dockerfile.base`) installs all production dependencies using `uv sync --locked --no-dev`. The
+`--locked` flag ensures the lockfile is respected exactly, and `--no-dev` skips development tools like ruff and mypy
+that aren't needed at runtime. The key optimization is setting `PATH="/app/.venv/bin:$PATH"` so Python and all
+installed packages are available directly without needing `uv run` at startup.
+
+Each service image extends the base and copies only application code. Since dependencies rarely change compared to
+code, Docker's layer caching means most builds only rebuild the thin application layer. First builds take longer
+because they install all packages, but subsequent builds are fast.
+
+For local development, the compose file mounts source directories into the container:
+
+```yaml
+volumes:
+ - ./backend/app:/app/app
+ - ./backend/workers:/app/workers
+ - ./backend/scripts:/app/scripts
+```
+
+This selective mounting preserves the container's `.venv` directory (with all installed packages) while allowing live
+code changes. The mounted directories overlay the baked-in copies, so edits take effect immediately. Gunicorn watches
+for file changes and reloads workers automatically.
+
+The design means `git clone` followed by `docker compose up` just works. No local Python environment needed, no named
+volumes for caching, no waiting for package downloads. Dependencies live in the image, code comes from the mount.
+
To stop everything and clean up volumes:
```bash
diff --git a/docs/operations/logging.md b/docs/operations/logging.md
new file mode 100644
index 00000000..b1c652d4
--- /dev/null
+++ b/docs/operations/logging.md
@@ -0,0 +1,57 @@
+# Logging
+
+This backend uses structured JSON logging with automatic correlation IDs, trace context injection, and sensitive data sanitization. The goal is logs that are both secure against injection attacks and easy to query in aggregation systems like Elasticsearch or Loki.
+
+## How it's wired
+
+The logger is created once during application startup via dependency injection. The `setup_logger` function in `app/core/logging.py` configures a JSON formatter and attaches filters for correlation IDs and OpenTelemetry trace context. Every log line comes out as a JSON object with timestamp, level, logger name, message, and whatever structured fields you added. Workers and background services use the same setup, so log format is consistent across the entire system.
+
+The JSON formatter does two things beyond basic formatting. First, it injects context that would be tedious to pass manually - the correlation ID from the current request, the trace and span IDs from OpenTelemetry, and request metadata like method and path. Second, it sanitizes sensitive data by pattern-matching things like API keys, JWT tokens, and database URLs, replacing them with redaction placeholders. This sanitization applies to both the log message and exception tracebacks.
+
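+A stripped-down sketch of that wiring using only the standard library; `CorrelationIdFilter` and `JsonFormatter` here are illustrative stand-ins for what `app/core/logging.py` provides (the real setup also injects trace IDs and sanitizes sensitive values):
+
+```python
+import json
+import logging
+from contextvars import ContextVar
+
+# In the real app the correlation ID is set by request middleware.
+correlation_id_var: ContextVar[str] = ContextVar("correlation_id", default="-")
+
+
+class CorrelationIdFilter(logging.Filter):
+    def filter(self, record: logging.LogRecord) -> bool:
+        record.correlation_id = correlation_id_var.get()
+        return True
+
+
+class JsonFormatter(logging.Formatter):
+    def format(self, record: logging.LogRecord) -> str:
+        payload = {
+            "timestamp": self.formatTime(record),
+            "level": record.levelname,
+            "logger": record.name,
+            "message": record.getMessage(),
+            "correlation_id": getattr(record, "correlation_id", "-"),
+        }
+        # Fields passed via `extra` become attributes on the record.
+        for key in ("event_id", "execution_id", "user_id"):
+            if hasattr(record, key):
+                payload[key] = getattr(record, key)
+        return json.dumps(payload)
+
+
+handler = logging.StreamHandler()
+handler.setFormatter(JsonFormatter())
+handler.addFilter(CorrelationIdFilter())
+logger = logging.getLogger("app")
+logger.addHandler(handler)
+logger.setLevel(logging.INFO)
+
+logger.info("Event deleted by admin", extra={"event_id": "abc123"})
+# -> {"timestamp": "...", "level": "INFO", "logger": "app",
+#     "message": "Event deleted by admin", "correlation_id": "-", "event_id": "abc123"}
+```
+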
+## Structured logging
+
+All log calls use the `extra` parameter to pass structured data rather than interpolating values into the message string. The message itself is a static string that describes what happened; the details go in `extra` where they become separate JSON fields.
+
+```python
+# This is how logging looks throughout the codebase
+self.logger.info(
+ "Event deleted by admin",
+ extra={
+ "event_id": event_id,
+ "admin_email": admin.email,
+ "event_type": result.event_type,
+ },
+)
+```
+
+The reason for this pattern is partly about queryability - log aggregators can index the `event_id` field separately and let you filter on it - but mostly about security. When you interpolate user-controlled data into a log message, you open the door to log injection attacks.
+
+## Log injection
+
+Log injection is what happens when an attacker crafts input that corrupts your log output. The classic attack looks like this: a user submits an event ID containing a newline and a fake log entry.
+
+```python
+# Attacker submits this as event_id
+event_id = "abc123\n[CRITICAL] System compromised - contact security@evil.com"
+
+# If you log it directly in the message...
+logger.warning(f"Processing event {event_id}")
+
+# Your log output now contains a forged critical alert
+```
+
+The fix is to keep user data out of the message string entirely. When you put it in `extra`, the JSON formatter escapes special characters, and the malicious content becomes a harmless string value rather than a log line injection.
+
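+The safe counterpart of the snippet above keeps the message static and moves the untrusted value into `extra`, where the formatter escapes it (output shown roughly):
+
+```python
+# Same attacker-controlled value, but now it is just an escaped JSON field
+logger.warning("Processing event", extra={"event_id": event_id})
+# -> {"level": "WARNING", "message": "Processing event",
+#     "event_id": "abc123\n[CRITICAL] System compromised - contact security@evil.com"}
+```
+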
+The codebase treats these as user-controlled and keeps them in `extra`: path parameters like `execution_id` or `saga_id`, query parameters, request body fields, Kafka message content, database results derived from user input, and exception messages (which often contain user data).
+
+## What gets logged
+
+Correlation and trace IDs are injected automatically by filters. The correlation ID follows a request through all services - it's set from incoming headers or generated for new requests. The trace and span IDs come from OpenTelemetry and link logs to distributed traces in Jaeger or Tempo. You don't need to pass these explicitly; they appear in every log line from code running in that request context.
+
+For domain-specific context, developers add fields to `extra` based on what operation they're logging. An execution service method might include `execution_id`, `user_id`, `language`, and `status`. A replay session logs `session_id`, `replayed_events`, `failed_events`, and `duration_seconds`. A saga operation includes `saga_id` and `user_id`. The pattern is consistent: the message says what happened, `extra` says to what and by whom.
+
+## Practical use
+
+When something goes wrong, start by filtering logs by correlation_id to see everything that happened during that request. If you need to correlate with traces, use the trace_id to jump to Jaeger. If you're investigating a specific execution or saga, filter by those IDs - they're in the structured fields, not buried in message text.
+
+The log level is controlled by the `LOG_LEVEL` environment variable. In production it's typically INFO, which captures normal operations (started, completed, processed) and problems (warnings for recoverable issues, errors for failures). DEBUG adds detailed diagnostic info and is usually too noisy for production but useful when investigating specific issues locally.
diff --git a/docs/reference/openapi.json b/docs/reference/openapi.json
index e0053102..4e543876 100644
--- a/docs/reference/openapi.json
+++ b/docs/reference/openapi.json
@@ -732,7 +732,7 @@
"content": {
"application/json": {
"schema": {
- "$ref": "#/components/schemas/SavedScriptCreateRequest"
+ "$ref": "#/components/schemas/SavedScriptUpdate"
}
}
}
@@ -1261,7 +1261,7 @@
"schema": {
"anyOf": [
{
- "type": "string"
+ "$ref": "#/components/schemas/EventType"
},
{
"type": "null"
@@ -1629,6 +1629,29 @@
"title": "Include System Events"
},
"description": "Include system-generated events"
+ },
+ {
+ "name": "limit",
+ "in": "query",
+ "required": false,
+ "schema": {
+ "type": "integer",
+ "maximum": 1000,
+ "minimum": 1,
+ "default": 100,
+ "title": "Limit"
+ }
+ },
+ {
+ "name": "skip",
+ "in": "query",
+ "required": false,
+ "schema": {
+ "type": "integer",
+ "minimum": 0,
+ "default": 0,
+ "title": "Skip"
+ }
}
],
"responses": {
@@ -1855,6 +1878,17 @@
"default": 100,
"title": "Limit"
}
+ },
+ {
+ "name": "skip",
+ "in": "query",
+ "required": false,
+ "schema": {
+ "type": "integer",
+ "minimum": 0,
+ "default": 0,
+ "title": "Skip"
+ }
}
],
"responses": {
@@ -1900,6 +1934,17 @@
"default": 100,
"title": "Limit"
}
+ },
+ {
+ "name": "skip",
+ "in": "query",
+ "required": false,
+ "schema": {
+ "type": "integer",
+ "minimum": 0,
+ "default": 0,
+ "title": "Skip"
+ }
}
],
"responses": {
@@ -4078,7 +4123,7 @@
"sagas"
],
"summary": "Get Execution Sagas",
- "description": "Get all sagas for an execution.\n\nArgs:\n execution_id: The execution identifier\n request: FastAPI request object\n saga_service: Saga service from DI\n auth_service: Auth service from DI\n state: Optional state filter\n\nReturns:\n List of sagas for the execution\n\nRaises:\n HTTPException: 403 if access denied",
+ "description": "Get all sagas for an execution.\n\nArgs:\n execution_id: The execution identifier\n request: FastAPI request object\n saga_service: Saga service from DI\n auth_service: Auth service from DI\n state: Optional state filter\n limit: Maximum number of results\n skip: Number of results to skip\n\nReturns:\n Paginated list of sagas for the execution\n\nRaises:\n HTTPException: 403 if access denied",
"operationId": "get_execution_sagas_api_v1_sagas_execution__execution_id__get",
"parameters": [
{
@@ -4107,6 +4152,29 @@
"title": "State"
},
"description": "Filter by saga state"
+ },
+ {
+ "name": "limit",
+ "in": "query",
+ "required": false,
+ "schema": {
+ "type": "integer",
+ "maximum": 1000,
+ "minimum": 1,
+ "default": 100,
+ "title": "Limit"
+ }
+ },
+ {
+ "name": "skip",
+ "in": "query",
+ "required": false,
+ "schema": {
+ "type": "integer",
+ "minimum": 0,
+ "default": 0,
+ "title": "Skip"
+ }
}
],
"responses": {
@@ -4139,7 +4207,7 @@
"sagas"
],
"summary": "List Sagas",
- "description": "List sagas accessible by the current user.\n\nArgs:\n request: FastAPI request object\n saga_service: Saga service from DI\n auth_service: Auth service from DI\n state: Optional state filter\n limit: Maximum number of results\n offset: Number of results to skip\n\nReturns:\n Paginated list of sagas",
+ "description": "List sagas accessible by the current user.\n\nArgs:\n request: FastAPI request object\n saga_service: Saga service from DI\n auth_service: Auth service from DI\n state: Optional state filter\n limit: Maximum number of results\n skip: Number of results to skip\n\nReturns:\n Paginated list of sagas",
"operationId": "list_sagas_api_v1_sagas__get",
"parameters": [
{
@@ -4173,14 +4241,14 @@
}
},
{
- "name": "offset",
+ "name": "skip",
"in": "query",
"required": false,
"schema": {
"type": "integer",
"minimum": 0,
"default": 0,
- "title": "Offset"
+ "title": "Skip"
}
}
],
@@ -4737,9 +4805,54 @@
"type": "number",
"title": "Age Seconds"
},
- "details": {
- "type": "object",
- "title": "Details"
+ "producer_id": {
+ "type": "string",
+ "title": "Producer Id"
+ },
+ "dlq_offset": {
+ "anyOf": [
+ {
+ "type": "integer"
+ },
+ {
+ "type": "null"
+ }
+ ],
+ "title": "Dlq Offset"
+ },
+ "dlq_partition": {
+ "anyOf": [
+ {
+ "type": "integer"
+ },
+ {
+ "type": "null"
+ }
+ ],
+ "title": "Dlq Partition"
+ },
+ "last_error": {
+ "anyOf": [
+ {
+ "type": "string"
+ },
+ {
+ "type": "null"
+ }
+ ],
+ "title": "Last Error"
+ },
+ "next_retry_at": {
+ "anyOf": [
+ {
+ "type": "string",
+ "format": "date-time"
+ },
+ {
+ "type": "null"
+ }
+ ],
+ "title": "Next Retry At"
}
},
"type": "object",
@@ -4752,7 +4865,7 @@
"failed_at",
"status",
"age_seconds",
- "details"
+ "producer_id"
],
"title": "DLQMessageResponse",
"description": "Response model for a DLQ message."
@@ -5385,7 +5498,7 @@
"title": "End Time",
"description": "Filter events before this time"
},
- "text_search": {
+ "search_text": {
"anyOf": [
{
"type": "string"
@@ -5394,7 +5507,7 @@
"type": "null"
}
],
- "title": "Text Search",
+ "title": "Search Text",
"description": "Full-text search in event data"
},
"sort_by": {
@@ -5464,6 +5577,68 @@
],
"title": "EventListResponse"
},
+ "EventMetadataResponse": {
+ "properties": {
+ "service_name": {
+ "type": "string",
+ "title": "Service Name"
+ },
+ "service_version": {
+ "type": "string",
+ "title": "Service Version"
+ },
+ "correlation_id": {
+ "type": "string",
+ "title": "Correlation Id"
+ },
+ "user_id": {
+ "anyOf": [
+ {
+ "type": "string"
+ },
+ {
+ "type": "null"
+ }
+ ],
+ "title": "User Id"
+ },
+ "ip_address": {
+ "anyOf": [
+ {
+ "type": "string"
+ },
+ {
+ "type": "null"
+ }
+ ],
+ "title": "Ip Address"
+ },
+ "user_agent": {
+ "anyOf": [
+ {
+ "type": "string"
+ },
+ {
+ "type": "null"
+ }
+ ],
+ "title": "User Agent"
+ },
+ "environment": {
+ "type": "string",
+ "title": "Environment",
+ "default": "production"
+ }
+ },
+ "type": "object",
+ "required": [
+ "service_name",
+ "service_version",
+ "correlation_id"
+ ],
+ "title": "EventMetadataResponse",
+ "description": "Pydantic schema for event metadata in API responses."
+ },
"EventReplayRequest": {
"properties": {
"event_ids": {
@@ -5660,20 +5835,19 @@
],
"title": "Completed At"
},
- "error": {
+ "errors": {
"anyOf": [
{
- "type": "string"
+ "items": {
+ "$ref": "#/components/schemas/ReplayErrorInfo"
+ },
+ "type": "array"
},
{
"type": "null"
}
],
- "title": "Error"
- },
- "progress_percentage": {
- "type": "number",
- "title": "Progress Percentage"
+ "title": "Errors"
},
"estimated_completion": {
"anyOf": [
@@ -5691,7 +5865,7 @@
"anyOf": [
{
"items": {
- "type": "object"
+ "$ref": "#/components/schemas/ExecutionResult"
},
"type": "array"
},
@@ -5700,6 +5874,11 @@
}
],
"title": "Execution Results"
+ },
+ "progress_percentage": {
+ "type": "number",
+ "title": "Progress Percentage",
+ "readOnly": true
}
},
"type": "object",
@@ -5769,8 +5948,7 @@
"title": "Causation Id"
},
"metadata": {
- "type": "object",
- "title": "Metadata"
+ "$ref": "#/components/schemas/EventMetadataResponse"
},
"payload": {
"type": "object",
@@ -5822,7 +6000,7 @@
},
"events_by_hour": {
"items": {
- "type": "object"
+ "$ref": "#/components/schemas/HourlyEventCountSchema"
},
"type": "array",
"title": "Events By Hour"
@@ -5899,14 +6077,14 @@
},
"events_by_hour": {
"items": {
- "type": "object"
+ "$ref": "#/components/schemas/HourlyEventCountSchema"
},
"type": "array",
"title": "Events By Hour"
},
"top_users": {
"items": {
- "type": "object"
+ "$ref": "#/components/schemas/UserEventCountSchema"
},
"type": "array",
"title": "Top Users"
@@ -6382,6 +6560,25 @@
"type": "object",
"title": "HTTPValidationError"
},
+ "HourlyEventCountSchema": {
+ "properties": {
+ "hour": {
+ "type": "string",
+ "title": "Hour"
+ },
+ "count": {
+ "type": "integer",
+ "title": "Count"
+ }
+ },
+ "type": "object",
+ "required": [
+ "hour",
+ "count"
+ ],
+ "title": "HourlyEventCountSchema",
+ "description": "Hourly event count for statistics."
+ },
"LanguageInfo": {
"properties": {
"versions": {
@@ -7305,12 +7502,53 @@
},
"type": "object",
"required": [
- "replay_type",
- "filter"
+ "replay_type"
],
"title": "ReplayConfigSchema"
},
- "ReplayFilterSchema": {
+ "ReplayErrorInfo": {
+ "properties": {
+ "timestamp": {
+ "type": "string",
+ "format": "date-time",
+ "title": "Timestamp"
+ },
+ "error": {
+ "type": "string",
+ "title": "Error"
+ },
+ "event_id": {
+ "anyOf": [
+ {
+ "type": "string"
+ },
+ {
+ "type": "null"
+ }
+ ],
+ "title": "Event Id"
+ },
+ "error_type": {
+ "anyOf": [
+ {
+ "type": "string"
+ },
+ {
+ "type": "null"
+ }
+ ],
+ "title": "Error Type"
+ }
+ },
+ "type": "object",
+ "required": [
+ "timestamp",
+ "error"
+ ],
+ "title": "ReplayErrorInfo",
+ "description": "Error info for replay operations."
+ },
+ "ReplayFilter": {
"properties": {
"execution_id": {
"anyOf": [
@@ -7327,7 +7565,7 @@
"anyOf": [
{
"items": {
- "type": "string"
+ "$ref": "#/components/schemas/EventType"
},
"type": "array"
},
@@ -7398,7 +7636,7 @@
"anyOf": [
{
"items": {
- "type": "string"
+ "$ref": "#/components/schemas/EventType"
},
"type": "array"
},
@@ -7410,17 +7648,10 @@
}
},
"type": "object",
- "title": "ReplayFilterSchema"
+ "title": "ReplayFilter"
},
- "ReplayRequest": {
+ "ReplayFilterSchema": {
"properties": {
- "replay_type": {
- "$ref": "#/components/schemas/ReplayType"
- },
- "target": {
- "$ref": "#/components/schemas/ReplayTarget",
- "default": "kafka"
- },
"execution_id": {
"anyOf": [
{
@@ -7492,6 +7723,47 @@
],
"title": "Service Name"
},
+ "custom_query": {
+ "anyOf": [
+ {
+ "type": "object"
+ },
+ {
+ "type": "null"
+ }
+ ],
+ "title": "Custom Query"
+ },
+ "exclude_event_types": {
+ "anyOf": [
+ {
+ "items": {
+ "$ref": "#/components/schemas/EventType"
+ },
+ "type": "array"
+ },
+ {
+ "type": "null"
+ }
+ ],
+ "title": "Exclude Event Types"
+ }
+ },
+ "type": "object",
+ "title": "ReplayFilterSchema"
+ },
+ "ReplayRequest": {
+ "properties": {
+ "replay_type": {
+ "$ref": "#/components/schemas/ReplayType"
+ },
+ "target": {
+ "$ref": "#/components/schemas/ReplayTarget",
+ "default": "kafka"
+ },
+ "filter": {
+ "$ref": "#/components/schemas/ReplayFilter"
+ },
"speed_multiplier": {
"type": "number",
"maximum": 100.0,
@@ -7538,6 +7810,37 @@
}
],
"title": "Target File Path"
+ },
+ "target_topics": {
+ "anyOf": [
+ {
+ "additionalProperties": {
+ "type": "string"
+ },
+ "type": "object"
+ },
+ {
+ "type": "null"
+ }
+ ],
+ "title": "Target Topics"
+ },
+ "retry_failed": {
+ "type": "boolean",
+ "title": "Retry Failed",
+ "default": false
+ },
+ "retry_attempts": {
+ "type": "integer",
+ "maximum": 10.0,
+ "minimum": 1.0,
+ "title": "Retry Attempts",
+ "default": 3
+ },
+ "enable_progress_tracking": {
+ "type": "boolean",
+ "title": "Enable Progress Tracking",
+ "default": true
}
},
"type": "object",
@@ -7737,52 +8040,24 @@
"ResourceUsage": {
"properties": {
"execution_time_wall_seconds": {
- "anyOf": [
- {
- "type": "number"
- },
- {
- "type": "null"
- }
- ],
+ "type": "number",
"title": "Execution Time Wall Seconds",
- "description": "Wall clock execution time in seconds"
+ "default": 0.0
},
"cpu_time_jiffies": {
- "anyOf": [
- {
- "type": "integer"
- },
- {
- "type": "null"
- }
- ],
+ "type": "integer",
"title": "Cpu Time Jiffies",
- "description": "CPU time in jiffies (multiply by 10 for milliseconds)"
+ "default": 0
},
"clk_tck_hertz": {
- "anyOf": [
- {
- "type": "integer"
- },
- {
- "type": "null"
- }
- ],
+ "type": "integer",
"title": "Clk Tck Hertz",
- "description": "Clock ticks per second (usually 100)"
+ "default": 0
},
"peak_memory_kb": {
- "anyOf": [
- {
- "type": "integer"
- },
- {
- "type": "null"
- }
- ],
+ "type": "integer",
"title": "Peak Memory Kb",
- "description": "Peak memory usage in KB"
+ "default": 0
}
},
"type": "object",
@@ -7983,12 +8258,27 @@
"total": {
"type": "integer",
"title": "Total"
+ },
+ "skip": {
+ "type": "integer",
+ "title": "Skip"
+ },
+ "limit": {
+ "type": "integer",
+ "title": "Limit"
+ },
+ "has_more": {
+ "type": "boolean",
+ "title": "Has More"
}
},
"type": "object",
"required": [
"sagas",
- "total"
+ "total",
+ "skip",
+ "limit",
+ "has_more"
],
"title": "SagaListResponse",
"description": "Response schema for saga list"
@@ -8062,16 +8352,19 @@
},
"created_at": {
"type": "string",
+ "format": "date-time",
"title": "Created At"
},
"updated_at": {
"type": "string",
+ "format": "date-time",
"title": "Updated At"
},
"completed_at": {
"anyOf": [
{
- "type": "string"
+ "type": "string",
+ "format": "date-time"
},
{
"type": "null"
@@ -8197,6 +8490,72 @@
],
"title": "SavedScriptResponse"
},
+ "SavedScriptUpdate": {
+ "properties": {
+ "name": {
+ "anyOf": [
+ {
+ "type": "string"
+ },
+ {
+ "type": "null"
+ }
+ ],
+ "title": "Name"
+ },
+ "script": {
+ "anyOf": [
+ {
+ "type": "string"
+ },
+ {
+ "type": "null"
+ }
+ ],
+ "title": "Script"
+ },
+ "lang": {
+ "anyOf": [
+ {
+ "type": "string"
+ },
+ {
+ "type": "null"
+ }
+ ],
+ "title": "Lang"
+ },
+ "lang_version": {
+ "anyOf": [
+ {
+ "type": "string"
+ },
+ {
+ "type": "null"
+ }
+ ],
+ "title": "Lang Version"
+ },
+ "description": {
+ "anyOf": [
+ {
+ "type": "string"
+ },
+ {
+ "type": "null"
+ }
+ ],
+ "title": "Description"
+ },
+ "updated_at": {
+ "type": "string",
+ "format": "date-time",
+ "title": "Updated At"
+ }
+ },
+ "type": "object",
+ "title": "SavedScriptUpdate"
+ },
"SecuritySettingsSchema": {
"properties": {
"password_min_length": {
@@ -8305,7 +8664,8 @@
"type": "null"
}
],
- "title": "Duration Seconds"
+ "title": "Duration Seconds",
+ "readOnly": true
},
"throughput_events_per_second": {
"anyOf": [
@@ -8316,7 +8676,8 @@
"type": "null"
}
],
- "title": "Throughput Events Per Second"
+ "title": "Throughput Events Per Second",
+ "readOnly": true
}
},
"type": "object",
@@ -8331,7 +8692,9 @@
"skipped_events",
"created_at",
"started_at",
- "completed_at"
+ "completed_at",
+ "duration_seconds",
+ "throughput_events_per_second"
],
"title": "SessionSummary",
"description": "Summary information for replay sessions"
@@ -8399,18 +8762,18 @@
"type": "array",
"title": "History"
},
- "total": {
+ "limit": {
"type": "integer",
- "title": "Total"
+ "title": "Limit"
}
},
"type": "object",
"required": [
"history",
- "total"
+ "limit"
],
"title": "SettingsHistoryResponse",
- "description": "Response model for settings history"
+ "description": "Response model for settings history (limited snapshot of recent changes)"
},
"ShutdownStatusResponse": {
"properties": {
@@ -8700,6 +9063,25 @@
"title": "UserCreate",
"description": "Model for creating a new user"
},
+ "UserEventCountSchema": {
+ "properties": {
+ "user_id": {
+ "type": "string",
+ "title": "User Id"
+ },
+ "event_count": {
+ "type": "integer",
+ "title": "Event Count"
+ }
+ },
+ "type": "object",
+ "required": [
+ "user_id",
+ "event_count"
+ ],
+ "title": "UserEventCountSchema",
+ "description": "User event count schema"
+ },
"UserListResponse": {
"properties": {
"users": {
diff --git a/frontend/src/components/admin/events/ReplayProgressBanner.svelte b/frontend/src/components/admin/events/ReplayProgressBanner.svelte
index f662b5c5..2c145959 100644
--- a/frontend/src/components/admin/events/ReplayProgressBanner.svelte
+++ b/frontend/src/components/admin/events/ReplayProgressBanner.svelte
@@ -32,40 +32,23 @@
Progress: {session.replayed_events} / {session.total_events} events
{session.progress_percentage}%
-
Execution Results: