diff --git a/.envs/.ci/.django b/.envs/.ci/.django index bec17501b..6577ebece 100644 --- a/.envs/.ci/.django +++ b/.envs/.ci/.django @@ -26,3 +26,7 @@ CELERY_BROKER_URL=amqp://rabbituser:rabbitpass@rabbitmq:5672/ CELERY_RESULT_BACKEND=rpc:// # Use RabbitMQ for results backend RABBITMQ_DEFAULT_USER=rabbituser RABBITMQ_DEFAULT_PASS=rabbitpass + +# NATS +# ------------------------------------------------------------------------------ +NATS_URL=nats://nats:4222 diff --git a/.envs/.local/.django b/.envs/.local/.django index 29780e680..8eb5610f7 100644 --- a/.envs/.local/.django +++ b/.envs/.local/.django @@ -12,6 +12,9 @@ DJANGO_SUPERUSER_PASSWORD=localadmin # Redis REDIS_URL=redis://redis:6379/0 +# NATS +NATS_URL=nats://nats:4222 + # Celery / Flower CELERY_FLOWER_USER=QSocnxapfMvzLqJXSsXtnEZqRkBtsmKT CELERY_FLOWER_PASSWORD=BEQgmCtgyrFieKNoGTsux9YIye0I7P5Q7vEgfJD2C4jxmtHDetFaE2jhS7K7rxaf diff --git a/.envs/.production/.django-example b/.envs/.production/.django-example index a54d4ae60..93737d527 100644 --- a/.envs/.production/.django-example +++ b/.envs/.production/.django-example @@ -65,3 +65,7 @@ WEB_CONCURRENCY=4 DEFAULT_PROCESSING_SERVICE_NAME="AMI Data Companion" DEFAULT_PROCESSING_SERVICE_ENDPOINT=https://ml.antenna.insectai.org/ DEFAULT_PIPELINES_ENABLED=global_moths_2024,quebec_vermont_moths_2023,panama_moths_2023,uk_denmark_moths_2023 + +# NATS +# ------------------------------------------------------------------------------ +NATS_URL=nats://nats:4222 diff --git a/README.md b/README.md index 16ddbb07f..4f1a00b2a 100644 --- a/README.md +++ b/README.md @@ -71,6 +71,7 @@ docker compose -f processing_services/example/docker-compose.yml up -d - Django admin: http://localhost:8000/admin/ - OpenAPI / Swagger documentation: http://localhost:8000/api/v2/docs/ - Minio UI: http://minio:9001, Minio service: http://minio:9000 +- NATS dashboard: https://natsdashboard.com/ (Add localhost) NOTE: If one of these services is not working properly, it could be due another process is using the port. You can check for this with `lsof -i :`. diff --git a/ami/jobs/models.py b/ami/jobs/models.py index b94baa9a2..482d01a58 100644 --- a/ami/jobs/models.py +++ b/ami/jobs/models.py @@ -322,15 +322,13 @@ def run(cls, job: "Job"): """ Procedure for an ML pipeline as a job. """ + from ami.ml.orchestration.jobs import queue_images_to_nats + job.update_status(JobState.STARTED) job.started_at = datetime.datetime.now() job.finished_at = None job.save() - # Keep track of sub-tasks for saving results, pair with batch number - save_tasks: list[tuple[int, AsyncResult]] = [] - save_tasks_completed: list[tuple[int, AsyncResult]] = [] - if job.delay: update_interval_seconds = 2 last_update = time.time() @@ -365,7 +363,7 @@ def run(cls, job: "Job"): progress=0, ) - images = list( + images: list[SourceImage] = list( # @TODO return generator plus image count # @TODO pass to celery group chain? job.pipeline.collect_images( @@ -389,8 +387,6 @@ def run(cls, job: "Job"): images = images[: job.limit] image_count = len(images) job.progress.add_stage_param("collect", "Limit", image_count) - else: - image_count = source_image_count job.progress.update_stage( "collect", @@ -401,6 +397,24 @@ def run(cls, job: "Job"): # End image collection stage job.save() + if job.project.feature_flags.async_pipeline_workers: + queued = queue_images_to_nats(job, images) + if not queued: + job.logger.error("Aborting job %s because images could not be queued to NATS", job.pk) + job.progress.update_stage("collect", status=JobState.FAILURE) + job.update_status(JobState.FAILURE) + job.finished_at = datetime.datetime.now() + job.save() + return + else: + cls.process_images(job, images) + + @classmethod + def process_images(cls, job, images): + image_count = len(images) + # Keep track of sub-tasks for saving results, pair with batch number + save_tasks: list[tuple[int, AsyncResult]] = [] + save_tasks_completed: list[tuple[int, AsyncResult]] = [] total_captures = 0 total_detections = 0 total_classifications = 0 diff --git a/ami/jobs/tasks.py b/ami/jobs/tasks.py index 30c594141..9e224a1ed 100644 --- a/ami/jobs/tasks.py +++ b/ami/jobs/tasks.py @@ -1,7 +1,16 @@ +import datetime +import functools import logging +import time +from collections.abc import Callable +from asgiref.sync import async_to_sync from celery.signals import task_failure, task_postrun, task_prerun +from django.db import transaction +from ami.ml.orchestration.nats_queue import TaskQueueManager +from ami.ml.orchestration.task_state import TaskStateManager +from ami.ml.schemas import PipelineResultsResponse from ami.tasks import default_soft_time_limit, default_time_limit from config import celery_app @@ -29,6 +38,132 @@ def run_job(self, job_id: int) -> None: job.logger.info(f"Finished job {job}") +@celery_app.task( + bind=True, + max_retries=0, # don't retry since we already have retry logic in the NATS queue + soft_time_limit=300, # 5 minutes + time_limit=360, # 6 minutes +) +def process_pipeline_result(self, job_id: int, result_data: dict, reply_subject: str) -> None: + """ + Process a single pipeline result asynchronously. + + This task: + 1. Deserializes the pipeline result + 2. Saves it to the database + 3. Updates progress by removing processed image IDs from Redis + 4. Acknowledges the task via NATS + + Args: + job_id: The job ID + result_data: Dictionary containing the pipeline result + reply_subject: NATS reply subject for acknowledgment + """ + from ami.jobs.models import Job # avoid circular import + + _, t = log_time() + error = result_data.get("error") + pipeline_result = None + if not error: + pipeline_result = PipelineResultsResponse(**result_data) + processed_image_ids = {str(img.id) for img in pipeline_result.source_images} + else: + image_id = result_data.get("image_id") + processed_image_ids = {str(image_id)} if image_id else set() + logger.error(f"Pipeline returned error for job {job_id}, image {image_id}: {error}") + + state_manager = TaskStateManager(job_id) + + progress_info = state_manager.update_state(processed_image_ids, stage="process", request_id=self.request.id) + if not progress_info: + logger.warning( + f"Another task is already processing results for job {job_id}. " + f"Retrying task {self.request.id} in 5 seconds..." + ) + raise self.retry(countdown=5, max_retries=10) + + try: + _update_job_progress(job_id, "process", progress_info.percentage) + + _, t = t(f"TIME: Updated job {job_id} progress in PROCESS stage progress to {progress_info.percentage*100}%") + job = Job.objects.get(pk=job_id) + job.logger.info(f"Processing pipeline result for job {job_id}, reply_subject: {reply_subject}") + job.logger.info( + f" Job {job_id} progress: {progress_info.processed}/{progress_info.total} images processed " + f"({progress_info.percentage*100}%), {progress_info.remaining} remaining, {len(processed_image_ids)} just " + "processed" + ) + except Job.DoesNotExist: + # don't raise and ack so that we don't retry since the job doesn't exists + logger.error(f"Job {job_id} not found") + _ack_task_via_nats(reply_subject, logger) + return + + try: + # Save to database (this is the slow operation) + if pipeline_result: + # should never happen since otherwise we could not be processing results here + assert job.pipeline is not None, "Job pipeline is None" + job.pipeline.save_results(results=pipeline_result, job_id=job.pk) + job.logger.info(f"Successfully saved results for job {job_id}") + + _, t = t( + f"Saved pipeline results to database with {len(pipeline_result.detections)} detections" + f", percentage: {progress_info.percentage*100}%" + ) + + _ack_task_via_nats(reply_subject, job.logger) + # Update job stage with calculated progress + progress_info = state_manager.update_state(processed_image_ids, stage="results", request_id=self.request.id) + + if not progress_info: + logger.warning( + f"Another task is already processing results for job {job_id}. " + f"Retrying task {self.request.id} in 5 seconds..." + ) + raise self.retry(countdown=5, max_retries=10) + _update_job_progress(job_id, "results", progress_info.percentage) + + except Exception as e: + job.logger.error(f"Failed to process pipeline result for job {job_id}: {e}. Retrying ...") + + +def _ack_task_via_nats(reply_subject: str, job_logger: logging.Logger) -> None: + try: + + async def ack_task(): + async with TaskQueueManager() as manager: + return await manager.acknowledge_task(reply_subject) + + ack_success = async_to_sync(ack_task)() + + if ack_success: + job_logger.info(f"Successfully acknowledged task via NATS: {reply_subject}") + else: + job_logger.warning(f"Failed to acknowledge task via NATS: {reply_subject}") + except Exception as ack_error: + job_logger.error(f"Error acknowledging task via NATS: {ack_error}") + # Don't fail the task if ACK fails - data is already saved + + +def _update_job_progress(job_id: int, stage: str, progress_percentage: float) -> None: + from ami.jobs.models import Job, JobState # avoid circular import + + with transaction.atomic(): + job = Job.objects.select_for_update().get(pk=job_id) + job.progress.update_stage( + stage, + status=JobState.SUCCESS if progress_percentage >= 1.0 else JobState.STARTED, + progress=progress_percentage, + ) + if stage == "results" and progress_percentage >= 1.0: + job.status = JobState.SUCCESS + job.progress.summary.status = JobState.SUCCESS + job.finished_at = datetime.datetime.now() # Use naive datetime in local time + job.logger.info(f"Updated job {job_id} progress in stage '{stage}' to {progress_percentage*100}%") + job.save() + + @task_prerun.connect(sender=run_job) def pre_update_job_status(sender, task_id, task, **kwargs): # in the prerun signal, set the job status to PENDING @@ -65,3 +200,28 @@ def update_job_failure(sender, task_id, exception, *args, **kwargs): job.logger.error(f'Job #{job.pk} "{job.name}" failed: {exception}') job.save() + + +def log_time(start: float = 0, msg: str | None = None) -> tuple[float, Callable]: + """ + Small helper to measure time between calls. + + Returns: elapsed time since the last call, and a partial function to measure from the current call + Usage: + + _, tlog = log_time() + # do something + _, tlog = tlog("Did something") # will log the time taken by 'something' + # do something else + t, tlog = tlog("Did something else") # will log the time taken by 'something else', returned as 't' + """ + + end = time.perf_counter() + if start == 0: + dur = 0.0 + else: + dur = end - start + if msg and start > 0: + logger.info(f"{msg}: {dur:.3f}s") + new_start = time.perf_counter() + return dur, functools.partial(log_time, new_start) diff --git a/ami/jobs/tests.py b/ami/jobs/tests.py index 4d61a9ea4..8e04d9dd9 100644 --- a/ami/jobs/tests.py +++ b/ami/jobs/tests.py @@ -11,6 +11,7 @@ from ami.jobs.models import Job, JobProgress, JobState, MLJob, SourceImageCollectionPopulateJob from ami.main.models import Project, SourceImage, SourceImageCollection from ami.ml.models import Pipeline +from ami.ml.orchestration.jobs import queue_images_to_nats from ami.users.models import User logger = logging.getLogger(__name__) @@ -326,6 +327,15 @@ def test_search_jobs(self): def _task_batch_helper(self, value: Any, expected_status: int): pipeline = self._create_pipeline() job = self._create_ml_job("Job for batch test", pipeline) + images = [ + SourceImage.objects.create( + path=f"image_{i}.jpg", + public_base_url="http://example.com", + project=self.project, + ) + for i in range(8) # more than 5 since we test with batch=5 + ] + queue_images_to_nats(job, images) self.client.force_authenticate(user=self.user) tasks_url = reverse_with_params( @@ -390,10 +400,9 @@ def test_result_endpoint_stub(self): self.assertEqual(resp.status_code, 200) data = resp.json() - self.assertEqual(data["status"], "received") + self.assertEqual(data["status"], "accepted") self.assertEqual(data["job_id"], job.pk) - self.assertEqual(data["results_received"], 1) - self.assertIn("message", data) + self.assertEqual(data["results_queued"], 1) def test_result_endpoint_validation(self): """Test the result endpoint validates request data.""" diff --git a/ami/jobs/views.py b/ami/jobs/views.py index fb94fd60b..a00052342 100644 --- a/ami/jobs/views.py +++ b/ami/jobs/views.py @@ -1,6 +1,7 @@ import logging import pydantic +from asgiref.sync import async_to_sync from django.db.models import Q from django.db.models.query import QuerySet from django.forms import IntegerField @@ -15,11 +16,12 @@ from ami.base.permissions import ObjectPermission from ami.base.views import ProjectMixin from ami.jobs.schemas import batch_param, ids_only_param, incomplete_only_param +from ami.jobs.tasks import process_pipeline_result from ami.main.api.schemas import project_id_doc_param # from ami.jobs.tasks import process_pipeline_result # TODO: Uncomment when available in main from ami.main.api.views import DefaultViewSet -from ami.ml.schemas import PipelineProcessingTask, PipelineTaskResult +from ami.ml.schemas import PipelineTaskResult from ami.utils.fields import url_boolean_param from .models import Job, JobState @@ -232,33 +234,34 @@ def tasks(self, request, pk=None): if not job.pipeline: raise ValidationError("This job does not have a pipeline configured") - # TODO: Implement task queue integration - logger.warning(f"Task queue endpoint called for job {job.pk} but the implementation is not yet available.") + # Get tasks from NATS JetStream + from ami.ml.orchestration.nats_queue import TaskQueueManager - dummy_task = PipelineProcessingTask( - id="1", - image_id="1", - image_url="http://example.com/image1", - queue_timestamp=timezone.now().isoformat(), - ) + async def get_tasks(): + tasks = [] + async with TaskQueueManager() as manager: + for _ in range(batch): + task = await manager.reserve_task(job.pk, timeout=0.1) + if task: + tasks.append(task.dict()) + return tasks + + # Use async_to_sync to properly handle the async call + tasks = async_to_sync(get_tasks)() - # @TODO when this gets fully implemented, use a Serializer or Pydantic schema - # for the full repsponse structure. - return Response({"tasks": [task.dict() for task in [dummy_task] * batch]}) + return Response({"tasks": tasks}) @action(detail=True, methods=["post"], name="result") def result(self, request, pk=None): """ - Submit pipeline results for asynchronous processing. + The request body should be a list of results: list[PipelineTaskResult] This endpoint accepts a list of pipeline results and queues them for - background processing. Each result will be validated and saved. - - The request body should be a list of results: list[PipelineTaskResult] + background processing. Each result will be validated, saved to the database, + and acknowledged via NATS in a Celery task. """ job = self.get_object() - job_id = job.pk # Validate request data is a list if isinstance(request.data, list): @@ -266,33 +269,50 @@ def result(self, request, pk=None): else: results = [request.data] + queued_tasks = [] try: - queued_tasks = [] + # Queue each result for background processing for item in results: task_result = PipelineTaskResult(**item) - # Stub: Log that we received the result but don't process it yet - logger.warning( - f"Result endpoint called for job {job_id} (reply_subject: {task_result.reply_subject}) " - "but result processing not yet available." + reply_subject = task_result.reply_subject + result_data = task_result.result + + # Queue the background task + # Convert Pydantic model to dict for JSON serialization + task = process_pipeline_result.delay( + job_id=job.pk, result_data=result_data.dict(), reply_subject=reply_subject ) - # TODO: Implement result storage and processing queued_tasks.append( { - "reply_subject": task_result.reply_subject, - "status": "pending_implementation", - "message": "Result processing not yet implemented.", + "reply_subject": reply_subject, + "status": "queued", + "task_id": task.id, } ) + + logger.info( + f"Queued pipeline result processing for job {job.pk}, " + f"task_id: {task.id}, reply_subject: {reply_subject}" + ) + + return Response( + { + "status": "accepted", + "job_id": job.pk, + "results_queued": len([t for t in queued_tasks if t["status"] == "queued"]), + "tasks": queued_tasks, + } + ) except pydantic.ValidationError as e: raise ValidationError(f"Invalid result data: {e}") from e - return Response( - { - "status": "received", - "job_id": job_id, - "results_received": len(queued_tasks), - "tasks": queued_tasks, - "message": "Result processing not yet implemented.", - } - ) + except Exception as e: + logger.error(f"Failed to queue pipeline results for job {job.pk}: {e}") + return Response( + { + "status": "error", + "job_id": job.pk, + }, + status=500, + ) diff --git a/ami/main/models.py b/ami/main/models.py index 515f5286a..0e45fc0b5 100644 --- a/ami/main/models.py +++ b/ami/main/models.py @@ -218,6 +218,7 @@ class ProjectFeatureFlags(pydantic.BaseModel): default_filters: bool = False # Whether to show default filters form in UI # Feature flag for jobs to reprocess all images in the project, even if already processed reprocess_all_images: bool = False + async_pipeline_workers: bool = False # Whether to use async pipeline workers that pull tasks from a queue def get_default_feature_flags() -> ProjectFeatureFlags: diff --git a/ami/ml/orchestration/__init__.py b/ami/ml/orchestration/__init__.py index d05bbbd82..75c2ec3b5 100644 --- a/ami/ml/orchestration/__init__.py +++ b/ami/ml/orchestration/__init__.py @@ -1 +1,5 @@ -from .processing import * # noqa: F401, F403 +# cgjs: This creates a circular import: +# - ami.jobs.models imports ami.jobs.tasks.run_job +# - ami.jobs.tasks imports ami.ml.orchestration +# -.processing imports ami.jobs.models +# from .processing import * # noqa: F401, F403 diff --git a/ami/ml/orchestration/jobs.py b/ami/ml/orchestration/jobs.py new file mode 100644 index 000000000..ab2b19b52 --- /dev/null +++ b/ami/ml/orchestration/jobs.py @@ -0,0 +1,102 @@ +from asgiref.sync import async_to_sync + +from ami.jobs.models import Job, JobState, logger +from ami.main.models import SourceImage +from ami.ml.orchestration.nats_queue import TaskQueueManager +from ami.ml.orchestration.task_state import TaskStateManager +from ami.ml.schemas import PipelineProcessingTask + + +# TODO CGJS: Call this once a job is fully complete (all images processed and saved) +def cleanup_nats_resources(job: "Job") -> bool: + """ + Clean up NATS JetStream resources (stream and consumer) for a completed job. + + Args: + job: The Job instance + Returns: + bool: True if cleanup was successful, False otherwise + """ + + async def cleanup(): + async with TaskQueueManager() as manager: + return await manager.cleanup_job_resources(job.pk) + + return async_to_sync(cleanup)() + + +def queue_images_to_nats(job: "Job", images: list[SourceImage]): + """ + Queue all images for a job to a NATS JetStream stream for the job. + + Args: + job: The Job instance + images: List of SourceImage instances to queue + + Returns: + bool: True if all images were successfully queued, False otherwise + """ + job.logger.info(f"Queuing {len(images)} images to NATS stream for job '{job.pk}'") + + # Prepare all messages outside of async context to avoid Django ORM issues + tasks: list[tuple[int, PipelineProcessingTask]] = [] + image_ids = [] + for image in images: + image_id = str(image.pk) + image_url = image.url() if hasattr(image, "url") and image.url() else "" + if not image_url: + job.logger.warning(f"Image {image.pk} has no URL, skipping queuing to NATS for job '{job.pk}'") + continue + image_ids.append(image_id) + task = PipelineProcessingTask( + id=image_id, + image_id=image_id, + image_url=image_url, + ) + tasks.append((image.pk, task)) + + # Store all image IDs in Redis for progress tracking + state_manager = TaskStateManager(job.pk) + state_manager.initialize_job(image_ids) + job.logger.info(f"Initialized task state tracking for {len(image_ids)} images") + + async def queue_all_images(): + successful_queues = 0 + failed_queues = 0 + + async with TaskQueueManager() as manager: + for image_pk, task in tasks: + try: + logger.info(f"Queueing image {image_pk} to stream for job '{job.pk}': {task.image_url}") + success = await manager.publish_task( + job_id=job.pk, + data=task, + ) + except Exception as e: + logger.error(f"Failed to queue image {image_pk} to stream for job '{job.pk}': {e}") + success = False + + if success: + successful_queues += 1 + else: + failed_queues += 1 + + return successful_queues, failed_queues + + if tasks: + successful_queues, failed_queues = async_to_sync(queue_all_images)() + else: + job.progress.update_stage("process", status=JobState.SUCCESS, progress=1.0) + job.progress.update_stage("results", status=JobState.SUCCESS, progress=1.0) + job.save() + successful_queues, failed_queues = 0, 0 + + # Log results (back in sync context) + if successful_queues > 0: + job.logger.info(f"Successfully queued {successful_queues}/{len(images)} images to stream for job '{job.pk}'") + + if failed_queues > 0: + job.logger.warning(f"Failed to queue {failed_queues}/{len(images)} images to stream for job '{job.pk}'") + return False + + return True diff --git a/ami/ml/orchestration/nats_queue.py b/ami/ml/orchestration/nats_queue.py new file mode 100644 index 000000000..9436d1df5 --- /dev/null +++ b/ami/ml/orchestration/nats_queue.py @@ -0,0 +1,297 @@ +""" +NATS JetStream utility for task queue management in the antenna project. + +This module provides a TaskQueueManager that uses NATS JetStream for distributed +task queuing with acknowledgment support via reply subjects. This allows workers +to pull tasks over HTTP and acknowledge them later without maintaining a persistent +connection to NATS. +""" + +import json +import logging + +import nats +from django.conf import settings +from nats.js import JetStreamContext +from nats.js.api import AckPolicy, ConsumerConfig, DeliverPolicy + +from ami.ml.schemas import PipelineProcessingTask + +logger = logging.getLogger(__name__) + + +async def get_connection(nats_url: str): + nc = await nats.connect(nats_url) + js = nc.jetstream() + return nc, js + + +TASK_TTR = 300 # Default Time-To-Run (visibility timeout) in seconds + + +class TaskQueueManager: + """ + Manager for NATS JetStream task queue operations. + + Use as an async context manager: + async with TaskQueueManager() as manager: + await manager.publish_task('job123', {'data': 'value'}) + task = await manager.reserve_task('job123') + await manager.acknowledge_task(task['reply_subject']) + """ + + def __init__(self, nats_url: str | None = None): + self.nats_url = nats_url or getattr(settings, "NATS_URL", "nats://nats:4222") + self.nc: nats.NATS | None = None + self.js: JetStreamContext | None = None + + async def __aenter__(self): + """Create connection on enter.""" + self.nc, self.js = await get_connection(self.nats_url) + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + if self.js: + self.js = None + if self.nc and not self.nc.is_closed: + await self.nc.close() + self.nc = None + + return False + + def _get_stream_name(self, job_id: int) -> str: + """Get stream name from job_id.""" + return f"job_{job_id}" + + def _get_subject(self, job_id: int) -> str: + """Get subject name from job_id.""" + return f"job.{job_id}.tasks" + + def _get_consumer_name(self, job_id: int) -> str: + """Get consumer name from job_id.""" + return f"job-{job_id}-consumer" + + async def _ensure_stream(self, job_id: int): + """Ensure stream exists for the given job.""" + if self.js is None: + raise RuntimeError("Connection is not open. Use TaskQueueManager as an async context manager.") + + stream_name = self._get_stream_name(job_id) + subject = self._get_subject(job_id) + + try: + await self.js.stream_info(stream_name) + logger.debug(f"Stream {stream_name} already exists") + except Exception as e: + logger.warning(f"Stream {stream_name} does not exist: {e}") + # Stream doesn't exist, create it + await self.js.add_stream( + name=stream_name, + subjects=[subject], + max_age=86400, # 24 hours retention + ) + logger.info(f"Created stream {stream_name}") + + async def _ensure_consumer(self, job_id: int): + """Ensure consumer exists for the given job.""" + if self.js is None: + raise RuntimeError("Connection is not open. Use TaskQueueManager as an async context manager.") + + stream_name = self._get_stream_name(job_id) + consumer_name = self._get_consumer_name(job_id) + subject = self._get_subject(job_id) + + try: + info = await self.js.consumer_info(stream_name, consumer_name) + logger.debug(f"Consumer {consumer_name} already exists: {info}") + except Exception: + # Consumer doesn't exist, create it + await self.js.add_consumer( + stream=stream_name, + config=ConsumerConfig( + durable_name=consumer_name, + ack_policy=AckPolicy.EXPLICIT, + ack_wait=TASK_TTR, # Visibility timeout (TTR) + max_deliver=5, # Max retry attempts + deliver_policy=DeliverPolicy.ALL, + max_ack_pending=100, # Max unacked messages + filter_subject=subject, + ), + ) + logger.info(f"Created consumer {consumer_name}") + + async def publish_task(self, job_id: int, data: PipelineProcessingTask) -> bool: + """ + Publish a task to it's job queue. + + Args: + job_id: The job ID (integer primary key) + data: PipelineProcessingTask object to be published + + Returns: + bool: True if successful, False otherwise + """ + if self.js is None: + raise RuntimeError("Connection is not open. Use TaskQueueManager as an async context manager.") + + try: + # Ensure stream and consumer exist + await self._ensure_stream(job_id) + await self._ensure_consumer(job_id) + + subject = self._get_subject(job_id) + # Convert Pydantic model to JSON + task_data = json.dumps(data.dict()) + + # Publish to JetStream + ack = await self.js.publish(subject, task_data.encode()) + + logger.info(f"Published task to stream for job '{job_id}', sequence {ack.seq}") + return True + + except Exception as e: + logger.error(f"Failed to publish task to stream for job '{job_id}': {e}") + return False + + async def reserve_task(self, job_id: int, timeout: float | None = None) -> PipelineProcessingTask | None: + """ + Reserve a task from the specified stream. + + Args: + job_id: The job ID (integer primary key) to pull tasks from + timeout: Timeout in seconds for reservation (default: 5 seconds) + + Returns: + PipelineProcessingTask with reply_subject set for acknowledgment, or None if no task available + """ + if self.js is None: + raise RuntimeError("Connection is not open. Use TaskQueueManager as an async context manager.") + + if timeout is None: + timeout = 5 + + try: + # Ensure stream and consumer exist + await self._ensure_stream(job_id) + await self._ensure_consumer(job_id) + + consumer_name = self._get_consumer_name(job_id) + subject = self._get_subject(job_id) + + # Create ephemeral subscription for this pull + psub = await self.js.pull_subscribe(subject, consumer_name) + + try: + # Fetch a single message + msgs = await psub.fetch(1, timeout=timeout) + + if msgs: + msg = msgs[0] + task_data = json.loads(msg.data.decode()) + metadata = msg.metadata + + # Parse the task data into PipelineProcessingTask + task = PipelineProcessingTask(**task_data) + # Set the reply_subject for acknowledgment + task.reply_subject = msg.reply + + logger.debug(f"Reserved task from stream for job '{job_id}', sequence {metadata.sequence.stream}") + return task + + except nats.errors.TimeoutError: + # No messages available + logger.debug(f"No tasks available in stream for job '{job_id}'") + return None + finally: + # Always unsubscribe + await psub.unsubscribe() + + except Exception as e: + logger.error(f"Failed to reserve task from stream for job '{job_id}': {e}") + return None + + async def acknowledge_task(self, reply_subject: str) -> bool: + """ + Acknowledge (delete) a completed task using its reply subject. + + Args: + reply_subject: The reply subject from reserve_task + + Returns: + bool: True if successful + """ + if self.nc is None: + raise RuntimeError("Connection is not open. Use TaskQueueManager as an async context manager.") + + try: + await self.nc.publish(reply_subject, b"+ACK") + logger.debug(f"Acknowledged task with reply subject {reply_subject}") + return True + except Exception as e: + logger.error(f"Failed to acknowledge task: {e}") + return False + + async def delete_consumer(self, job_id: int) -> bool: + """ + Delete the consumer for a job. + + Args: + job_id: The job ID (integer primary key) + + Returns: + bool: True if successful, False otherwise + """ + if self.js is None: + raise RuntimeError("Connection is not open. Use TaskQueueManager as an async context manager.") + + try: + stream_name = self._get_stream_name(job_id) + consumer_name = self._get_consumer_name(job_id) + + await self.js.delete_consumer(stream_name, consumer_name) + logger.info(f"Deleted consumer {consumer_name} for job '{job_id}'") + return True + except Exception as e: + logger.error(f"Failed to delete consumer for job '{job_id}': {e}") + return False + + async def delete_stream(self, job_id: int) -> bool: + """ + Delete the stream for a job. + + Args: + job_id: The job ID (integer primary key) + + Returns: + bool: True if successful, False otherwise + """ + if self.js is None: + raise RuntimeError("Connection is not open. Use TaskQueueManager as an async context manager.") + + try: + stream_name = self._get_stream_name(job_id) + + await self.js.delete_stream(stream_name) + logger.info(f"Deleted stream {stream_name} for job '{job_id}'") + return True + except Exception as e: + logger.error(f"Failed to delete stream for job '{job_id}': {e}") + return False + + async def cleanup_job_resources(self, job_id: int) -> bool: + """ + Clean up all NATS resources (consumer and stream) for a job. + + This should be called when a job completes or is cancelled. + + Args: + job_id: The job ID (integer primary key) + + Returns: + bool: True if successful, False otherwise + """ + # Delete consumer first, then stream + consumer_deleted = await self.delete_consumer(job_id) + stream_deleted = await self.delete_stream(job_id) + + return consumer_deleted and stream_deleted diff --git a/ami/ml/orchestration/task_state.py b/ami/ml/orchestration/task_state.py new file mode 100644 index 000000000..483275453 --- /dev/null +++ b/ami/ml/orchestration/task_state.py @@ -0,0 +1,125 @@ +""" +Task state management for job progress tracking using Redis. +""" + +import logging +from collections import namedtuple + +from django.core.cache import cache + +logger = logging.getLogger(__name__) + + +# Define a namedtuple for a TaskProgress with the image counts +TaskProgress = namedtuple("TaskProgress", ["remaining", "total", "processed", "percentage"]) + + +class TaskStateManager: + """ + Manages job progress tracking state in Redis. + + Tracks pending images for jobs to calculate progress percentages + as workers process images asynchronously. + """ + + TIMEOUT = 86400 * 7 # 7 days in seconds + STAGES = ["process", "results"] + + def __init__(self, job_id: int): + """ + Initialize the task state manager for a specific job. + + Args: + job_id: The job primary key + """ + self.job_id = job_id + self._pending_key = f"job:{job_id}:pending_images" + self._total_key = f"job:{job_id}:pending_images_total" + + def initialize_job(self, image_ids: list[str]) -> None: + """ + Initialize job tracking with a list of image IDs to process. + + Args: + image_ids: List of image IDs that need to be processed + """ + for stage in self.STAGES: + cache.set(self._get_pending_key(stage), image_ids, timeout=self.TIMEOUT) + + cache.set(self._total_key, len(image_ids), timeout=self.TIMEOUT) + + def _get_pending_key(self, stage: str) -> str: + return f"{self._pending_key}:{stage}" + + def update_state( + self, + processed_image_ids: set[str], + stage: str, + request_id: str, + ) -> None | TaskProgress: + """ + Update the task state with newly processed images. + + Args: + processed_image_ids: Set of image IDs that have just been processed + """ + # Create a unique lock key for this job + lock_key = f"job:{self.job_id}:process_results_lock" + lock_timeout = 360 # 6 minutes (matches task time_limit) + lock_acquired = cache.add(lock_key, request_id, timeout=lock_timeout) + if not lock_acquired: + return None + + try: + # Update progress tracking in Redis + progress_info = self._get_progress(processed_image_ids, stage) + return progress_info + finally: + # Always release the lock when done + current_lock_value = cache.get(lock_key) + # Only delete if we still own the lock (prevents race condition) + if current_lock_value == request_id: + cache.delete(lock_key) + logger.debug(f"Released lock for job {self.job_id}, task {request_id}") + + def _get_progress(self, processed_image_ids: set[str], stage: str) -> TaskProgress | None: + """ + Get current progress information for the job. + + Returns: + TaskProgress namedtuple with fields: + - remaining: Number of images still pending (or None if not tracked) + - total: Total number of images (or None if not tracked) + - processed: Number of images processed (or None if not tracked) + - percentage: Progress as float 0.0-1.0 (or None if not tracked) + """ + pending_images = cache.get(self._get_pending_key(stage)) + total_images = cache.get(self._total_key) + if pending_images is None or total_images is None: + return None + remaining_images = [img_id for img_id in pending_images if img_id not in processed_image_ids] + assert len(pending_images) >= len(remaining_images) + cache.set(self._get_pending_key(stage), remaining_images, timeout=self.TIMEOUT) + + remaining = len(remaining_images) + processed = total_images - remaining + percentage = float(processed) / total_images if total_images > 0 else 1.0 + logger.info( + f"Pending images from Redis for job {self.job_id} {stage}: " + f"{remaining}/{total_images}: {percentage*100}%" + ) + + return TaskProgress( + remaining=remaining, + total=total_images, + processed=processed, + percentage=percentage, + ) + + def cleanup(self) -> None: + """ + Delete all Redis keys associated with this job. + """ + for stage in self.STAGES: + cache.delete(self._get_pending_key(stage)) + cache.delete(self._total_key) diff --git a/ami/ml/orchestration/test_nats_queue.py b/ami/ml/orchestration/test_nats_queue.py new file mode 100644 index 000000000..0cd2c3bef --- /dev/null +++ b/ami/ml/orchestration/test_nats_queue.py @@ -0,0 +1,150 @@ +"""Unit tests for TaskQueueManager.""" + +import unittest +from unittest.mock import AsyncMock, MagicMock, patch + +from ami.ml.orchestration.nats_queue import TaskQueueManager +from ami.ml.schemas import PipelineProcessingTask + + +class TestTaskQueueManager(unittest.IsolatedAsyncioTestCase): + """Test suite for TaskQueueManager.""" + + def _create_sample_task(self): + """Helper to create a sample PipelineProcessingTask.""" + return PipelineProcessingTask( + id="task-123", + image_id="img-456", + image_url="https://example.com/image.jpg", + ) + + def _create_mock_nats_connection(self): + """Helper to create mock NATS connection and JetStream context.""" + nc = MagicMock() + nc.is_closed = False + nc.close = AsyncMock() + + js = MagicMock() + js.stream_info = AsyncMock() + js.add_stream = AsyncMock() + js.add_consumer = AsyncMock() + js.consumer_info = AsyncMock() + js.publish = AsyncMock(return_value=MagicMock(seq=1)) + js.pull_subscribe = AsyncMock() + js.delete_consumer = AsyncMock() + js.delete_stream = AsyncMock() + + return nc, js + + async def test_context_manager_lifecycle(self): + """Test that context manager properly opens and closes connections.""" + nc, js = self._create_mock_nats_connection() + + with patch("ami.ml.orchestration.nats_queue.get_connection", AsyncMock(return_value=(nc, js))): + async with TaskQueueManager("nats://test:4222") as manager: + self.assertIsNotNone(manager.nc) + self.assertIsNotNone(manager.js) + + nc.close.assert_called_once() + + async def test_publish_task_creates_stream_and_consumer(self): + """Test that publish_task ensures stream and consumer exist.""" + nc, js = self._create_mock_nats_connection() + sample_task = self._create_sample_task() + js.stream_info.side_effect = Exception("Not found") + js.consumer_info.side_effect = Exception("Not found") + + with patch("ami.ml.orchestration.nats_queue.get_connection", AsyncMock(return_value=(nc, js))): + async with TaskQueueManager() as manager: + await manager.publish_task(456, sample_task) + + js.add_stream.assert_called_once() + self.assertIn("job_456", str(js.add_stream.call_args)) + js.add_consumer.assert_called_once() + + async def test_reserve_task_success(self): + """Test successful task reservation.""" + nc, js = self._create_mock_nats_connection() + sample_task = self._create_sample_task() + + # Mock message with task data + mock_msg = MagicMock() + mock_msg.data = sample_task.json().encode() + mock_msg.reply = "reply.subject.123" + mock_msg.metadata = MagicMock(sequence=MagicMock(stream=1)) + + mock_psub = MagicMock() + mock_psub.fetch = AsyncMock(return_value=[mock_msg]) + mock_psub.unsubscribe = AsyncMock() + js.pull_subscribe = AsyncMock(return_value=mock_psub) + + with patch("ami.ml.orchestration.nats_queue.get_connection", AsyncMock(return_value=(nc, js))): + async with TaskQueueManager() as manager: + task = await manager.reserve_task(123) + + self.assertIsNotNone(task) + self.assertEqual(task.id, sample_task.id) + self.assertEqual(task.reply_subject, "reply.subject.123") + mock_psub.unsubscribe.assert_called_once() + + async def test_reserve_task_no_messages(self): + """Test reserve_task when no messages are available.""" + nc, js = self._create_mock_nats_connection() + + mock_psub = MagicMock() + mock_psub.fetch = AsyncMock(return_value=[]) + mock_psub.unsubscribe = AsyncMock() + js.pull_subscribe = AsyncMock(return_value=mock_psub) + + with patch("ami.ml.orchestration.nats_queue.get_connection", AsyncMock(return_value=(nc, js))): + async with TaskQueueManager() as manager: + task = await manager.reserve_task(123) + + self.assertIsNone(task) + mock_psub.unsubscribe.assert_called_once() + + async def test_acknowledge_task_success(self): + """Test successful task acknowledgment.""" + nc, js = self._create_mock_nats_connection() + nc.publish = AsyncMock() + + with patch("ami.ml.orchestration.nats_queue.get_connection", AsyncMock(return_value=(nc, js))): + async with TaskQueueManager() as manager: + result = await manager.acknowledge_task("reply.subject.123") + + self.assertTrue(result) + nc.publish.assert_called_once_with("reply.subject.123", b"+ACK") + + async def test_cleanup_job_resources(self): + """Test cleanup of job resources (consumer and stream).""" + nc, js = self._create_mock_nats_connection() + + with patch("ami.ml.orchestration.nats_queue.get_connection", AsyncMock(return_value=(nc, js))): + async with TaskQueueManager() as manager: + result = await manager.cleanup_job_resources(123) + + self.assertTrue(result) + js.delete_consumer.assert_called_once() + js.delete_stream.assert_called_once() + + async def test_naming_conventions(self): + """Test stream, subject, and consumer naming conventions.""" + manager = TaskQueueManager() + + self.assertEqual(manager._get_stream_name(123), "job_123") + self.assertEqual(manager._get_subject(123), "job.123.tasks") + self.assertEqual(manager._get_consumer_name(123), "job-123-consumer") + + async def test_operations_without_connection_raise_error(self): + """Test that operations without connection raise RuntimeError.""" + manager = TaskQueueManager() + sample_task = self._create_sample_task() + + with self.assertRaisesRegex(RuntimeError, "Connection is not open"): + await manager.publish_task(123, sample_task) + + with self.assertRaisesRegex(RuntimeError, "Connection is not open"): + await manager.reserve_task(123) + + with self.assertRaisesRegex(RuntimeError, "Connection is not open"): + await manager.delete_stream(123) diff --git a/ami/ml/schemas.py b/ami/ml/schemas.py index 478b4c8fd..2a17c3daa 100644 --- a/ami/ml/schemas.py +++ b/ami/ml/schemas.py @@ -222,7 +222,6 @@ class PipelineProcessingTask(pydantic.BaseModel): id: str image_id: str image_url: str - queue_timestamp: str reply_subject: str | None = None # The NATS subject to send the result to # TODO: Do we need these? # detections: list[DetectionRequest] | None = None diff --git a/ami/ml/tests.py b/ami/ml/tests.py index 14e4374f2..20e0368fe 100644 --- a/ami/ml/tests.py +++ b/ami/ml/tests.py @@ -855,3 +855,127 @@ def test_small_size_filter_assigns_not_identifiable(self): not_identifiable_taxon, f"Occurrence {occurrence.pk} should have its determination set to 'Not identifiable'.", ) + + +class TestTaskStateManager(TestCase): + """Test TaskStateManager for job progress tracking.""" + + def setUp(self): + """Set up test fixtures.""" + from django.core.cache import cache + + from ami.ml.orchestration.task_state import TaskStateManager + + cache.clear() + self.job_id = 123 + self.manager = TaskStateManager(self.job_id) + self.image_ids = ["img1", "img2", "img3", "img4", "img5"] + + def _init_and_verify(self, image_ids): + """Helper to initialize job and verify initial state.""" + self.manager.initialize_job(image_ids) + progress = self.manager._get_progress(set(), "process") + assert progress is not None + self.assertEqual(progress.total, len(image_ids)) + self.assertEqual(progress.remaining, len(image_ids)) + self.assertEqual(progress.processed, 0) + self.assertEqual(progress.percentage, 0.0) + return progress + + def test_initialize_job(self): + """Test job initialization sets up tracking for all stages.""" + self._init_and_verify(self.image_ids) + + # Verify both stages are initialized + for stage in self.manager.STAGES: + progress = self.manager._get_progress(set(), stage) + assert progress is not None + self.assertEqual(progress.total, len(self.image_ids)) + + def test_progress_tracking(self): + """Test progress updates correctly as images are processed.""" + self._init_and_verify(self.image_ids) + + # Process 2 images + progress = self.manager._get_progress({"img1", "img2"}, "process") + assert progress is not None + self.assertEqual(progress.remaining, 3) + self.assertEqual(progress.processed, 2) + self.assertEqual(progress.percentage, 0.4) + + # Process 2 more images + progress = self.manager._get_progress({"img3", "img4"}, "process") + assert progress is not None + self.assertEqual(progress.remaining, 1) + self.assertEqual(progress.processed, 4) + self.assertEqual(progress.percentage, 0.8) + + # Process last image + progress = self.manager._get_progress({"img5"}, "process") + assert progress is not None + self.assertEqual(progress.remaining, 0) + self.assertEqual(progress.processed, 5) + self.assertEqual(progress.percentage, 1.0) + + def test_update_state_with_locking(self): + """Test update_state acquires lock, updates progress, and releases lock.""" + from django.core.cache import cache + + self._init_and_verify(self.image_ids) + + # First update should succeed + progress = self.manager.update_state({"img1", "img2"}, "process", "task1") + assert progress is not None + self.assertEqual(progress.processed, 2) + + # Simulate concurrent update by holding the lock + lock_key = f"job:{self.job_id}:process_results_lock" + cache.set(lock_key, "other_task", timeout=60) + + # Update should fail (lock held by another task) + progress = self.manager.update_state({"img3"}, "process", "task1") + self.assertIsNone(progress) + + # Release the lock and retry + cache.delete(lock_key) + progress = self.manager.update_state({"img3"}, "process", "task1") + assert progress is not None + self.assertEqual(progress.processed, 3) + + def test_stages_independent(self): + """Test that different stages track progress independently.""" + self._init_and_verify(self.image_ids) + + # Update process stage + self.manager._get_progress({"img1", "img2"}, "process") + progress_process = self.manager._get_progress(set(), "process") + assert progress_process is not None + self.assertEqual(progress_process.remaining, 3) + + # Results stage should still have all images pending + progress_results = self.manager._get_progress(set(), "results") + assert progress_results is not None + self.assertEqual(progress_results.remaining, 5) + + def test_empty_job(self): + """Test handling of job with no images.""" + self.manager.initialize_job([]) + progress = self.manager._get_progress(set(), "process") + assert progress is not None + self.assertEqual(progress.total, 0) + self.assertEqual(progress.percentage, 1.0) # Empty job is 100% complete + + def test_cleanup(self): + """Test cleanup removes all tracking keys.""" + self._init_and_verify(self.image_ids) + + # Verify keys exist + progress = self.manager._get_progress(set(), "process") + self.assertIsNotNone(progress) + + # Cleanup + self.manager.cleanup() + + # Verify keys are gone + progress = self.manager._get_progress(set(), "process") + self.assertIsNone(progress) diff --git a/ami/utils/requests.py b/ami/utils/requests.py index c4396b725..e4de57c0f 100644 --- a/ami/utils/requests.py +++ b/ami/utils/requests.py @@ -2,6 +2,8 @@ import requests from django.forms import BooleanField, FloatField +from drf_spectacular.types import OpenApiTypes +from drf_spectacular.utils import OpenApiParameter from requests.adapters import HTTPAdapter from rest_framework.request import Request from urllib3.util import Retry @@ -142,3 +144,30 @@ def get_default_classification_threshold(project: "Project | None" = None, reque return project.default_filters_score_threshold else: return default_threshold + + +project_id_doc_param = OpenApiParameter( + name="project_id", + description="Filter by project ID", + required=False, + type=int, +) + +ids_only_param = OpenApiParameter( + name="ids_only", + description="Return only job IDs instead of full job objects", + required=False, + type=OpenApiTypes.BOOL, +) +incomplete_only_param = OpenApiParameter( + name="incomplete_only", + description="Filter to only show incomplete jobs (excludes SUCCESS, FAILURE, REVOKED)", + required=False, + type=OpenApiTypes.BOOL, +) +batch_param = OpenApiParameter( + name="batch", + description="Number of tasks to pull in the batch", + required=False, + type=OpenApiTypes.INT, +) diff --git a/config/settings/base.py b/config/settings/base.py index c9a8a9681..d7f0c11e6 100644 --- a/config/settings/base.py +++ b/config/settings/base.py @@ -263,6 +263,10 @@ } REDIS_URL = env("REDIS_URL", default=None) +# NATS +# ------------------------------------------------------------------------------ +NATS_URL = env("NATS_URL", default="nats://localhost:4222") # type: ignore[no-untyped-call] + # ADMIN # ------------------------------------------------------------------------------ # Django Admin URL. diff --git a/docker-compose.ci.yml b/docker-compose.ci.yml index 8e93b684d..57f6fbc9f 100644 --- a/docker-compose.ci.yml +++ b/docker-compose.ci.yml @@ -22,6 +22,7 @@ services: - minio-init - ml_backend - rabbitmq + - nats env_file: - ./.envs/.ci/.django - ./.envs/.ci/.postgres @@ -39,6 +40,17 @@ services: redis: image: redis:6 + nats: + image: nats:2.10-alpine + container_name: ami_ci_nats + hostname: nats + command: ["-js", "-m", "8222"] # Enable JetStream and monitoring + healthcheck: + test: ["CMD", "wget", "--spider", "-q", "http://localhost:8222/healthz"] + interval: 10s + timeout: 5s + retries: 3 + celeryworker: <<: *django depends_on: @@ -58,7 +70,7 @@ services: env_file: - ./.envs/.ci/.django healthcheck: - test: [ "CMD", "mc", "ready", "local" ] + test: ["CMD", "mc", "ready", "local"] interval: 5s timeout: 5s retries: 5 diff --git a/docker-compose.staging.yml b/docker-compose.staging.yml index 13525c044..045de9a21 100644 --- a/docker-compose.staging.yml +++ b/docker-compose.staging.yml @@ -5,7 +5,6 @@ # 1. The database is a service in the Docker Compose configuration rather than external as in production. # 2. Redis is a service in the Docker Compose configuration rather than external as in production. # 3. Port 5001 is exposed for the Django application. -version: "3" volumes: ami_local_postgres_data: {} @@ -21,6 +20,7 @@ services: depends_on: - postgres - redis + # - nats env_file: - ./.envs/.production/.django - ./.envs/.local/.postgres @@ -29,6 +29,7 @@ services: ports: - "5001:5000" command: /start + restart: always postgres: build: @@ -42,9 +43,11 @@ services: - ./data/db/snapshots:/backups env_file: - ./.envs/.local/.postgres + restart: always redis: image: redis:6 + restart: always celeryworker: <<: *django @@ -62,3 +65,18 @@ services: ports: - "5550:5555" command: /start-flower + + nats: + image: nats:2.10-alpine + container_name: ami_local_nats + hostname: nats + ports: + - "4222:4222" # Client port + - "8222:8222" # HTTP monitoring port + command: ["-js", "-m", "8222"] # Enable JetStream and monitoring + healthcheck: + test: ["CMD", "wget", "--spider", "-q", "http://localhost:8222/healthz"] + interval: 10s + timeout: 5s + retries: 3 + restart: always diff --git a/docker-compose.yml b/docker-compose.yml index e2ad3a100..703ecea0d 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -21,6 +21,7 @@ services: depends_on: - postgres - redis + - nats - minio-init - ml_backend - rabbitmq @@ -75,7 +76,12 @@ services: volumes: - ./.git:/app/.git:ro - ./ui:/app - entrypoint: ["sh", "-c", "yarn install && yarn start --debug --host 0.0.0.0 --port 4000"] + entrypoint: + [ + "sh", + "-c", + "yarn install && yarn start --debug --host 0.0.0.0 --port 4000", + ] environment: - API_PROXY_TARGET=http://django:8000 - CHOKIDAR_USEPOLLING=true @@ -84,6 +90,20 @@ services: image: redis:6 container_name: ami_local_redis + nats: + image: nats:2.10-alpine + container_name: ami_local_nats + hostname: nats + ports: + - "4222:4222" # Client port + - "8222:8222" # HTTP monitoring port + command: ["-js", "-m", "8222"] # Enable JetStream and monitoring + healthcheck: + test: ["CMD", "wget", "--spider", "-q", "http://localhost:8222/healthz"] + interval: 10s + timeout: 5s + retries: 3 + celeryworker: <<: *django image: ami_local_celeryworker diff --git a/requirements/base.txt b/requirements/base.txt index dd9de69d5..d6f27a4ec 100644 --- a/requirements/base.txt +++ b/requirements/base.txt @@ -8,6 +8,7 @@ celery==5.4.0 # pyup: < 6.0 # https://github.com/celery/celery django-celery-beat==2.5.0 # https://github.com/celery/django-celery-beat flower==2.0.1 # https://github.com/mher/flower kombu==5.4.2 +nats-py==2.10.0 # https://github.com/nats-io/nats.py uvicorn[standard]==0.22.0 # https://github.com/encode/uvicorn rich==13.5.0 markdown==3.4.4 @@ -41,7 +42,7 @@ djoser==2.2.0 django-guardian==2.4.0 # Email sending django-sendgrid-v5==1.2.2 -django-anymail[sendgrid]==10.0 # https://github.com/anymail/django-anymail +django-anymail[sendgrid]==10.0 # https://github.com/anymail/django-anymail/ ## Formerly dev-only dependencies # However we cannot run the app without some of these these dependencies @@ -52,6 +53,7 @@ django-anymail[sendgrid]==10.0 # https://github.com/anymail/django-anymail Werkzeug[watchdog]==2.3.6 # https://github.com/pallets/werkzeug ipdb==0.13.13 # https://github.com/gotcha/ipdb psycopg[binary]==3.1.9 # https://github.com/psycopg/psycopg +# psycopg==3.1.9 # https://github.com/psycopg/psycopg # the non-binary version is needed for some platforms watchfiles==0.19.0 # https://github.com/samuelcolvin/watchfiles # Testing