diff --git a/.github/workflows/benchmark.yml b/.github/workflows/benchmark.yml index 0503cb10..e59e0388 100644 --- a/.github/workflows/benchmark.yml +++ b/.github/workflows/benchmark.yml @@ -80,7 +80,7 @@ jobs: exit 1 - name: Upload benchmark results - uses: actions/upload-artifact@v4 + uses: actions/upload-artifact@v6 with: name: bench-${{ matrix.service }}-${{ matrix.dataset }}-${{ matrix.count }}.csv path: bench-${{ matrix.service }}-${{ matrix.dataset }}-${{ matrix.count }}.csv diff --git a/.github/workflows/cd-staging.yml b/.github/workflows/cd-staging.yml index 44743d66..9cb41cc3 100644 --- a/.github/workflows/cd-staging.yml +++ b/.github/workflows/cd-staging.yml @@ -7,6 +7,7 @@ on: jobs: build: + if: false runs-on: ubuntu-latest environment: AWS_ENV_VARS diff --git a/backend/app/alembic/versions/041_extend_collection_table_for_provider_.py b/backend/app/alembic/versions/041_extend_collection_table_for_provider_.py new file mode 100644 index 00000000..84c386e8 --- /dev/null +++ b/backend/app/alembic/versions/041_extend_collection_table_for_provider_.py @@ -0,0 +1,109 @@ +"""extend collection table for provider agnostic support + +Revision ID: 041 +Revises: 040 +Create Date: 2026-01-15 16:53:19.495583 + +""" +from alembic import op +import sqlalchemy as sa +import sqlmodel.sql.sqltypes +from sqlalchemy.dialects import postgresql + + +# revision identifiers, used by Alembic. +revision = "041" +down_revision = "040" +branch_labels = None +depends_on = None + +provider_type = postgresql.ENUM( +    "OPENAI", +    # aws +    # gemini +    name="providertype", +    create_type=False, +) + + +def upgrade(): +    provider_type.create(op.get_bind(), checkfirst=True) +    op.add_column( +        "collection", +        sa.Column( +            "provider", +            provider_type, +            nullable=True, +            comment="LLM provider used for this collection", +        ), +    ) +    op.execute("UPDATE collection SET provider = 'OPENAI' WHERE provider IS NULL") +    op.alter_column( +        "collection", +        "provider", +        existing_type=provider_type, +        nullable=False, +    ) +    op.add_column( +        "collection", +        sa.Column( +            "name", +            sqlmodel.sql.sqltypes.AutoString(), +            nullable=True, +            comment="Name of the collection", +        ), +    ) +    op.add_column( +        "collection", +        sa.Column( +            "description", +            sqlmodel.sql.sqltypes.AutoString(), +            nullable=True, +            comment="Description of the collection", +        ), +    ) +    op.alter_column( +        "collection", +        "llm_service_name", +        existing_type=sa.VARCHAR(), +        comment="Name of the LLM service", +        existing_comment="Name of the LLM service provider", +        existing_nullable=False, +    ) +    op.create_unique_constraint(None, "collection", ["name"]) +    op.drop_constraint( +        op.f("collection_organization_id_fkey"), "collection", type_="foreignkey" +    ) +    op.drop_column("collection", "organization_id") + + +def downgrade(): +    op.add_column( +        "collection", +        sa.Column( +            "organization_id", +            sa.INTEGER(), +            autoincrement=False, +            nullable=True, +            comment="Reference to the organization", +        ), +    ) +    op.execute( +        """UPDATE collection SET organization_id = (SELECT organization_id FROM project +        WHERE project.id = collection.project_id)""" +    ) +    op.alter_column("collection", "organization_id", nullable=False) +    op.create_foreign_key( +        op.f("collection_organization_id_fkey"), +        "collection", +        "organization", +        ["organization_id"], +        ["id"], +        ondelete="CASCADE", +    ) +    op.drop_constraint("collection_name_key", "collection", type_="unique") +    op.alter_column( +        "collection", +        "llm_service_name", +        existing_type=sa.VARCHAR(), +        comment="Name of the LLM service provider", +        existing_comment="Name of the LLM service", +        existing_nullable=False, +    ) +    op.drop_column("collection",
"description") + op.drop_column("collection", "name") + op.drop_column("collection", "provider") diff --git a/backend/app/api/main.py b/backend/app/api/main.py index c071a9e1..47cea3b1 100644 --- a/backend/app/api/main.py +++ b/backend/app/api/main.py @@ -20,7 +20,7 @@ onboarding, credentials, cron, - evaluation, + evaluations, fine_tuning, model_evaluation, collection_job, @@ -37,7 +37,7 @@ api_router.include_router(cron.router) api_router.include_router(documents.router) api_router.include_router(doc_transformation_job.router) -api_router.include_router(evaluation.router) +api_router.include_router(evaluations.router) api_router.include_router(llm.router) api_router.include_router(login.router) api_router.include_router(onboarding.router) diff --git a/backend/app/api/routes/collections.py b/backend/app/api/routes/collections.py index d19fad31..614871fb 100644 --- a/backend/app/api/routes/collections.py +++ b/backend/app/api/routes/collections.py @@ -28,6 +28,7 @@ CollectionPublic, ) from app.utils import APIResponse, load_description, validate_callback_url +from app.services.collections.helpers import ensure_unique_name from app.services.collections import ( create_collection as create_service, delete_collection as delete_service, @@ -87,6 +88,9 @@ def create_collection( if request.callback_url: validate_callback_url(str(request.callback_url)) + if request.name: + ensure_unique_name(session, current_user.project_.id, request.name) + collection_job_crud = CollectionJobCrud(session, current_user.project_.id) collection_job = collection_job_crud.create( CollectionJobCreate( diff --git a/backend/app/api/routes/evaluation.py b/backend/app/api/routes/evaluation.py deleted file mode 100644 index 6175476d..00000000 --- a/backend/app/api/routes/evaluation.py +++ /dev/null @@ -1,748 +0,0 @@ -import csv -import io -import logging -import re -from pathlib import Path - -from fastapi import ( - APIRouter, - Body, - File, - Form, - HTTPException, - Query, - UploadFile, - Depends, -) - -from app.api.deps import AuthContextDep, SessionDep -from app.api.permissions import Permission, require_permission -from app.core.cloud import get_cloud_storage -from app.crud.assistants import get_assistant_by_id -from app.crud.evaluations import ( - create_evaluation_dataset, - create_evaluation_run, - get_dataset_by_id, - get_evaluation_run_by_id, - list_datasets, - start_evaluation_batch, - upload_csv_to_object_store, - upload_dataset_to_langfuse, -) -from app.crud.evaluations import list_evaluation_runs as list_evaluation_runs_crud -from app.crud.evaluations.core import save_score -from app.crud.evaluations.dataset import delete_dataset as delete_dataset_crud -from app.crud.evaluations.langfuse import fetch_trace_scores_from_langfuse -from app.models.evaluation import ( - DatasetUploadResponse, - EvaluationRunPublic, -) -from app.utils import ( - APIResponse, - get_langfuse_client, - get_openai_client, - load_description, -) - -logger = logging.getLogger(__name__) - -# File upload security constants -MAX_FILE_SIZE = 1024 * 1024 # 1 MB -ALLOWED_EXTENSIONS = {".csv"} -ALLOWED_MIME_TYPES = { - "text/csv", - "application/csv", - "text/plain", # Some systems report CSV as text/plain -} - -router = APIRouter(tags=["Evaluation"]) - - -def _dataset_to_response(dataset) -> DatasetUploadResponse: - """Convert a dataset model to a DatasetUploadResponse.""" - return DatasetUploadResponse( - dataset_id=dataset.id, - dataset_name=dataset.name, - total_items=dataset.dataset_metadata.get("total_items_count", 0), - 
original_items=dataset.dataset_metadata.get("original_items_count", 0), - duplication_factor=dataset.dataset_metadata.get("duplication_factor", 1), - langfuse_dataset_id=dataset.langfuse_dataset_id, - object_store_url=dataset.object_store_url, - ) - - -def sanitize_dataset_name(name: str) -> str: - """ - Sanitize dataset name for Langfuse compatibility. - - Langfuse has issues with spaces and special characters in dataset names. - This function ensures the name can be both created and fetched. - - Rules: - - Replace spaces with underscores - - Replace hyphens with underscores - - Keep only alphanumeric characters and underscores - - Convert to lowercase for consistency - - Remove leading/trailing underscores - - Collapse multiple consecutive underscores into one - - Args: - name: Original dataset name - - Returns: - Sanitized dataset name safe for Langfuse - - Examples: - "testing 0001" -> "testing_0001" - "My Dataset!" -> "my_dataset" - "Test--Data__Set" -> "test_data_set" - """ - # Convert to lowercase - sanitized = name.lower() - - # Replace spaces and hyphens with underscores - sanitized = sanitized.replace(" ", "_").replace("-", "_") - - # Keep only alphanumeric characters and underscores - sanitized = re.sub(r"[^a-z0-9_]", "", sanitized) - - # Collapse multiple underscores into one - sanitized = re.sub(r"_+", "_", sanitized) - - # Remove leading/trailing underscores - sanitized = sanitized.strip("_") - - # Ensure name is not empty - if not sanitized: - raise ValueError("Dataset name cannot be empty after sanitization") - - return sanitized - - -@router.post( - "/evaluations/datasets", - description=load_description("evaluation/upload_dataset.md"), - response_model=APIResponse[DatasetUploadResponse], - dependencies=[Depends(require_permission(Permission.REQUIRE_PROJECT))], -) -async def upload_dataset( - _session: SessionDep, - auth_context: AuthContextDep, - file: UploadFile = File( - ..., description="CSV file with 'question' and 'answer' columns" - ), - dataset_name: str = Form(..., description="Name for the dataset"), - description: str | None = Form(None, description="Optional dataset description"), - duplication_factor: int = Form( - default=1, - ge=1, - le=5, - description="Number of times to duplicate each item (min: 1, max: 5)", - ), -) -> APIResponse[DatasetUploadResponse]: - # Sanitize dataset name for Langfuse compatibility - original_name = dataset_name - try: - dataset_name = sanitize_dataset_name(dataset_name) - except ValueError as e: - raise HTTPException(status_code=422, detail=f"Invalid dataset name: {str(e)}") - - if original_name != dataset_name: - logger.info( - f"[upload_dataset] Dataset name sanitized | '{original_name}' -> '{dataset_name}'" - ) - - logger.info( - f"[upload_dataset] Uploading dataset | dataset={dataset_name} | " - f"duplication_factor={duplication_factor} | org_id={auth_context.organization_.id} | " - f"project_id={auth_context.project_.id}" - ) - - # Security validation: Check file extension - file_ext = Path(file.filename).suffix.lower() - if file_ext not in ALLOWED_EXTENSIONS: - raise HTTPException( - status_code=422, - detail=f"Invalid file type. Only CSV files are allowed. Got: {file_ext}", - ) - - # Security validation: Check MIME type - content_type = file.content_type - if content_type not in ALLOWED_MIME_TYPES: - raise HTTPException( - status_code=422, - detail=f"Invalid content type. 
Expected CSV, got: {content_type}", - ) - - # Security validation: Check file size - file.file.seek(0, 2) # Seek to end - file_size = file.file.tell() - file.file.seek(0) # Reset to beginning - - if file_size > MAX_FILE_SIZE: - raise HTTPException( - status_code=413, - detail=f"File too large. Maximum size: {MAX_FILE_SIZE / (1024 * 1024):.0f}MB", - ) - - if file_size == 0: - raise HTTPException(status_code=422, detail="Empty file uploaded") - - # Read CSV content - csv_content = await file.read() - - # Step 1: Parse and validate CSV - try: - csv_text = csv_content.decode("utf-8") - csv_reader = csv.DictReader(io.StringIO(csv_text)) - - if not csv_reader.fieldnames: - raise HTTPException(status_code=422, detail="CSV file has no headers") - - # Normalize headers for case-insensitive matching - clean_headers = { - field.strip().lower(): field for field in csv_reader.fieldnames - } - - # Validate required headers (case-insensitive) - if "question" not in clean_headers or "answer" not in clean_headers: - raise HTTPException( - status_code=422, - detail=f"CSV must contain 'question' and 'answer' columns " - f"Found columns: {csv_reader.fieldnames}", - ) - - # Get the actual column names from the CSV - question_col = clean_headers["question"] - answer_col = clean_headers["answer"] - - # Count original items - original_items = [] - for row in csv_reader: - question = row.get(question_col, "").strip() - answer = row.get(answer_col, "").strip() - if question and answer: - original_items.append({"question": question, "answer": answer}) - - if not original_items: - raise HTTPException( - status_code=422, detail="No valid items found in CSV file" - ) - - original_items_count = len(original_items) - total_items_count = original_items_count * duplication_factor - - logger.info( - f"[upload_dataset] Parsed items from CSV | original={original_items_count} | " - f"total_with_duplication={total_items_count}" - ) - - except Exception as e: - logger.error(f"[upload_dataset] Failed to parse CSV | {e}", exc_info=True) - raise HTTPException(status_code=422, detail=f"Invalid CSV file: {e}") - - # Step 2: Upload to object store (if credentials configured) - object_store_url = None - try: - storage = get_cloud_storage( - session=_session, project_id=auth_context.project_.id - ) - object_store_url = upload_csv_to_object_store( - storage=storage, csv_content=csv_content, dataset_name=dataset_name - ) - if object_store_url: - logger.info( - f"[upload_dataset] Successfully uploaded CSV to object store | {object_store_url}" - ) - else: - logger.info( - "[upload_dataset] Object store upload returned None | continuing without object store storage" - ) - except Exception as e: - logger.warning( - f"[upload_dataset] Failed to upload CSV to object store (continuing without object store) | {e}", - exc_info=True, - ) - object_store_url = None - - # Step 3: Upload to Langfuse - langfuse_dataset_id = None - try: - # Get Langfuse client - langfuse = get_langfuse_client( - session=_session, - org_id=auth_context.organization_.id, - project_id=auth_context.project_.id, - ) - - # Upload to Langfuse - langfuse_dataset_id, _ = upload_dataset_to_langfuse( - langfuse=langfuse, - items=original_items, - dataset_name=dataset_name, - duplication_factor=duplication_factor, - ) - - logger.info( - f"[upload_dataset] Successfully uploaded dataset to Langfuse | " - f"dataset={dataset_name} | id={langfuse_dataset_id}" - ) - - except Exception as e: - logger.error( - f"[upload_dataset] Failed to upload dataset to Langfuse | {e}", - exc_info=True, 
- ) - raise HTTPException( - status_code=500, detail=f"Failed to upload dataset to Langfuse: {e}" - ) - - # Step 4: Store metadata in database - metadata = { - "original_items_count": original_items_count, - "total_items_count": total_items_count, - "duplication_factor": duplication_factor, - } - - dataset = create_evaluation_dataset( - session=_session, - name=dataset_name, - description=description, - dataset_metadata=metadata, - object_store_url=object_store_url, - langfuse_dataset_id=langfuse_dataset_id, - organization_id=auth_context.organization_.id, - project_id=auth_context.project_.id, - ) - - logger.info( - f"[upload_dataset] Successfully created dataset record in database | " - f"id={dataset.id} | name={dataset_name}" - ) - - # Return response - return APIResponse.success_response( - data=DatasetUploadResponse( - dataset_id=dataset.id, - dataset_name=dataset_name, - total_items=total_items_count, - original_items=original_items_count, - duplication_factor=duplication_factor, - langfuse_dataset_id=langfuse_dataset_id, - object_store_url=object_store_url, - ) - ) - - -@router.get( - "/evaluations/datasets", - description=load_description("evaluation/list_datasets.md"), - response_model=APIResponse[list[DatasetUploadResponse]], - dependencies=[Depends(require_permission(Permission.REQUIRE_PROJECT))], -) -def list_datasets_endpoint( - _session: SessionDep, - auth_context: AuthContextDep, - limit: int = 50, - offset: int = 0, -) -> APIResponse[list[DatasetUploadResponse]]: - # Enforce maximum limit - if limit > 100: - limit = 100 - - datasets = list_datasets( - session=_session, - organization_id=auth_context.organization_.id, - project_id=auth_context.project_.id, - limit=limit, - offset=offset, - ) - - return APIResponse.success_response( - data=[_dataset_to_response(dataset) for dataset in datasets] - ) - - -@router.get( - "/evaluations/datasets/{dataset_id}", - description=load_description("evaluation/get_dataset.md"), - response_model=APIResponse[DatasetUploadResponse], - dependencies=[Depends(require_permission(Permission.REQUIRE_PROJECT))], -) -def get_dataset( - dataset_id: int, - _session: SessionDep, - auth_context: AuthContextDep, -) -> APIResponse[DatasetUploadResponse]: - logger.info( - f"[get_dataset] Fetching dataset | id={dataset_id} | " - f"org_id={auth_context.organization_.id} | " - f"project_id={auth_context.project_.id}" - ) - - dataset = get_dataset_by_id( - session=_session, - dataset_id=dataset_id, - organization_id=auth_context.organization_.id, - project_id=auth_context.project_.id, - ) - - if not dataset: - raise HTTPException( - status_code=404, detail=f"Dataset {dataset_id} not found or not accessible" - ) - - return APIResponse.success_response(data=_dataset_to_response(dataset)) - - -@router.delete( - "/evaluations/datasets/{dataset_id}", - description=load_description("evaluation/delete_dataset.md"), - response_model=APIResponse[dict], - dependencies=[Depends(require_permission(Permission.REQUIRE_PROJECT))], -) -def delete_dataset( - dataset_id: int, - _session: SessionDep, - auth_context: AuthContextDep, -) -> APIResponse[dict]: - logger.info( - f"[delete_dataset] Deleting dataset | id={dataset_id} | " - f"org_id={auth_context.organization_.id} | " - f"project_id={auth_context.project_.id}" - ) - - success, message = delete_dataset_crud( - session=_session, - dataset_id=dataset_id, - organization_id=auth_context.organization_.id, - project_id=auth_context.project_.id, - ) - - if not success: - # Check if it's a not found error or other error type - if 
"not found" in message.lower(): - raise HTTPException(status_code=404, detail=message) - else: - raise HTTPException(status_code=400, detail=message) - - logger.info(f"[delete_dataset] Successfully deleted dataset | id={dataset_id}") - return APIResponse.success_response( - data={"message": message, "dataset_id": dataset_id} - ) - - -@router.post( - "/evaluations", - description=load_description("evaluation/create_evaluation.md"), - response_model=APIResponse[EvaluationRunPublic], - dependencies=[Depends(require_permission(Permission.REQUIRE_PROJECT))], -) -def evaluate( - _session: SessionDep, - auth_context: AuthContextDep, - dataset_id: int = Body(..., description="ID of the evaluation dataset"), - experiment_name: str = Body( - ..., description="Name for this evaluation experiment/run" - ), - config: dict = Body(default_factory=dict, description="Evaluation configuration"), - assistant_id: str - | None = Body( - None, description="Optional assistant ID to fetch configuration from" - ), -) -> APIResponse[EvaluationRunPublic]: - logger.info( - f"[evaluate] Starting evaluation | experiment_name={experiment_name} | " - f"dataset_id={dataset_id} | " - f"org_id={auth_context.organization_.id} | " - f"assistant_id={assistant_id} | " - f"config_keys={list(config.keys())}" - ) - - # Step 1: Fetch dataset from database - dataset = get_dataset_by_id( - session=_session, - dataset_id=dataset_id, - organization_id=auth_context.organization_.id, - project_id=auth_context.project_.id, - ) - - if not dataset: - raise HTTPException( - status_code=404, - detail=f"Dataset {dataset_id} not found or not accessible to this " - f"organization/project", - ) - - logger.info( - f"[evaluate] Found dataset | id={dataset.id} | name={dataset.name} | " - f"object_store_url={'present' if dataset.object_store_url else 'None'} | " - f"langfuse_id={dataset.langfuse_dataset_id}" - ) - - dataset_name = dataset.name - - # Get API clients - openai_client = get_openai_client( - session=_session, - org_id=auth_context.organization_.id, - project_id=auth_context.project_.id, - ) - langfuse = get_langfuse_client( - session=_session, - org_id=auth_context.organization_.id, - project_id=auth_context.project_.id, - ) - - # Validate dataset has Langfuse ID (should have been set during dataset creation) - if not dataset.langfuse_dataset_id: - raise HTTPException( - status_code=400, - detail=f"Dataset {dataset_id} does not have a Langfuse dataset ID. " - "Please ensure Langfuse credentials were configured when the dataset was created.", - ) - - # Handle assistant_id if provided - if assistant_id: - # Fetch assistant details from database - assistant = get_assistant_by_id( - session=_session, - assistant_id=assistant_id, - project_id=auth_context.project_.id, - ) - - if not assistant: - raise HTTPException( - status_code=404, detail=f"Assistant {assistant_id} not found" - ) - - logger.info( - f"[evaluate] Found assistant in DB | id={assistant.id} | " - f"model={assistant.model} | instructions=" - f"{assistant.instructions[:50] if assistant.instructions else 'None'}..." 
- ) - - # Build config from assistant (use provided config values to override - # if present) - config = { - "model": config.get("model", assistant.model), - "instructions": config.get("instructions", assistant.instructions), - "temperature": config.get("temperature", assistant.temperature), - } - - # Add tools if vector stores are available - vector_store_ids = config.get( - "vector_store_ids", assistant.vector_store_ids or [] - ) - if vector_store_ids and len(vector_store_ids) > 0: - config["tools"] = [ - { - "type": "file_search", - "vector_store_ids": vector_store_ids, - } - ] - - logger.info("[evaluate] Using config from assistant") - else: - logger.info("[evaluate] Using provided config directly") - # Validate that config has minimum required fields - if not config.get("model"): - raise HTTPException( - status_code=400, - detail="Config must include 'model' when assistant_id is not provided", - ) - - # Create EvaluationRun record - eval_run = create_evaluation_run( - session=_session, - run_name=experiment_name, - dataset_name=dataset_name, - dataset_id=dataset_id, - config=config, - organization_id=auth_context.organization_.id, - project_id=auth_context.project_.id, - ) - - # Start the batch evaluation - try: - eval_run = start_evaluation_batch( - langfuse=langfuse, - openai_client=openai_client, - session=_session, - eval_run=eval_run, - config=config, - ) - - logger.info( - f"[evaluate] Evaluation started successfully | " - f"batch_job_id={eval_run.batch_job_id} | total_items={eval_run.total_items}" - ) - - return APIResponse.success_response(data=eval_run) - - except Exception as e: - logger.error( - f"[evaluate] Failed to start evaluation | run_id={eval_run.id} | {e}", - exc_info=True, - ) - # Error is already handled in start_evaluation_batch - _session.refresh(eval_run) - return APIResponse.success_response(data=eval_run) - - -@router.get( - "/evaluations", - description=load_description("evaluation/list_evaluations.md"), - response_model=APIResponse[list[EvaluationRunPublic]], - dependencies=[Depends(require_permission(Permission.REQUIRE_PROJECT))], -) -def list_evaluation_runs( - _session: SessionDep, - auth_context: AuthContextDep, - limit: int = 50, - offset: int = 0, -) -> APIResponse[list[EvaluationRunPublic]]: - logger.info( - f"[list_evaluation_runs] Listing evaluation runs | " - f"org_id={auth_context.organization_.id} | " - f"project_id={auth_context.project_.id} | limit={limit} | offset={offset}" - ) - - return APIResponse.success_response( - data=list_evaluation_runs_crud( - session=_session, - organization_id=auth_context.organization_.id, - project_id=auth_context.project_.id, - limit=limit, - offset=offset, - ) - ) - - -@router.get( - "/evaluations/{evaluation_id}", - description=load_description("evaluation/get_evaluation.md"), - response_model=APIResponse[EvaluationRunPublic], - dependencies=[Depends(require_permission(Permission.REQUIRE_PROJECT))], -) -def get_evaluation_run_status( - evaluation_id: int, - _session: SessionDep, - auth_context: AuthContextDep, - get_trace_info: bool = Query( - False, - description=( - "If true, fetch and include Langfuse trace scores with Q&A context. " - "On first request, data is fetched from Langfuse and cached. " - "Subsequent requests return cached data." - ), - ), - resync_score: bool = Query( - False, - description=( - "If true, clear cached scores and re-fetch from Langfuse. " - "Useful when new evaluators have been added or scores have been updated. " - "Requires get_trace_info=true." 
- ), - ), -) -> APIResponse[EvaluationRunPublic]: - logger.info( - f"[get_evaluation_run_status] Fetching status for evaluation run | " - f"evaluation_id={evaluation_id} | " - f"org_id={auth_context.organization_.id} | " - f"project_id={auth_context.project_.id} | " - f"get_trace_info={get_trace_info} | " - f"resync_score={resync_score}" - ) - - if resync_score and not get_trace_info: - raise HTTPException( - status_code=400, - detail="resync_score=true requires get_trace_info=true", - ) - - eval_run = get_evaluation_run_by_id( - session=_session, - evaluation_id=evaluation_id, - organization_id=auth_context.organization_.id, - project_id=auth_context.project_.id, - ) - - if not eval_run: - raise HTTPException( - status_code=404, - detail=( - f"Evaluation run {evaluation_id} not found or not accessible " - "to this organization" - ), - ) - - if get_trace_info: - # Only fetch trace info for completed evaluations - if eval_run.status != "completed": - return APIResponse.failure_response( - error=f"Trace info is only available for completed evaluations. " - f"Current status: {eval_run.status}", - data=eval_run, - ) - - # Check if we already have cached scores (before any slow operations) - has_cached_score = eval_run.score is not None and "traces" in eval_run.score - if not resync_score and has_cached_score: - return APIResponse.success_response(data=eval_run) - - # Get Langfuse client (needs session for credentials lookup) - langfuse = get_langfuse_client( - session=_session, - org_id=auth_context.organization_.id, - project_id=auth_context.project_.id, - ) - - # Capture data needed for Langfuse fetch and DB update - dataset_name = eval_run.dataset_name - run_name = eval_run.run_name - eval_run_id = eval_run.id - org_id = auth_context.organization_.id - project_id = auth_context.project_.id - - # Session is no longer needed - slow Langfuse API calls happen here - # without holding the DB connection - try: - score = fetch_trace_scores_from_langfuse( - langfuse=langfuse, - dataset_name=dataset_name, - run_name=run_name, - ) - except ValueError as e: - # Run not found in Langfuse - return eval_run with error - logger.warning( - f"[get_evaluation_run_status] Run not found in Langfuse | " - f"evaluation_id={evaluation_id} | error={e}" - ) - return APIResponse.failure_response(error=str(e), data=eval_run) - except Exception as e: - logger.error( - f"[get_evaluation_run_status] Failed to fetch trace info | " - f"evaluation_id={evaluation_id} | error={e}", - exc_info=True, - ) - return APIResponse.failure_response( - error=f"Failed to fetch trace info from Langfuse: {str(e)}", - data=eval_run, - ) - - # Open new session just for the score commit - eval_run = save_score( - eval_run_id=eval_run_id, - organization_id=org_id, - project_id=project_id, - score=score, - ) - - if not eval_run: - raise HTTPException( - status_code=404, - detail=f"Evaluation run {evaluation_id} not found after score update", - ) - - return APIResponse.success_response(data=eval_run) diff --git a/backend/app/api/routes/evaluations/__init__.py b/backend/app/api/routes/evaluations/__init__.py new file mode 100644 index 00000000..3f7fe120 --- /dev/null +++ b/backend/app/api/routes/evaluations/__init__.py @@ -0,0 +1,13 @@ +"""Evaluation API routes.""" + +from fastapi import APIRouter + +from app.api.routes.evaluations import dataset, evaluation + +router = APIRouter(prefix="/evaluations", tags=["evaluation"]) + +# Include dataset routes under /evaluations/datasets +router.include_router(dataset.router, prefix="/datasets") + +# 
Include evaluation routes directly under /evaluations +router.include_router(evaluation.router) diff --git a/backend/app/api/routes/evaluations/dataset.py b/backend/app/api/routes/evaluations/dataset.py new file mode 100644 index 00000000..25ecacff --- /dev/null +++ b/backend/app/api/routes/evaluations/dataset.py @@ -0,0 +1,191 @@ +"""Evaluation dataset API routes.""" + +import logging + +from fastapi import ( + APIRouter, + Depends, + File, + Form, + HTTPException, + Query, + UploadFile, +) + +from app.api.deps import AuthContextDep, SessionDep +from app.api.permissions import Permission, require_permission +from app.crud.evaluations import ( + get_dataset_by_id, + list_datasets as list_evaluation_datasets, +) +from app.crud.evaluations.dataset import delete_dataset as delete_dataset_crud +from app.models.evaluation import DatasetUploadResponse, EvaluationDataset +from app.services.evaluations import ( + upload_dataset as upload_evaluation_dataset, + validate_csv_file, +) +from app.utils import ( + APIResponse, + load_description, +) + +logger = logging.getLogger(__name__) + +router = APIRouter() + + +def _dataset_to_response(dataset: EvaluationDataset) -> DatasetUploadResponse: + """Convert a dataset model to a DatasetUploadResponse.""" + return DatasetUploadResponse( + dataset_id=dataset.id, + dataset_name=dataset.name, + total_items=dataset.dataset_metadata.get("total_items_count", 0), + original_items=dataset.dataset_metadata.get("original_items_count", 0), + duplication_factor=dataset.dataset_metadata.get("duplication_factor", 1), + langfuse_dataset_id=dataset.langfuse_dataset_id, + object_store_url=dataset.object_store_url, + ) + + +@router.post( + "/", + description=load_description("evaluation/upload_dataset.md"), + response_model=APIResponse[DatasetUploadResponse], + dependencies=[Depends(require_permission(Permission.REQUIRE_PROJECT))], +) +async def upload_dataset( + _session: SessionDep, + auth_context: AuthContextDep, + file: UploadFile = File( + ..., description="CSV file with 'question' and 'answer' columns" + ), + dataset_name: str = Form(..., description="Name for the dataset"), + description: str | None = Form(None, description="Optional dataset description"), + duplication_factor: int = Form( + default=1, + ge=1, + le=5, + description="Number of times to duplicate each item (min: 1, max: 5)", + ), +) -> APIResponse[DatasetUploadResponse]: + """Upload an evaluation dataset.""" + # Validate and read CSV file + csv_content = await validate_csv_file(file) + + # Upload dataset using service + dataset = upload_evaluation_dataset( + session=_session, + csv_content=csv_content, + dataset_name=dataset_name, + description=description, + duplication_factor=duplication_factor, + organization_id=auth_context.organization_.id, + project_id=auth_context.project_.id, + ) + + return APIResponse.success_response(data=_dataset_to_response(dataset)) + + +@router.get( + "/", + description=load_description("evaluation/list_datasets.md"), + response_model=APIResponse[list[DatasetUploadResponse]], + dependencies=[Depends(require_permission(Permission.REQUIRE_PROJECT))], +) +def list_datasets( + _session: SessionDep, + auth_context: AuthContextDep, + limit: int = Query( + default=50, ge=1, le=100, description="Maximum number of datasets to return" + ), + offset: int = Query(default=0, ge=0, description="Number of datasets to skip"), +) -> APIResponse[list[DatasetUploadResponse]]: + """List evaluation datasets.""" + datasets = list_evaluation_datasets( + session=_session, + 
organization_id=auth_context.organization_.id, + project_id=auth_context.project_.id, + limit=limit, + offset=offset, + ) + + return APIResponse.success_response( + data=[_dataset_to_response(dataset) for dataset in datasets] + ) + + +@router.get( + "/{dataset_id}", + description=load_description("evaluation/get_dataset.md"), + response_model=APIResponse[DatasetUploadResponse], + dependencies=[Depends(require_permission(Permission.REQUIRE_PROJECT))], +) +def get_dataset( + dataset_id: int, + _session: SessionDep, + auth_context: AuthContextDep, +) -> APIResponse[DatasetUploadResponse]: + """Get a specific evaluation dataset.""" + logger.info( + f"[get_dataset] Fetching dataset | id={dataset_id} | " + f"org_id={auth_context.organization_.id} | " + f"project_id={auth_context.project_.id}" + ) + + dataset = get_dataset_by_id( + session=_session, + dataset_id=dataset_id, + organization_id=auth_context.organization_.id, + project_id=auth_context.project_.id, + ) + + if not dataset: + raise HTTPException( + status_code=404, detail=f"Dataset {dataset_id} not found or not accessible" + ) + + return APIResponse.success_response(data=_dataset_to_response(dataset)) + + +@router.delete( + "/{dataset_id}", + description=load_description("evaluation/delete_dataset.md"), + response_model=APIResponse[dict], + dependencies=[Depends(require_permission(Permission.REQUIRE_PROJECT))], +) +def delete_dataset( + dataset_id: int, + _session: SessionDep, + auth_context: AuthContextDep, +) -> APIResponse[dict]: + """Delete an evaluation dataset.""" + logger.info( + f"[delete_dataset] Deleting dataset | id={dataset_id} | " + f"org_id={auth_context.organization_.id} | " + f"project_id={auth_context.project_.id}" + ) + + dataset = get_dataset_by_id( + session=_session, + dataset_id=dataset_id, + organization_id=auth_context.organization_.id, + project_id=auth_context.project_.id, + ) + + if not dataset: + raise HTTPException( + status_code=404, detail=f"Dataset {dataset_id} not found or not accessible" + ) + dataset_name = dataset.name + error = delete_dataset_crud(session=_session, dataset=dataset) + + if error: + raise HTTPException(status_code=400, detail=error) + + logger.info(f"[delete_dataset] Successfully deleted dataset | id={dataset_id}") + return APIResponse.success_response( + data={ + "message": f"Successfully deleted dataset '{dataset_name}' (id={dataset_id})", + "dataset_id": dataset_id, + } + ) diff --git a/backend/app/api/routes/evaluations/evaluation.py b/backend/app/api/routes/evaluations/evaluation.py new file mode 100644 index 00000000..b51a5948 --- /dev/null +++ b/backend/app/api/routes/evaluations/evaluation.py @@ -0,0 +1,154 @@ +"""Evaluation run API routes.""" + +import logging + +from fastapi import ( + APIRouter, + Body, + Depends, + HTTPException, + Query, +) + +from app.api.deps import AuthContextDep, SessionDep +from app.crud.evaluations import list_evaluation_runs as list_evaluation_runs_crud +from app.models.evaluation import EvaluationRunPublic +from app.api.permissions import Permission, require_permission +from app.services.evaluations import ( + get_evaluation_with_scores, + start_evaluation, +) +from app.utils import ( + APIResponse, + load_description, +) + +logger = logging.getLogger(__name__) + +router = APIRouter() + + +@router.post( + "/", + description=load_description("evaluation/create_evaluation.md"), + response_model=APIResponse[EvaluationRunPublic], + dependencies=[Depends(require_permission(Permission.REQUIRE_PROJECT))], +) +def evaluate( + _session: SessionDep, + 
auth_context: AuthContextDep, + dataset_id: int = Body(..., description="ID of the evaluation dataset"), + experiment_name: str = Body( + ..., description="Name for this evaluation experiment/run" + ), + config: dict = Body(default_factory=dict, description="Evaluation configuration"), + assistant_id: str + | None = Body( + None, description="Optional assistant ID to fetch configuration from" + ), +) -> APIResponse[EvaluationRunPublic]: + """Start an evaluation run.""" + eval_run = start_evaluation( + session=_session, + dataset_id=dataset_id, + experiment_name=experiment_name, + config=config, + assistant_id=assistant_id, + organization_id=auth_context.organization_.id, + project_id=auth_context.project_.id, + ) + + if eval_run.status == "failed": + return APIResponse.failure_response( + error=eval_run.error_message or "Evaluation failed to start", + data=eval_run, + ) + + return APIResponse.success_response(data=eval_run) + + +@router.get( + "/", + description=load_description("evaluation/list_evaluations.md"), + response_model=APIResponse[list[EvaluationRunPublic]], + dependencies=[Depends(require_permission(Permission.REQUIRE_PROJECT))], +) +def list_evaluation_runs( + _session: SessionDep, + auth_context: AuthContextDep, + limit: int = 50, + offset: int = 0, +) -> APIResponse[list[EvaluationRunPublic]]: + """List evaluation runs.""" + logger.info( + f"[list_evaluation_runs] Listing evaluation runs | " + f"org_id={auth_context.organization_.id} | " + f"project_id={auth_context.project_.id} | limit={limit} | offset={offset}" + ) + + return APIResponse.success_response( + data=list_evaluation_runs_crud( + session=_session, + organization_id=auth_context.organization_.id, + project_id=auth_context.project_.id, + limit=limit, + offset=offset, + ) + ) + + +@router.get( + "/{evaluation_id}", + description=load_description("evaluation/get_evaluation.md"), + response_model=APIResponse[EvaluationRunPublic], + dependencies=[Depends(require_permission(Permission.REQUIRE_PROJECT))], +) +def get_evaluation_run_status( + evaluation_id: int, + _session: SessionDep, + auth_context: AuthContextDep, + get_trace_info: bool = Query( + False, + description=( + "If true, fetch and include Langfuse trace scores with Q&A context. " + "On first request, data is fetched from Langfuse and cached. " + "Subsequent requests return cached data." + ), + ), + resync_score: bool = Query( + False, + description=( + "If true, clear cached scores and re-fetch from Langfuse. " + "Useful when new evaluators have been added or scores have been updated. " + "Requires get_trace_info=true." 
+ ), + ), +) -> APIResponse[EvaluationRunPublic]: + """Get evaluation run status with optional trace info.""" + if resync_score and not get_trace_info: + raise HTTPException( + status_code=400, + detail="resync_score=true requires get_trace_info=true", + ) + + eval_run, error = get_evaluation_with_scores( + session=_session, + evaluation_id=evaluation_id, + organization_id=auth_context.organization_.id, + project_id=auth_context.project_.id, + get_trace_info=get_trace_info, + resync_score=resync_score, + ) + + if not eval_run: + raise HTTPException( + status_code=404, + detail=( + f"Evaluation run {evaluation_id} not found or not accessible " + "to this organization" + ), + ) + + if error: + return APIResponse.failure_response(error=error, data=eval_run) + return APIResponse.success_response(data=eval_run) diff --git a/backend/app/core/batch/__init__.py b/backend/app/core/batch/__init__.py index 9f7cd88d..0cd10bd5 100644 --- a/backend/app/core/batch/__init__.py +++ b/backend/app/core/batch/__init__.py @@ -1,5 +1,21 @@ """Batch processing infrastructure for LLM providers.""" from .base import BatchProvider +from .openai import OpenAIBatchProvider +from .operations import ( + download_batch_results, + process_completed_batch, + start_batch_job, + upload_batch_results_to_object_store, +) +from .polling import poll_batch_status -__all__ = ["BatchProvider"] +__all__ = [ + "BatchProvider", + "OpenAIBatchProvider", + "start_batch_job", + "download_batch_results", + "process_completed_batch", + "upload_batch_results_to_object_store", + "poll_batch_status", +] diff --git a/backend/app/crud/batch_operations.py b/backend/app/core/batch/operations.py similarity index 79% rename from backend/app/crud/batch_operations.py rename to backend/app/core/batch/operations.py index f2bb332e..d02884fd 100644 --- a/backend/app/crud/batch_operations.py +++ b/backend/app/core/batch/operations.py @@ -8,10 +8,7 @@ from app.core.batch.base import BatchProvider from app.core.cloud import get_cloud_storage from app.core.storage_utils import upload_jsonl_to_object_store as shared_upload_jsonl -from app.crud.batch_job import ( - create_batch_job, - update_batch_job, -) +from app.crud.job.job import create_batch_job, update_batch_job from app.models.batch_job import BatchJob, BatchJobCreate, BatchJobUpdate logger = logging.getLogger(__name__) @@ -86,47 +83,6 @@ def start_batch_job( raise -def poll_batch_status( - session: Session, provider: BatchProvider, batch_job: BatchJob -) -> dict[str, Any]: - """Poll provider for batch status and update database.""" - logger.info( - f"[poll_batch_status] Polling | id={batch_job.id} | " - f"provider_batch_id={batch_job.provider_batch_id}" - ) - - try: - status_result = provider.get_batch_status(batch_job.provider_batch_id) - - provider_status = status_result["provider_status"] - if provider_status != batch_job.provider_status: - update_data = {"provider_status": provider_status} - - if status_result.get("provider_output_file_id"): - update_data["provider_output_file_id"] = status_result[ - "provider_output_file_id" - ] - - if status_result.get("error_message"): - update_data["error_message"] = status_result["error_message"] - - batch_job_update = BatchJobUpdate(**update_data) - batch_job = update_batch_job( - session=session, batch_job=batch_job, batch_job_update=batch_job_update - ) - - logger.info( - f"[poll_batch_status] Updated | id={batch_job.id} | " - f"{batch_job.provider_status} -> {provider_status}" - ) - - return status_result - - except Exception as e: - 
logger.error(f"[poll_batch_status] Failed | {e}", exc_info=True) - raise - - def download_batch_results( provider: BatchProvider, batch_job: BatchJob ) -> list[dict[str, Any]]: diff --git a/backend/app/core/batch/polling.py b/backend/app/core/batch/polling.py new file mode 100644 index 00000000..c364aeb3 --- /dev/null +++ b/backend/app/core/batch/polling.py @@ -0,0 +1,53 @@ +"""Batch status polling operations.""" + +import logging +from typing import Any + +from sqlmodel import Session + +from app.core.batch.base import BatchProvider +from app.crud.job.job import update_batch_job +from app.models.batch_job import BatchJob, BatchJobUpdate + +logger = logging.getLogger(__name__) + + +def poll_batch_status( + session: Session, provider: BatchProvider, batch_job: BatchJob +) -> dict[str, Any]: + """Poll provider for batch status and update database.""" + logger.info( + f"[poll_batch_status] Polling | id={batch_job.id} | " + f"provider_batch_id={batch_job.provider_batch_id}" + ) + + try: + status_result = provider.get_batch_status(batch_job.provider_batch_id) + + provider_status = status_result["provider_status"] + if provider_status != batch_job.provider_status: + update_data = {"provider_status": provider_status} + + if status_result.get("provider_output_file_id"): + update_data["provider_output_file_id"] = status_result[ + "provider_output_file_id" + ] + + if status_result.get("error_message"): + update_data["error_message"] = status_result["error_message"] + + batch_job_update = BatchJobUpdate(**update_data) + batch_job = update_batch_job( + session=session, batch_job=batch_job, batch_job_update=batch_job_update + ) + + logger.info( + f"[poll_batch_status] Updated | id={batch_job.id} | " + f"{batch_job.provider_status} -> {provider_status}" + ) + + return status_result + + except Exception as e: + logger.error(f"[poll_batch_status] Failed | {e}", exc_info=True) + raise diff --git a/backend/app/core/security.py b/backend/app/core/security.py index 8807d705..8cee6e98 100644 --- a/backend/app/core/security.py +++ b/backend/app/core/security.py @@ -110,44 +110,6 @@ def get_password_hash(password: str) -> str: return pwd_context.hash(password) -def encrypt_api_key(api_key: str) -> str: - """ - Encrypt an API key before storage. - - Args: - api_key: The plain text API key to encrypt - - Returns: - str: The encrypted API key - - Raises: - ValueError: If encryption fails - """ - try: - return get_fernet().encrypt(api_key.encode()).decode() - except Exception as e: - raise ValueError(f"Failed to encrypt API key: {e}") - - -def decrypt_api_key(encrypted_api_key: str) -> str: - """ - Decrypt an API key when retrieving it. - - Args: - encrypted_api_key: The encrypted API key to decrypt - - Returns: - str: The decrypted API key - - Raises: - ValueError: If decryption fails - """ - try: - return get_fernet().decrypt(encrypted_api_key.encode()).decode() - except Exception as e: - raise ValueError(f"Failed to decrypt API key: {e}") - - def encrypt_credentials(credentials: dict) -> str: """ Encrypt the entire credentials object before storage. 
diff --git a/backend/app/crud/collection/collection.py b/backend/app/crud/collection/collection.py index 3c83912a..cb8b6e27 100644 --- a/backend/app/crud/collection/collection.py +++ b/backend/app/crud/collection/collection.py @@ -93,6 +93,16 @@ def read_all(self): collections = self.session.exec(statement).all() return collections + def exists_by_name(self, collection_name: str) -> bool: + statement = ( + select(Collection.id) + .where(Collection.project_id == self.project_id) + .where(Collection.name == collection_name) + .where(Collection.deleted_at.is_(None)) + ) + result = self.session.exec(statement).first() + return result is not None + def delete_by_id(self, collection_id: UUID) -> Collection: coll = self.read_one(collection_id) coll.deleted_at = now() diff --git a/backend/app/crud/evaluations/__init__.py b/backend/app/crud/evaluations/__init__.py index 5ca0aacd..bb095413 100644 --- a/backend/app/crud/evaluations/__init__.py +++ b/backend/app/crud/evaluations/__init__.py @@ -5,6 +5,7 @@ create_evaluation_run, get_evaluation_run_by_id, list_evaluation_runs, + save_score, ) from app.crud.evaluations.cron import ( process_all_pending_evaluations, @@ -24,6 +25,7 @@ ) from app.crud.evaluations.langfuse import ( create_langfuse_dataset_run, + fetch_trace_scores_from_langfuse, update_traces_with_cosine_scores, upload_dataset_to_langfuse, ) @@ -33,34 +35,3 @@ process_completed_embedding_batch, process_completed_evaluation, ) - -__all__ = [ - # Core - "create_evaluation_run", - "get_evaluation_run_by_id", - "list_evaluation_runs", - # Cron - "process_all_pending_evaluations", - "process_all_pending_evaluations_sync", - # Dataset - "create_evaluation_dataset", - "delete_dataset", - "get_dataset_by_id", - "list_datasets", - "upload_csv_to_object_store", - # Batch - "start_evaluation_batch", - # Processing - "check_and_process_evaluation", - "poll_all_pending_evaluations", - "process_completed_embedding_batch", - "process_completed_evaluation", - # Embeddings - "calculate_average_similarity", - "calculate_cosine_similarity", - "start_embedding_batch", - # Langfuse - "create_langfuse_dataset_run", - "update_traces_with_cosine_scores", - "upload_dataset_to_langfuse", -] diff --git a/backend/app/crud/evaluations/batch.py b/backend/app/crud/evaluations/batch.py index 7e8b6904..e880d7d0 100644 --- a/backend/app/crud/evaluations/batch.py +++ b/backend/app/crud/evaluations/batch.py @@ -14,8 +14,7 @@ from openai import OpenAI from sqlmodel import Session -from app.core.batch.openai import OpenAIBatchProvider -from app.crud.batch_operations import start_batch_job +from app.core.batch import OpenAIBatchProvider, start_batch_job from app.models import EvaluationRun logger = logging.getLogger(__name__) diff --git a/backend/app/crud/evaluations/dataset.py b/backend/app/crud/evaluations/dataset.py index 7efa03d4..6a94809c 100644 --- a/backend/app/crud/evaluations/dataset.py +++ b/backend/app/crud/evaluations/dataset.py @@ -320,68 +320,51 @@ def update_dataset_langfuse_id( ) -def delete_dataset( - session: Session, dataset_id: int, organization_id: int, project_id: int -) -> tuple[bool, str]: +def delete_dataset(session: Session, dataset: EvaluationDataset) -> str | None: """ - Delete an evaluation dataset by ID. + Delete an evaluation dataset. This performs a hard delete from the database. The CSV file in object store (if exists) will remain for audit purposes. 
Args: session: Database session - dataset_id: Dataset ID to delete - organization_id: Organization ID for validation - project_id: Project ID for validation + dataset: The dataset to delete (must be fetched beforehand) Returns: - Tuple of (success: bool, message: str) + None on success, error message string on failure """ - # First, fetch the dataset to ensure it exists and belongs to the org/project - dataset = get_dataset_by_id( - session=session, - dataset_id=dataset_id, - organization_id=organization_id, - project_id=project_id, - ) - - if not dataset: - return ( - False, - f"Dataset {dataset_id} not found or not accessible", - ) - # Check if dataset is being used by any evaluation runs - statement = select(EvaluationRun).where(EvaluationRun.dataset_id == dataset_id) + statement = select(EvaluationRun).where(EvaluationRun.dataset_id == dataset.id) evaluation_runs = session.exec(statement).all() if evaluation_runs: return ( - False, - f"Cannot delete dataset {dataset_id}: it is being used by " + f"Cannot delete dataset {dataset.id}: it is being used by " f"{len(evaluation_runs)} evaluation run(s). Please delete " - f"the evaluation runs first.", + f"the evaluation runs first." ) # Delete the dataset try: + dataset_id = dataset.id + dataset_name = dataset.name + organization_id = dataset.organization_id + project_id = dataset.project_id + session.delete(dataset) session.commit() logger.info( - f"[delete_dataset] Deleted dataset | id={dataset_id} | name={dataset.name} | org_id={organization_id} | project_id={project_id}" + f"[delete_dataset] Deleted dataset | id={dataset_id} | name={dataset_name} | org_id={organization_id} | project_id={project_id}" ) - return ( - True, - f"Successfully deleted dataset '{dataset.name}' (id={dataset_id})", - ) + return None except Exception as e: session.rollback() logger.error( - f"[delete_dataset] Failed to delete dataset | dataset_id={dataset_id} | {e}", + f"[delete_dataset] Failed to delete dataset | dataset_id={dataset.id} | {e}", exc_info=True, ) - return (False, f"Failed to delete dataset: {e}") + return f"Failed to delete dataset: {e}" diff --git a/backend/app/crud/evaluations/embeddings.py b/backend/app/crud/evaluations/embeddings.py index 70e37421..17ead39a 100644 --- a/backend/app/crud/evaluations/embeddings.py +++ b/backend/app/crud/evaluations/embeddings.py @@ -15,9 +15,8 @@ from openai import OpenAI from sqlmodel import Session -from app.core.batch.openai import OpenAIBatchProvider +from app.core.batch import OpenAIBatchProvider, start_batch_job from app.core.util import now -from app.crud.batch_operations import start_batch_job from app.models import EvaluationRun logger = logging.getLogger(__name__) @@ -43,7 +42,7 @@ def validate_embedding_model(model: str) -> None: if model not in VALID_EMBEDDING_MODELS: valid_models = ", ".join(VALID_EMBEDDING_MODELS.keys()) raise ValueError( - f"Invalid embedding model '{model}'. " f"Supported models: {valid_models}" + f"Invalid embedding model '{model}'. 
Supported models: {valid_models}" ) @@ -82,7 +81,7 @@ def build_embedding_jsonl( validate_embedding_model(embedding_model) logger.info( - f"Building embedding JSONL for {len(results)} items with model {embedding_model}" + f"[build_embedding_jsonl] Building JSONL | items={len(results)} | model={embedding_model}" ) jsonl_data = [] @@ -253,7 +252,7 @@ def calculate_cosine_similarity(vec1: list[float], vec2: list[float]) -> float: def calculate_average_similarity( - embedding_pairs: list[dict[str, Any]] + embedding_pairs: list[dict[str, Any]], ) -> dict[str, Any]: """ Calculate cosine similarity statistics for all embedding pairs. diff --git a/backend/app/crud/evaluations/processing.py b/backend/app/crud/evaluations/processing.py index 12b89266..fbc2d231 100644 --- a/backend/app/crud/evaluations/processing.py +++ b/backend/app/crud/evaluations/processing.py @@ -19,10 +19,10 @@ from openai import OpenAI from sqlmodel import Session, select -from app.core.batch.openai import OpenAIBatchProvider -from app.crud.batch_job import get_batch_job -from app.crud.batch_operations import ( +from app.core.batch import ( + OpenAIBatchProvider, download_batch_results, + poll_batch_status, upload_batch_results_to_object_store, ) from app.crud.evaluations.batch import fetch_dataset_items @@ -36,6 +36,7 @@ create_langfuse_dataset_run, update_traces_with_cosine_scores, ) +from app.crud.job import get_batch_job from app.models import EvaluationRun from app.utils import get_langfuse_client, get_openai_client @@ -484,10 +485,6 @@ async def check_and_process_evaluation( if embedding_batch_job: # Poll embedding batch status provider = OpenAIBatchProvider(client=openai_client) - - # Local import to avoid circular dependency with batch_operations - from app.crud.batch_operations import poll_batch_status - poll_batch_status( session=session, provider=provider, batch_job=embedding_batch_job ) @@ -560,8 +557,6 @@ async def check_and_process_evaluation( # IMPORTANT: Poll OpenAI to get the latest status before checking provider = OpenAIBatchProvider(client=openai_client) - from app.crud.batch_operations import poll_batch_status - poll_batch_status(session=session, provider=provider, batch_job=batch_job) # Refresh batch_job to get the updated provider_status diff --git a/backend/app/crud/job/__init__.py b/backend/app/crud/job/__init__.py new file mode 100644 index 00000000..a7bbf558 --- /dev/null +++ b/backend/app/crud/job/__init__.py @@ -0,0 +1,24 @@ +"""Job-related CRUD operations. + +For batch operations (start_batch_job, poll_batch_status, etc.), +import directly from app.core.batch instead. +""" + +from app.crud.job.job import ( + create_batch_job, + delete_batch_job, + get_batch_job, + get_batch_jobs_by_ids, + get_batches_by_type, + update_batch_job, +) + +__all__ = [ + # CRUD operations + "create_batch_job", + "get_batch_job", + "update_batch_job", + "get_batch_jobs_by_ids", + "get_batches_by_type", + "delete_batch_job", +] diff --git a/backend/app/crud/job/batch.py b/backend/app/crud/job/batch.py new file mode 100644 index 00000000..4ee8c084 --- /dev/null +++ b/backend/app/crud/job/batch.py @@ -0,0 +1,22 @@ +""" +Batch operations re-export layer. + +This module provides convenient imports for batch-related operations +while the actual implementation lives in app.core.batch. 
+""" + +from app.core.batch.operations import ( + download_batch_results, + process_completed_batch, + start_batch_job, + upload_batch_results_to_object_store, +) +from app.core.batch.polling import poll_batch_status + +__all__ = [ + "start_batch_job", + "download_batch_results", + "process_completed_batch", + "upload_batch_results_to_object_store", + "poll_batch_status", +] diff --git a/backend/app/crud/batch_job.py b/backend/app/crud/job/job.py similarity index 100% rename from backend/app/crud/batch_job.py rename to backend/app/crud/job/job.py diff --git a/backend/app/models/__init__.py b/backend/app/models/__init__.py index ac7e89d6..a4d76ee2 100644 --- a/backend/app/models/__init__.py +++ b/backend/app/models/__init__.py @@ -8,9 +8,12 @@ from .collection import ( Collection, + CreationRequest, CollectionPublic, CollectionIDPublic, CollectionWithDocsPublic, + DeletionRequest, + ProviderType, ) from .collection_job import ( CollectionActionType, diff --git a/backend/app/models/collection.py b/backend/app/models/collection.py index 57e5a17b..322262bb 100644 --- a/backend/app/models/collection.py +++ b/backend/app/models/collection.py @@ -1,4 +1,5 @@ from datetime import datetime +from enum import Enum from typing import Any, Literal from uuid import UUID, uuid4 @@ -7,65 +8,87 @@ from app.core.util import now from app.models.document import DocumentPublic - -from .organization import Organization from .project import Project +class ProviderType(str, Enum): + """Supported LLM providers for collections.""" + + OPENAI = "OPENAI" + # BEDROCK = "bedrock" + # GEMINI = "gemini" + + class Collection(SQLModel, table=True): """Database model for Collection operations.""" id: UUID = Field( default_factory=uuid4, primary_key=True, + description="Unique identifier for the collection", sa_column_kwargs={"comment": "Unique identifier for the collection"}, ) + provider: ProviderType = ( + Field( + nullable=False, + description="LLM provider used for this collection (e.g., 'openai', 'bedrock', 'gemini', etc)", + sa_column_kwargs={"LLM provider used for this collection"}, + ), + ) llm_service_id: str = Field( nullable=False, + description="External LLM service identifier (e.g., OpenAI vector store ID)", sa_column_kwargs={ "comment": "External LLM service identifier (e.g., OpenAI vector store ID)" }, ) llm_service_name: str = Field( nullable=False, + description="Name of the LLM service", sa_column_kwargs={"comment": "Name of the LLM service"}, ) - - # Foreign keys - organization_id: int = Field( - foreign_key="organization.id", - nullable=False, - ondelete="CASCADE", - sa_column_kwargs={"comment": "Reference to the organization"}, + name: str = Field( + nullable=True, + unique=True, + description="Name of the collection", + sa_column_kwargs={"comment": "Name of the collection"}, + ) + description: str = Field( + nullable=True, + description="Description of the collection", + sa_column_kwargs={"comment": "Description of the collection"}, ) project_id: int = Field( foreign_key="project.id", nullable=False, ondelete="CASCADE", + description="Project the collection belongs to", sa_column_kwargs={"comment": "Reference to the project"}, ) - - # Timestamps inserted_at: datetime = Field( default_factory=now, + description="Timestamp when the collection was created", sa_column_kwargs={"comment": "Timestamp when the collection was created"}, ) updated_at: datetime = Field( default_factory=now, + description="Timestamp when the collection was updated", sa_column_kwargs={"comment": "Timestamp when the collection was 
last updated"}, ) deleted_at: datetime | None = Field( default=None, + description="Timestamp when the collection was deleted", sa_column_kwargs={"comment": "Timestamp when the collection was deleted"}, ) - - # Relationships - organization: Organization = Relationship(back_populates="collections") project: Project = Relationship(back_populates="collections") # Request models -class DocumentOptions(SQLModel): +class CollectionOptions(SQLModel): + name: str | None = Field(default=None, description="Name of the collection") + description: str | None = Field( + default=None, description="Description of the collection" + ) documents: list[UUID] = Field( description="List of document IDs", ) @@ -154,9 +177,9 @@ class ProviderOptions(SQLModel): class CreationRequest( - DocumentOptions, - ProviderOptions, AssistantOptions, + CollectionOptions, + ProviderOptions, CallbackRequest, ): def extract_super_type(self, cls: "CreationRequest"): @@ -181,7 +204,6 @@ class CollectionPublic(SQLModel): llm_service_id: str llm_service_name: str project_id: int - organization_id: int inserted_at: datetime updated_at: datetime diff --git a/backend/app/models/organization.py b/backend/app/models/organization.py index 0f936607..b9ff9a7f 100644 --- a/backend/app/models/organization.py +++ b/backend/app/models/organization.py @@ -74,9 +74,6 @@ class Organization(OrganizationBase, table=True): assistants: list["Assistant"] = Relationship( back_populates="organization", cascade_delete=True ) - collections: list["Collection"] = Relationship( - back_populates="organization", cascade_delete=True - ) openai_conversations: list["OpenAIConversation"] = Relationship( back_populates="organization", cascade_delete=True ) diff --git a/backend/app/services/collections/create_collection.py b/backend/app/services/collections/create_collection.py index ed83e4a8..1522f863 100644 --- a/backend/app/services/collections/create_collection.py +++ b/backend/app/services/collections/create_collection.py @@ -6,7 +6,6 @@ from asgi_correlation_id import correlation_id from app.core.cloud import get_cloud_storage -from app.core.util import now from app.core.db import engine from app.crud import ( CollectionCrud, @@ -14,7 +13,6 @@ DocumentCollectionCrud, CollectionJobCrud, ) -from app.crud.rag import OpenAIVectorStoreCrud, OpenAIAssistantCrud from app.models import ( CollectionJobStatus, CollectionJob, @@ -22,19 +20,12 @@ CollectionJobUpdate, CollectionPublic, CollectionJobPublic, -) -from app.models.collection import ( CreationRequest, - AssistantOptions, -) -from app.services.collections.helpers import ( - _backout, - batch_documents, - extract_error_message, - OPENAI_VECTOR_STORE, ) +from app.services.collections.helpers import extract_error_message +from app.services.collections.providers.registry import get_llm_provider from app.celery.utils import start_low_priority_job -from app.utils import get_openai_client, send_callback, APIResponse +from app.utils import send_callback, APIResponse logger = logging.getLogger(__name__) @@ -60,8 +51,8 @@ def start_job( project_id=project_id, job_id=str(collection_job_id), trace_id=trace_id, - with_assistant=with_assistant, request=request.model_dump(mode="json"), + with_assistant=with_assistant, organization_id=organization_id, ) @@ -116,26 +107,6 @@ def build_failure_payload(collection_job: CollectionJob, error_message: str) -> ) -def _cleanup_remote_resources( - assistant, - assistant_crud, - vector_store, - vector_store_crud, -) -> None: - """Best-effort cleanup of partially created remote 
resources.""" - try: - if assistant is not None and assistant_crud is not None: - _backout(assistant_crud, assistant.id) - elif vector_store is not None and vector_store_crud is not None: - _backout(vector_store_crud, vector_store.id) - else: - logger.warning( - "[create_collection._backout] Skipping: no resource/crud available" - ) - except Exception: - logger.warning("[create_collection.execute_job] Backout failed") - - def _mark_job_failed( project_id: int, job_id: str, @@ -163,29 +134,32 @@ def _mark_job_failed( def execute_job( request: dict, + with_assistant: bool, project_id: int, organization_id: int, task_id: str, job_id: str, - with_assistant: bool, task_instance, ) -> None: """ Worker entrypoint scheduled by start_job. - Orchestrates: job state, client/storage init, batching, vector-store upload, + Orchestrates: job state, provider init, collection creation, optional assistant creation, collection persistence, linking, callbacks, and cleanup. """ start_time = time.time() - # Keep references for potential backout/cleanup on failure - assistant = None - assistant_crud = None - vector_store = None - vector_store_crud = None + # Keeping the references for potential backout/cleanup on failure collection_job = None + result = None + provider = None try: creation_request = CreationRequest(**request) + if ( + with_assistant == True + ): # this will be removed once dalgo switches to vector store creation only + creation_request.provider = "openai" + job_uuid = UUID(job_id) with Session(engine) as session: @@ -199,50 +173,28 @@ def execute_job( ), ) - client = get_openai_client(session, organization_id, project_id) storage = get_cloud_storage(session=session, project_id=project_id) - - # Batch documents for upload, and flatten for linking/metrics later document_crud = DocumentCrud(session, project_id) - docs_batches = batch_documents( - document_crud, - creation_request.documents, - creation_request.batch_size, + + provider = get_llm_provider( + session=session, + provider=creation_request.provider, + project_id=project_id, + organization_id=organization_id, ) - flat_docs = [doc for batch in docs_batches for doc in batch] - vector_store_crud = OpenAIVectorStoreCrud(client) - vector_store = vector_store_crud.create() - list(vector_store_crud.update(vector_store.id, storage, docs_batches)) + result = provider.create( + collection_request=creation_request, + storage=storage, + document_crud=document_crud, + ) - # if with_assistant is true, create assistant backed by the vector store - if with_assistant: - assistant_crud = OpenAIAssistantCrud(client) + llm_service_id = result.llm_service_id + llm_service_name = result.llm_service_name - # Filter out None to avoid sending unset options - assistant_options = dict( - creation_request.extract_super_type(AssistantOptions) - ) - assistant_options = { - k: v for k, v in assistant_options.items() if v is not None - } - - assistant = assistant_crud.create(vector_store.id, **assistant_options) - llm_service_id = assistant.id - llm_service_name = assistant_options.get("model") or "assistant" - - logger.info( - "[execute_job] Assistant created | assistant_id=%s, vector_store_id=%s", - assistant.id, - vector_store.id, - ) - else: - # If no assistant, the collection points directly at the vector store - llm_service_id = vector_store.id - llm_service_name = OPENAI_VECTOR_STORE - logger.info( - "[execute_job] Skipping assistant creation | with_assistant=False" - ) + with Session(engine) as session: + document_crud = DocumentCrud(session, project_id) + 
flat_docs = document_crud.read_each(creation_request.documents) file_exts = {doc.fname.split(".")[-1] for doc in flat_docs if "." in doc.fname} file_sizes_kb = [ @@ -256,14 +208,15 @@ def execute_job( collection = Collection( id=collection_id, project_id=project_id, - organization_id=organization_id, llm_service_id=llm_service_id, llm_service_name=llm_service_name, + provider=creation_request.provider.upper(), + name=creation_request.name, + description=creation_request.description, ) collection_crud.create(collection) collection = collection_crud.read_one(collection.id) - # Link documents to the new collection if flat_docs: DocumentCollectionCrud(session).create(collection, flat_docs) @@ -299,12 +252,13 @@ def execute_job( exc_info=True, ) - _cleanup_remote_resources( - assistant=assistant, - assistant_crud=assistant_crud, - vector_store=vector_store, - vector_store_crud=vector_store_crud, - ) + if provider is not None and result is not None: + try: + provider.cleanup(result) + except Exception: + logger.warning( + "[create_collection.execute_job] Provider cleanup failed" + ) collection_job = _mark_job_failed( project_id=project_id, diff --git a/backend/app/services/collections/delete_collection.py b/backend/app/services/collections/delete_collection.py index ca337b79..e175301c 100644 --- a/backend/app/services/collections/delete_collection.py +++ b/backend/app/services/collections/delete_collection.py @@ -6,7 +6,6 @@ from app.core.db import engine from app.crud import CollectionCrud, CollectionJobCrud -from app.crud.rag import OpenAIAssistantCrud, OpenAIVectorStoreCrud from app.models import ( CollectionJobStatus, CollectionJobUpdate, @@ -15,9 +14,10 @@ CollectionIDPublic, ) from app.models.collection import DeletionRequest -from app.services.collections.helpers import extract_error_message, OPENAI_VECTOR_STORE +from app.services.collections.helpers import extract_error_message +from app.services.collections.providers.registry import get_llm_provider from app.celery.utils import start_low_priority_job -from app.utils import get_openai_client, send_callback, APIResponse +from app.utils import send_callback, APIResponse logger = logging.getLogger(__name__) @@ -155,7 +155,6 @@ def execute_job( job_uuid = UUID(job_id) collection_job = None - client = None try: with Session(engine) as session: @@ -169,20 +168,16 @@ def execute_job( ), ) - client = get_openai_client(session, organization_id, project_id) - collection = CollectionCrud(session, project_id).read_one(collection_id) - # Identify which external service (assistant/vector store) this collection belongs to - service = (collection.llm_service_name or "").strip().lower() - is_vector = service == OPENAI_VECTOR_STORE - llm_service_id = collection.llm_service_id + provider = get_llm_provider( + session=session, + provider=collection.provider.lower(), + project_id=project_id, + organization_id=organization_id, + ) - # Delete the corresponding OpenAI resource (vector store or assistant) - if is_vector: - OpenAIVectorStoreCrud(client).delete(llm_service_id) - else: - OpenAIAssistantCrud(client).delete(llm_service_id) + provider.delete(collection) with Session(engine) as session: CollectionCrud(session, project_id).delete_by_id(collection_id) diff --git a/backend/app/services/collections/helpers.py b/backend/app/services/collections/helpers.py index 795b04cd..0665260a 100644 --- a/backend/app/services/collections/helpers.py +++ b/backend/app/services/collections/helpers.py @@ -5,17 +5,26 @@ from uuid import UUID from typing import List +from 
fastapi import HTTPException from sqlmodel import select from openai import OpenAIError -from app.crud.document.document import DocumentCrud +from app.crud import DocumentCrud, CollectionCrud +from app.api.deps import SessionDep from app.models import DocumentCollection, Collection logger = logging.getLogger(__name__) -# llm service name for when only an openai vector store is being made -OPENAI_VECTOR_STORE = "openai vector store" + +def get_service_name(provider: str) -> str: + """Get the collection service name for a provider.""" + names = { + "openai": "openai vector store", + # "bedrock": "bedrock knowledge base", + # "gemini": "gemini file search store", + } + return names.get(provider.lower(), "") def extract_error_message(err: Exception) -> str: @@ -101,4 +110,23 @@ def pick_service_for_documennt(session, doc_id: UUID, a_crud, v_crud): service = ( (getattr(coll, "llm_service_name", "") or "").strip().lower() if coll else "" ) - return v_crud if service == OPENAI_VECTOR_STORE else a_crud + return v_crud if service == get_service_name("openai") else a_crud + + +def ensure_unique_name( + session: SessionDep, + project_id: int, + requested_name: str, +) -> str: + """ + Ensure collection name is unique based on strategy. + + """ + existing = CollectionCrud(session, project_id).exists_by_name(requested_name) + if existing: + raise HTTPException( + status_code=409, + detail=f"Collection '{requested_name}' already exists. Choose a different name.", + ) + + return requested_name diff --git a/backend/app/services/collections/providers/__init__.py b/backend/app/services/collections/providers/__init__.py new file mode 100644 index 00000000..5a9b6a55 --- /dev/null +++ b/backend/app/services/collections/providers/__init__.py @@ -0,0 +1,6 @@ +from app.services.collections.providers.base import BaseProvider +from app.services.collections.providers.openai import OpenAIProvider +from app.services.collections.providers.registry import ( + LLMProvider, + get_llm_provider, +) diff --git a/backend/app/services/collections/providers/base.py b/backend/app/services/collections/providers/base.py new file mode 100644 index 00000000..e4c91d9c --- /dev/null +++ b/backend/app/services/collections/providers/base.py @@ -0,0 +1,82 @@ +from abc import ABC, abstractmethod +from typing import Any + +from app.crud import DocumentCrud +from app.core.cloud.storage import CloudStorage +from app.models import CreationRequest, Collection + + +class BaseProvider(ABC): + """Abstract base class for collection providers. + + All provider implementations (OpenAI, Bedrock, etc.) must inherit from + this class and implement the required methods. + + Providers handle creation of collection and + optional assistant/agent creation backed by those collections. + + Attributes: + client: The provider-specific client instance + """ + + def __init__(self, client: Any): + """Initialize provider with client. + + Args: + client: Provider-specific client instance + """ + self.client = client + + @abstractmethod + def create( + self, + collection_request: CreationRequest, + storage: CloudStorage, + document_crud: DocumentCrud, + ) -> Collection: + """Create collection with documents and optionally an assistant. + + Args: + collection_request: Collection parameters (name, description, document list, etc.) 
+            storage: Cloud storage instance for file access
+            document_crud: DocumentCrud instance for fetching documents
+
+        Batch size and any assistant/agent options are read from the request
+        itself; whether an assistant is created is provider-specific.
+
+        Returns:
+            A transient Collection carrying llm_service_id and llm_service_name
+            for the caller to persist.
+        """
+        raise NotImplementedError("Providers must implement the create method")
+
+    @abstractmethod
+    def delete(self, collection: Collection) -> None:
+        """Delete remote resources associated with a collection.
+
+        Called when a collection is being deleted and its remote resources need to be cleaned up.
+
+        Args:
+            collection: The Collection whose llm_service_id and llm_service_name
+                identify the remote resource to delete
+        """
+        raise NotImplementedError("Providers must implement the delete method")
+
+    @abstractmethod
+    def cleanup(self, collection: Collection) -> None:
+        """Clean up/roll back resources created during create().
+
+        Called when collection creation fails and remote resources need to be deleted.
+
+        Args:
+            collection: The transient Collection returned by create(), containing the remote resource IDs
+        """
+        raise NotImplementedError("Providers must implement the cleanup method")
+
+    def get_provider_name(self) -> str:
+        """Get the name of the provider.
+
+        Returns:
+            Provider name (e.g., "openai", "bedrock", "pinecone")
+        """
+        return self.__class__.__name__.replace("Provider", "").lower()
diff --git a/backend/app/services/collections/providers/openai.py b/backend/app/services/collections/providers/openai.py
new file mode 100644
index 00000000..a33bd854
--- /dev/null
+++ b/backend/app/services/collections/providers/openai.py
@@ -0,0 +1,124 @@
+import logging
+
+from openai import OpenAI
+
+from app.services.collections.providers import BaseProvider
+from app.crud import DocumentCrud
+from app.core.cloud.storage import CloudStorage
+from app.crud.rag import OpenAIVectorStoreCrud, OpenAIAssistantCrud
+from app.services.collections.helpers import batch_documents, get_service_name, _backout
+from app.models import CreationRequest, Collection
+
+
+logger = logging.getLogger(__name__)
+
+
+class OpenAIProvider(BaseProvider):
+    """OpenAI-specific collection provider for vector stores and assistants."""
+
+    def __init__(self, client: OpenAI):
+        super().__init__(client)
+        self.client = client
+
+    def create(
+        self,
+        collection_request: CreationRequest,
+        storage: CloudStorage,
+        document_crud: DocumentCrud,
+    ) -> Collection:
+        """
+        Create OpenAI vector store with documents and optionally an assistant.
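+
+        An assistant is created only when both model and instructions are set
+        on the request; otherwise the returned transient Collection points
+        directly at the vector store.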
+ """ + try: + docs_batches = batch_documents( + document_crud, + collection_request.documents, + collection_request.batch_size, + ) + + vector_store_crud = OpenAIVectorStoreCrud(self.client) + vector_store = vector_store_crud.create() + + list(vector_store_crud.update(vector_store.id, storage, docs_batches)) + + logger.info( + "[OpenAIProvider.execute] Vector store created | " + f"vector_store_id={vector_store.id}, batches={len(docs_batches)}" + ) + + # Check if we need to create an assistant (based on assistant options in request) + with_assistant = ( + collection_request.model is not None + and collection_request.instructions is not None + ) + if with_assistant: + assistant_crud = OpenAIAssistantCrud(self.client) + + assistant_options = { + "model": collection_request.model, + "instructions": collection_request.instructions, + "temperature": collection_request.temperature, + } + filtered_options = { + k: v for k, v in assistant_options.items() if v is not None + } + + assistant = assistant_crud.create(vector_store.id, **filtered_options) + + logger.info( + "[OpenAIProvider.execute] Assistant created | " + f"assistant_id={assistant.id}, vector_store_id={vector_store.id}" + ) + + return Collection( + llm_service_id=assistant.id, + llm_service_name=filtered_options.get("model", "assistant"), + ) + else: + logger.info( + "[OpenAIProvider.execute] Skipping assistant creation | with_assistant=False" + ) + + return Collection( + llm_service_id=vector_store.id, + llm_service_name=get_service_name("openai"), + ) + + except Exception as e: + logger.error( + f"[OpenAIProvider.execute] Failed to create collection: {str(e)}", + exc_info=True, + ) + raise + + def delete(self, collection: Collection) -> None: + """Delete OpenAI resources (assistant or vector store). + + Determines what to delete based on llm_service_name: + - If assistant was created, delete the assistant (which also removes the vector store) + - If only vector store was created, delete the vector store + """ + try: + if collection.llm_service_name != get_service_name("openai"): + OpenAIAssistantCrud(self.client).delete(collection.llm_service_id) + logger.info( + f"[OpenAIProvider.delete] Deleted assistant | assistant_id={collection.llm_service_id}" + ) + else: + OpenAIVectorStoreCrud(self.client).delete(collection.llm_service_id) + logger.info( + f"[OpenAIProvider.delete] Deleted vector store | vector_store_id={collection.llm_service_id}" + ) + except Exception as e: + logger.error( + f"[OpenAIProvider.delete] Failed to delete resource | " + f"llm_service_id={collection.llm_service_id}, error={str(e)}", + exc_info=True, + ) + raise + + def cleanup(self, result: Collection) -> None: + """ + Clean up OpenAI resources (assistant or vector store). 
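+
+        Invoked from the failure path of execute_job with the transient
+        Collection returned by create(); delegates to the shared _backout helper.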
+ """ + _backout(result.llm_service_id, result.llm_service_name) diff --git a/backend/app/services/collections/providers/registry.py b/backend/app/services/collections/providers/registry.py new file mode 100644 index 00000000..10d07d45 --- /dev/null +++ b/backend/app/services/collections/providers/registry.py @@ -0,0 +1,71 @@ +import logging + +from sqlmodel import Session +from openai import OpenAI + +from app.crud import get_provider_credential +from app.services.collections.providers.base import BaseProvider +from app.services.collections.providers.openai import OpenAIProvider + + +logger = logging.getLogger(__name__) + + +class LLMProvider: + OPENAI = "openai" + # Future constants for providers: + # ANTHROPIC = "ANTHROPIC" + # GEMINI = "gemini" + + _registry: dict[str, type[BaseProvider]] = { + OPENAI: OpenAIProvider, + # Future providers: + # ANTHROPIC: BedrockProvider, + # GEMINI: GeminiProvider, + } + + @classmethod + def get(cls, name: str) -> type[BaseProvider]: + """Return the provider class for a given name.""" + provider = cls._registry.get(name) + if not provider: + raise ValueError( + f"Provider '{name}' is not supported. " + f"Supported providers: {', '.join(cls._registry.keys())}" + ) + return provider + + @classmethod + def supported_providers(cls) -> list[str]: + """Return a list of supported provider names.""" + return list(cls._registry.keys()) + + +def get_llm_provider( + session: Session, provider: str, project_id: int, organization_id: int +) -> BaseProvider: + provider_class = LLMProvider.get(provider) + + credentials = get_provider_credential( + session=session, + provider=provider, + project_id=project_id, + org_id=organization_id, + ) + + if not credentials: + raise ValueError( + f"Credentials for provider '{provider}' not configured for this project." 
+ ) + + if provider == LLMProvider.OPENAI: + if "api_key" not in credentials: + raise ValueError("OpenAI credentials not configured for this project.") + client = OpenAI(api_key=credentials["api_key"]) + else: + logger.error( + f"[get_llm_provider] Unsupported provider type requested: {provider}" + ) + raise ValueError(f"Provider '{provider}' is not supported.") + + return provider_class(client=client) diff --git a/backend/app/services/evaluations/__init__.py b/backend/app/services/evaluations/__init__.py new file mode 100644 index 00000000..62201b42 --- /dev/null +++ b/backend/app/services/evaluations/__init__.py @@ -0,0 +1,16 @@ +"""Evaluation services.""" + +from app.services.evaluations.dataset import upload_dataset +from app.services.evaluations.evaluation import ( + build_evaluation_config, + get_evaluation_with_scores, + start_evaluation, +) +from app.services.evaluations.validators import ( + ALLOWED_EXTENSIONS, + ALLOWED_MIME_TYPES, + MAX_FILE_SIZE, + parse_csv_items, + sanitize_dataset_name, + validate_csv_file, +) diff --git a/backend/app/services/evaluations/dataset.py b/backend/app/services/evaluations/dataset.py new file mode 100644 index 00000000..fe0e8924 --- /dev/null +++ b/backend/app/services/evaluations/dataset.py @@ -0,0 +1,163 @@ +"""Dataset management service for evaluations.""" + +import logging + +from fastapi import HTTPException +from sqlmodel import Session + +from app.core.cloud import get_cloud_storage +from app.crud.evaluations import ( + create_evaluation_dataset, + upload_csv_to_object_store, + upload_dataset_to_langfuse, +) +from app.models.evaluation import EvaluationDataset +from app.services.evaluations.validators import ( + parse_csv_items, + sanitize_dataset_name, +) +from app.utils import get_langfuse_client + +logger = logging.getLogger(__name__) + + +def upload_dataset( + session: Session, + csv_content: bytes, + dataset_name: str, + description: str | None, + duplication_factor: int, + organization_id: int, + project_id: int, +) -> EvaluationDataset: + """ + Orchestrate dataset upload workflow. + + Steps: + 1. Sanitize dataset name + 2. Parse and validate CSV + 3. Upload to object store + 4. Upload to Langfuse + 5. 
Store metadata in database + + Args: + session: Database session + csv_content: Raw CSV file content + dataset_name: Name for the dataset + description: Optional dataset description + duplication_factor: Number of times to duplicate each item + organization_id: Organization ID + project_id: Project ID + + Returns: + Created EvaluationDataset record + + Raises: + HTTPException: If upload fails at any step + """ + # Step 1: Sanitize dataset name for Langfuse compatibility + original_name = dataset_name + try: + dataset_name = sanitize_dataset_name(dataset_name) + except ValueError as e: + raise HTTPException(status_code=422, detail=f"Invalid dataset name: {str(e)}") + + if original_name != dataset_name: + logger.info( + f"[upload_dataset] Dataset name sanitized | '{original_name}' -> '{dataset_name}'" + ) + + logger.info( + f"[upload_dataset] Uploading dataset | dataset={dataset_name} | " + f"duplication_factor={duplication_factor} | org_id={organization_id} | " + f"project_id={project_id}" + ) + + # Step 2: Parse CSV and extract items + original_items = parse_csv_items(csv_content) + original_items_count = len(original_items) + total_items_count = original_items_count * duplication_factor + + logger.info( + f"[upload_dataset] Parsed items from CSV | original={original_items_count} | " + f"total_with_duplication={total_items_count}" + ) + + # Step 3: Upload to object store (if credentials configured) + object_store_url = None + try: + storage = get_cloud_storage(session=session, project_id=project_id) + object_store_url = upload_csv_to_object_store( + storage=storage, csv_content=csv_content, dataset_name=dataset_name + ) + if object_store_url: + logger.info( + f"[upload_dataset] Successfully uploaded CSV to object store | {object_store_url}" + ) + else: + logger.info( + "[upload_dataset] Object store upload returned None | " + "continuing without object store storage" + ) + except Exception as e: + logger.warning( + f"[upload_dataset] Failed to upload CSV to object store " + f"(continuing without object store) | {e}", + exc_info=True, + ) + object_store_url = None + + # Step 4: Upload to Langfuse + langfuse_dataset_id = None + try: + langfuse = get_langfuse_client( + session=session, + org_id=organization_id, + project_id=project_id, + ) + + langfuse_dataset_id, _ = upload_dataset_to_langfuse( + langfuse=langfuse, + items=original_items, + dataset_name=dataset_name, + duplication_factor=duplication_factor, + ) + + logger.info( + f"[upload_dataset] Successfully uploaded dataset to Langfuse | " + f"dataset={dataset_name} | id={langfuse_dataset_id}" + ) + + except Exception as e: + logger.error( + f"[upload_dataset] Failed to upload dataset to Langfuse | {e}", + exc_info=True, + ) + raise HTTPException( + status_code=500, detail=f"Failed to upload dataset to Langfuse: {e}" + ) + + # Step 5: Store metadata in database + metadata = { + "original_items_count": original_items_count, + "total_items_count": total_items_count, + "duplication_factor": duplication_factor, + } + + dataset = create_evaluation_dataset( + session=session, + name=dataset_name, + description=description, + dataset_metadata=metadata, + object_store_url=object_store_url, + langfuse_dataset_id=langfuse_dataset_id, + organization_id=organization_id, + project_id=project_id, + ) + + logger.info( + f"[upload_dataset] Successfully created dataset record in database | " + f"id={dataset.id} | name={dataset_name}" + ) + + return dataset diff --git a/backend/app/services/evaluations/evaluation.py 
b/backend/app/services/evaluations/evaluation.py new file mode 100644 index 00000000..bf0d4dd0 --- /dev/null +++ b/backend/app/services/evaluations/evaluation.py @@ -0,0 +1,319 @@ +"""Evaluation run orchestration service.""" + +import logging + +from fastapi import HTTPException +from sqlmodel import Session + +from app.crud.assistants import get_assistant_by_id +from app.crud.evaluations import ( + create_evaluation_run, + fetch_trace_scores_from_langfuse, + get_dataset_by_id, + get_evaluation_run_by_id, + save_score, + start_evaluation_batch, +) +from app.models.evaluation import EvaluationRun +from app.utils import get_langfuse_client, get_openai_client + +logger = logging.getLogger(__name__) + + +def build_evaluation_config( + session: Session, + config: dict, + assistant_id: str | None, + project_id: int, +) -> dict: + """ + Build evaluation configuration from assistant or provided config. + + If assistant_id is provided, fetch assistant and merge with config. + Config values take precedence over assistant values. + + Args: + session: Database session + config: Provided configuration dict + assistant_id: Optional assistant ID to fetch configuration from + project_id: Project ID for assistant lookup + + Returns: + Complete evaluation configuration dict + + Raises: + HTTPException: If assistant not found or model missing + """ + if assistant_id: + assistant = get_assistant_by_id( + session=session, + assistant_id=assistant_id, + project_id=project_id, + ) + + if not assistant: + raise HTTPException( + status_code=404, detail=f"Assistant {assistant_id} not found" + ) + + logger.info( + f"[build_evaluation_config] Found assistant in DB | id={assistant.id} | " + f"model={assistant.model} | instructions=" + f"{assistant.instructions[:50] if assistant.instructions else 'None'}..." + ) + + # Build config from assistant (use provided config values to override if present) + merged_config = { + "model": config.get("model", assistant.model), + "instructions": config.get("instructions", assistant.instructions), + "temperature": config.get("temperature", assistant.temperature), + } + + # Add tools if vector stores are available + vector_store_ids = config.get( + "vector_store_ids", assistant.vector_store_ids or [] + ) + if vector_store_ids and len(vector_store_ids) > 0: + merged_config["tools"] = [ + { + "type": "file_search", + "vector_store_ids": vector_store_ids, + } + ] + + logger.info("[build_evaluation_config] Using config from assistant") + return merged_config + + # Using provided config directly + logger.info("[build_evaluation_config] Using provided config directly") + + # Validate that config has minimum required fields + if not config.get("model"): + raise HTTPException( + status_code=400, + detail="Config must include 'model' when assistant_id is not provided", + ) + + return config + + +def start_evaluation( + session: Session, + dataset_id: int, + experiment_name: str, + config: dict, + assistant_id: str | None, + organization_id: int, + project_id: int, +) -> EvaluationRun: + """ + Start an evaluation run. + + Steps: + 1. Validate dataset exists and has Langfuse ID + 2. Build config (from assistant or direct) + 3. Create evaluation run record + 4. 
Start batch processing + + Args: + session: Database session + dataset_id: ID of the evaluation dataset + experiment_name: Name for this evaluation experiment/run + config: Evaluation configuration + assistant_id: Optional assistant ID to fetch configuration from + organization_id: Organization ID + project_id: Project ID + + Returns: + EvaluationRun instance + + Raises: + HTTPException: If dataset not found or evaluation fails to start + """ + logger.info( + f"[start_evaluation] Starting evaluation | experiment_name={experiment_name} | " + f"dataset_id={dataset_id} | " + f"org_id={organization_id} | " + f"assistant_id={assistant_id} | " + f"config_keys={list(config.keys())}" + ) + + dataset = get_dataset_by_id( + session=session, + dataset_id=dataset_id, + organization_id=organization_id, + project_id=project_id, + ) + + if not dataset: + raise HTTPException( + status_code=404, + detail=f"Dataset {dataset_id} not found or not accessible to this " + f"organization/project", + ) + + logger.info( + f"[start_evaluation] Found dataset | id={dataset.id} | name={dataset.name} | " + f"object_store_url={'present' if dataset.object_store_url else 'None'} | " + f"langfuse_id={dataset.langfuse_dataset_id}" + ) + + if not dataset.langfuse_dataset_id: + raise HTTPException( + status_code=400, + detail=f"Dataset {dataset_id} does not have a Langfuse dataset ID. " + "Please ensure Langfuse credentials were configured when the dataset was created.", + ) + + eval_config = build_evaluation_config( + session=session, + config=config, + assistant_id=assistant_id, + project_id=project_id, + ) + + openai_client = get_openai_client( + session=session, + org_id=organization_id, + project_id=project_id, + ) + langfuse = get_langfuse_client( + session=session, + org_id=organization_id, + project_id=project_id, + ) + + eval_run = create_evaluation_run( + session=session, + run_name=experiment_name, + dataset_name=dataset.name, + dataset_id=dataset_id, + config=eval_config, + organization_id=organization_id, + project_id=project_id, + ) + + try: + eval_run = start_evaluation_batch( + langfuse=langfuse, + openai_client=openai_client, + session=session, + eval_run=eval_run, + config=eval_config, + ) + + logger.info( + f"[start_evaluation] Evaluation started successfully | " + f"batch_job_id={eval_run.batch_job_id} | total_items={eval_run.total_items}" + ) + + return eval_run + + except Exception as e: + logger.error( + f"[start_evaluation] Failed to start evaluation | run_id={eval_run.id} | {e}", + exc_info=True, + ) + # Error is already handled in start_evaluation_batch + session.refresh(eval_run) + return eval_run + + +def get_evaluation_with_scores( + session: Session, + evaluation_id: int, + organization_id: int, + project_id: int, + get_trace_info: bool, + resync_score: bool, +) -> tuple[EvaluationRun | None, str | None]: + """ + Get evaluation run, optionally with trace scores from Langfuse. + + Handles caching logic for trace scores - scores are fetched on first request + and cached in the database. 
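+
+    Illustrative call (parameter names are the function's own; nothing external
+    is assumed):
+
+        eval_run, error = get_evaluation_with_scores(
+            session=session,
+            evaluation_id=evaluation_id,
+            organization_id=organization_id,
+            project_id=project_id,
+            get_trace_info=True,
+            resync_score=False,
+        )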
+ + Args: + session: Database session + evaluation_id: ID of the evaluation run + organization_id: Organization ID + project_id: Project ID + get_trace_info: If true, fetch trace scores + resync_score: If true, clear cached scores and re-fetch + + Returns: + Tuple of (EvaluationRun or None, error_message or None) + """ + logger.info( + f"[get_evaluation_with_scores] Fetching status for evaluation run | " + f"evaluation_id={evaluation_id} | " + f"org_id={organization_id} | " + f"project_id={project_id} | " + f"get_trace_info={get_trace_info} | " + f"resync_score={resync_score}" + ) + + eval_run = get_evaluation_run_by_id( + session=session, + evaluation_id=evaluation_id, + organization_id=organization_id, + project_id=project_id, + ) + + if not eval_run: + return None, None + + if not get_trace_info: + return eval_run, None + + # Only fetch trace info for completed evaluations + if eval_run.status != "completed": + return eval_run, ( + f"Trace info is only available for completed evaluations. " + f"Current status: {eval_run.status}" + ) + + # Check if we already have cached scores + has_cached_score = eval_run.score is not None and "traces" in eval_run.score + if not resync_score and has_cached_score: + return eval_run, None + + langfuse = get_langfuse_client( + session=session, + org_id=organization_id, + project_id=project_id, + ) + + # Capture data needed for Langfuse fetch and DB update + dataset_name = eval_run.dataset_name + run_name = eval_run.run_name + eval_run_id = eval_run.id + + try: + score = fetch_trace_scores_from_langfuse( + langfuse=langfuse, + dataset_name=dataset_name, + run_name=run_name, + ) + except ValueError as e: + logger.warning( + f"[get_evaluation_with_scores] Run not found in Langfuse | " + f"evaluation_id={evaluation_id} | error={e}" + ) + return eval_run, str(e) + except Exception as e: + logger.error( + f"[get_evaluation_with_scores] Failed to fetch trace info | " + f"evaluation_id={evaluation_id} | error={e}", + exc_info=True, + ) + return eval_run, f"Failed to fetch trace info from Langfuse: {str(e)}" + + eval_run = save_score( + eval_run_id=eval_run_id, + organization_id=organization_id, + project_id=project_id, + score=score, + ) + + return eval_run, None diff --git a/backend/app/services/evaluations/validators.py b/backend/app/services/evaluations/validators.py new file mode 100644 index 00000000..92733a2f --- /dev/null +++ b/backend/app/services/evaluations/validators.py @@ -0,0 +1,169 @@ +"""Validation utilities for evaluation datasets.""" + +import csv +import io +import logging +import re +from pathlib import Path + +from fastapi import HTTPException, UploadFile + +logger = logging.getLogger(__name__) + +MAX_FILE_SIZE = 1024 * 1024 # 1 MB +ALLOWED_EXTENSIONS = {".csv"} +ALLOWED_MIME_TYPES = { + "text/csv", + "application/csv", + "text/plain", +} + + +def sanitize_dataset_name(name: str) -> str: + """ + Sanitize dataset name for Langfuse compatibility. + + Langfuse has issues with spaces and special characters in dataset names. + This function ensures the name can be both created and fetched. + + Rules: + - Replace spaces with underscores + - Replace hyphens with underscores + - Keep only alphanumeric characters and underscores + - Convert to lowercase for consistency + - Remove leading/trailing underscores + - Collapse multiple consecutive underscores into one + + Args: + name: Original dataset name + + Returns: + Sanitized dataset name safe for Langfuse + + Examples: + "testing 0001" -> "testing_0001" + "My Dataset!" 
-> "my_dataset" + "Test--Data__Set" -> "test_data_set" + """ + sanitized = name.lower() + + # Replace spaces and hyphens with underscores + sanitized = sanitized.replace(" ", "_").replace("-", "_") + + # Keep only alphanumeric characters and underscores + sanitized = re.sub(r"[^a-z0-9_]", "", sanitized) + + # Collapse multiple underscores into one + sanitized = re.sub(r"_+", "_", sanitized) + + sanitized = sanitized.strip("_") + + if not sanitized: + raise ValueError("Dataset name cannot be empty after sanitization") + + return sanitized + + +async def validate_csv_file(file: UploadFile) -> bytes: + """ + Validate CSV file extension, MIME type, and size. + + Args: + file: The uploaded file + + Returns: + CSV content as bytes if valid + + Raises: + HTTPException: If validation fails + """ + if not file.filename: + raise HTTPException( + status_code=422, + detail="File must have a filename", + ) + file_ext = Path(file.filename).suffix.lower() + if file_ext not in ALLOWED_EXTENSIONS: + raise HTTPException( + status_code=422, + detail=f"Invalid file type. Only CSV files are allowed. Got: {file_ext}", + ) + + content_type = file.content_type + if content_type not in ALLOWED_MIME_TYPES: + raise HTTPException( + status_code=422, + detail=f"Invalid content type. Expected CSV, got: {content_type}", + ) + + file.file.seek(0, 2) + file_size = file.file.tell() + file.file.seek(0) + + if file_size > MAX_FILE_SIZE: + raise HTTPException( + status_code=413, + detail=f"File too large. Maximum size: {MAX_FILE_SIZE / (1024 * 1024):.0f}MB", + ) + + if file_size == 0: + raise HTTPException(status_code=422, detail="Empty file uploaded") + + return await file.read() + + +def parse_csv_items(csv_content: bytes) -> list[dict[str, str]]: + """ + Parse CSV and extract question/answer pairs. 
+ + Args: + csv_content: CSV file content as bytes + + Returns: + List of dicts with 'question' and 'answer' keys + + Raises: + HTTPException: If CSV is invalid or empty + """ + try: + csv_text = csv_content.decode("utf-8") + csv_reader = csv.DictReader(io.StringIO(csv_text)) + + if not csv_reader.fieldnames: + raise HTTPException(status_code=422, detail="CSV file has no headers") + + # Normalize headers for case-insensitive matching + clean_headers = { + field.strip().lower(): field for field in csv_reader.fieldnames + } + + # Validate required headers (case-insensitive) + if "question" not in clean_headers or "answer" not in clean_headers: + raise HTTPException( + status_code=422, + detail=f"CSV must contain 'question' and 'answer' columns " + f"Found columns: {csv_reader.fieldnames}", + ) + + question_col = clean_headers["question"] + answer_col = clean_headers["answer"] + + items = [] + for row in csv_reader: + question = row.get(question_col, "").strip() + answer = row.get(answer_col, "").strip() + if question and answer: + items.append({"question": question, "answer": answer}) + + if not items: + raise HTTPException( + status_code=422, detail="No valid items found in CSV file" + ) + + return items + + except HTTPException: + raise + except Exception as e: + logger.error(f"[parse_csv_items] Failed to parse CSV | {e}", exc_info=True) + raise HTTPException(status_code=422, detail=f"Invalid CSV file: {e}") diff --git a/backend/app/tests/api/routes/collections/test_collection_delete.py b/backend/app/tests/api/routes/collections/test_collection_delete.py index 99e61d67..8f5e72fa 100644 --- a/backend/app/tests/api/routes/collections/test_collection_delete.py +++ b/backend/app/tests/api/routes/collections/test_collection_delete.py @@ -7,7 +7,7 @@ from app.core.config import settings from app.models import CollectionJobStatus from app.tests.utils.utils import get_project -from app.tests.utils.collection import get_collection +from app.tests.utils.collection import get_assistant_collection @patch("app.api.routes.collections.delete_service.start_job") @@ -26,7 +26,7 @@ def test_delete_collection_calls_start_job_and_returns_job( - Calls delete_service.start_job with correct arguments """ project = get_project(db, "Dalgo") - collection = get_collection(db, project) + collection = get_assistant_collection(db, project) resp = client.request( "DELETE", @@ -70,7 +70,7 @@ def test_delete_collection_with_callback_url_passes_it_to_start_job( into the DeletionRequest and then into delete_service.start_job. 
""" project = get_project(db, "Dalgo") - collection = get_collection(db, project) + collection = get_assistant_collection(db, project) payload = { "callback_url": "https://example.com/collections/delete-callback", diff --git a/backend/app/tests/api/routes/collections/test_collection_info.py b/backend/app/tests/api/routes/collections/test_collection_info.py index 90f8b80c..ee594f62 100644 --- a/backend/app/tests/api/routes/collections/test_collection_info.py +++ b/backend/app/tests/api/routes/collections/test_collection_info.py @@ -6,9 +6,13 @@ from app.core.config import settings from app.tests.utils.utils import get_project, get_document -from app.tests.utils.collection import get_collection, get_vector_store_collection +from app.tests.utils.collection import ( + get_assistant_collection, + get_vector_store_collection, +) from app.crud import DocumentCollectionCrud from app.models import Collection, Document +from app.services.collections.helpers import get_service_name def link_document_to_collection( @@ -40,13 +44,13 @@ def test_collection_info_returns_assistant_collection_with_docs( ): """ Happy path: - - Assistant-style collection (get_collection) + - Assistant-style collection (get_assistant_collection) - include_docs = True (default) - At least one document linked """ project = get_project(db, "Dalgo") - collection = get_collection(db, project) + collection = get_assistant_collection(db, project) document = link_document_to_collection(db, collection) @@ -82,7 +86,7 @@ def test_collection_info_include_docs_false_returns_no_docs( When include_docs=false, the endpoint should not populate the documents list. """ project = get_project(db, "Dalgo") - collection = get_collection(db, project) + collection = get_assistant_collection(db, project) link_document_to_collection(db, collection) @@ -113,7 +117,7 @@ def test_collection_info_pagination_skip_and_limit( We create multiple document links and then request a paginated slice. """ project = get_project(db, "Dalgo") - collection = get_collection(db, project) + collection = get_assistant_collection(db, project) documents = db.exec( select(Document).where(Document.deleted_at.is_(None)).limit(2) @@ -148,7 +152,7 @@ def test_collection_info_vector_store_collection( via get_vector_store_collection. 
""" project = get_project(db, "Dalgo") - collection = get_vector_store_collection(db, project) + collection = get_vector_store_collection(db, project, provider="openai") link_document_to_collection(db, collection) @@ -163,7 +167,7 @@ def test_collection_info_vector_store_collection( payload = data["data"] assert payload["id"] == str(collection.id) - assert payload["llm_service_name"] == "openai vector store" + assert payload["llm_service_name"] == get_service_name("openai") assert payload["llm_service_id"] == collection.llm_service_id docs = payload.get("documents", []) diff --git a/backend/app/tests/api/routes/collections/test_collection_job_info.py b/backend/app/tests/api/routes/collections/test_collection_job_info.py index 8735b51d..e2c8b44b 100644 --- a/backend/app/tests/api/routes/collections/test_collection_job_info.py +++ b/backend/app/tests/api/routes/collections/test_collection_job_info.py @@ -3,7 +3,7 @@ from app.core.config import settings from app.tests.utils.utils import get_project -from app.tests.utils.collection import get_collection, get_collection_job +from app.tests.utils.collection import get_assistant_collection, get_collection_job from app.models import ( CollectionActionType, CollectionJobStatus, @@ -39,7 +39,7 @@ def test_collection_info_create_successful( headers = user_api_key_header project = get_project(db, "Dalgo") - collection = get_collection(db, project) + collection = get_assistant_collection(db, project) collection_job = get_collection_job( db, project, collection_id=collection.id, status=CollectionJobStatus.SUCCESSFUL @@ -99,7 +99,7 @@ def test_collection_info_delete_successful( headers = user_api_key_header project = get_project(db, "Dalgo") - collection = get_collection(db, project) + collection = get_assistant_collection(db, project) collection_job = get_collection_job( db, @@ -131,7 +131,7 @@ def test_collection_info_delete_failed( headers = user_api_key_header project = get_project(db, "Dalgo") - collection = get_collection(db, project) + collection = get_assistant_collection(db, project) collection_job = get_collection_job( db, diff --git a/backend/app/tests/api/routes/collections/test_collection_list.py b/backend/app/tests/api/routes/collections/test_collection_list.py index f7507c12..57510c09 100644 --- a/backend/app/tests/api/routes/collections/test_collection_list.py +++ b/backend/app/tests/api/routes/collections/test_collection_list.py @@ -4,9 +4,10 @@ from app.core.config import settings from app.tests.utils.utils import get_project from app.tests.utils.collection import ( - get_collection, + get_assistant_collection, get_vector_store_collection, ) +from app.services.collections.helpers import get_service_name def test_list_collections_returns_api_response( @@ -39,7 +40,7 @@ def test_list_collections_includes_assistant_collection( user_api_key_header, ): """ - Ensure that a newly created assistant-style collection (get_collection) + Ensure that a newly created assistant-style collection (get_assistant_collection) appears in the list for the current project. """ @@ -51,7 +52,7 @@ def test_list_collections_includes_assistant_collection( ) assert response_before.status_code == 200 - collection = get_collection(db, project) + collection = get_assistant_collection(db, project) response_after = client.get( f"{settings.API_V1_STR}/collections/", @@ -82,7 +83,7 @@ def test_list_collections_includes_vector_store_collection_with_fields( appear in the list and expose the expected LLM fields. 
""" project = get_project(db, "Dalgo") - collection = get_vector_store_collection(db, project) + collection = get_vector_store_collection(db, project, provider="openai") response = client.get( f"{settings.API_V1_STR}/collections/", @@ -101,7 +102,7 @@ def test_list_collections_includes_vector_store_collection_with_fields( row = matching[0] assert row["project_id"] == project.id - assert row["llm_service_name"] == "openai vector store" + assert row["llm_service_name"] == get_service_name("openai") assert row["llm_service_id"] == collection.llm_service_id diff --git a/backend/app/tests/api/routes/documents/conftest.py b/backend/app/tests/api/routes/documents/conftest.py new file mode 100644 index 00000000..d36dc181 --- /dev/null +++ b/backend/app/tests/api/routes/documents/conftest.py @@ -0,0 +1,11 @@ +import pytest +from starlette.testclient import TestClient + +from app.tests.utils.auth import TestAuthContext +from app.tests.utils.document import WebCrawler + + +@pytest.fixture +def crawler(client: TestClient, user_api_key: TestAuthContext) -> WebCrawler: + """Provides a WebCrawler instance for document API testing.""" + return WebCrawler(client, user_api_key=user_api_key) diff --git a/backend/app/tests/api/routes/documents/test_route_document_info.py b/backend/app/tests/api/routes/documents/test_route_document_info.py index 3a49013f..425cb50b 100644 --- a/backend/app/tests/api/routes/documents/test_route_document_info.py +++ b/backend/app/tests/api/routes/documents/test_route_document_info.py @@ -7,7 +7,6 @@ DocumentStore, Route, WebCrawler, - crawler, httpx_to_standard, ) diff --git a/backend/app/tests/api/routes/documents/test_route_document_list.py b/backend/app/tests/api/routes/documents/test_route_document_list.py index a580fb2a..d7e3c1dd 100644 --- a/backend/app/tests/api/routes/documents/test_route_document_list.py +++ b/backend/app/tests/api/routes/documents/test_route_document_list.py @@ -6,7 +6,6 @@ DocumentStore, Route, WebCrawler, - crawler, httpx_to_standard, ) diff --git a/backend/app/tests/api/routes/documents/test_route_document_permanent_remove.py b/backend/app/tests/api/routes/documents/test_route_document_permanent_remove.py index 57de20ad..32a1ea27 100644 --- a/backend/app/tests/api/routes/documents/test_route_document_permanent_remove.py +++ b/backend/app/tests/api/routes/documents/test_route_document_permanent_remove.py @@ -16,11 +16,10 @@ from app.core.config import settings from app.models import Document from app.tests.utils.document import ( - DocumentStore, DocumentMaker, + DocumentStore, Route, WebCrawler, - crawler, ) diff --git a/backend/app/tests/api/routes/test_assistants.py b/backend/app/tests/api/routes/test_assistants.py index d4d2aadc..9914fbae 100644 --- a/backend/app/tests/api/routes/test_assistants.py +++ b/backend/app/tests/api/routes/test_assistants.py @@ -4,7 +4,7 @@ from fastapi import HTTPException from fastapi.testclient import TestClient from unittest.mock import patch -from app.tests.utils.openai import mock_openai_assistant +from app.tests.utils.llm_provider import mock_openai_assistant from app.tests.utils.utils import get_assistant from app.tests.utils.auth import TestAuthContext diff --git a/backend/app/tests/api/routes/test_evaluation.py b/backend/app/tests/api/routes/test_evaluation.py index c4eb3f0b..ec1af4f8 100644 --- a/backend/app/tests/api/routes/test_evaluation.py +++ b/backend/app/tests/api/routes/test_evaluation.py @@ -1,11 +1,15 @@ import io +from typing import Any from unittest.mock import Mock, patch import pytest -from 
sqlmodel import select +from fastapi.testclient import TestClient +from sqlmodel import Session, select from app.crud.evaluations.batch import build_evaluation_jsonl from app.models import EvaluationDataset, EvaluationRun +from app.tests.utils.auth import TestAuthContext +from app.tests.utils.test_data import create_test_evaluation_dataset # Helper function to create CSV file-like object @@ -16,17 +20,17 @@ def create_csv_file(content: str) -> tuple[str, io.BytesIO]: @pytest.fixture -def valid_csv_content(): +def valid_csv_content() -> str: """Valid CSV content with question and answer columns.""" return """question,answer "Who is known as the strongest jujutsu sorcerer?","Satoru Gojo" -"What is the name of Gojo’s Domain Expansion?","Infinite Void" +"What is the name of Gojo's Domain Expansion?","Infinite Void" "Who is known as the King of Curses?","Ryomen Sukuna" """ @pytest.fixture -def invalid_csv_missing_columns(): +def invalid_csv_missing_columns() -> str: """CSV content missing required columns.""" return """query,response "Who is known as the strongest jujutsu sorcerer?","Satoru Gojo" @@ -34,7 +38,7 @@ def invalid_csv_missing_columns(): @pytest.fixture -def csv_with_empty_rows(): +def csv_with_empty_rows() -> str: """CSV content with some empty rows.""" return """question,answer "Who is known as the strongest jujutsu sorcerer?","Satoru Gojo" @@ -48,34 +52,35 @@ class TestDatasetUploadValidation: """Test CSV validation and parsing.""" def test_upload_dataset_valid_csv( - self, client, user_api_key_header, valid_csv_content, db - ): + self, + client: TestClient, + user_api_key_header: dict[str, str], + valid_csv_content: str, + db: Session, + ) -> None: """Test uploading a valid CSV file.""" with ( patch("app.core.cloud.get_cloud_storage") as _mock_storage, patch( - "app.api.routes.evaluation.upload_csv_to_object_store" + "app.services.evaluations.dataset.upload_csv_to_object_store" ) as mock_store_upload, patch( - "app.api.routes.evaluation.get_langfuse_client" + "app.services.evaluations.dataset.get_langfuse_client" ) as mock_get_langfuse_client, patch( - "app.api.routes.evaluation.upload_dataset_to_langfuse" + "app.services.evaluations.dataset.upload_dataset_to_langfuse" ) as mock_langfuse_upload, ): - # Mock object store upload mock_store_upload.return_value = "s3://bucket/datasets/test_dataset.csv" - # Mock Langfuse client mock_get_langfuse_client.return_value = Mock() - # Mock Langfuse upload mock_langfuse_upload.return_value = ("test_dataset_id", 9) filename, file_obj = create_csv_file(valid_csv_content) response = client.post( - "/api/v1/evaluations/datasets", + "/api/v1/evaluations/datasets/", files={"file": (filename, file_obj, "text/csv")}, data={ "dataset_name": "test_dataset", @@ -101,22 +106,21 @@ def test_upload_dataset_valid_csv( # Verify object store upload was called mock_store_upload.assert_called_once() - # Verify Langfuse upload was called mock_langfuse_upload.assert_called_once() def test_upload_dataset_missing_columns( self, - client, - user_api_key_header, - invalid_csv_missing_columns, - ): + client: TestClient, + user_api_key_header: dict[str, str], + invalid_csv_missing_columns: str, + ) -> None: """Test uploading CSV with missing required columns.""" filename, file_obj = create_csv_file(invalid_csv_missing_columns) # The CSV validation happens before any mocked functions are called # so this test checks the actual validation logic response = client.post( - "/api/v1/evaluations/datasets", + "/api/v1/evaluations/datasets/", files={"file": (filename, file_obj, 
"text/csv")}, data={ "dataset_name": "test_dataset", @@ -125,7 +129,6 @@ def test_upload_dataset_missing_columns( headers=user_api_key_header, ) - # Check that the response indicates unprocessable entity assert response.status_code == 422 response_data = response.json() error_str = response_data.get( @@ -134,22 +137,24 @@ def test_upload_dataset_missing_columns( assert "question" in error_str.lower() or "answer" in error_str.lower() def test_upload_dataset_empty_rows( - self, client, user_api_key_header, csv_with_empty_rows - ): + self, + client: TestClient, + user_api_key_header: dict[str, str], + csv_with_empty_rows: str, + ) -> None: """Test uploading CSV with empty rows (should skip them).""" with ( patch("app.core.cloud.get_cloud_storage") as _mock_storage, patch( - "app.api.routes.evaluation.upload_csv_to_object_store" + "app.services.evaluations.dataset.upload_csv_to_object_store" ) as mock_store_upload, patch( - "app.api.routes.evaluation.get_langfuse_client" + "app.services.evaluations.dataset.get_langfuse_client" ) as mock_get_langfuse_client, patch( - "app.api.routes.evaluation.upload_dataset_to_langfuse" + "app.services.evaluations.dataset.upload_dataset_to_langfuse" ) as mock_langfuse_upload, ): - # Mock object store and Langfuse uploads mock_store_upload.return_value = "s3://bucket/datasets/test_dataset.csv" mock_get_langfuse_client.return_value = Mock() mock_langfuse_upload.return_value = ("test_dataset_id", 4) @@ -157,7 +162,7 @@ def test_upload_dataset_empty_rows( filename, file_obj = create_csv_file(csv_with_empty_rows) response = client.post( - "/api/v1/evaluations/datasets", + "/api/v1/evaluations/datasets/", files={"file": (filename, file_obj, "text/csv")}, data={ "dataset_name": "test_dataset", @@ -180,19 +185,22 @@ class TestDatasetUploadDuplication: """Test duplication logic.""" def test_upload_with_default_duplication( - self, client, user_api_key_header, valid_csv_content - ): + self, + client: TestClient, + user_api_key_header: dict[str, str], + valid_csv_content: str, + ) -> None: """Test uploading with default duplication factor (1).""" with ( patch("app.core.cloud.get_cloud_storage") as _mock_storage, patch( - "app.api.routes.evaluation.upload_csv_to_object_store" + "app.services.evaluations.dataset.upload_csv_to_object_store" ) as mock_store_upload, patch( - "app.api.routes.evaluation.get_langfuse_client" + "app.services.evaluations.dataset.get_langfuse_client" ) as mock_get_langfuse_client, patch( - "app.api.routes.evaluation.upload_dataset_to_langfuse" + "app.services.evaluations.dataset.upload_dataset_to_langfuse" ) as mock_langfuse_upload, ): mock_store_upload.return_value = "s3://bucket/datasets/test_dataset.csv" @@ -202,11 +210,11 @@ def test_upload_with_default_duplication( filename, file_obj = create_csv_file(valid_csv_content) response = client.post( - "/api/v1/evaluations/datasets", + "/api/v1/evaluations/datasets/", files={"file": (filename, file_obj, "text/csv")}, data={ "dataset_name": "test_dataset", - # duplication_factor not provided, should default to 1 + # duplication_factor not provided, would default to 1 }, headers=user_api_key_header, ) @@ -218,22 +226,25 @@ def test_upload_with_default_duplication( assert data["duplication_factor"] == 1 assert data["original_items"] == 3 - assert data["total_items"] == 3 # 3 items * 1 duplication + assert data["total_items"] == 3 def test_upload_with_custom_duplication( - self, client, user_api_key_header, valid_csv_content - ): + self, + client: TestClient, + user_api_key_header: dict[str, str], + 
valid_csv_content: str, + ) -> None: """Test uploading with custom duplication factor.""" with ( patch("app.core.cloud.get_cloud_storage") as _mock_storage, patch( - "app.api.routes.evaluation.upload_csv_to_object_store" + "app.services.evaluations.dataset.upload_csv_to_object_store" ) as mock_store_upload, patch( - "app.api.routes.evaluation.get_langfuse_client" + "app.services.evaluations.dataset.get_langfuse_client" ) as mock_get_langfuse_client, patch( - "app.api.routes.evaluation.upload_dataset_to_langfuse" + "app.services.evaluations.dataset.upload_dataset_to_langfuse" ) as mock_langfuse_upload, ): mock_store_upload.return_value = "s3://bucket/datasets/test_dataset.csv" @@ -243,7 +254,7 @@ def test_upload_with_custom_duplication( filename, file_obj = create_csv_file(valid_csv_content) response = client.post( - "/api/v1/evaluations/datasets", + "/api/v1/evaluations/datasets/", files={"file": (filename, file_obj, "text/csv")}, data={ "dataset_name": "test_dataset", @@ -262,19 +273,23 @@ def test_upload_with_custom_duplication( assert data["total_items"] == 12 # 3 items * 4 duplication def test_upload_with_description( - self, client, user_api_key_header, valid_csv_content, db - ): + self, + client: TestClient, + user_api_key_header: dict[str, str], + valid_csv_content: str, + db: Session, + ) -> None: """Test uploading with a description.""" with ( patch("app.core.cloud.get_cloud_storage") as _mock_storage, patch( - "app.api.routes.evaluation.upload_csv_to_object_store" + "app.services.evaluations.dataset.upload_csv_to_object_store" ) as mock_store_upload, patch( - "app.api.routes.evaluation.get_langfuse_client" + "app.services.evaluations.dataset.get_langfuse_client" ) as mock_get_langfuse_client, patch( - "app.api.routes.evaluation.upload_dataset_to_langfuse" + "app.services.evaluations.dataset.upload_dataset_to_langfuse" ) as mock_langfuse_upload, ): mock_store_upload.return_value = "s3://bucket/datasets/test_dataset.csv" @@ -284,7 +299,7 @@ def test_upload_with_description( filename, file_obj = create_csv_file(valid_csv_content) response = client.post( - "/api/v1/evaluations/datasets", + "/api/v1/evaluations/datasets/", files={"file": (filename, file_obj, "text/csv")}, data={ "dataset_name": "test_dataset_with_description", @@ -310,13 +325,16 @@ def test_upload_with_description( assert dataset.description == "This is a test dataset for evaluation" def test_upload_with_duplication_factor_below_minimum( - self, client, user_api_key_header, valid_csv_content - ): + self, + client: TestClient, + user_api_key_header: dict[str, str], + valid_csv_content: str, + ) -> None: """Test uploading with duplication factor below minimum (0).""" filename, file_obj = create_csv_file(valid_csv_content) response = client.post( - "/api/v1/evaluations/datasets", + "/api/v1/evaluations/datasets/", files={"file": (filename, file_obj, "text/csv")}, data={ "dataset_name": "test_dataset", @@ -332,13 +350,16 @@ def test_upload_with_duplication_factor_below_minimum( assert "greater than or equal to 1" in response_data["error"] def test_upload_with_duplication_factor_above_maximum( - self, client, user_api_key_header, valid_csv_content - ): + self, + client: TestClient, + user_api_key_header: dict[str, str], + valid_csv_content: str, + ) -> None: """Test uploading with duplication factor above maximum (6).""" filename, file_obj = create_csv_file(valid_csv_content) response = client.post( - "/api/v1/evaluations/datasets", + "/api/v1/evaluations/datasets/", files={"file": (filename, file_obj, "text/csv")}, data={ 
"dataset_name": "test_dataset", @@ -354,19 +375,22 @@ def test_upload_with_duplication_factor_above_maximum( assert "less than or equal to 5" in response_data["error"] def test_upload_with_duplication_factor_boundary_minimum( - self, client, user_api_key_header, valid_csv_content - ): + self, + client: TestClient, + user_api_key_header: dict[str, str], + valid_csv_content: str, + ) -> None: """Test uploading with duplication factor at minimum boundary (1).""" with ( patch("app.core.cloud.get_cloud_storage") as _mock_storage, patch( - "app.api.routes.evaluation.upload_csv_to_object_store" + "app.services.evaluations.dataset.upload_csv_to_object_store" ) as mock_store_upload, patch( - "app.api.routes.evaluation.get_langfuse_client" + "app.services.evaluations.dataset.get_langfuse_client" ) as mock_get_langfuse_client, patch( - "app.api.routes.evaluation.upload_dataset_to_langfuse" + "app.services.evaluations.dataset.upload_dataset_to_langfuse" ) as mock_langfuse_upload, ): mock_store_upload.return_value = "s3://bucket/datasets/test_dataset.csv" @@ -376,7 +400,7 @@ def test_upload_with_duplication_factor_boundary_minimum( filename, file_obj = create_csv_file(valid_csv_content) response = client.post( - "/api/v1/evaluations/datasets", + "/api/v1/evaluations/datasets/", files={"file": (filename, file_obj, "text/csv")}, data={ "dataset_name": "test_dataset", @@ -392,32 +416,33 @@ def test_upload_with_duplication_factor_boundary_minimum( assert data["duplication_factor"] == 1 assert data["original_items"] == 3 - assert data["total_items"] == 3 # 3 items * 1 duplication + assert data["total_items"] == 3 class TestDatasetUploadErrors: """Test error handling.""" def test_upload_langfuse_configuration_fails( - self, client, user_api_key_header, valid_csv_content - ): + self, + client: TestClient, + user_api_key_header: dict[str, str], + valid_csv_content: str, + ) -> None: """Test when Langfuse client configuration fails.""" with ( patch("app.core.cloud.get_cloud_storage") as _mock_storage, patch( - "app.api.routes.evaluation.upload_csv_to_object_store" + "app.services.evaluations.dataset.upload_csv_to_object_store" ) as mock_store_upload, patch("app.crud.credentials.get_provider_credential") as mock_get_cred, ): - # Mock object store upload succeeds mock_store_upload.return_value = "s3://bucket/datasets/test_dataset.csv" - # Mock Langfuse credentials not found mock_get_cred.return_value = None filename, file_obj = create_csv_file(valid_csv_content) response = client.post( - "/api/v1/evaluations/datasets", + "/api/v1/evaluations/datasets/", files={"file": (filename, file_obj, "text/csv")}, data={ "dataset_name": "test_dataset", @@ -438,13 +463,15 @@ def test_upload_langfuse_configuration_fails( or "unauthorized" in error_str.lower() ) - def test_upload_invalid_csv_format(self, client, user_api_key_header): + def test_upload_invalid_csv_format( + self, client: TestClient, user_api_key_header: dict[str, str] + ) -> None: """Test uploading invalid CSV format.""" invalid_csv = "not,a,valid\ncsv format here!!!" 
filename, file_obj = create_csv_file(invalid_csv) response = client.post( - "/api/v1/evaluations/datasets", + "/api/v1/evaluations/datasets/", files={"file": (filename, file_obj, "text/csv")}, data={ "dataset_name": "test_dataset", @@ -470,7 +497,7 @@ def test_upload_without_authentication(self, client, valid_csv_content): filename, file_obj = create_csv_file(valid_csv_content) response = client.post( - "/api/v1/evaluations/datasets", + "/api/v1/evaluations/datasets/", files={"file": (filename, file_obj, "text/csv")}, data={ "dataset_name": "test_dataset", @@ -485,7 +512,7 @@ class TestBatchEvaluation: """Test batch evaluation endpoint using OpenAI Batch API.""" @pytest.fixture - def sample_evaluation_config(self): + def sample_evaluation_config(self) -> dict[str, Any]: """Sample evaluation configuration.""" return { "model": "gpt-4o", @@ -494,15 +521,17 @@ def sample_evaluation_config(self): } def test_start_batch_evaluation_invalid_dataset_id( - self, client, user_api_key_header, sample_evaluation_config - ): - """Test batch evaluation fails with invalid dataset_id.""" - # Try to start evaluation with non-existent dataset_id + self, + client: TestClient, + user_api_key_header: dict[str, str], + sample_evaluation_config: dict[str, Any], + ) -> None: + """Test batch evaluation fails with invalid/non-existent dataset_id.""" response = client.post( - "/api/v1/evaluations", + "/api/v1/evaluations/", json={ "experiment_name": "test_evaluation_run", - "dataset_id": 99999, # Non-existent + "dataset_id": 99999, "config": sample_evaluation_config, }, headers=user_api_key_header, @@ -515,7 +544,9 @@ def test_start_batch_evaluation_invalid_dataset_id( ) assert "not found" in error_str.lower() or "not accessible" in error_str.lower() - def test_start_batch_evaluation_missing_model(self, client, user_api_key_header): + def test_start_batch_evaluation_missing_model( + self, client: TestClient, user_api_key_header: dict[str, str] + ) -> None: """Test batch evaluation fails when model is missing from config.""" # We don't need a real dataset for this test - the validation should happen # before dataset lookup. Use any dataset_id and expect config validation error @@ -525,7 +556,7 @@ def test_start_batch_evaluation_missing_model(self, client, user_api_key_header) } response = client.post( - "/api/v1/evaluations", + "/api/v1/evaluations/", json={ "experiment_name": "test_no_model", "dataset_id": 1, # Dummy ID, error should come before this is checked @@ -548,7 +579,7 @@ def test_start_batch_evaluation_without_authentication( ): """Test batch evaluation requires authentication.""" response = client.post( - "/api/v1/evaluations", + "/api/v1/evaluations/", json={ "experiment_name": "test_evaluation_run", "dataset_id": 1, @@ -562,7 +593,7 @@ def test_start_batch_evaluation_without_authentication( class TestBatchEvaluationJSONLBuilding: """Test JSONL building logic for batch evaluation.""" - def test_build_batch_jsonl_basic(self): + def test_build_batch_jsonl_basic(self) -> None: """Test basic JSONL building with minimal config.""" dataset_items = [ { @@ -593,7 +624,7 @@ def test_build_batch_jsonl_basic(self): assert request["body"]["instructions"] == "You are a helpful assistant" assert request["body"]["input"] == "What is 2+2?" 
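# A minimal sketch (an assumption, not something these tests assert) of one JSONL line
# that build_evaluation_jsonl emits for the OpenAI Batch API. Only custom_id and the
# body fields checked by the assertions above are grounded in the tests; the method/url
# keys follow the usual Batch API input format and are illustrative.
# {"custom_id": "item1",
#  "method": "POST",
#  "url": "/v1/responses",
#  "body": {"model": "gpt-4o",
#           "instructions": "You are a helpful assistant",
#           "input": "What is 2+2?"}}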
- def test_build_batch_jsonl_with_tools(self): + def test_build_batch_jsonl_with_tools(self) -> None: """Test JSONL building with tools configuration.""" dataset_items = [ { @@ -622,7 +653,7 @@ def test_build_batch_jsonl_with_tools(self): assert request["body"]["tools"][0]["type"] == "file_search" assert "vs_abc123" in request["body"]["tools"][0]["vector_store_ids"] - def test_build_batch_jsonl_minimal_config(self): + def test_build_batch_jsonl_minimal_config(self) -> None: """Test JSONL building with minimal config (only model required).""" dataset_items = [ { @@ -642,7 +673,7 @@ def test_build_batch_jsonl_minimal_config(self): assert request["body"]["model"] == "gpt-4o" assert request["body"]["input"] == "Test question" - def test_build_batch_jsonl_skips_empty_questions(self): + def test_build_batch_jsonl_skips_empty_questions(self) -> None: """Test that items with empty questions are skipped.""" dataset_items = [ { @@ -673,7 +704,7 @@ def test_build_batch_jsonl_skips_empty_questions(self): assert len(jsonl_data) == 1 assert jsonl_data[0]["custom_id"] == "item1" - def test_build_batch_jsonl_multiple_items(self): + def test_build_batch_jsonl_multiple_items(self) -> None: """Test JSONL building with multiple items.""" dataset_items = [ { @@ -704,29 +735,28 @@ class TestGetEvaluationRunStatus: """Test GET /evaluations/{evaluation_id} endpoint.""" @pytest.fixture - def create_test_dataset(self, db, user_api_key): + def create_test_dataset( + self, db: Session, user_api_key: TestAuthContext + ) -> EvaluationDataset: """Create a test dataset for evaluation runs.""" - dataset = EvaluationDataset( - name="test_dataset_for_runs", - description="Test dataset", - dataset_metadata={ - "original_items_count": 3, - "total_items_count": 3, - "duplication_factor": 1, - }, - langfuse_dataset_id="langfuse_test_id", - object_store_url="s3://test/dataset.csv", + return create_test_evaluation_dataset( + db=db, organization_id=user_api_key.organization_id, project_id=user_api_key.project_id, + name="test_dataset_for_runs", + description="Test dataset", + original_items_count=3, + duplication_factor=1, ) - db.add(dataset) - db.commit() - db.refresh(dataset) - return dataset def test_get_evaluation_run_trace_info_not_completed( - self, client, user_api_key_header, db, user_api_key, create_test_dataset - ): + self, + client: TestClient, + user_api_key_header: dict[str, str], + db: Session, + user_api_key: TestAuthContext, + create_test_dataset: EvaluationDataset, + ) -> None: """Test requesting trace info for incomplete evaluation returns error.""" eval_run = EvaluationRun( run_name="test_pending_run", @@ -756,8 +786,13 @@ def test_get_evaluation_run_trace_info_not_completed( assert response_data["data"]["id"] == eval_run.id def test_get_evaluation_run_trace_info_completed( - self, client, user_api_key_header, db, user_api_key, create_test_dataset - ): + self, + client: TestClient, + user_api_key_header: dict[str, str], + db: Session, + user_api_key: TestAuthContext, + create_test_dataset: EvaluationDataset, + ) -> None: """Test requesting trace info for completed evaluation returns cached scores.""" eval_run = EvaluationRun( run_name="test_completed_run", @@ -792,3 +827,216 @@ def test_get_evaluation_run_trace_info_completed( assert data["id"] == eval_run.id assert data["status"] == "completed" assert "traces" in data["score"] + + def test_get_evaluation_run_not_found( + self, client: TestClient, user_api_key_header: dict[str, str] + ) -> None: + """Test getting non-existent evaluation run returns 404.""" + 
response = client.get( + "/api/v1/evaluations/99999", + headers=user_api_key_header, + ) + + assert response.status_code == 404 + response_data = response.json() + error_str = response_data.get( + "detail", response_data.get("error", str(response_data)) + ) + assert "not found" in error_str.lower() or "not accessible" in error_str.lower() + + def test_get_evaluation_run_without_trace_info( + self, + client: TestClient, + user_api_key_header: dict[str, str], + db: Session, + user_api_key: TestAuthContext, + create_test_dataset: EvaluationDataset, + ) -> None: + """Test getting evaluation run without requesting trace info.""" + eval_run = EvaluationRun( + run_name="test_simple_run", + dataset_name=create_test_dataset.name, + dataset_id=create_test_dataset.id, + config={"model": "gpt-4o"}, + status="completed", + total_items=3, + organization_id=user_api_key.organization_id, + project_id=user_api_key.project_id, + ) + db.add(eval_run) + db.commit() + db.refresh(eval_run) + + response = client.get( + f"/api/v1/evaluations/{eval_run.id}", + headers=user_api_key_header, + ) + + assert response.status_code == 200 + response_data = response.json() + assert response_data["success"] is True + data = response_data["data"] + assert data["id"] == eval_run.id + assert data["status"] == "completed" + + def test_get_evaluation_run_resync_without_trace_info_fails( + self, + client: TestClient, + user_api_key_header: dict[str, str], + db: Session, + user_api_key: TestAuthContext, + create_test_dataset: EvaluationDataset, + ) -> None: + """Test that resync_score=true requires get_trace_info=true.""" + eval_run = EvaluationRun( + run_name="test_run", + dataset_name=create_test_dataset.name, + dataset_id=create_test_dataset.id, + config={"model": "gpt-4o"}, + status="completed", + total_items=3, + organization_id=user_api_key.organization_id, + project_id=user_api_key.project_id, + ) + db.add(eval_run) + db.commit() + db.refresh(eval_run) + + response = client.get( + f"/api/v1/evaluations/{eval_run.id}", + params={"resync_score": True}, # Missing get_trace_info=true + headers=user_api_key_header, + ) + + assert response.status_code == 400 + response_data = response.json() + error_str = response_data.get( + "detail", response_data.get("error", str(response_data)) + ) + assert ( + "resync_score" in error_str.lower() + and "get_trace_info" in error_str.lower() + ) + + +class TestGetDataset: + """Test GET /evaluations/datasets/{dataset_id} endpoint.""" + + @pytest.fixture + def create_test_dataset( + self, db: Session, user_api_key: TestAuthContext + ) -> EvaluationDataset: + """Create a test dataset.""" + return create_test_evaluation_dataset( + db=db, + organization_id=user_api_key.organization_id, + project_id=user_api_key.project_id, + name="test_dataset_get", + description="Test dataset for GET", + original_items_count=5, + duplication_factor=2, + ) + + def test_get_dataset_success( + self, + client: TestClient, + user_api_key_header: dict[str, str], + create_test_dataset: EvaluationDataset, + ) -> None: + """Test successfully getting a dataset by ID.""" + response = client.get( + f"/api/v1/evaluations/datasets/{create_test_dataset.id}", + headers=user_api_key_header, + ) + + assert response.status_code == 200 + response_data = response.json() + assert response_data["success"] is True + data = response_data["data"] + + assert data["dataset_id"] == create_test_dataset.id + assert data["dataset_name"] == "test_dataset_get" + assert data["original_items"] == 5 + assert data["total_items"] == 10 + assert 
data["duplication_factor"] == 2 + assert data["langfuse_dataset_id"].startswith("langfuse") + assert data["object_store_url"].startswith("s3://test/") + + def test_get_dataset_not_found( + self, client: TestClient, user_api_key_header: dict[str, str] + ) -> None: + """Test getting non-existent dataset returns 404.""" + response = client.get( + "/api/v1/evaluations/datasets/99999", + headers=user_api_key_header, + ) + + assert response.status_code == 404 + response_data = response.json() + error_str = response_data.get( + "detail", response_data.get("error", str(response_data)) + ) + assert "not found" in error_str.lower() or "not accessible" in error_str.lower() + + +class TestDeleteDataset: + """Test DELETE /evaluations/datasets/{dataset_id} endpoint.""" + + @pytest.fixture + def create_test_dataset( + self, db: Session, user_api_key: TestAuthContext + ) -> EvaluationDataset: + """Create a test dataset for deletion.""" + return create_test_evaluation_dataset( + db=db, + organization_id=user_api_key.organization_id, + project_id=user_api_key.project_id, + name="test_dataset_delete", + description="Test dataset for deletion", + original_items_count=3, + duplication_factor=1, + ) + + def test_delete_dataset_success( + self, + client: TestClient, + user_api_key_header: dict[str, str], + create_test_dataset: EvaluationDataset, + db: Session, + ) -> None: + """Test successfully deleting a dataset.""" + dataset_id = create_test_dataset.id + + response = client.delete( + f"/api/v1/evaluations/datasets/{dataset_id}", + headers=user_api_key_header, + ) + + assert response.status_code == 200 + response_data = response.json() + assert response_data["success"] is True + data = response_data["data"] + assert data["dataset_id"] == dataset_id + assert "message" in data + + verify_response = client.get( + f"/api/v1/evaluations/datasets/{dataset_id}", + headers=user_api_key_header, + ) + assert verify_response.status_code == 404 + + def test_delete_dataset_not_found( + self, client: TestClient, user_api_key_header: dict[str, str] + ) -> None: + """Test deleting non-existent dataset returns 404.""" + response = client.delete( + "/api/v1/evaluations/datasets/99999", + headers=user_api_key_header, + ) + + assert response.status_code == 404 + response_data = response.json() + error_str = response_data.get( + "detail", response_data.get("error", str(response_data)) + ) + assert "not found" in error_str.lower() \ No newline at end of file diff --git a/backend/app/tests/api/routes/test_openai_conversation.py b/backend/app/tests/api/routes/test_openai_conversation.py index 500a467b..658a8431 100644 --- a/backend/app/tests/api/routes/test_openai_conversation.py +++ b/backend/app/tests/api/routes/test_openai_conversation.py @@ -3,7 +3,7 @@ from app.crud.openai_conversation import create_conversation from app.models import OpenAIConversationCreate -from app.tests.utils.openai import generate_openai_id +from app.tests.utils.llm_provider import generate_openai_id from app.tests.utils.auth import TestAuthContext diff --git a/backend/app/tests/api/test_auth_failures.py b/backend/app/tests/api/test_auth_failures.py new file mode 100644 index 00000000..aaab64ca --- /dev/null +++ b/backend/app/tests/api/test_auth_failures.py @@ -0,0 +1,84 @@ +import pytest +from fastapi.testclient import TestClient + +from app.core.config import settings + + +PROTECTED_ENDPOINTS = [ + (f"{settings.API_V1_STR}/collections/", "GET"), + (f"{settings.API_V1_STR}/collections/", "POST"), + 
(f"{settings.API_V1_STR}/collections/12345678-1234-5678-1234-567812345678", "GET"), + ( + f"{settings.API_V1_STR}/collections/12345678-1234-5678-1234-567812345678", + "DELETE", + ), + ( + f"{settings.API_V1_STR}/collections/jobs/12345678-1234-5678-1234-567812345678", + "GET", + ), + (f"{settings.API_V1_STR}/documents/", "GET"), + (f"{settings.API_V1_STR}/documents/", "POST"), + (f"{settings.API_V1_STR}/documents/12345678-1234-5678-1234-567812345678", "GET"), + (f"{settings.API_V1_STR}/documents/12345678-1234-5678-1234-567812345678", "DELETE"), + ( + f"{settings.API_V1_STR}/documents/transformation/12345678-1234-5678-1234-567812345678", + "GET", + ), + (f"{settings.API_V1_STR}/cron/evaluations", "GET"), + (f"{settings.API_V1_STR}/evaluations/datasets/", "POST"), + (f"{settings.API_V1_STR}/evaluations/datasets/", "GET"), + ( + f"{settings.API_V1_STR}/evaluations/datasets/12345678-1234-5678-1234-567812345678", + "GET", + ), + (f"{settings.API_V1_STR}/evaluations", "POST"), + (f"{settings.API_V1_STR}/evaluations", "GET"), + (f"{settings.API_V1_STR}/evaluations/12345678-1234-5678-1234-567812345678", "GET"), + (f"{settings.API_V1_STR}/llm/call", "POST"), +] + + +@pytest.mark.parametrize("endpoint,method", PROTECTED_ENDPOINTS) +def test_endpoints_reject_missing_auth_header( + client: TestClient, endpoint: str, method: str +) -> None: + """Test that all protected endpoints return 401 when no auth header is provided.""" + kwargs = {"json": {"name": "test"}} if method in ["POST", "PATCH"] else {} + response = client.request(method, endpoint, **kwargs) + + assert ( + response.status_code == 401 + ), f"Expected 401 for {method} {endpoint} without auth, got {response.status_code}" + + +@pytest.mark.parametrize("endpoint,method", PROTECTED_ENDPOINTS) +def test_endpoints_reject_invalid_auth_format( + client: TestClient, endpoint: str, method: str +) -> None: + """Test that all protected endpoints return 401 when auth header has invalid format.""" + kwargs = {"json": {"name": "test"}} if method in ["POST", "PATCH"] else {} + response = client.request( + method, endpoint, headers={"Authorization": "InvalidFormat"}, **kwargs + ) + + assert ( + response.status_code == 401 + ), f"Expected 401 for {method} {endpoint} with invalid format, got {response.status_code}" + + +@pytest.mark.parametrize("endpoint,method", PROTECTED_ENDPOINTS) +def test_endpoints_reject_nonexistent_api_key( + client: TestClient, endpoint: str, method: str +) -> None: + """Test that all protected endpoints return 401 when API key doesn't exist.""" + kwargs = {"json": {"name": "test"}} if method in ["POST", "PATCH"] else {} + response = client.request( + method, + endpoint, + headers={"Authorization": "ApiKey FakeKeyThatDoesNotExist123456789"}, + **kwargs, + ) + + assert ( + response.status_code == 401 + ), f"Expected 401 for {method} {endpoint} with fake key, got {response.status_code}" diff --git a/backend/app/tests/core/test_security.py b/backend/app/tests/core/test_security.py index 59101375..b21b82ce 100644 --- a/backend/app/tests/core/test_security.py +++ b/backend/app/tests/core/test_security.py @@ -4,8 +4,6 @@ from app.core.security import ( get_password_hash, verify_password, - encrypt_api_key, - decrypt_api_key, get_encryption_key, APIKeyManager, ) @@ -13,107 +11,6 @@ from app.tests.utils.test_data import create_test_api_key -def test_encrypt_decrypt_api_key(): - """Test that API key encryption and decryption works correctly.""" - # Test data - test_key = "ApiKey test123456789" - - # Encrypt the key - encrypted_key = 
encrypt_api_key(test_key) - - # Verify encryption worked - assert encrypted_key is not None - assert encrypted_key != test_key - assert isinstance(encrypted_key, str) - - # Decrypt the key - decrypted_key = decrypt_api_key(encrypted_key) - - # Verify decryption worked - assert decrypted_key is not None - assert decrypted_key == test_key - - -def test_api_key_format_validation(): - """Test that API key format is validated correctly.""" - # Test valid API key format - valid_key = "ApiKey test123456789" - encrypted_valid = encrypt_api_key(valid_key) - assert encrypted_valid is not None - assert decrypt_api_key(encrypted_valid) == valid_key - - # Test invalid API key format (missing prefix) - invalid_key = "test123456789" - encrypted_invalid = encrypt_api_key(invalid_key) - assert encrypted_invalid is not None - assert decrypt_api_key(encrypted_invalid) == invalid_key - - -def test_encrypt_api_key_edge_cases(): - """Test edge cases for API key encryption.""" - # Test empty string - empty_key = "" - encrypted_empty = encrypt_api_key(empty_key) - assert encrypted_empty is not None - assert decrypt_api_key(encrypted_empty) == empty_key - - # Test whitespace only - whitespace_key = " " - encrypted_whitespace = encrypt_api_key(whitespace_key) - assert encrypted_whitespace is not None - assert decrypt_api_key(encrypted_whitespace) == whitespace_key - - # Test very long input - long_key = "ApiKey " + "a" * 1000 - encrypted_long = encrypt_api_key(long_key) - assert encrypted_long is not None - assert decrypt_api_key(encrypted_long) == long_key - - -def test_encrypt_api_key_type_validation(): - """Test type validation for API key encryption.""" - # Test non-string inputs - invalid_inputs = [123, [], {}, True] - for invalid_input in invalid_inputs: - with pytest.raises(ValueError, match="Failed to encrypt API key"): - encrypt_api_key(invalid_input) - - -def test_encrypt_api_key_security(): - """Test security properties of API key encryption.""" - # Test that same input produces different encrypted output - test_key = "ApiKey test123456789" - encrypted1 = encrypt_api_key(test_key) - encrypted2 = encrypt_api_key(test_key) - assert encrypted1 != encrypted2 # Different encrypted outputs for same input - - -def test_encrypt_api_key_error_handling(): - """Test error handling in encrypt_api_key.""" - # Test with invalid input - with pytest.raises(ValueError, match="Failed to encrypt API key"): - encrypt_api_key(None) - - -def test_decrypt_api_key_error_handling(): - """Test error handling in decrypt_api_key.""" - # Test with invalid input - with pytest.raises(ValueError, match="Failed to decrypt API key"): - decrypt_api_key(None) - - # Test with various invalid encrypted data formats - invalid_encrypted_data = [ - "invalid_encrypted_data", # Not base64 - "not_a_base64_string", # Not base64 - "a" * 44, # Wrong length - "!" 
* 44, # Invalid base64 chars - "aGVsbG8=", # Valid base64 but not encrypted - ] - for invalid_data in invalid_encrypted_data: - with pytest.raises(ValueError, match="Failed to decrypt API key"): - decrypt_api_key(invalid_data) - - def test_get_encryption_key(): """Test that encryption key generation works correctly.""" # Get the encryption key diff --git a/backend/app/tests/crud/collections/collection/test_crud_collection_create.py b/backend/app/tests/crud/collections/collection/test_crud_collection_create.py index fc52cd08..984a42db 100644 --- a/backend/app/tests/crud/collections/collection/test_crud_collection_create.py +++ b/backend/app/tests/crud/collections/collection/test_crud_collection_create.py @@ -18,9 +18,9 @@ def test_create_associates_documents(self, db: Session): collection = Collection( id=uuid4(), project_id=project.id, - organization_id=project.organization_id, llm_service_id="asst_dummy", llm_service_name="gpt-4o", + provider="OPENAI", ) store = DocumentStore(db, project_id=collection.project_id) diff --git a/backend/app/tests/crud/collections/collection/test_crud_collection_delete.py b/backend/app/tests/crud/collections/collection/test_crud_collection_delete.py index a2668b19..bec471c9 100644 --- a/backend/app/tests/crud/collections/collection/test_crud_collection_delete.py +++ b/backend/app/tests/crud/collections/collection/test_crud_collection_delete.py @@ -12,7 +12,7 @@ from app.tests.utils.document import DocumentStore -def get_collection_for_delete( +def get_assistant_collection_for_delete( db: Session, client=None, project_id: int = None ) -> Collection: project = get_project(db) @@ -27,10 +27,10 @@ def get_collection_for_delete( ) return Collection( - organization_id=project.organization_id, project_id=project_id, llm_service_id=assistant.id, llm_service_name="gpt-4o", + provider="OPENAI", ) @@ -43,7 +43,9 @@ def test_delete_marks_deleted(self, db: Session): client = OpenAI(api_key="sk-test-key") assistant = OpenAIAssistantCrud(client) - collection = get_collection_for_delete(db, client, project_id=project.id) + collection = get_assistant_collection_for_delete( + db, client, project_id=project.id + ) crud = CollectionCrud(db, collection.project_id) collection_ = crud.delete(collection, assistant) @@ -56,7 +58,7 @@ def test_delete_follows_insert(self, db: Session): assistant = OpenAIAssistantCrud(client) project = get_project(db) - collection = get_collection_for_delete(db, project_id=project.id) + collection = get_assistant_collection_for_delete(db, project_id=project.id) crud = CollectionCrud(db, collection.project_id) collection_ = crud.delete(collection, assistant) @@ -77,7 +79,9 @@ def test_delete_document_deletes_collections(self, db: Session): client = OpenAI(api_key="sk-test-key") resources = [] for _ in range(self._n_collections): - coll = get_collection_for_delete(db, client, project_id=project.id) + coll = get_assistant_collection_for_delete( + db, client, project_id=project.id + ) crud = CollectionCrud(db, project_id=project.id) collection = crud.create(coll, documents) resources.append((crud, collection)) diff --git a/backend/app/tests/crud/collections/collection/test_crud_collection_read_all.py b/backend/app/tests/crud/collections/collection/test_crud_collection_read_all.py index a9da3523..1f1aeab1 100644 --- a/backend/app/tests/crud/collections/collection/test_crud_collection_read_all.py +++ b/backend/app/tests/crud/collections/collection/test_crud_collection_read_all.py @@ -8,7 +8,7 @@ from app.models import Collection from app.tests.utils.document 
import DocumentStore from app.tests.utils.utils import get_project -from app.tests.utils.collection import get_collection +from app.tests.utils.collection import get_assistant_collection def create_collections(db: Session, n: int): @@ -18,7 +18,7 @@ def create_collections(db: Session, n: int): with openai_mock.router: client = OpenAI(api_key="sk-test-key") for _ in range(n): - collection = get_collection(db, project=project) + collection = get_assistant_collection(db, project=project) store = DocumentStore(db, project_id=collection.project_id) documents = store.fill(1) if crud is None: diff --git a/backend/app/tests/crud/collections/collection/test_crud_collection_read_one.py b/backend/app/tests/crud/collections/collection/test_crud_collection_read_one.py index ceb46c1a..498738cb 100644 --- a/backend/app/tests/crud/collections/collection/test_crud_collection_read_one.py +++ b/backend/app/tests/crud/collections/collection/test_crud_collection_read_one.py @@ -8,7 +8,7 @@ from app.crud import CollectionCrud from app.tests.utils.document import DocumentStore from app.tests.utils.utils import get_project -from app.tests.utils.collection import get_collection +from app.tests.utils.collection import get_assistant_collection def mk_collection(db: Session): @@ -16,7 +16,7 @@ def mk_collection(db: Session): project = get_project(db) with openai_mock.router: client = OpenAI(api_key="sk-test-key") - collection = get_collection(db, project=project) + collection = get_assistant_collection(db, project=project) store = DocumentStore(db, project_id=collection.project_id) documents = store.fill(1) crud = CollectionCrud(db, collection.project_id) diff --git a/backend/app/tests/crud/evaluations/test_dataset.py b/backend/app/tests/crud/evaluations/test_dataset.py index ccd2e4f3..1a1c4d60 100644 --- a/backend/app/tests/crud/evaluations/test_dataset.py +++ b/backend/app/tests/crud/evaluations/test_dataset.py @@ -410,4 +410,134 @@ def test_update_dataset_langfuse_id_nonexistent(self, db: Session): update_dataset_langfuse_id( session=db, dataset_id=99999, langfuse_dataset_id="langfuse_123" ) - # No assertion needed, just ensuring it doesn't crash + + +class TestDeleteDataset: + """Test deleting evaluation datasets.""" + + def test_delete_dataset_success(self, db: Session) -> None: + """Test successfully deleting a dataset.""" + org = db.exec(select(Organization)).first() + project = db.exec( + select(Project).where(Project.organization_id == org.id) + ).first() + + dataset = create_evaluation_dataset( + session=db, + name="dataset_to_delete", + dataset_metadata={"original_items_count": 5}, + organization_id=org.id, + project_id=project.id, + ) + dataset_id = dataset.id + + # New signature: delete_dataset(session, dataset) returns str | None + error = delete_dataset(session=db, dataset=dataset) + + assert error is None # None means success + + # Verify dataset is deleted + fetched = get_dataset_by_id( + session=db, + dataset_id=dataset_id, + organization_id=org.id, + project_id=project.id, + ) + assert fetched is None + + def test_delete_dataset_not_found(self, db: Session) -> None: + """Test deleting a non-existent dataset - dataset must be fetched first.""" + org = db.exec(select(Organization)).first() + project = db.exec( + select(Project).where(Project.organization_id == org.id) + ).first() + + # Try to fetch a non-existent dataset + dataset = get_dataset_by_id( + session=db, + dataset_id=99999, + organization_id=org.id, + project_id=project.id, + ) + + # The pattern now is: fetch dataset first, if not found, 
handle in caller + assert dataset is None + + def test_delete_dataset_wrong_org(self, db: Session) -> None: + """Test that dataset cannot be fetched with wrong organization.""" + org = db.exec(select(Organization)).first() + project = db.exec( + select(Project).where(Project.organization_id == org.id) + ).first() + + dataset = create_evaluation_dataset( + session=db, + name="dataset_to_delete", + dataset_metadata={"original_items_count": 5}, + organization_id=org.id, + project_id=project.id, + ) + + # Try to fetch with wrong org - should return None + fetched_wrong_org = get_dataset_by_id( + session=db, + dataset_id=dataset.id, + organization_id=99999, + project_id=project.id, + ) + assert fetched_wrong_org is None + + # Original dataset should still exist + fetched = get_dataset_by_id( + session=db, + dataset_id=dataset.id, + organization_id=org.id, + project_id=project.id, + ) + assert fetched is not None + + def test_delete_dataset_with_evaluation_runs(self, db: Session) -> None: + """Test that dataset cannot be deleted if it has evaluation runs.""" + + org = db.exec(select(Organization)).first() + project = db.exec( + select(Project).where(Project.organization_id == org.id) + ).first() + + dataset = create_evaluation_dataset( + session=db, + name="dataset_with_runs", + dataset_metadata={"original_items_count": 5}, + organization_id=org.id, + project_id=project.id, + ) + + eval_run = EvaluationRun( + run_name="test_run", + dataset_name=dataset.name, + dataset_id=dataset.id, + config={"model": "gpt-4o"}, + status="pending", + organization_id=org.id, + project_id=project.id, + inserted_at=now(), + updated_at=now(), + ) + db.add(eval_run) + db.commit() + + # Attempt to delete - should return an error message + error = delete_dataset(session=db, dataset=dataset) + + assert error is not None + assert "cannot delete" in error.lower() or "being used" in error.lower() + assert "evaluation run" in error.lower() + + # Dataset should still exist + fetched = get_dataset_by_id( + session=db, + dataset_id=dataset.id, + organization_id=org.id, + project_id=project.id, + ) + assert fetched is not None diff --git a/backend/app/tests/crud/evaluations/test_processing.py b/backend/app/tests/crud/evaluations/test_processing.py new file mode 100644 index 00000000..bb3b699f --- /dev/null +++ b/backend/app/tests/crud/evaluations/test_processing.py @@ -0,0 +1,805 @@ +from typing import Any +import json +from unittest.mock import MagicMock, patch + +import pytest +from sqlmodel import Session, select + +from app.crud.evaluations.processing import ( + check_and_process_evaluation, + parse_evaluation_output, + process_completed_embedding_batch, + process_completed_evaluation, + poll_all_pending_evaluations, +) +from app.models import BatchJob, Organization, Project, EvaluationDataset, EvaluationRun +from app.tests.utils.test_data import create_test_evaluation_dataset +from app.crud.evaluations.core import create_evaluation_run +from app.core.util import now + + +class TestParseEvaluationOutput: + """Test parsing evaluation batch output.""" + + def test_parse_evaluation_output_basic(self) -> None: + """Test basic parsing with valid data.""" + raw_results = [ + { + "custom_id": "item1", + "response": { + "body": { + "id": "resp_123", + "output": [ + { + "type": "message", + "content": [ + {"type": "output_text", "text": "The answer is 4"} + ], + } + ], + "usage": { + "input_tokens": 10, + "output_tokens": 5, + "total_tokens": 15, + }, + } + }, + } + ] + + dataset_items = [ + { + "id": "item1", + "input": 
{"question": "What is 2+2?"}, + "expected_output": {"answer": "4"}, + } + ] + + results = parse_evaluation_output(raw_results, dataset_items) + + assert len(results) == 1 + assert results[0]["item_id"] == "item1" + assert results[0]["question"] == "What is 2+2?" + assert results[0]["generated_output"] == "The answer is 4" + assert results[0]["ground_truth"] == "4" + assert results[0]["response_id"] == "resp_123" + assert results[0]["usage"]["total_tokens"] == 15 + + def test_parse_evaluation_output_simple_string(self) -> None: + """Test parsing with simple string output.""" + raw_results = [ + { + "custom_id": "item1", + "response": { + "body": { + "id": "resp_123", + "output": "Simple text response", + "usage": { + "input_tokens": 10, + "output_tokens": 5, + "total_tokens": 15, + }, + } + }, + } + ] + + dataset_items = [ + { + "id": "item1", + "input": {"question": "Test?"}, + "expected_output": {"answer": "Test"}, + } + ] + + results = parse_evaluation_output(raw_results, dataset_items) + + assert len(results) == 1 + assert results[0]["generated_output"] == "Simple text response" + + def test_parse_evaluation_output_with_error(self) -> None: + """Test parsing item with error.""" + raw_results = [ + { + "custom_id": "item1", + "error": {"message": "Rate limit exceeded"}, + "response": {"body": {}}, + } + ] + + dataset_items = [ + { + "id": "item1", + "input": {"question": "Test?"}, + "expected_output": {"answer": "Test"}, + } + ] + + results = parse_evaluation_output(raw_results, dataset_items) + + assert len(results) == 1 + assert "ERROR: Rate limit exceeded" in results[0]["generated_output"] + + def test_parse_evaluation_output_missing_custom_id(self) -> None: + """Test parsing skips items without custom_id.""" + raw_results = [ + { + "response": { + "body": { + "output": "Test", + "usage": {"total_tokens": 10}, + } + } + } + ] + + dataset_items = [ + { + "id": "item1", + "input": {"question": "Test?"}, + "expected_output": {"answer": "Test"}, + } + ] + + results = parse_evaluation_output(raw_results, dataset_items) + + assert len(results) == 0 + + def test_parse_evaluation_output_missing_dataset_item(self) -> None: + """Test parsing skips items not in dataset.""" + raw_results = [ + { + "custom_id": "item999", + "response": {"body": {"output": "Test", "usage": {"total_tokens": 10}}}, + } + ] + + dataset_items = [ + { + "id": "item1", + "input": {"question": "Test?"}, + "expected_output": {"answer": "Test"}, + } + ] + + results = parse_evaluation_output(raw_results, dataset_items) + + assert len(results) == 0 + + def test_parse_evaluation_output_json_string(self) -> None: + """Test parsing JSON string output.""" + raw_results = [ + { + "custom_id": "item1", + "response": { + "body": { + "output": json.dumps( + [ + { + "type": "message", + "content": [ + {"type": "output_text", "text": "Parsed JSON"} + ], + } + ] + ), + "usage": {"total_tokens": 10}, + } + }, + } + ] + + dataset_items = [ + { + "id": "item1", + "input": {"question": "Test?"}, + "expected_output": {"answer": "Test"}, + } + ] + + results = parse_evaluation_output(raw_results, dataset_items) + + assert len(results) == 1 + assert results[0]["generated_output"] == "Parsed JSON" + + def test_parse_evaluation_output_multiple_items(self) -> None: + """Test parsing multiple items.""" + raw_results = [ + { + "custom_id": f"item{i}", + "response": { + "body": { + "output": f"Output {i}", + "usage": {"total_tokens": 10}, + } + }, + } + for i in range(3) + ] + + dataset_items = [ + { + "id": f"item{i}", + "input": {"question": 
f"Q{i}"}, + "expected_output": {"answer": f"A{i}"}, + } + for i in range(3) + ] + + results = parse_evaluation_output(raw_results, dataset_items) + + assert len(results) == 3 + for i, result in enumerate(results): + assert result["item_id"] == f"item{i}" + assert result["generated_output"] == f"Output {i}" + assert result["ground_truth"] == f"A{i}" + + +class TestProcessCompletedEvaluation: + """Test processing completed evaluation batch.""" + + @pytest.fixture + def test_dataset(self, db: Session) -> EvaluationDataset: + """Create a test dataset for evaluation runs.""" + org = db.exec(select(Organization)).first() + project = db.exec( + select(Project).where(Project.organization_id == org.id) + ).first() + + return create_test_evaluation_dataset( + db=db, + organization_id=org.id, + project_id=project.id, + name="test_dataset_processing", + description="Test dataset", + original_items_count=3, + duplication_factor=1, + ) + + @pytest.fixture + def eval_run_with_batch(self, db: Session, test_dataset) -> EvaluationRun: + """Create evaluation run with batch job.""" + # Create batch job + batch_job = BatchJob( + provider="openai", + provider_batch_id="batch_abc123", + provider_status="completed", + job_type="evaluation", + total_items=2, + status="submitted", + organization_id=test_dataset.organization_id, + project_id=test_dataset.project_id, + inserted_at=now(), + updated_at=now(), + ) + db.add(batch_job) + db.commit() + db.refresh(batch_job) + + eval_run = create_evaluation_run( + session=db, + run_name="test_run", + dataset_name=test_dataset.name, + dataset_id=test_dataset.id, + config={"model": "gpt-4o"}, + organization_id=test_dataset.organization_id, + project_id=test_dataset.project_id, + ) + eval_run.batch_job_id = batch_job.id + eval_run.status = "processing" + db.add(eval_run) + db.commit() + db.refresh(eval_run) + + return eval_run + + @pytest.mark.asyncio + @patch("app.crud.evaluations.processing.download_batch_results") + @patch("app.crud.evaluations.processing.fetch_dataset_items") + @patch("app.crud.evaluations.processing.create_langfuse_dataset_run") + @patch("app.crud.evaluations.processing.start_embedding_batch") + @patch("app.crud.evaluations.processing.upload_batch_results_to_object_store") + async def test_process_completed_evaluation_success( + self, + mock_upload, + mock_start_embedding, + mock_create_langfuse, + mock_fetch_dataset, + mock_download, + db: Session, + eval_run_with_batch, + ): + """Test successfully processing completed evaluation.""" + # Mock batch results + mock_download.return_value = [ + { + "custom_id": "item1", + "response": { + "body": { + "id": "resp_123", + "output": "Answer 1", + "usage": {"total_tokens": 10}, + } + }, + } + ] + + # Mock dataset items + mock_fetch_dataset.return_value = [ + { + "id": "item1", + "input": {"question": "Q1"}, + "expected_output": {"answer": "A1"}, + } + ] + + # Mock Langfuse + mock_create_langfuse.return_value = {"item1": "trace_123"} + + # Mock embedding batch + mock_start_embedding.return_value = eval_run_with_batch + + # Mock upload + mock_upload.return_value = "s3://bucket/results.jsonl" + + mock_openai = MagicMock() + mock_langfuse = MagicMock() + + result = await process_completed_evaluation( + eval_run=eval_run_with_batch, + session=db, + openai_client=mock_openai, + langfuse=mock_langfuse, + ) + + assert result is not None + mock_download.assert_called_once() + mock_fetch_dataset.assert_called_once() + mock_create_langfuse.assert_called_once() + mock_start_embedding.assert_called_once() + + 
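# A rough, hypothetical outline of the flow the mocks and assertions above imply for
# process_completed_evaluation; the call names are taken from the patched targets, but
# the argument lists shown here are assumptions, not the actual implementation:
#
#   raw = download_batch_results(openai_client, batch_job)
#   items = fetch_dataset_items(langfuse, eval_run.dataset_name)
#   results = parse_evaluation_output(raw, items)
#   trace_ids = create_langfuse_dataset_run(langfuse, eval_run, results)
#   start_embedding_batch(session, eval_run, results, trace_ids, openai_client)
#   upload_batch_results_to_object_store(session, eval_run, raw)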
@pytest.mark.asyncio + @patch("app.crud.evaluations.processing.download_batch_results") + @patch("app.crud.evaluations.processing.fetch_dataset_items") + async def test_process_completed_evaluation_no_results( + self, + mock_fetch_dataset, + mock_download, + db: Session, + eval_run_with_batch, + ): + """Test processing with no valid results.""" + mock_download.return_value = [] + mock_fetch_dataset.return_value = [ + { + "id": "item1", + "input": {"question": "Q1"}, + "expected_output": {"answer": "A1"}, + } + ] + + mock_openai = MagicMock() + mock_langfuse = MagicMock() + + result = await process_completed_evaluation( + eval_run=eval_run_with_batch, + session=db, + openai_client=mock_openai, + langfuse=mock_langfuse, + ) + + db.refresh(result) + assert result.status == "failed" + assert "No valid results" in result.error_message + + @pytest.mark.asyncio + async def test_process_completed_evaluation_no_batch_job_id( + self, db: Session, test_dataset + ): + """Test processing without batch_job_id.""" + eval_run = create_evaluation_run( + session=db, + run_name="test_run", + dataset_name=test_dataset.name, + dataset_id=test_dataset.id, + config={"model": "gpt-4o"}, + organization_id=test_dataset.organization_id, + project_id=test_dataset.project_id, + ) + + mock_openai = MagicMock() + mock_langfuse = MagicMock() + + result = await process_completed_evaluation( + eval_run=eval_run, + session=db, + openai_client=mock_openai, + langfuse=mock_langfuse, + ) + + db.refresh(result) + assert result.status == "failed" + assert "no batch_job_id" in result.error_message + + +class TestProcessCompletedEmbeddingBatch: + """Test processing completed embedding batch.""" + + @pytest.fixture + def test_dataset(self, db: Session) -> EvaluationDataset: + """Create a test dataset.""" + org = db.exec(select(Organization)).first() + project = db.exec( + select(Project).where(Project.organization_id == org.id) + ).first() + + return create_test_evaluation_dataset( + db=db, + organization_id=org.id, + project_id=project.id, + name="test_dataset_embedding", + description="Test dataset", + original_items_count=2, + duplication_factor=1, + ) + + @pytest.fixture + def eval_run_with_embedding_batch(self, db: Session, test_dataset) -> EvaluationRun: + """Create evaluation run with embedding batch job.""" + # Create embedding batch job + embedding_batch = BatchJob( + provider="openai", + provider_batch_id="batch_embed_123", + provider_status="completed", + job_type="embedding", + total_items=4, + status="submitted", + organization_id=test_dataset.organization_id, + project_id=test_dataset.project_id, + inserted_at=now(), + updated_at=now(), + ) + db.add(embedding_batch) + db.commit() + db.refresh(embedding_batch) + + # Create evaluation run + eval_run = create_evaluation_run( + session=db, + run_name="test_run_embedding", + dataset_name=test_dataset.name, + dataset_id=test_dataset.id, + config={"model": "gpt-4o"}, + organization_id=test_dataset.organization_id, + project_id=test_dataset.project_id, + ) + eval_run.embedding_batch_job_id = embedding_batch.id + eval_run.status = "processing" + db.add(eval_run) + db.commit() + db.refresh(eval_run) + + return eval_run + + @pytest.mark.asyncio + @patch("app.crud.evaluations.processing.download_batch_results") + @patch("app.crud.evaluations.processing.parse_embedding_results") + @patch("app.crud.evaluations.processing.calculate_average_similarity") + @patch("app.crud.evaluations.processing.update_traces_with_cosine_scores") + async def 
test_process_completed_embedding_batch_success( + self, + mock_update_traces, + mock_calculate, + mock_parse, + mock_download, + db: Session, + eval_run_with_embedding_batch, + ): + """Test successfully processing completed embedding batch.""" + mock_download.return_value = [] + mock_parse.return_value = [ + { + "item_id": "item1", + "trace_id": "trace_123", + "output_embedding": [1.0, 0.0], + "ground_truth_embedding": [1.0, 0.0], + } + ] + mock_calculate.return_value = { + "cosine_similarity_avg": 0.95, + "cosine_similarity_std": 0.02, + "total_pairs": 1, + "per_item_scores": [ + {"item_id": "item1", "trace_id": "trace_123", "cosine_similarity": 0.95} + ], + } + + mock_openai = MagicMock() + mock_langfuse = MagicMock() + + result = await process_completed_embedding_batch( + eval_run=eval_run_with_embedding_batch, + session=db, + openai_client=mock_openai, + langfuse=mock_langfuse, + ) + + db.refresh(result) + assert result.status == "completed" + assert result.score is not None + assert "cosine_similarity" in result.score + assert result.score["cosine_similarity"]["avg"] == 0.95 + mock_update_traces.assert_called_once() + + @pytest.mark.asyncio + @patch("app.crud.evaluations.processing.download_batch_results") + @patch("app.crud.evaluations.processing.parse_embedding_results") + async def test_process_completed_embedding_batch_no_results( + self, + mock_parse, + mock_download, + db: Session, + eval_run_with_embedding_batch, + ): + """Test processing with no valid embedding results.""" + mock_download.return_value = [] + mock_parse.return_value = [] + + mock_openai = MagicMock() + mock_langfuse = MagicMock() + + result = await process_completed_embedding_batch( + eval_run=eval_run_with_embedding_batch, + session=db, + openai_client=mock_openai, + langfuse=mock_langfuse, + ) + + db.refresh(result) + assert result.status == "completed" + assert "failed" in result.error_message.lower() + + +class TestCheckAndProcessEvaluation: + """Test check and process evaluation function.""" + + @pytest.fixture + def test_dataset(self, db: Session) -> EvaluationDataset: + """Create a test dataset.""" + org = db.exec(select(Organization)).first() + project = db.exec( + select(Project).where(Project.organization_id == org.id) + ).first() + + return create_test_evaluation_dataset( + db=db, + organization_id=org.id, + project_id=project.id, + name="test_dataset_check", + description="Test dataset", + original_items_count=2, + duplication_factor=1, + ) + + @pytest.mark.asyncio + @patch("app.crud.evaluations.processing.get_batch_job") + @patch("app.crud.evaluations.processing.poll_batch_status") + @patch("app.crud.evaluations.processing.process_completed_evaluation") + async def test_check_and_process_evaluation_completed( + self, + mock_process, + mock_poll, + mock_get_batch, + db: Session, + test_dataset, + ): + """Test checking evaluation with completed batch.""" + # Create batch job + batch_job = BatchJob( + provider="openai", + provider_batch_id="batch_abc", + provider_status="completed", + job_type="evaluation", + total_items=2, + status="submitted", + organization_id=test_dataset.organization_id, + project_id=test_dataset.project_id, + inserted_at=now(), + updated_at=now(), + ) + db.add(batch_job) + db.commit() + db.refresh(batch_job) + + # Create evaluation run + eval_run = create_evaluation_run( + session=db, + run_name="test_run", + dataset_name=test_dataset.name, + dataset_id=test_dataset.id, + config={"model": "gpt-4o"}, + organization_id=test_dataset.organization_id, + 
project_id=test_dataset.project_id, + ) + eval_run.batch_job_id = batch_job.id + eval_run.status = "processing" + db.add(eval_run) + db.commit() + db.refresh(eval_run) + + mock_get_batch.return_value = batch_job + mock_process.return_value = eval_run + + mock_openai = MagicMock() + mock_langfuse = MagicMock() + + result = await check_and_process_evaluation( + eval_run=eval_run, + session=db, + openai_client=mock_openai, + langfuse=mock_langfuse, + ) + + assert result["action"] == "processed" + assert result["run_id"] == eval_run.id + mock_process.assert_called_once() + + @pytest.mark.asyncio + @patch("app.crud.evaluations.processing.get_batch_job") + @patch("app.crud.evaluations.processing.poll_batch_status") + async def test_check_and_process_evaluation_failed( + self, + mock_poll, + mock_get_batch, + db: Session, + test_dataset, + ): + """Test checking evaluation with failed batch.""" + # Create failed batch job + batch_job = BatchJob( + provider="openai", + provider_batch_id="batch_fail", + provider_status="failed", + job_type="evaluation", + total_items=2, + status="submitted", + error_message="Provider error", + organization_id=test_dataset.organization_id, + project_id=test_dataset.project_id, + inserted_at=now(), + updated_at=now(), + ) + db.add(batch_job) + db.commit() + db.refresh(batch_job) + + # Create evaluation run + eval_run = create_evaluation_run( + session=db, + run_name="test_run_fail", + dataset_name=test_dataset.name, + dataset_id=test_dataset.id, + config={"model": "gpt-4o"}, + organization_id=test_dataset.organization_id, + project_id=test_dataset.project_id, + ) + eval_run.batch_job_id = batch_job.id + eval_run.status = "processing" + db.add(eval_run) + db.commit() + db.refresh(eval_run) + + mock_get_batch.return_value = batch_job + + mock_openai = MagicMock() + mock_langfuse = MagicMock() + + result = await check_and_process_evaluation( + eval_run=eval_run, + session=db, + openai_client=mock_openai, + langfuse=mock_langfuse, + ) + + assert result["action"] == "failed" + assert result["current_status"] == "failed" + db.refresh(eval_run) + assert eval_run.status == "failed" + + +class TestPollAllPendingEvaluations: + """Test polling all pending evaluations.""" + + @pytest.fixture + def test_dataset(self, db: Session) -> EvaluationDataset: + """Create a test dataset.""" + org = db.exec(select(Organization)).first() + project = db.exec( + select(Project).where(Project.organization_id == org.id) + ).first() + + return create_test_evaluation_dataset( + db=db, + organization_id=org.id, + project_id=project.id, + name="test_dataset_poll", + description="Test dataset", + original_items_count=2, + duplication_factor=1, + ) + + @pytest.mark.asyncio + async def test_poll_all_pending_evaluations_no_pending( + self, db: Session, test_dataset + ): + """Test polling with no pending evaluations.""" + result = await poll_all_pending_evaluations( + session=db, org_id=test_dataset.organization_id + ) + + assert result["total"] == 0 + assert result["processed"] == 0 + assert result["failed"] == 0 + assert result["still_processing"] == 0 + + @pytest.mark.asyncio + @patch("app.crud.evaluations.processing.check_and_process_evaluation") + @patch("app.crud.evaluations.processing.get_openai_client") + @patch("app.crud.evaluations.processing.get_langfuse_client") + async def test_poll_all_pending_evaluations_with_runs( + self, + mock_langfuse_client, + mock_openai_client, + mock_check, + db: Session, + test_dataset, + ): + """Test polling with pending evaluations.""" + # Create batch job + 
batch_job = BatchJob( + provider="openai", + provider_batch_id="batch_test", + provider_status="in_progress", + job_type="evaluation", + total_items=2, + status="submitted", + organization_id=test_dataset.organization_id, + project_id=test_dataset.project_id, + inserted_at=now(), + updated_at=now(), + ) + db.add(batch_job) + db.commit() + db.refresh(batch_job) + + # Create pending evaluation run + eval_run = create_evaluation_run( + session=db, + run_name="test_pending_run", + dataset_name=test_dataset.name, + dataset_id=test_dataset.id, + config={"model": "gpt-4o"}, + organization_id=test_dataset.organization_id, + project_id=test_dataset.project_id, + ) + eval_run.batch_job_id = batch_job.id + eval_run.status = "processing" + db.add(eval_run) + db.commit() + db.refresh(eval_run) + + mock_openai_client.return_value = MagicMock() + mock_langfuse_client.return_value = MagicMock() + mock_check.return_value = { + "run_id": eval_run.id, + "run_name": eval_run.run_name, + "action": "no_change", + } + + result = await poll_all_pending_evaluations( + session=db, org_id=test_dataset.organization_id + ) + + assert result["total"] == 1 + assert result["still_processing"] == 1 + mock_check.assert_called_once() \ No newline at end of file diff --git a/backend/app/tests/crud/test_assistants.py b/backend/app/tests/crud/test_assistants.py index 227fdb31..40c9c6c2 100644 --- a/backend/app/tests/crud/test_assistants.py +++ b/backend/app/tests/crud/test_assistants.py @@ -15,7 +15,7 @@ get_assistant_by_id, get_assistants_by_project, ) -from app.tests.utils.openai import mock_openai_assistant +from app.tests.utils.llm_provider import mock_openai_assistant from app.tests.utils.utils import ( get_project, get_assistant, diff --git a/backend/app/tests/crud/test_openai_conversation.py b/backend/app/tests/crud/test_openai_conversation.py index 890041b0..b9564e00 100644 --- a/backend/app/tests/crud/test_openai_conversation.py +++ b/backend/app/tests/crud/test_openai_conversation.py @@ -14,7 +14,7 @@ ) from app.models import OpenAIConversationCreate from app.tests.utils.utils import get_project, get_organization -from app.tests.utils.openai import generate_openai_id +from app.tests.utils.llm_provider import generate_openai_id def test_get_conversation_by_id_success(db: Session): diff --git a/backend/app/tests/services/collections/test_create_collection.py b/backend/app/tests/services/collections/test_create_collection.py index 9d5e7e97..8a7a5403 100644 --- a/backend/app/tests/services/collections/test_create_collection.py +++ b/backend/app/tests/services/collections/test_create_collection.py @@ -15,9 +15,9 @@ from app.models import CollectionJobStatus, CollectionJob, CollectionActionType from app.models.collection import CreationRequest from app.services.collections.create_collection import start_job, execute_job -from app.tests.utils.openai import get_mock_openai_client_with_vector_store +from app.tests.utils.llm_provider import get_mock_provider from app.tests.utils.utils import get_project -from app.tests.utils.collection import get_collection_job, get_collection +from app.tests.utils.collection import get_collection_job, get_assistant_collection from app.tests.utils.document import DocumentStore @@ -56,12 +56,10 @@ def test_start_job_creates_collection_job_and_schedules_task(db: Session): """ project = get_project(db) request = CreationRequest( - model="gpt-4o", - instructions="string", - temperature=0.000001, documents=[UUID("f3e86a17-1e6f-41ec-b020-5b08eebef928")], batch_size=1, callback_url=None, + 
provider="openai", ) job_id = uuid4() @@ -114,9 +112,9 @@ def test_start_job_creates_collection_job_and_schedules_task(db: Session): @pytest.mark.usefixtures("aws_credentials") @mock_aws -@patch("app.services.collections.create_collection.get_openai_client") +@patch("app.services.collections.create_collection.get_llm_provider") def test_execute_job_success_flow_updates_job_and_creates_collection( - mock_get_openai_client, db: Session + mock_get_llm_provider, db: Session ): """ execute_job should: @@ -138,16 +136,12 @@ def test_execute_job_success_flow_updates_job_and_creates_collection( aws.client.put_object(Bucket=settings.AWS_S3_BUCKET, Key=str(s3_key), Body=b"test") sample_request = CreationRequest( - model="gpt-4o", - instructions="string", - temperature=0.000001, - documents=[document.id], - batch_size=1, - callback_url=None, + documents=[document.id], batch_size=1, callback_url=None, provider="openai" ) - mock_client = get_mock_openai_client_with_vector_store() - mock_get_openai_client.return_value = mock_client + mock_get_llm_provider.return_value = get_mock_provider( + llm_service_id="mock_vector_store_id", llm_service_name="openai vector store" + ) job_id = uuid4() _ = get_collection_job( @@ -183,8 +177,8 @@ def test_execute_job_success_flow_updates_job_and_creates_collection( created_collection = CollectionCrud(db, project.id).read_one( updated_job.collection_id ) - assert created_collection.llm_service_id == "mock_assistant_id" - assert created_collection.llm_service_name == sample_request.model + assert created_collection.llm_service_id == "mock_vector_store_id" + assert created_collection.llm_service_name == "openai vector store" assert created_collection.updated_at is not None docs = DocumentCollectionCrud(db).read(created_collection, skip=0, limit=10) @@ -194,9 +188,9 @@ def test_execute_job_success_flow_updates_job_and_creates_collection( @pytest.mark.usefixtures("aws_credentials") @mock_aws -@patch("app.services.collections.create_collection.get_openai_client") -def test_execute_job_assistant_create_failure_marks_failed_and_deletes_vector( - mock_get_openai_client, db +@patch("app.services.collections.create_collection.get_llm_provider") +def test_execute_job_assistant_create_failure_marks_failed_and_deletes_collection( + mock_get_llm_provider, db ): project = get_project(db) @@ -210,32 +204,23 @@ def test_execute_job_assistant_create_failure_marks_failed_and_deletes_vector( ) req = CreationRequest( - model="gpt-4o", - instructions="string", - temperature=0.0, - documents=[], - batch_size=1, - callback_url=None, + documents=[], batch_size=1, callback_url=None, provider="openai" ) - _ = mock_get_openai_client.return_value + mock_provider = get_mock_provider( + llm_service_id="vs_123", llm_service_name="openai vector store" + ) + mock_get_llm_provider.return_value = mock_provider with patch( "app.services.collections.create_collection.Session" ) as SessionCtor, patch( - "app.services.collections.create_collection.OpenAIVectorStoreCrud" - ) as MockVS, patch( - "app.services.collections.create_collection.OpenAIAssistantCrud" - ) as MockAsst: + "app.services.collections.create_collection.CollectionCrud" + ) as MockCrud: SessionCtor.return_value.__enter__.return_value = db SessionCtor.return_value.__exit__.return_value = False - MockVS.return_value.create.return_value = type( - "Vector store", (), {"id": "vs_123"} - )() - MockVS.return_value.update.return_value = [] - - MockAsst.return_value.create.side_effect = RuntimeError("assistant boom") + 
MockCrud.return_value.create.side_effect = Exception("DB constraint violation") task_id = str(uuid4()) execute_job( @@ -248,21 +233,16 @@ def test_execute_job_assistant_create_failure_marks_failed_and_deletes_vector( task_instance=None, ) - failed = CollectionJobCrud(db, project.id).read_one(job.id) - assert failed.task_id == task_id - assert failed.status == CollectionJobStatus.FAILED - assert "assistant boom" in (failed.error_message or "") - - MockVS.return_value.delete.assert_called_once_with("vs_123") + mock_provider.cleanup.assert_called_once() @pytest.mark.usefixtures("aws_credentials") @mock_aws -@patch("app.services.collections.create_collection.get_openai_client") +@patch("app.services.collections.create_collection.get_llm_provider") @patch("app.services.collections.create_collection.send_callback") def test_execute_job_success_flow_callback_job_and_creates_collection( mock_send_callback, - mock_get_openai_client, + mock_get_llm_provider, db, ): """ @@ -287,16 +267,15 @@ def test_execute_job_success_flow_callback_job_and_creates_collection( callback_url = "https://example.com/collections/create-success" sample_request = CreationRequest( - model="gpt-4o", - instructions="string", - temperature=0.000001, documents=[document.id], batch_size=1, callback_url=callback_url, + provider="openai", ) - mock_client = get_mock_openai_client_with_vector_store() - mock_get_openai_client.return_value = mock_client + mock_get_llm_provider.return_value = get_mock_provider( + llm_service_id="mock_vector_store_id", llm_service_name="openai vector store" + ) job_id = uuid.uuid4() _ = get_collection_job( @@ -340,11 +319,11 @@ def test_execute_job_success_flow_callback_job_and_creates_collection( @pytest.mark.usefixtures("aws_credentials") @mock_aws -@patch("app.services.collections.create_collection.get_openai_client") +@patch("app.services.collections.create_collection.get_llm_provider") @patch("app.services.collections.create_collection.send_callback") def test_execute_job_success_creates_collection_with_callback( mock_send_callback, - mock_get_openai_client, + mock_get_llm_provider, db, ): """ @@ -369,16 +348,15 @@ def test_execute_job_success_creates_collection_with_callback( callback_url = "https://example.com/collections/create-success" sample_request = CreationRequest( - model="gpt-4o", - instructions="string", - temperature=0.000001, documents=[document.id], batch_size=1, callback_url=callback_url, + provider="openai", ) - mock_client = get_mock_openai_client_with_vector_store() - mock_get_openai_client.return_value = mock_client + mock_get_llm_provider.return_value = get_mock_provider( + llm_service_id="mock_vector_store_id", llm_service_name="gpt-4o" + ) job_id = uuid.uuid4() _ = get_collection_job( @@ -422,13 +400,13 @@ def test_execute_job_success_creates_collection_with_callback( @pytest.mark.usefixtures("aws_credentials") @mock_aws -@patch("app.services.collections.create_collection.get_openai_client") +@patch("app.services.collections.create_collection.get_llm_provider") @patch("app.services.collections.create_collection.send_callback") @patch("app.services.collections.create_collection.CollectionCrud") def test_execute_job_failure_flow_callback_job_and_marks_failed( MockCollectionCrud, mock_send_callback, - mock_get_openai_client, + mock_get_llm_provider, db: Session, ): """ @@ -437,7 +415,7 @@ def test_execute_job_failure_flow_callback_job_and_marks_failed( """ project = get_project(db) - collection = get_collection(db, project, assistant_id="asst_123") + collection = 
get_assistant_collection(db, project, assistant_id="asst_123") job = get_collection_job( db, project, @@ -446,7 +424,7 @@ def test_execute_job_failure_flow_callback_job_and_marks_failed( collection_id=None, ) - mock_get_openai_client.return_value = MagicMock() + mock_get_llm_provider.return_value = MagicMock() callback_url = "https://example.com/collections/create-failure" @@ -454,12 +432,10 @@ def test_execute_job_failure_flow_callback_job_and_marks_failed( collection_crud_instance.read_one.return_value = collection sample_request = CreationRequest( - model="gpt-4o", - instructions="string", - temperature=0.000001, documents=[uuid.uuid4()], batch_size=1, callback_url=callback_url, + provider="openai", ) task_id = uuid.uuid4() diff --git a/backend/app/tests/services/collections/test_delete_collection.py b/backend/app/tests/services/collections/test_delete_collection.py index 07e2af08..4c7284c3 100644 --- a/backend/app/tests/services/collections/test_delete_collection.py +++ b/backend/app/tests/services/collections/test_delete_collection.py @@ -1,15 +1,13 @@ from unittest.mock import patch, MagicMock from uuid import uuid4, UUID -from sqlalchemy.exc import SQLAlchemyError -from app.models.collection import ( - DeletionRequest, -) +from app.models.collection import DeletionRequest + from app.tests.utils.utils import get_project from app.crud import CollectionJobCrud from app.models import CollectionJobStatus, CollectionActionType -from app.tests.utils.collection import get_collection, get_collection_job +from app.tests.utils.collection import get_collection_job, get_vector_store_collection from app.services.collections.delete_collection import start_job, execute_job @@ -20,7 +18,7 @@ def test_start_job_creates_collection_job_and_schedules_task(db): - return the same job_id (UUID) """ project = get_project(db) - created_collection = get_collection(db, project) + created_collection = get_vector_store_collection(db, project, provider="OPENAI") req = DeletionRequest(collection_id=created_collection.id) @@ -72,19 +70,22 @@ def test_start_job_creates_collection_job_and_schedules_task(db): assert "trace_id" in kwargs -@patch("app.services.collections.delete_collection.get_openai_client") +@patch("app.services.collections.delete_collection.get_llm_provider") def test_execute_job_delete_success_updates_job_and_calls_delete( - mock_get_openai_client, db + mock_get_llm_provider, db ): """ - execute_job should set task_id on the CollectionJob - - call remote delete via OpenAIAssistantCrud.delete(...) + - call provider.delete() to delete remote resources - delete local record via CollectionCrud.delete_by_id(...) 
- mark job successful and clear error_message """ project = get_project(db) - collection = get_collection(db, project, assistant_id="asst_123") + collection = get_vector_store_collection( + db, project, vector_store_id="asst_123", provider="OPENAI" + ) + job = get_collection_job( db, project, @@ -93,13 +94,13 @@ def test_execute_job_delete_success_updates_job_and_calls_delete( collection_id=collection.id, ) - mock_get_openai_client.return_value = MagicMock() + mock_provider = MagicMock() + mock_provider.delete = MagicMock() + mock_get_llm_provider.return_value = mock_provider with patch( "app.services.collections.delete_collection.Session" ) as SessionCtor, patch( - "app.services.collections.delete_collection.OpenAIAssistantCrud" - ) as MockAssistantCrud, patch( "app.services.collections.delete_collection.CollectionCrud" ) as MockCollectionCrud: SessionCtor.return_value.__enter__.return_value = db @@ -108,8 +109,6 @@ def test_execute_job_delete_success_updates_job_and_calls_delete( collection_crud_instance = MockCollectionCrud.return_value collection_crud_instance.read_one.return_value = collection - MockAssistantCrud.return_value.delete.return_value = None - task_id = uuid4() req = DeletionRequest(collection_id=collection.id) @@ -128,25 +127,32 @@ def test_execute_job_delete_success_updates_job_and_calls_delete( assert updated_job.status == CollectionJobStatus.SUCCESSFUL assert updated_job.error_message in (None, "") + mock_provider.delete.assert_called_once_with(collection) + MockCollectionCrud.assert_called_with(db, project.id) collection_crud_instance.read_one.assert_called_once_with(collection.id) - - MockAssistantCrud.assert_called_once() - MockAssistantCrud.return_value.delete.assert_called_once_with("asst_123") - collection_crud_instance.delete_by_id.assert_called_once_with(collection.id) - mock_get_openai_client.assert_called_once() + + mock_get_llm_provider.assert_called_once() -@patch("app.services.collections.delete_collection.get_openai_client") -def test_execute_job_delete_failure_marks_job_failed(mock_get_openai_client, db): +@patch("app.services.collections.delete_collection.get_llm_provider") +def test_execute_job_delete_failure_marks_job_failed(mock_get_llm_provider, db): """ - When the remote delete (OpenAIAssistantCrud.delete) raises, - the job should be marked FAILED and error_message set. 
+ When provider.delete() raises an exception: + - Job should be marked FAILED + - error_message should be set + - Local collection should NOT be deleted """ project = get_project(db) - collection = get_collection(db, project, assistant_id="asst_123") + collection = get_vector_store_collection( + db, + project, + vector_store_id="vector_123", + provider="OPENAI", + ) + job = get_collection_job( db, project, @@ -155,13 +161,13 @@ def test_execute_job_delete_failure_marks_job_failed(mock_get_openai_client, db) collection_id=collection.id, ) - mock_get_openai_client.return_value = MagicMock() + mock_provider = MagicMock() + mock_provider.delete.side_effect = Exception("Remote deletion failed") + mock_get_llm_provider.return_value = mock_provider with patch( "app.services.collections.delete_collection.Session" ) as SessionCtor, patch( - "app.services.collections.delete_collection.OpenAIAssistantCrud" - ) as MockAssistantCrud, patch( "app.services.collections.delete_collection.CollectionCrud" ) as MockCollectionCrud: SessionCtor.return_value.__enter__.return_value = db @@ -170,10 +176,6 @@ def test_execute_job_delete_failure_marks_job_failed(mock_get_openai_client, db) collection_crud_instance = MockCollectionCrud.return_value collection_crud_instance.read_one.return_value = collection - MockAssistantCrud.return_value.delete.side_effect = SQLAlchemyError( - "something went wrong" - ) - task_id = uuid4() req = DeletionRequest(collection_id=collection.id) @@ -192,22 +194,23 @@ def test_execute_job_delete_failure_marks_job_failed(mock_get_openai_client, db) assert failed_job.status == CollectionJobStatus.FAILED assert ( failed_job.error_message - and "something went wrong" in failed_job.error_message + and "Remote deletion failed" in failed_job.error_message ) + mock_provider.delete.assert_called_once_with(collection) + MockCollectionCrud.assert_called_with(db, project.id) collection_crud_instance.read_one.assert_called_once_with(collection.id) - - MockAssistantCrud.assert_called_once() - MockAssistantCrud.return_value.delete.assert_called_once_with("asst_123") - collection_crud_instance.delete_by_id.assert_not_called() - mock_get_openai_client.assert_called_once() + mock_get_llm_provider.assert_called_once() -@patch("app.services.collections.delete_collection.get_openai_client") + +@patch("app.services.collections.delete_collection.get_llm_provider") +@patch("app.services.collections.delete_collection.send_callback") def test_execute_job_delete_success_with_callback_sends_success_payload( - mock_get_openai_client, + mock_send_callback, + mock_get_llm_provider, db, ): """ @@ -218,7 +221,13 @@ def test_execute_job_delete_success_with_callback_sends_success_payload( """ project = get_project(db) - collection = get_collection(db, project, assistant_id="asst_123") + collection = get_vector_store_collection( + db, + project, + vector_store_id="vector 123", + provider="OPENAI", + ) + job = get_collection_job( db, project, @@ -227,32 +236,26 @@ def test_execute_job_delete_success_with_callback_sends_success_payload( collection_id=collection.id, ) - mock_get_openai_client.return_value = MagicMock() + mock_provider = MagicMock() + mock_provider.delete = MagicMock() + mock_get_llm_provider.return_value = mock_provider callback_url = "https://example.com/collections/delete-success" with patch( "app.services.collections.delete_collection.Session" ) as SessionCtor, patch( - "app.services.collections.delete_collection.OpenAIAssistantCrud" - ) as MockAssistantCrud, patch( 
"app.services.collections.delete_collection.CollectionCrud" - ) as MockCollectionCrud, patch( - "app.services.collections.delete_collection.send_callback" - ) as mock_send_callback: + ) as MockCollectionCrud: SessionCtor.return_value.__enter__.return_value = db SessionCtor.return_value.__exit__.return_value = False collection_crud_instance = MockCollectionCrud.return_value collection_crud_instance.read_one.return_value = collection - MockAssistantCrud.return_value.delete.return_value = None - task_id = uuid4() req = DeletionRequest(collection_id=collection.id, callback_url=callback_url) - from app.services.collections.delete_collection import execute_job - execute_job( request=req.model_dump(mode="json"), project_id=project.id, @@ -268,12 +271,12 @@ def test_execute_job_delete_success_with_callback_sends_success_payload( assert updated_job.status == CollectionJobStatus.SUCCESSFUL assert updated_job.error_message in (None, "") + mock_provider.delete.assert_called_once_with(collection) + MockCollectionCrud.assert_called_with(db, project.id) collection_crud_instance.read_one.assert_called_once_with(collection.id) - MockAssistantCrud.assert_called_once() - MockAssistantCrud.return_value.delete.assert_called_once_with("asst_123") collection_crud_instance.delete_by_id.assert_called_once_with(collection.id) - mock_get_openai_client.assert_called_once() + mock_get_llm_provider.assert_called_once() mock_send_callback.assert_called_once() cb_url_arg, payload_arg = mock_send_callback.call_args.args @@ -285,20 +288,28 @@ def test_execute_job_delete_success_with_callback_sends_success_payload( assert UUID(payload_arg["data"]["job_id"]) == job.id -@patch("app.services.collections.delete_collection.get_openai_client") +@patch("app.services.collections.delete_collection.get_llm_provider") +@patch("app.services.collections.delete_collection.send_callback") def test_execute_job_delete_remote_failure_with_callback_sends_failure_payload( - mock_get_openai_client, + mock_send_callback, + mock_get_llm_provider, db, ): """ - When the remote delete raises AND a callback_url is provided: + When provider.delete() raises AND a callback_url is provided: - job is marked FAILED with error_message set - send_callback is called once - failure payload has success=False, status=FAILED, correct collection id, and error message """ project = get_project(db) - collection = get_collection(db, project, assistant_id="asst_123") + collection = get_vector_store_collection( + db, + project, + vector_store_id="vector_123", + provider="OPENAI", + ) + job = get_collection_job( db, project, @@ -307,33 +318,26 @@ def test_execute_job_delete_remote_failure_with_callback_sends_failure_payload( collection_id=collection.id, ) - mock_get_openai_client.return_value = MagicMock() + mock_provider = MagicMock() + mock_provider.delete.side_effect = Exception("Remote deletion failed") + mock_get_llm_provider.return_value = mock_provider + callback_url = "https://example.com/collections/delete-failed" with patch( "app.services.collections.delete_collection.Session" ) as SessionCtor, patch( - "app.services.collections.delete_collection.OpenAIAssistantCrud" - ) as MockAssistantCrud, patch( "app.services.collections.delete_collection.CollectionCrud" - ) as MockCollectionCrud, patch( - "app.services.collections.delete_collection.send_callback" - ) as mock_send_callback: + ) as MockCollectionCrud: SessionCtor.return_value.__enter__.return_value = db SessionCtor.return_value.__exit__.return_value = False collection_crud_instance = 
MockCollectionCrud.return_value collection_crud_instance.read_one.return_value = collection - MockAssistantCrud.return_value.delete.side_effect = SQLAlchemyError( - "something went wrong" - ) - task_id = uuid4() req = DeletionRequest(collection_id=collection.id, callback_url=callback_url) - from app.services.collections.delete_collection import execute_job - execute_job( request=req.model_dump(mode="json"), project_id=project.id, @@ -349,24 +353,22 @@ def test_execute_job_delete_remote_failure_with_callback_sends_failure_payload( assert failed_job.status == CollectionJobStatus.FAILED assert ( failed_job.error_message - and "something went wrong" in failed_job.error_message + and "Remote deletion failed" in failed_job.error_message ) + mock_provider.delete.assert_called_once_with(collection) + MockCollectionCrud.assert_called_with(db, project.id) collection_crud_instance.read_one.assert_called_once_with(collection.id) - - MockAssistantCrud.assert_called_once() - MockAssistantCrud.return_value.delete.assert_called_once_with("asst_123") - collection_crud_instance.delete_by_id.assert_not_called() - mock_get_openai_client.assert_called_once() + mock_get_llm_provider.assert_called_once() mock_send_callback.assert_called_once() cb_url_arg, payload_arg = mock_send_callback.call_args.args assert str(cb_url_arg) == callback_url assert payload_arg["success"] is False - assert "something went wrong" in (payload_arg["error"] or "") + assert "Remote deletion failed" in (payload_arg["error"] or "") assert payload_arg["data"]["status"] == CollectionJobStatus.FAILED assert payload_arg["data"]["collection"]["id"] == str(collection.id) assert UUID(payload_arg["data"]["job_id"]) == job.id diff --git a/backend/app/tests/services/llm/providers/test_openai.py b/backend/app/tests/services/llm/providers/test_openai.py index 745dd00b..4dfb671e 100644 --- a/backend/app/tests/services/llm/providers/test_openai.py +++ b/backend/app/tests/services/llm/providers/test_openai.py @@ -12,7 +12,7 @@ ) from app.models.llm.request import ConversationConfig from app.services.llm.providers.openai import OpenAIProvider -from app.tests.utils.openai import mock_openai_response +from app.tests.utils.llm_provider import mock_openai_response class TestOpenAIProvider: diff --git a/backend/app/tests/services/response/response/test_process_response.py b/backend/app/tests/services/response/response/test_process_response.py index a6a03e56..179a494f 100644 --- a/backend/app/tests/services/response/response/test_process_response.py +++ b/backend/app/tests/services/response/response/test_process_response.py @@ -16,7 +16,7 @@ from app.utils import APIResponse from app.tests.utils.utils import get_project from app.tests.utils.test_data import create_test_credential -from app.tests.utils.openai import mock_openai_response, generate_openai_id +from app.tests.utils.llm_provider import mock_openai_response, generate_openai_id from app.crud import JobCrud, create_assistant from openai import OpenAI diff --git a/backend/app/tests/utils/collection.py b/backend/app/tests/utils/collection.py index 429bfc8b..d97f6188 100644 --- a/backend/app/tests/utils/collection.py +++ b/backend/app/tests/utils/collection.py @@ -8,8 +8,10 @@ CollectionActionType, CollectionJob, CollectionJobStatus, + ProviderType, ) from app.crud import CollectionCrud, CollectionJobCrud +from app.services.collections.helpers import get_service_name class constants: @@ -22,7 +24,7 @@ def uuid_increment(value: UUID): return UUID(int=inc) -def get_collection( +def 
get_assistant_collection( db: Session, project, *, @@ -43,6 +45,7 @@ def get_collection( organization_id=project.organization_id, llm_service_name=model, llm_service_id=assistant_id, + provider=ProviderType.OPENAI, ) return CollectionCrud(db, project.id).create(collection) @@ -53,6 +56,7 @@ def get_vector_store_collection( *, vector_store_id: Optional[str] = None, collection_id: Optional[UUID] = None, + provider: str, ) -> Collection: """ Create a Collection configured for the Vector Store path. @@ -64,9 +68,9 @@ def get_vector_store_collection( collection = Collection( id=collection_id or uuid4(), project_id=project.id, - organization_id=project.organization_id, - llm_service_name="openai vector store", + llm_service_name=get_service_name("openai"), llm_service_id=vector_store_id, + provider=provider.upper(), ) return CollectionCrud(db, project.id).create(collection) diff --git a/backend/app/tests/utils/document.py b/backend/app/tests/utils/document.py index dddb1c2a..10381ed8 100644 --- a/backend/app/tests/utils/document.py +++ b/backend/app/tests/utils/document.py @@ -1,12 +1,12 @@ import itertools as it import functools as ft +from typing import Any, Generator from uuid import UUID from pathlib import Path from datetime import datetime from dataclasses import dataclass from urllib.parse import ParseResult, urlunparse -import pytest from httpx import Response from sqlmodel import Session, delete from fastapi.testclient import TestClient @@ -51,32 +51,32 @@ def __next__(self): class DocumentStore: - def __init__(self, db: Session, project_id: int): + def __init__(self, db: Session, project_id: int) -> None: self.db = db self.documents = DocumentMaker(project_id=project_id, session=db) self.clear(self.db) @staticmethod - def clear(db: Session): + def clear(db: Session) -> None: db.exec(delete(Document)) db.commit() @property - def project(self): + def project(self) -> Project: return self.documents.project - def put(self): + def put(self) -> Document: doc = next(self.documents) self.db.add(doc) self.db.commit() self.db.refresh(doc) return doc - def extend(self, n: int): + def extend(self, n: int) -> Generator[Document, None, None]: for _ in range(n): yield self.put() - def fill(self, n: int): + def fill(self, n: int) -> list[Document]: return list(self.extend(n)) @@ -84,14 +84,14 @@ class Route: _empty = ParseResult(*it.repeat("", len(ParseResult._fields))) _root = Path(settings.API_V1_STR, "documents") - def __init__(self, endpoint, **qs_args): + def __init__(self, endpoint: str | Path, **qs_args: Any) -> None: self.endpoint = endpoint self.qs_args = qs_args - def __str__(self): + def __str__(self) -> str: return urlunparse(self.to_url()) - def to_url(self): + def to_url(self) -> ParseResult: path = self._root.joinpath(self.endpoint) kwargs = { "path": str(path), @@ -102,7 +102,7 @@ def to_url(self): return self._empty._replace(**kwargs) - def append(self, doc: Document, suffix: str = None): + def append(self, doc: Document, suffix: str | None = None) -> "Route": segments = [self.endpoint, str(doc.id)] if suffix: segments.append(suffix) @@ -115,13 +115,13 @@ class WebCrawler: client: TestClient user_api_key: TestAuthContext - def get(self, route: Route): + def get(self, route: Route) -> Response: return self.client.get( str(route), headers={"X-API-KEY": self.user_api_key.key}, ) - def delete(self, route: Route): + def delete(self, route: Route) -> Response: return self.client.delete( str(route), headers={"X-API-KEY": self.user_api_key.key}, @@ -133,23 +133,23 @@ class DocumentComparator: 
@ft.singledispatchmethod @staticmethod - def to_string(value): + def to_string(value: Any) -> Any: return value @to_string.register @staticmethod - def _(value: UUID): + def _(value: UUID) -> str: return str(value) @to_string.register @staticmethod - def _(value: datetime): + def _(value: datetime) -> str: return value.isoformat() - def __init__(self, document: Document): + def __init__(self, document: Document) -> None: self.document = document - def __eq__(self, other: dict): + def __eq__(self, other: dict) -> bool: this = dict(self.to_public_dict()) return this == other @@ -162,9 +162,4 @@ def to_public_dict(self) -> dict: value = getattr(self.document, field, None) result[field] = self.to_string(value) - return result - - -@pytest.fixture -def crawler(client: TestClient, user_api_key: TestAuthContext): - return WebCrawler(client, user_api_key=user_api_key) + return result \ No newline at end of file diff --git a/backend/app/tests/utils/openai.py b/backend/app/tests/utils/llm_provider.py similarity index 85% rename from backend/app/tests/utils/openai.py rename to backend/app/tests/utils/llm_provider.py index c92a5214..b50f1446 100644 --- a/backend/app/tests/utils/openai.py +++ b/backend/app/tests/utils/llm_provider.py @@ -114,12 +114,10 @@ def mock_openai_response( def get_mock_openai_client_with_vector_store(): mock_client = MagicMock() - # Vector store mock_vector_store = MagicMock() mock_vector_store.id = "mock_vector_store_id" mock_client.vector_stores.create.return_value = mock_vector_store - # File upload + polling mock_file_batch = MagicMock() mock_file_batch.file_counts.completed = 2 mock_file_batch.file_counts.total = 2 @@ -127,10 +125,8 @@ def get_mock_openai_client_with_vector_store(): mock_file_batch ) - # File list mock_client.vector_stores.files.list.return_value = {"data": []} - # Assistant mock_assistant = MagicMock() mock_assistant.id = "mock_assistant_id" mock_assistant.name = "Mock Assistant" @@ -139,3 +135,28 @@ def get_mock_openai_client_with_vector_store(): mock_client.beta.assistants.create.return_value = mock_assistant return mock_client + + +def get_mock_provider( + llm_service_id: str = "mock_service_id", + llm_service_name: str = "mock_service_name", +): + """ + Create a properly configured mock provider for tests. + + Returns a mock that mimics BaseProvider with: + - create() method returning result with llm_service_id and llm_service_name + - cleanup() method for cleanup on failure + - delete() method for deletion + """ + mock_provider = MagicMock() + + mock_result = MagicMock() + mock_result.llm_service_id = llm_service_id + mock_result.llm_service_name = llm_service_name + + mock_provider.create.return_value = mock_result + mock_provider.cleanup = MagicMock() + mock_provider.delete = MagicMock() + + return mock_provider
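
A minimal usage sketch for the get_mock_provider helper above. It is illustrative only: the patch target (app.services.collections.create_collection.get_llm_provider) and the stubbed values are taken from the collection-service tests earlier in this diff, while the test name and the direct provider.create()/cleanup() calls are stand-ins for what the service code does internally, not part of the change itself.

from unittest.mock import patch

from app.tests.utils.llm_provider import get_mock_provider


def test_provider_mock_wiring_sketch():
    # Patch the provider factory exactly as the create_collection tests do.
    with patch(
        "app.services.collections.create_collection.get_llm_provider"
    ) as mock_get_llm_provider:
        mock_get_llm_provider.return_value = get_mock_provider(
            llm_service_id="mock_vector_store_id",
            llm_service_name="openai vector store",
        )

        # What the service would do: resolve the provider, then create resources.
        provider = mock_get_llm_provider()
        result = provider.create()

        # The stubbed result carries the values the collection row is built from.
        assert result.llm_service_id == "mock_vector_store_id"
        assert result.llm_service_name == "openai vector store"

        # Failure paths are expected to call cleanup(); the mock records it.
        provider.cleanup()
        provider.cleanup.assert_called_once()

The delete-path tests follow the same wiring but patch app.services.collections.delete_collection.get_llm_provider and assert mock_provider.delete.assert_called_once_with(collection) instead.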