Processing service V2 - Phase 1 #987
base: main
Changes from all commits
@@ -322,15 +322,13 @@ def run(cls, job: "Job"):
         """
         Procedure for an ML pipeline as a job.
         """
+        from ami.ml.orchestration.jobs import queue_images_to_nats
+
         job.update_status(JobState.STARTED)
         job.started_at = datetime.datetime.now()
         job.finished_at = None
         job.save()
-
-        # Keep track of sub-tasks for saving results, pair with batch number
-        save_tasks: list[tuple[int, AsyncResult]] = []
-        save_tasks_completed: list[tuple[int, AsyncResult]] = []

         if job.delay:
             update_interval_seconds = 2
             last_update = time.time()
@@ -365,7 +363,7 @@ def run(cls, job: "Job"):
             progress=0,
         )

-        images = list(
+        images: list[SourceImage] = list(
            # @TODO return generator plus image count
            # @TODO pass to celery group chain?
            job.pipeline.collect_images(
@@ -389,8 +387,6 @@ def run(cls, job: "Job"):
            images = images[: job.limit]
            image_count = len(images)
            job.progress.add_stage_param("collect", "Limit", image_count)
-        else:
-            image_count = source_image_count

        job.progress.update_stage(
            "collect",
@@ -401,6 +397,24 @@ def run(cls, job: "Job"):
        # End image collection stage
        job.save()

+        if job.project.feature_flags.async_pipeline_workers:
+            queued = queue_images_to_nats(job, images)
+            if not queued:
+                job.logger.error("Aborting job %s because images could not be queued to NATS", job.pk)
+                job.progress.update_stage("collect", status=JobState.FAILURE)
+                job.update_status(JobState.FAILURE)
+                job.finished_at = datetime.datetime.now()
+                job.save()
+                return
+        else:
+            cls.process_images(job, images)
Comment on lines +400 to +410 (Contributor)

Async ML path leaves the overall job lifecycle undefined (status/finished_at, Celery status mismatch). When the feature flag routes a job through `queue_images_to_nats()`, `run()` returns without setting a final status or `finished_at`, even though the Celery task that ran it completes. Net effect: jobs can show a finished Celery task while the job record itself is never finalized by this code path. Consider tightening this by handling the running/failed/finished transitions explicitly in the async branch. This will make the async path match the synchronous `process_images()` path.
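Below is a minimal sketch (not part of this PR) of how the async branch in `run()` could take explicit ownership of the job lifecycle before handing off to NATS, along the lines the reviewer suggests. It assumes only the `Job`/`JobState` API already visible in this diff; the helper name `_mark_job_running_async` is illustrative, not something defined in the PR.

```python
# Sketch only: assumes the Job / JobState API used elsewhere in this diff.
# The helper name is hypothetical; it is not defined anywhere in the PR.


def _mark_job_running_async(job) -> None:
    """Keep the job in a visible 'running' state after queueing to NATS.

    Finalization (SUCCESS status and finished_at) is deferred to
    _update_job_progress(), which runs when the 'results' stage hits 100%.
    """
    from ami.jobs.models import JobState  # same import style as the rest of the diff

    job.update_status(JobState.STARTED)
    job.finished_at = None  # will be set later by the result-processing tasks
    job.logger.info(
        "Job %s handed off to NATS workers; completion is owned by process_pipeline_result",
        job.pk,
    )
    job.save()
```

In `run()`, the async branch would call this right after a successful `queue_images_to_nats(job, images)`; if the existing Celery `task_postrun` handling also touches job status, it would likewise need to skip jobs running under the async feature flag.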
+
+    @classmethod
+    def process_images(cls, job, images):
+        image_count = len(images)
+        # Keep track of sub-tasks for saving results, pair with batch number
+        save_tasks: list[tuple[int, AsyncResult]] = []
+        save_tasks_completed: list[tuple[int, AsyncResult]] = []
+        total_captures = 0
+        total_detections = 0
+        total_classifications = 0
ami/jobs/tasks.py
@@ -1,7 +1,16 @@
 import datetime
+import functools
+import logging
 import time
+from collections.abc import Callable

+from asgiref.sync import async_to_sync
 from celery.signals import task_failure, task_postrun, task_prerun
+from django.db import transaction

+from ami.ml.orchestration.nats_queue import TaskQueueManager
+from ami.ml.orchestration.task_state import TaskStateManager
+from ami.ml.schemas import PipelineResultsResponse
 from ami.tasks import default_soft_time_limit, default_time_limit
 from config import celery_app
@@ -29,6 +38,132 @@ def run_job(self, job_id: int) -> None:
    job.logger.info(f"Finished job {job}")


+@celery_app.task(
+    bind=True,
+    max_retries=0,  # don't retry since we already have retry logic in the NATS queue
+    soft_time_limit=300,  # 5 minutes
+    time_limit=360,  # 6 minutes
+)
+def process_pipeline_result(self, job_id: int, result_data: dict, reply_subject: str) -> None:
+    """
+    Process a single pipeline result asynchronously.
+
+    This task:
+    1. Deserializes the pipeline result
+    2. Saves it to the database
+    3. Updates progress by removing processed image IDs from Redis
+    4. Acknowledges the task via NATS
+
+    Args:
+        job_id: The job ID
+        result_data: Dictionary containing the pipeline result
+        reply_subject: NATS reply subject for acknowledgment
+    """
+    from ami.jobs.models import Job  # avoid circular import
+
+    _, t = log_time()
+    error = result_data.get("error")
+    pipeline_result = None
+    if not error:
+        pipeline_result = PipelineResultsResponse(**result_data)
+        processed_image_ids = {str(img.id) for img in pipeline_result.source_images}
+    else:
+        image_id = result_data.get("image_id")
+        processed_image_ids = {str(image_id)} if image_id else set()
+        logger.error(f"Pipeline returned error for job {job_id}, image {image_id}: {error}")
+
+    state_manager = TaskStateManager(job_id)
+
+    progress_info = state_manager.update_state(processed_image_ids, stage="process", request_id=self.request.id)
+    if not progress_info:
+        logger.warning(
+            f"Another task is already processing results for job {job_id}. "
+            f"Retrying task {self.request.id} in 5 seconds..."
+        )
+        raise self.retry(countdown=5, max_retries=10)
+
+    try:
+        _update_job_progress(job_id, "process", progress_info.percentage)
+
+        _, t = t(f"TIME: Updated job {job_id} progress in PROCESS stage to {progress_info.percentage*100}%")
+        job = Job.objects.get(pk=job_id)
+        job.logger.info(f"Processing pipeline result for job {job_id}, reply_subject: {reply_subject}")
+        job.logger.info(
+            f"  Job {job_id} progress: {progress_info.processed}/{progress_info.total} images processed "
+            f"({progress_info.percentage*100}%), {progress_info.remaining} remaining, {len(processed_image_ids)} just "
+            "processed"
+        )
+    except Job.DoesNotExist:
+        # don't raise; ack instead so the task isn't retried, since the job doesn't exist
+        logger.error(f"Job {job_id} not found")
+        _ack_task_via_nats(reply_subject, logger)
+        return
+
+    try:
+        # Save to database (this is the slow operation)
+        if pipeline_result:
+            # should never happen since otherwise we could not be processing results here
+            assert job.pipeline is not None, "Job pipeline is None"
+            job.pipeline.save_results(results=pipeline_result, job_id=job.pk)
+            job.logger.info(f"Successfully saved results for job {job_id}")
+
+            _, t = t(
+                f"Saved pipeline results to database with {len(pipeline_result.detections)} detections"
+                f", percentage: {progress_info.percentage*100}%"
+            )
+
+        _ack_task_via_nats(reply_subject, job.logger)
+        # Update job stage with calculated progress
+        progress_info = state_manager.update_state(processed_image_ids, stage="results", request_id=self.request.id)
+
+        if not progress_info:
+            logger.warning(
+                f"Another task is already processing results for job {job_id}. "
+                f"Retrying task {self.request.id} in 5 seconds..."
+            )
+            raise self.retry(countdown=5, max_retries=10)
+        _update_job_progress(job_id, "results", progress_info.percentage)
+
+    except Exception as e:
+        job.logger.error(f"Failed to process pipeline result for job {job_id}: {e}. Retrying ...")
Comment on lines +102 to +129 (Contributor)

Analysis chain — scripts executed against the repository (RolnickLab/antenna):

    # Search for save_results method definition to check for idempotency handling
    rg -n "def save_results" --type=py -A 20

    # Also search for _ack_task_via_nats to understand what happens on ACK failure
    rg -n "_ack_task_via_nats|def _ack_task_via_nats" --type=py -A 15

    # Check the exception handler and retry logic in the task
    rg -n "except Exception" ami/jobs/tasks.py -B 5 -A 10

    # Get the full save_results function implementation
    sed -n '814,900p' ami/ml/models/pipeline.py

    # Look for Detection.objects.create calls in save_results
    rg -n "Detection.*create|bulk_create" ami/ml/models/pipeline.py -B 2 -A 2 | head -50

Confirmed: the implementation of `save_results` in ami/ml/models/pipeline.py creates detections without checking whether results for these images and this job were already saved, so re-running this task after a partial failure can duplicate results. Implement idempotency by either making `save_results` safe to re-run for the same images, or skipping the save when their results already exist.

Minor: Clarify the exception handler log message. Line 128 logs "Retrying ..." but the task doesn't call `self.retry()` in that handler, so no retry actually happens there.

Ruff (0.14.8) findings:
- 124: TRY301 — Abstract `raise` to an inner function
- 127: BLE001 — Do not catch blind exception: `Exception`
- 128: TRY400 — Use `logging.exception` instead of `logging.error`
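A hypothetical sketch of the kind of idempotency guard the reviewer is asking for: skip the save when results for these images and this job already exist. The `Detection` model import path and the `source_image`/`job` field names below are assumptions for illustration, not taken from this PR.

```python
# Hypothetical guard; the model and field names (Detection, source_image_id, job_id) are assumed.
def _results_already_saved(job, pipeline_result) -> bool:
    """True if every source image in this result already has detections saved for this job."""
    from ami.main.models import Detection  # assumed import path

    image_ids = [img.id for img in pipeline_result.source_images]
    saved_images = (
        Detection.objects.filter(source_image_id__in=image_ids, job_id=job.pk)  # assumed FK names
        .values("source_image_id")
        .distinct()
        .count()
    )
    return saved_images == len(image_ids)


# Inside process_pipeline_result, before the slow save:
#
#     if pipeline_result and not _results_already_saved(job, pipeline_result):
#         job.pipeline.save_results(results=pipeline_result, job_id=job.pk)
```

The other direction the reviewer's "either" could take is making `save_results()` itself retry-safe, for example writing detections with `bulk_create(..., ignore_conflicts=True)` against a unique constraint, which would also cover failures that happen partway through a save.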
+
+
+def _ack_task_via_nats(reply_subject: str, job_logger: logging.Logger) -> None:
+    try:
+
+        async def ack_task():
+            async with TaskQueueManager() as manager:
+                return await manager.acknowledge_task(reply_subject)
+
+        ack_success = async_to_sync(ack_task)()
+
+        if ack_success:
+            job_logger.info(f"Successfully acknowledged task via NATS: {reply_subject}")
+        else:
+            job_logger.warning(f"Failed to acknowledge task via NATS: {reply_subject}")
+    except Exception as ack_error:
+        job_logger.error(f"Error acknowledging task via NATS: {ack_error}")
+        # Don't fail the task if ACK fails - data is already saved
+
+
+def _update_job_progress(job_id: int, stage: str, progress_percentage: float) -> None:
+    from ami.jobs.models import Job, JobState  # avoid circular import
+
+    with transaction.atomic():
+        job = Job.objects.select_for_update().get(pk=job_id)
+        job.progress.update_stage(
+            stage,
+            status=JobState.SUCCESS if progress_percentage >= 1.0 else JobState.STARTED,
+            progress=progress_percentage,
+        )
+        if stage == "results" and progress_percentage >= 1.0:
+            job.status = JobState.SUCCESS
+            job.progress.summary.status = JobState.SUCCESS
+            job.finished_at = datetime.datetime.now()  # Use naive datetime in local time
+        job.logger.info(f"Updated job {job_id} progress in stage '{stage}' to {progress_percentage*100}%")
+        job.save()
+
+
 @task_prerun.connect(sender=run_job)
 def pre_update_job_status(sender, task_id, task, **kwargs):
     # in the prerun signal, set the job status to PENDING
@@ -65,3 +200,28 @@ def update_job_failure(sender, task_id, exception, *args, **kwargs):
     job.logger.error(f'Job #{job.pk} "{job.name}" failed: {exception}')

     job.save()
+
+
+def log_time(start: float = 0, msg: str | None = None) -> tuple[float, Callable]:
+    """
+    Small helper to measure time between calls.
+
+    Returns: elapsed time since the last call, and a partial function to measure from the current call
+    Usage:
+
+        _, tlog = log_time()
+        # do something
+        _, tlog = tlog("Did something")  # will log the time taken by 'something'
+        # do something else
+        t, tlog = tlog("Did something else")  # will log the time taken by 'something else', returned as 't'
+    """
+    end = time.perf_counter()
+    if start == 0:
+        dur = 0.0
+    else:
+        dur = end - start
+    if msg and start > 0:
+        logger.info(f"{msg}: {dur:.3f}s")
+    new_start = time.perf_counter()
+    return dur, functools.partial(log_time, new_start)