diff --git a/backend/app/alembic/versions/041_extend_collection_table_for_provider_.py b/backend/app/alembic/versions/041_extend_collection_table_for_provider_.py
new file mode 100644
index 00000000..ac0bcd89
--- /dev/null
+++ b/backend/app/alembic/versions/041_extend_collection_table_for_provider_.py
@@ -0,0 +1,112 @@
+"""extend collection table for provider-agnostic support
+
+Revision ID: 041
+Revises: 040
+Create Date: 2026-01-15 16:53:19.495583
+
+"""
+from alembic import op
+import sqlalchemy as sa
+import sqlmodel.sql.sqltypes
+from sqlalchemy.dialects import postgresql
+
+
+# revision identifiers, used by Alembic.
+revision = "041"
+down_revision = "040"
+branch_labels = None
+depends_on = None
+
+provider_type = postgresql.ENUM(
+    "openai",
+    # aws
+    # gemini
+    name="providertype",
+    create_type=False,
+)
+
+
+def upgrade():
+    provider_type.create(op.get_bind(), checkfirst=True)
+    op.add_column(
+        "collection",
+        sa.Column(
+            "provider",
+            provider_type,
+            nullable=True,
+            comment="LLM provider used for this collection",
+        ),
+    )
+    op.execute("UPDATE collection SET provider = 'openai' WHERE provider IS NULL")
+    op.alter_column("collection", "provider", nullable=False)
+    op.add_column(
+        "collection",
+        sa.Column(
+            "name",
+            sqlmodel.sql.sqltypes.AutoString(),
+            nullable=True,
+            comment="Name of the collection",
+        ),
+    )
+    op.add_column(
+        "collection",
+        sa.Column(
+            "description",
+            sqlmodel.sql.sqltypes.AutoString(),
+            nullable=True,
+            comment="Description of the collection",
+        ),
+    )
+    op.alter_column(
+        "collection",
+        "llm_service_name",
+        existing_type=sa.VARCHAR(),
+        comment="Name of the LLM service",
+        existing_comment="Name of the LLM service provider",
+        existing_nullable=False,
+    )
+    op.create_unique_constraint(
+        "uq_collection_project_id_name", "collection", ["project_id", "name"]
+    )
+    op.drop_constraint(
+        op.f("collection_organization_id_fkey"), "collection", type_="foreignkey"
+    )
+    op.drop_column("collection", "organization_id")
+
+
+def downgrade():
+    op.add_column(
+        "collection",
+        sa.Column(
+            "organization_id",
+            sa.INTEGER(),
+            autoincrement=False,
+            nullable=True,
+            comment="Reference to the organization",
+        ),
+    )
+    op.execute(
+        """UPDATE collection SET organization_id = (SELECT organization_id FROM project
+        WHERE project.id = collection.project_id)"""
+    )
+    op.alter_column("collection", "organization_id", nullable=False)
+    op.create_foreign_key(
+        op.f("collection_organization_id_fkey"),
+        "collection",
+        "organization",
+        ["organization_id"],
+        ["id"],
+        ondelete="CASCADE",
+    )
+    op.drop_constraint("uq_collection_project_id_name", "collection", type_="unique")
+    op.alter_column(
+        "collection",
+        "llm_service_name",
+        existing_type=sa.VARCHAR(),
+        comment="Name of the LLM service provider",
+        existing_comment="Name of the LLM service",
+        existing_nullable=False,
+    )
+    op.drop_column("collection", "description")
+    op.drop_column("collection", "name")
+    op.drop_column("collection", "provider")
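The migration above follows the standard three-step pattern for introducing a NOT NULL column into a populated table. A minimal, self-contained sketch of the same pattern, with a hypothetical `item` table standing in for `collection`:

```python
# Sketch of the add-nullable -> backfill -> tighten pattern that
# migration 041 applies to collection.provider. The "item" table and
# the 'openai' default are illustrative, not part of this changeset.
from alembic import op
import sqlalchemy as sa


def upgrade():
    # 1. Add the column as NULLable so existing rows remain valid.
    op.add_column("item", sa.Column("provider", sa.String(), nullable=True))
    # 2. Backfill every pre-existing row with a sensible default.
    op.execute("UPDATE item SET provider = 'openai' WHERE provider IS NULL")
    # 3. Only then tighten the column to NOT NULL.
    op.alter_column("item", "provider", nullable=False)
```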
diff --git a/backend/app/api/docs/collections/create.md b/backend/app/api/docs/collections/create.md
index c3a5f440..915951b9 100644
--- a/backend/app/api/docs/collections/create.md
+++ b/backend/app/api/docs/collections/create.md
@@ -10,17 +10,17 @@
 pipeline: "model" and "instruction" in the request body otherwise only a
 vector store will be created from the documents given.
 
-If any one of the OpenAI interactions fail, all OpenAI resources are
-cleaned up. If a Vector Store is unable to be created, for example,
+If any one of the LLM service interactions fails, all service resources are
+cleaned up. If an OpenAI vector store is unable to be created, for example,
 all file(s) that were uploaded to OpenAI are removed from OpenAI.
 Failure can occur from OpenAI being down, or some parameter value
 being invalid. It can also fail due to document types not being
 accepted. This is especially true for PDFs that may not be parseable.
 
-Vector store/assistant will be created asynchronously. The immediate response
-from this endpoint is `collection_job` object which is going to contain
-the collection "job ID" and status. Once the collection has been created,
-information about the collection will be returned to the user via the
-callback URL. If a callback URL is not provided, clients can check the
+In the case of OpenAI, the vector store/assistant is created asynchronously.
+The immediate response from this endpoint is a `collection_job` object that
+contains the collection job ID and status. Once the collection has been
+created, information about the collection will be returned to the user via
+the callback URL. If a callback URL is not provided, clients can check the
 `collection job info` endpoint with the `job_id`, to retrieve information
 about the creation of collection.
diff --git a/backend/app/api/routes/collections.py b/backend/app/api/routes/collections.py
index b9dc7b3f..bf063a08 100644
--- a/backend/app/api/routes/collections.py
+++ b/backend/app/api/routes/collections.py
@@ -28,6 +28,7 @@
     CollectionPublic,
 )
 from app.utils import APIResponse, load_description, validate_callback_url
+from app.services.collections.helpers import ensure_unique_name
 from app.services.collections import (
     create_collection as create_service,
     delete_collection as delete_service,
@@ -88,6 +89,9 @@ def create_collection(
     if request.callback_url:
         validate_callback_url(str(request.callback_url))
 
+    if request.name:
+        ensure_unique_name(session, current_user.project_.id, request.name)
+
     collection_job_crud = CollectionJobCrud(session, current_user.project_.id)
     collection_job = collection_job_crud.create(
         CollectionJobCreate(
diff --git a/backend/app/crud/collection/collection.py b/backend/app/crud/collection/collection.py
index 3c83912a..cb8b6e27 100644
--- a/backend/app/crud/collection/collection.py
+++ b/backend/app/crud/collection/collection.py
@@ -93,6 +93,16 @@ def read_all(self):
         collections = self.session.exec(statement).all()
         return collections
 
+    def exists_by_name(self, collection_name: str) -> bool:
+        statement = (
+            select(Collection.id)
+            .where(Collection.project_id == self.project_id)
+            .where(Collection.name == collection_name)
+            .where(Collection.deleted_at.is_(None))
+        )
+        result = self.session.exec(statement).first()
+        return result is not None
+
     def delete_by_id(self, collection_id: UUID) -> Collection:
         coll = self.read_one(collection_id)
         coll.deleted_at = now()
diff --git a/backend/app/models/__init__.py b/backend/app/models/__init__.py
index ac7e89d6..a4d76ee2 100644
--- a/backend/app/models/__init__.py
+++ b/backend/app/models/__init__.py
@@ -8,9 +8,12 @@
 from .collection import (
     Collection,
+    CreationRequest,
     CollectionPublic,
     CollectionIDPublic,
     CollectionWithDocsPublic,
+    DeletionRequest,
+    ProviderType,
 )
 from .collection_job import (
     CollectionActionType,
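Note the two layers of the uniqueness guarantee introduced here: `exists_by_name` filters on `deleted_at IS NULL`, so the application-level check treats soft-deleted names as reusable, while the `uq_collection_project_id_name` constraint added in migration 041 applies to all rows (multiple `NULL` names remain legal, since Postgres treats NULLs as distinct in unique constraints). If reuse of soft-deleted names is the intended semantics, the database-level backstop would need a partial unique index instead of a plain constraint. A hedged sketch of that variant, not part of this changeset:

```python
# Alternative (illustrative only): enforce uniqueness among live rows only,
# matching exists_by_name's deleted_at IS NULL filter.
from alembic import op
import sqlalchemy as sa


def upgrade():
    op.create_index(
        "uq_collection_project_id_name_live",
        "collection",
        ["project_id", "name"],
        unique=True,
        postgresql_where=sa.text("deleted_at IS NULL"),
    )
```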
diff --git a/backend/app/models/collection.py b/backend/app/models/collection.py
index 57e5a17b..7a537079 100644
--- a/backend/app/models/collection.py
+++ b/backend/app/models/collection.py
@@ -1,71 +1,102 @@
 from datetime import datetime
+from enum import Enum
 from typing import Any, Literal
 from uuid import UUID, uuid4
 
 from pydantic import HttpUrl, model_validator
+from sqlalchemy import UniqueConstraint
 from sqlmodel import Field, Relationship, SQLModel
 
 from app.core.util import now
 from app.models.document import DocumentPublic
-
-from .organization import Organization
 from .project import Project
 
 
+class ProviderType(str, Enum):
+    """Supported LLM providers for collections."""
+
+    openai = "openai"
+    # BEDROCK = "bedrock"
+    # GEMINI = "gemini"
+
+
 class Collection(SQLModel, table=True):
     """Database model for Collection operations."""
 
+    __table_args__ = (
+        UniqueConstraint(
+            "project_id",
+            "name",
+            name="uq_collection_project_id_name",
+        ),
+    )
+
     id: UUID = Field(
         default_factory=uuid4,
         primary_key=True,
+        description="Unique identifier for the collection",
         sa_column_kwargs={"comment": "Unique identifier for the collection"},
     )
+    provider: ProviderType = Field(
+        nullable=False,
+        description="LLM provider used for this collection (e.g., 'openai', 'bedrock', 'gemini', etc.)",
+        sa_column_kwargs={"comment": "LLM provider used for this collection"},
+    )
     llm_service_id: str = Field(
         nullable=False,
+        description="External LLM service identifier (e.g., OpenAI vector store ID)",
         sa_column_kwargs={
             "comment": "External LLM service identifier (e.g., OpenAI vector store ID)"
         },
     )
     llm_service_name: str = Field(
         nullable=False,
+        description="Name of the LLM service",
         sa_column_kwargs={"comment": "Name of the LLM service"},
     )
-
-    # Foreign keys
-    organization_id: int = Field(
-        foreign_key="organization.id",
-        nullable=False,
-        ondelete="CASCADE",
-        sa_column_kwargs={"comment": "Reference to the organization"},
+    name: str | None = Field(
+        default=None,
+        nullable=True,
+        description="Name of the collection",
+        sa_column_kwargs={"comment": "Name of the collection"},
+    )
+    description: str | None = Field(
+        default=None,
+        nullable=True,
+        description="Description of the collection",
+        sa_column_kwargs={"comment": "Description of the collection"},
     )
     project_id: int = Field(
         foreign_key="project.id",
         nullable=False,
         ondelete="CASCADE",
+        description="Project the collection belongs to",
         sa_column_kwargs={"comment": "Reference to the project"},
     )
-
-    # Timestamps
     inserted_at: datetime = Field(
         default_factory=now,
+        description="Timestamp when the collection was created",
         sa_column_kwargs={"comment": "Timestamp when the collection was created"},
     )
     updated_at: datetime = Field(
         default_factory=now,
+        description="Timestamp when the collection was updated",
         sa_column_kwargs={"comment": "Timestamp when the collection was last updated"},
     )
     deleted_at: datetime | None = Field(
         default=None,
+        description="Timestamp when the collection was deleted",
         sa_column_kwargs={"comment": "Timestamp when the collection was deleted"},
     )
-
-    # Relationships
-    organization: Organization = Relationship(back_populates="collections")
     project: Project = Relationship(back_populates="collections")
 
 
 # Request models
-class DocumentOptions(SQLModel):
+class CollectionOptions(SQLModel):
+    name: str | None = Field(default=None, description="Name of the collection")
+    description: str | None = Field(
+        default=None, description="Description of the collection"
+    )
     documents: list[UUID] = Field(
         description="List of document IDs",
     )
@@ -154,9 +185,9 @@ class ProviderOptions(SQLModel):
 
 
 class CreationRequest(
-    DocumentOptions,
-    ProviderOptions,
     AssistantOptions,
+    CollectionOptions,
+    ProviderOptions,
     CallbackRequest,
): def extract_super_type(self, cls: "CreationRequest"): @@ -181,7 +212,6 @@ class CollectionPublic(SQLModel): llm_service_id: str llm_service_name: str project_id: int - organization_id: int inserted_at: datetime updated_at: datetime diff --git a/backend/app/models/organization.py b/backend/app/models/organization.py index 0f936607..b9ff9a7f 100644 --- a/backend/app/models/organization.py +++ b/backend/app/models/organization.py @@ -74,9 +74,6 @@ class Organization(OrganizationBase, table=True): assistants: list["Assistant"] = Relationship( back_populates="organization", cascade_delete=True ) - collections: list["Collection"] = Relationship( - back_populates="organization", cascade_delete=True - ) openai_conversations: list["OpenAIConversation"] = Relationship( back_populates="organization", cascade_delete=True ) diff --git a/backend/app/services/collections/create_collection.py b/backend/app/services/collections/create_collection.py index ed83e4a8..5e7389db 100644 --- a/backend/app/services/collections/create_collection.py +++ b/backend/app/services/collections/create_collection.py @@ -6,7 +6,6 @@ from asgi_correlation_id import correlation_id from app.core.cloud import get_cloud_storage -from app.core.util import now from app.core.db import engine from app.crud import ( CollectionCrud, @@ -14,7 +13,6 @@ DocumentCollectionCrud, CollectionJobCrud, ) -from app.crud.rag import OpenAIVectorStoreCrud, OpenAIAssistantCrud from app.models import ( CollectionJobStatus, CollectionJob, @@ -22,19 +20,12 @@ CollectionJobUpdate, CollectionPublic, CollectionJobPublic, -) -from app.models.collection import ( CreationRequest, - AssistantOptions, -) -from app.services.collections.helpers import ( - _backout, - batch_documents, - extract_error_message, - OPENAI_VECTOR_STORE, ) +from app.services.collections.helpers import extract_error_message +from app.services.collections.providers.registry import get_llm_provider from app.celery.utils import start_low_priority_job -from app.utils import get_openai_client, send_callback, APIResponse +from app.utils import send_callback, APIResponse logger = logging.getLogger(__name__) @@ -60,8 +51,8 @@ def start_job( project_id=project_id, job_id=str(collection_job_id), trace_id=trace_id, - with_assistant=with_assistant, request=request.model_dump(mode="json"), + with_assistant=with_assistant, organization_id=organization_id, ) @@ -116,26 +107,6 @@ def build_failure_payload(collection_job: CollectionJob, error_message: str) -> ) -def _cleanup_remote_resources( - assistant, - assistant_crud, - vector_store, - vector_store_crud, -) -> None: - """Best-effort cleanup of partially created remote resources.""" - try: - if assistant is not None and assistant_crud is not None: - _backout(assistant_crud, assistant.id) - elif vector_store is not None and vector_store_crud is not None: - _backout(vector_store_crud, vector_store.id) - else: - logger.warning( - "[create_collection._backout] Skipping: no resource/crud available" - ) - except Exception: - logger.warning("[create_collection.execute_job] Backout failed") - - def _mark_job_failed( project_id: int, job_id: str, @@ -163,29 +134,33 @@ def _mark_job_failed( def execute_job( request: dict, + with_assistant: bool, project_id: int, organization_id: int, task_id: str, job_id: str, - with_assistant: bool, task_instance, ) -> None: """ Worker entrypoint scheduled by start_job. 
- Orchestrates: job state, client/storage init, batching, vector-store upload, + Orchestrates: job state, provider init, collection creation, optional assistant creation, collection persistence, linking, callbacks, and cleanup. """ start_time = time.time() - # Keep references for potential backout/cleanup on failure - assistant = None - assistant_crud = None - vector_store = None - vector_store_crud = None + # Keeping the references for potential backout/cleanup on failure collection_job = None + result = None + creation_request = None + provider = None try: creation_request = CreationRequest(**request) + if ( + with_assistant + ): # this will be removed once dalgo switches to vector store creation only + creation_request.provider = "openai" + job_uuid = UUID(job_id) with Session(engine) as session: @@ -199,50 +174,28 @@ def execute_job( ), ) - client = get_openai_client(session, organization_id, project_id) storage = get_cloud_storage(session=session, project_id=project_id) - - # Batch documents for upload, and flatten for linking/metrics later document_crud = DocumentCrud(session, project_id) - docs_batches = batch_documents( - document_crud, - creation_request.documents, - creation_request.batch_size, + + provider = get_llm_provider( + session=session, + provider=creation_request.provider, + project_id=project_id, + organization_id=organization_id, ) - flat_docs = [doc for batch in docs_batches for doc in batch] - vector_store_crud = OpenAIVectorStoreCrud(client) - vector_store = vector_store_crud.create() - list(vector_store_crud.update(vector_store.id, storage, docs_batches)) + result = provider.create( + collection_request=creation_request, + storage=storage, + document_crud=document_crud, + ) - # if with_assistant is true, create assistant backed by the vector store - if with_assistant: - assistant_crud = OpenAIAssistantCrud(client) + llm_service_id = result.llm_service_id + llm_service_name = result.llm_service_name - # Filter out None to avoid sending unset options - assistant_options = dict( - creation_request.extract_super_type(AssistantOptions) - ) - assistant_options = { - k: v for k, v in assistant_options.items() if v is not None - } - - assistant = assistant_crud.create(vector_store.id, **assistant_options) - llm_service_id = assistant.id - llm_service_name = assistant_options.get("model") or "assistant" - - logger.info( - "[execute_job] Assistant created | assistant_id=%s, vector_store_id=%s", - assistant.id, - vector_store.id, - ) - else: - # If no assistant, the collection points directly at the vector store - llm_service_id = vector_store.id - llm_service_name = OPENAI_VECTOR_STORE - logger.info( - "[execute_job] Skipping assistant creation | with_assistant=False" - ) + with Session(engine) as session: + document_crud = DocumentCrud(session, project_id) + flat_docs = document_crud.read_each(creation_request.documents) file_exts = {doc.fname.split(".")[-1] for doc in flat_docs if "." 
in doc.fname} file_sizes_kb = [ @@ -253,17 +206,19 @@ def execute_job( collection_crud = CollectionCrud(session, project_id) collection_id = uuid4() + collection = Collection( id=collection_id, project_id=project_id, - organization_id=organization_id, llm_service_id=llm_service_id, llm_service_name=llm_service_name, + provider=creation_request.provider, + name=creation_request.name, + description=creation_request.description, ) collection_crud.create(collection) collection = collection_crud.read_one(collection.id) - # Link documents to the new collection if flat_docs: DocumentCollectionCrud(session).create(collection, flat_docs) @@ -299,12 +254,13 @@ def execute_job( exc_info=True, ) - _cleanup_remote_resources( - assistant=assistant, - assistant_crud=assistant_crud, - vector_store=vector_store, - vector_store_crud=vector_store_crud, - ) + if provider is not None and result is not None: + try: + provider.delete(result) + except Exception: + logger.warning( + "[create_collection.execute_job] Provider cleanup failed" + ) collection_job = _mark_job_failed( project_id=project_id, diff --git a/backend/app/services/collections/delete_collection.py b/backend/app/services/collections/delete_collection.py index ca337b79..db02a353 100644 --- a/backend/app/services/collections/delete_collection.py +++ b/backend/app/services/collections/delete_collection.py @@ -6,7 +6,6 @@ from app.core.db import engine from app.crud import CollectionCrud, CollectionJobCrud -from app.crud.rag import OpenAIAssistantCrud, OpenAIVectorStoreCrud from app.models import ( CollectionJobStatus, CollectionJobUpdate, @@ -15,9 +14,10 @@ CollectionIDPublic, ) from app.models.collection import DeletionRequest -from app.services.collections.helpers import extract_error_message, OPENAI_VECTOR_STORE +from app.services.collections.helpers import extract_error_message +from app.services.collections.providers.registry import get_llm_provider from app.celery.utils import start_low_priority_job -from app.utils import get_openai_client, send_callback, APIResponse +from app.utils import send_callback, APIResponse logger = logging.getLogger(__name__) @@ -155,7 +155,6 @@ def execute_job( job_uuid = UUID(job_id) collection_job = None - client = None try: with Session(engine) as session: @@ -169,20 +168,16 @@ def execute_job( ), ) - client = get_openai_client(session, organization_id, project_id) - collection = CollectionCrud(session, project_id).read_one(collection_id) - # Identify which external service (assistant/vector store) this collection belongs to - service = (collection.llm_service_name or "").strip().lower() - is_vector = service == OPENAI_VECTOR_STORE - llm_service_id = collection.llm_service_id + provider = get_llm_provider( + session=session, + provider=collection.provider, + project_id=project_id, + organization_id=organization_id, + ) - # Delete the corresponding OpenAI resource (vector store or assistant) - if is_vector: - OpenAIVectorStoreCrud(client).delete(llm_service_id) - else: - OpenAIAssistantCrud(client).delete(llm_service_id) + provider.delete(collection) with Session(engine) as session: CollectionCrud(session, project_id).delete_by_id(collection_id) diff --git a/backend/app/services/collections/helpers.py b/backend/app/services/collections/helpers.py index 795b04cd..7965e2e2 100644 --- a/backend/app/services/collections/helpers.py +++ b/backend/app/services/collections/helpers.py @@ -5,17 +5,26 @@ from uuid import UUID from typing import List +from fastapi import HTTPException from sqlmodel import select from openai 
import OpenAIError

-from app.crud.document.document import DocumentCrud
+from app.crud import DocumentCrud, CollectionCrud
+from app.api.deps import SessionDep
 from app.models import DocumentCollection, Collection
 
 logger = logging.getLogger(__name__)
 
-# llm service name for when only an openai vector store is being made
-OPENAI_VECTOR_STORE = "openai vector store"
+
+def get_service_name(provider: str) -> str:
+    """Get the collection service name for a provider."""
+    names = {
+        "openai": "openai vector store",
+        # "bedrock": "bedrock knowledge base",
+        # "gemini": "gemini file search store",
+    }
+    return names.get(provider.lower(), "")
 
 
 def extract_error_message(err: Exception) -> str:
@@ -69,17 +78,6 @@ def batch_documents(
     return docs_batches
 
 
-def _backout(crud, llm_service_id: str):
-    """Best-effort cleanup: attempt to delete the assistant by ID"""
-    try:
-        crud.delete(llm_service_id)
-    except OpenAIError as err:
-        logger.error(
-            f"[backout] Failed to delete resource | {{'llm_service_id': '{llm_service_id}', 'error': '{str(err)}'}}",
-            exc_info=True,
-        )
-
-
 # Even though this function is used in the documents router, it's kept here for now since the assistant creation logic will
 # eventually be removed from Kaapi. Once that happens, this function can be safely deleted
-
 def pick_service_for_documennt(session, doc_id: UUID, a_crud, v_crud):
@@ -101,4 +99,23 @@ def pick_service_for_documennt(session, doc_id: UUID, a_crud, v_crud):
     service = (
         (getattr(coll, "llm_service_name", "") or "").strip().lower() if coll else ""
     )
-    return v_crud if service == OPENAI_VECTOR_STORE else a_crud
+    return v_crud if service == get_service_name("openai") else a_crud
+
+
+def ensure_unique_name(
+    session: SessionDep,
+    project_id: int,
+    requested_name: str,
+) -> str:
+    """Ensure the requested collection name is unique within the project.
+
+    Raises HTTP 409 Conflict if a non-deleted collection in this project
+    already uses the name; otherwise returns the name unchanged.
+    """
+    existing = CollectionCrud(session, project_id).exists_by_name(requested_name)
+    if existing:
+        raise HTTPException(
+            status_code=409,
+            detail=f"Collection '{requested_name}' already exists. Choose a different name.",
+        )
+
+    return requested_name
diff --git a/backend/app/services/collections/providers/__init__.py b/backend/app/services/collections/providers/__init__.py
new file mode 100644
index 00000000..5a9b6a55
--- /dev/null
+++ b/backend/app/services/collections/providers/__init__.py
@@ -0,0 +1,6 @@
+from app.services.collections.providers.base import BaseProvider
+from app.services.collections.providers.openai import OpenAIProvider
+from app.services.collections.providers.registry import (
+    LLMProvider,
+    get_llm_provider,
+)
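The package surface above is what the refactored create/delete jobs consume. A condensed sketch of the calling contract as this diff wires it (names match the diff; setup elided). One point worth making explicit: `execute_job` only calls `provider.delete(result)` once `create` has returned, so rolling back resources partially created *inside* `create` is left to the provider implementation itself:

```python
from app.services.collections.providers import get_llm_provider


def run_provider_create(session, creation_request, storage, document_crud,
                        project_id: int, organization_id: int, persist):
    """Mirror of execute_job's cleanup contract (sketch only)."""
    provider = get_llm_provider(
        session=session,
        provider=creation_request.provider,
        project_id=project_id,
        organization_id=organization_id,
    )
    result = None
    try:
        # Returns an unsaved Collection carrying llm_service_id/_name.
        result = provider.create(
            collection_request=creation_request,
            storage=storage,
            document_crud=document_crud,
        )
        # e.g. CollectionCrud.create + document linking + job bookkeeping.
        return persist(result)
    except Exception:
        if result is not None:
            provider.delete(result)  # best-effort remote cleanup
        raise
```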
diff --git a/backend/app/services/collections/providers/base.py b/backend/app/services/collections/providers/base.py
new file mode 100644
index 00000000..d76fb618
--- /dev/null
+++ b/backend/app/services/collections/providers/base.py
@@ -0,0 +1,71 @@
+from abc import ABC, abstractmethod
+from typing import Any
+
+from app.crud import DocumentCrud
+from app.core.cloud.storage import CloudStorage
+from app.models import CreationRequest, Collection
+
+
+class BaseProvider(ABC):
+    """Abstract base class for collection providers.
+
+    All provider implementations (OpenAI, Bedrock, etc.) must inherit from
+    this class and implement the required methods.
+
+    Providers handle collection creation and optional assistant/agent
+    creation backed by those collections.
+
+    Attributes:
+        client: The provider-specific client instance
+    """
+
+    def __init__(self, client: Any) -> None:
+        """Initialize provider with client.
+
+        Args:
+            client: Provider-specific client instance
+        """
+        self.client = client
+
+    @abstractmethod
+    def create(
+        self,
+        collection_request: CreationRequest,
+        storage: CloudStorage,
+        document_crud: DocumentCrud,
+    ) -> Collection:
+        """Create a collection with documents and optionally an assistant.
+
+        Args:
+            collection_request: Collection parameters (name, description, document list, etc.)
+            storage: Cloud storage instance for file access
+            document_crud: DocumentCrud instance for fetching documents
+
+        Returns:
+            An unsaved Collection carrying the llm_service_id and
+            llm_service_name of the remote resources that were created
+        """
+        raise NotImplementedError("Providers must implement create method")
+
+    @abstractmethod
+    def delete(self, collection: Collection) -> None:
+        """Delete remote resources associated with a collection.
+
+        Called when a collection is being deleted and remote resources need to be cleaned up.
+
+        Args:
+            collection: Collection whose llm_service_id and llm_service_name
+                identify the remote resource to delete
+        """
+        raise NotImplementedError("Providers must implement delete method")
+
+    def get_provider_name(self) -> str:
+        """Get the name of the provider.
+
+        Returns:
+            Provider name (e.g., "openai", "bedrock", "gemini")
+        """
+        return self.__class__.__name__.replace("Provider", "").lower()
diff --git a/backend/app/services/collections/providers/openai.py b/backend/app/services/collections/providers/openai.py
new file mode 100644
index 00000000..b8a73412
--- /dev/null
+++ b/backend/app/services/collections/providers/openai.py
@@ -0,0 +1,118 @@
+import logging
+
+from openai import OpenAI
+
+from app.services.collections.providers.base import BaseProvider
+from app.crud import DocumentCrud
+from app.core.cloud.storage import CloudStorage
+from app.crud.rag import OpenAIVectorStoreCrud, OpenAIAssistantCrud
+from app.services.collections.helpers import batch_documents, get_service_name
+from app.models import CreationRequest, Collection
+
+
+logger = logging.getLogger(__name__)
+
+
+class OpenAIProvider(BaseProvider):
+    """OpenAI-specific collection provider for vector stores and assistants."""
+
+    def __init__(self, client: OpenAI):
+        super().__init__(client)
+        self.client = client
+
+    def create(
+        self,
+        collection_request: CreationRequest,
+        storage: CloudStorage,
+        document_crud: DocumentCrud,
+    ) -> Collection:
+        """
+        Create an OpenAI vector store with documents and optionally an assistant.
+ """ + try: + docs_batches = batch_documents( + document_crud, + collection_request.documents, + collection_request.batch_size, + ) + + vector_store_crud = OpenAIVectorStoreCrud(self.client) + vector_store = vector_store_crud.create() + + list(vector_store_crud.update(vector_store.id, storage, docs_batches)) + + logger.info( + "[OpenAIProvider.create] Vector store created | " + f"vector_store_id={vector_store.id}, batches={len(docs_batches)}" + ) + + # Check if we need to create an assistant (based on assistant options in request) + with_assistant = ( + collection_request.model is not None + and collection_request.instructions is not None + ) + if with_assistant: + assistant_crud = OpenAIAssistantCrud(self.client) + + assistant_options = { + "model": collection_request.model, + "instructions": collection_request.instructions, + "temperature": collection_request.temperature, + } + filtered_options = { + k: v for k, v in assistant_options.items() if v is not None + } + + assistant = assistant_crud.create(vector_store.id, **filtered_options) + + logger.info( + "[OpenAIProvider.create] Assistant created | " + f"assistant_id={assistant.id}, vector_store_id={vector_store.id}" + ) + + return Collection( + llm_service_id=assistant.id, + llm_service_name=filtered_options.get("model", "assistant"), + ) + else: + logger.info( + "[OpenAIProvider.create] Skipping assistant creation | with_assistant=False" + ) + + return Collection( + llm_service_id=vector_store.id, + llm_service_name=get_service_name("openai"), + ) + + except Exception as e: + logger.error( + f"[OpenAIProvider.create] Failed to create collection: {str(e)}", + exc_info=True, + ) + raise + + def delete(self, collection: Collection) -> None: + """Delete OpenAI resources (assistant or vector store). 
+
+        Determines what to delete based on llm_service_name:
+        - If assistant was created, delete the assistant (which also removes the vector store)
+        - If only vector store was created, delete the vector store
+        """
+        try:
+            if collection.llm_service_name != get_service_name("openai"):
+                OpenAIAssistantCrud(self.client).delete(collection.llm_service_id)
+                logger.info(
+                    f"[OpenAIProvider.delete] Deleted assistant | assistant_id={collection.llm_service_id}"
+                )
+            else:
+                OpenAIVectorStoreCrud(self.client).delete(collection.llm_service_id)
+                logger.info(
+                    f"[OpenAIProvider.delete] Deleted vector store | vector_store_id={collection.llm_service_id}"
+                )
+        except Exception as e:
+            logger.error(
+                f"[OpenAIProvider.delete] Failed to delete resource | "
+                f"llm_service_id={collection.llm_service_id}, error={str(e)}",
+                exc_info=True,
+            )
+            raise
diff --git a/backend/app/services/collections/providers/registry.py b/backend/app/services/collections/providers/registry.py
new file mode 100644
index 00000000..10d07d45
--- /dev/null
+++ b/backend/app/services/collections/providers/registry.py
@@ -0,0 +1,71 @@
+import logging
+
+from sqlmodel import Session
+from openai import OpenAI
+
+from app.crud import get_provider_credential
+from app.services.collections.providers.base import BaseProvider
+from app.services.collections.providers.openai import OpenAIProvider
+
+
+logger = logging.getLogger(__name__)
+
+
+class LLMProvider:
+    OPENAI = "openai"
+    # Future constants for providers:
+    # BEDROCK = "bedrock"
+    # GEMINI = "gemini"
+
+    _registry: dict[str, type[BaseProvider]] = {
+        OPENAI: OpenAIProvider,
+        # Future providers:
+        # BEDROCK: BedrockProvider,
+        # GEMINI: GeminiProvider,
+    }
+
+    @classmethod
+    def get(cls, name: str) -> type[BaseProvider]:
+        """Return the provider class for a given name."""
+        provider = cls._registry.get(name)
+        if not provider:
+            raise ValueError(
+                f"Provider '{name}' is not supported. "
+                f"Supported providers: {', '.join(cls._registry.keys())}"
+            )
+        return provider
+
+    @classmethod
+    def supported_providers(cls) -> list[str]:
+        """Return a list of supported provider names."""
+        return list(cls._registry.keys())
+
+
+def get_llm_provider(
+    session: Session, provider: str, project_id: int, organization_id: int
+) -> BaseProvider:
+    provider_class = LLMProvider.get(provider)
+
+    credentials = get_provider_credential(
+        session=session,
+        provider=provider,
+        project_id=project_id,
+        org_id=organization_id,
+    )
+
+    if not credentials:
+        raise ValueError(
+            f"Credentials for provider '{provider}' not configured for this project."
+ ) + + if provider == LLMProvider.OPENAI: + if "api_key" not in credentials: + raise ValueError("OpenAI credentials not configured for this project.") + client = OpenAI(api_key=credentials["api_key"]) + else: + logger.error( + f"[get_llm_provider] Unsupported provider type requested: {provider}" + ) + raise ValueError(f"Provider '{provider}' is not supported.") + + return provider_class(client=client) diff --git a/backend/app/tests/api/routes/collections/test_collection_delete.py b/backend/app/tests/api/routes/collections/test_collection_delete.py index d55b461a..f7ed400d 100644 --- a/backend/app/tests/api/routes/collections/test_collection_delete.py +++ b/backend/app/tests/api/routes/collections/test_collection_delete.py @@ -9,7 +9,7 @@ from app.tests.utils.auth import TestAuthContext from app.models import CollectionJobStatus from app.tests.utils.utils import get_project -from app.tests.utils.collection import get_collection +from app.tests.utils.collection import get_assistant_collection @patch("app.api.routes.collections.delete_service.start_job") @@ -28,7 +28,7 @@ def test_delete_collection_calls_start_job_and_returns_job( - Calls delete_service.start_job with correct arguments """ project = get_project(db, "Dalgo") - collection = get_collection(db, project) + collection = get_assistant_collection(db, project) resp = client.request( "DELETE", @@ -72,7 +72,7 @@ def test_delete_collection_with_callback_url_passes_it_to_start_job( into the DeletionRequest and then into delete_service.start_job. """ project = get_project(db, "Dalgo") - collection = get_collection(db, project) + collection = get_assistant_collection(db, project) payload = { "callback_url": "https://example.com/collections/delete-callback", diff --git a/backend/app/tests/api/routes/collections/test_collection_info.py b/backend/app/tests/api/routes/collections/test_collection_info.py index 781a491b..88cc7ed3 100644 --- a/backend/app/tests/api/routes/collections/test_collection_info.py +++ b/backend/app/tests/api/routes/collections/test_collection_info.py @@ -6,9 +6,13 @@ from app.core.config import settings from app.tests.utils.utils import get_project, get_document -from app.tests.utils.collection import get_collection, get_vector_store_collection +from app.tests.utils.collection import ( + get_assistant_collection, + get_vector_store_collection, +) from app.crud import DocumentCollectionCrud from app.models import Collection, Document +from app.services.collections.helpers import get_service_name def link_document_to_collection( @@ -40,13 +44,13 @@ def test_collection_info_returns_assistant_collection_with_docs( ) -> None: """ Happy path: - - Assistant-style collection (get_collection) + - Assistant-style collection (get_assistant_collection) - include_docs = True (default) - At least one document linked """ project = get_project(db, "Dalgo") - collection = get_collection(db, project) + collection = get_assistant_collection(db, project) document = link_document_to_collection(db, collection) @@ -82,7 +86,7 @@ def test_collection_info_include_docs_false_returns_no_docs( When include_docs=false, the endpoint should not populate the documents list. """ project = get_project(db, "Dalgo") - collection = get_collection(db, project) + collection = get_assistant_collection(db, project) link_document_to_collection(db, collection) @@ -113,7 +117,7 @@ def test_collection_info_pagination_skip_and_limit( We create multiple document links and then request a paginated slice. 
""" project = get_project(db, "Dalgo") - collection = get_collection(db, project) + collection = get_assistant_collection(db, project) documents = db.exec( select(Document).where(Document.deleted_at.is_(None)).limit(2) @@ -163,7 +167,7 @@ def test_collection_info_vector_store_collection( payload = data["data"] assert payload["id"] == str(collection.id) - assert payload["llm_service_name"] == "openai vector store" + assert payload["llm_service_name"] == get_service_name("openai") assert payload["llm_service_id"] == collection.llm_service_id docs = payload.get("documents", []) @@ -197,7 +201,7 @@ def test_collection_info_include_docs_and_url( the endpoint returns documents with their URLs. """ project = get_project(db, "Dalgo") - collection = get_collection(db, project) + collection = get_assistant_collection(db, project) document = link_document_to_collection(db, collection) diff --git a/backend/app/tests/api/routes/collections/test_collection_job_info.py b/backend/app/tests/api/routes/collections/test_collection_job_info.py index abc2e3cb..ad95b676 100644 --- a/backend/app/tests/api/routes/collections/test_collection_job_info.py +++ b/backend/app/tests/api/routes/collections/test_collection_job_info.py @@ -5,7 +5,7 @@ from app.core.config import settings from app.tests.utils.utils import get_project -from app.tests.utils.collection import get_collection, get_collection_job +from app.tests.utils.collection import get_assistant_collection, get_collection_job from app.models import ( CollectionActionType, CollectionJobStatus, @@ -41,7 +41,7 @@ def test_collection_info_create_successful( headers = user_api_key_header project = get_project(db, "Dalgo") - collection = get_collection(db, project) + collection = get_assistant_collection(db, project) collection_job = get_collection_job( db, project, collection_id=collection.id, status=CollectionJobStatus.SUCCESSFUL @@ -101,7 +101,7 @@ def test_collection_info_delete_successful( headers = user_api_key_header project = get_project(db, "Dalgo") - collection = get_collection(db, project) + collection = get_assistant_collection(db, project) collection_job = get_collection_job( db, @@ -133,7 +133,7 @@ def test_collection_info_delete_failed( headers = user_api_key_header project = get_project(db, "Dalgo") - collection = get_collection(db, project) + collection = get_assistant_collection(db, project) collection_job = get_collection_job( db, diff --git a/backend/app/tests/api/routes/collections/test_collection_list.py b/backend/app/tests/api/routes/collections/test_collection_list.py index de713532..e9b5626d 100644 --- a/backend/app/tests/api/routes/collections/test_collection_list.py +++ b/backend/app/tests/api/routes/collections/test_collection_list.py @@ -4,9 +4,10 @@ from app.core.config import settings from app.tests.utils.utils import get_project from app.tests.utils.collection import ( - get_collection, + get_assistant_collection, get_vector_store_collection, ) +from app.services.collections.helpers import get_service_name def test_list_collections_returns_api_response( @@ -39,7 +40,7 @@ def test_list_collections_includes_assistant_collection( user_api_key_header: dict[str, str], ) -> None: """ - Ensure that a newly created assistant-style collection (get_collection) + Ensure that a newly created assistant-style collection (get_assistant_collection) appears in the list for the current project. 
""" @@ -51,7 +52,7 @@ def test_list_collections_includes_assistant_collection( ) assert response_before.status_code == 200 - collection = get_collection(db, project) + collection = get_assistant_collection(db, project) response_after = client.get( f"{settings.API_V1_STR}/collections/", @@ -101,7 +102,7 @@ def test_list_collections_includes_vector_store_collection_with_fields( row = matching[0] assert row["project_id"] == project.id - assert row["llm_service_name"] == "openai vector store" + assert row["llm_service_name"] == get_service_name("openai") assert row["llm_service_id"] == collection.llm_service_id diff --git a/backend/app/tests/api/routes/test_assistants.py b/backend/app/tests/api/routes/test_assistants.py index 3d2e9aa9..d54651c2 100644 --- a/backend/app/tests/api/routes/test_assistants.py +++ b/backend/app/tests/api/routes/test_assistants.py @@ -6,8 +6,8 @@ from sqlmodel import Session from fastapi import HTTPException from fastapi.testclient import TestClient - -from app.tests.utils.openai import mock_openai_assistant +from unittest.mock import patch +from app.tests.utils.llm_provider import mock_openai_assistant from app.tests.utils.utils import get_assistant from app.tests.utils.auth import TestAuthContext diff --git a/backend/app/tests/api/routes/test_openai_conversation.py b/backend/app/tests/api/routes/test_openai_conversation.py index 7957c3d9..2620d80b 100644 --- a/backend/app/tests/api/routes/test_openai_conversation.py +++ b/backend/app/tests/api/routes/test_openai_conversation.py @@ -3,7 +3,7 @@ from app.crud.openai_conversation import create_conversation from app.models import OpenAIConversationCreate -from app.tests.utils.openai import generate_openai_id +from app.tests.utils.llm_provider import generate_openai_id from app.tests.utils.auth import TestAuthContext diff --git a/backend/app/tests/crud/collections/collection/test_crud_collection_create.py b/backend/app/tests/crud/collections/collection/test_crud_collection_create.py index f8b0e1f2..87c8ac67 100644 --- a/backend/app/tests/crud/collections/collection/test_crud_collection_create.py +++ b/backend/app/tests/crud/collections/collection/test_crud_collection_create.py @@ -4,7 +4,7 @@ from sqlmodel import Session, select from app.crud import CollectionCrud -from app.models import DocumentCollection, Collection +from app.models import DocumentCollection, Collection, ProviderType from app.tests.utils.document import DocumentStore from app.tests.utils.utils import get_project @@ -18,9 +18,9 @@ def test_create_associates_documents(self, db: Session) -> None: collection = Collection( id=uuid4(), project_id=project.id, - organization_id=project.organization_id, llm_service_id="asst_dummy", llm_service_name="gpt-4o", + provider=ProviderType.openai, ) store = DocumentStore(db, project_id=collection.project_id) diff --git a/backend/app/tests/crud/collections/collection/test_crud_collection_delete.py b/backend/app/tests/crud/collections/collection/test_crud_collection_delete.py index 2c2ff45c..5cf4643d 100644 --- a/backend/app/tests/crud/collections/collection/test_crud_collection_delete.py +++ b/backend/app/tests/crud/collections/collection/test_crud_collection_delete.py @@ -3,13 +3,13 @@ from sqlmodel import Session, select from app.crud import CollectionCrud -from app.models import APIKey, Collection +from app.models import APIKey, Collection, ProviderType from app.crud.rag import OpenAIAssistantCrud from app.tests.utils.utils import get_project from app.tests.utils.document import DocumentStore -def 
get_collection_for_delete( +def get_assistant_collection_for_delete( db: Session, client=None, project_id: int = None ) -> Collection: project = get_project(db) @@ -24,10 +24,10 @@ def get_collection_for_delete( ) return Collection( - organization_id=project.organization_id, project_id=project_id, llm_service_id=assistant.id, llm_service_name="gpt-4o", + provider=ProviderType.openai, ) @@ -40,7 +40,9 @@ def test_delete_marks_deleted(self, db: Session) -> None: client = OpenAI(api_key="sk-test-key") assistant = OpenAIAssistantCrud(client) - collection = get_collection_for_delete(db, client, project_id=project.id) + collection = get_assistant_collection_for_delete( + db, client, project_id=project.id + ) crud = CollectionCrud(db, collection.project_id) collection_ = crud.delete(collection, assistant) @@ -53,7 +55,7 @@ def test_delete_follows_insert(self, db: Session) -> None: assistant = OpenAIAssistantCrud(client) project = get_project(db) - collection = get_collection_for_delete(db, project_id=project.id) + collection = get_assistant_collection_for_delete(db, project_id=project.id) crud = CollectionCrud(db, collection.project_id) collection_ = crud.delete(collection, assistant) @@ -74,7 +76,9 @@ def test_delete_document_deletes_collections(self, db: Session) -> None: client = OpenAI(api_key="sk-test-key") resources = [] for _ in range(self._n_collections): - coll = get_collection_for_delete(db, client, project_id=project.id) + coll = get_assistant_collection_for_delete( + db, client, project_id=project.id + ) crud = CollectionCrud(db, project_id=project.id) collection = crud.create(coll, documents) resources.append((crud, collection)) diff --git a/backend/app/tests/crud/collections/collection/test_crud_collection_read_all.py b/backend/app/tests/crud/collections/collection/test_crud_collection_read_all.py index fcf1cafc..0382b883 100644 --- a/backend/app/tests/crud/collections/collection/test_crud_collection_read_all.py +++ b/backend/app/tests/crud/collections/collection/test_crud_collection_read_all.py @@ -6,7 +6,7 @@ from app.models import Collection from app.tests.utils.document import DocumentStore from app.tests.utils.utils import get_project -from app.tests.utils.collection import get_collection +from app.tests.utils.collection import get_assistant_collection def create_collections(db: Session, n: int) -> Collection: @@ -16,7 +16,7 @@ def create_collections(db: Session, n: int) -> Collection: with openai_mock.router: client = OpenAI(api_key="sk-test-key") for _ in range(n): - collection = get_collection(db, project=project) + collection = get_assistant_collection(db, project=project) store = DocumentStore(db, project_id=collection.project_id) documents = store.fill(1) if crud is None: diff --git a/backend/app/tests/crud/collections/collection/test_crud_collection_read_one.py b/backend/app/tests/crud/collections/collection/test_crud_collection_read_one.py index fe601196..2fc5f767 100644 --- a/backend/app/tests/crud/collections/collection/test_crud_collection_read_one.py +++ b/backend/app/tests/crud/collections/collection/test_crud_collection_read_one.py @@ -8,15 +8,14 @@ from app.models import Collection from app.tests.utils.document import DocumentStore from app.tests.utils.utils import get_project -from app.tests.utils.collection import get_collection +from app.tests.utils.collection import get_assistant_collection def mk_collection(db: Session) -> Collection: openai_mock = OpenAIMock() project = get_project(db) with openai_mock.router: - client = OpenAI(api_key="sk-test-key") - 
collection = get_collection(db, project=project) + collection = get_assistant_collection(db, project=project) store = DocumentStore(db, project_id=collection.project_id) documents = store.fill(1) crud = CollectionCrud(db, collection.project_id) diff --git a/backend/app/tests/crud/test_assistants.py b/backend/app/tests/crud/test_assistants.py index e40669c1..12aa4994 100644 --- a/backend/app/tests/crud/test_assistants.py +++ b/backend/app/tests/crud/test_assistants.py @@ -15,7 +15,7 @@ get_assistant_by_id, get_assistants_by_project, ) -from app.tests.utils.openai import mock_openai_assistant +from app.tests.utils.llm_provider import mock_openai_assistant from app.tests.utils.utils import ( get_project, get_assistant, diff --git a/backend/app/tests/crud/test_openai_conversation.py b/backend/app/tests/crud/test_openai_conversation.py index fec3da1e..314238bc 100644 --- a/backend/app/tests/crud/test_openai_conversation.py +++ b/backend/app/tests/crud/test_openai_conversation.py @@ -15,7 +15,7 @@ ) from app.models import OpenAIConversationCreate, Project from app.tests.utils.utils import get_project, get_organization -from app.tests.utils.openai import generate_openai_id +from app.tests.utils.llm_provider import generate_openai_id def test_get_conversation_by_id_success(db: Session) -> None: diff --git a/backend/app/tests/services/collections/providers/test_openai_provider.py b/backend/app/tests/services/collections/providers/test_openai_provider.py new file mode 100644 index 00000000..a4193b2f --- /dev/null +++ b/backend/app/tests/services/collections/providers/test_openai_provider.py @@ -0,0 +1,156 @@ +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest + +from app.services.collections.providers.openai import OpenAIProvider +from app.models.collection import Collection +from app.services.collections.helpers import get_service_name +from app.tests.utils.llm_provider import ( + generate_openai_id, + get_mock_openai_client_with_vector_store, +) + + +def test_create_openai_vector_store_only() -> None: + client = get_mock_openai_client_with_vector_store() + provider = OpenAIProvider(client=client) + + collection_request = SimpleNamespace( + documents=["doc1", "doc2"], + batch_size=1, + model=None, + instructions=None, + temperature=None, + ) + + storage = MagicMock() + document_crud = MagicMock() + + fake_batches = [["doc1"], ["doc2"]] + vector_store_id = generate_openai_id("vs_") + + with patch( + "app.services.collections.providers.openai.batch_documents", + return_value=fake_batches, + ), patch( + "app.services.collections.providers.openai.OpenAIVectorStoreCrud" + ) as vector_store_crud_cls: + vector_store_crud = vector_store_crud_cls.return_value + vector_store_crud.create.return_value = MagicMock(id=vector_store_id) + vector_store_crud.update.return_value = iter([None]) + + collection = provider.create( + collection_request, + storage, + document_crud, + ) + + assert isinstance(collection, Collection) + assert collection.llm_service_id == vector_store_id + assert collection.llm_service_name == get_service_name("openai") + + +def test_create_openai_with_assistant() -> None: + client = get_mock_openai_client_with_vector_store() + provider = OpenAIProvider(client=client) + + collection_request = SimpleNamespace( + documents=["doc1"], + batch_size=1, + model="gpt-4o", + instructions="You are helpful", + temperature=0.7, + ) + + storage = MagicMock() + document_crud = MagicMock() + + fake_batches = [["doc1"]] + vector_store_id = generate_openai_id("vs_") 
+ assistant_id = generate_openai_id("asst_") + + with patch( + "app.services.collections.providers.openai.batch_documents", + return_value=fake_batches, + ), patch( + "app.services.collections.providers.openai.OpenAIVectorStoreCrud" + ) as vector_store_crud_cls, patch( + "app.services.collections.providers.openai.OpenAIAssistantCrud" + ) as assistant_crud_cls: + vector_store_crud = vector_store_crud_cls.return_value + vector_store_crud.create.return_value = MagicMock(id=vector_store_id) + vector_store_crud.update.return_value = iter([None]) + + assistant_crud = assistant_crud_cls.return_value + assistant_crud.create.return_value = MagicMock(id=assistant_id) + + collection = provider.create( + collection_request, + storage, + document_crud, + ) + + assert collection.llm_service_id == assistant_id + assert collection.llm_service_name == "gpt-4o" + + +def test_delete_openai_assistant() -> None: + client = MagicMock() + provider = OpenAIProvider(client=client) + + collection = Collection( + llm_service_id=generate_openai_id("asst_"), + llm_service_name="gpt-4o", + provider="openai", + project_id=1, + ) + + with patch( + "app.services.collections.providers.openai.OpenAIAssistantCrud" + ) as assistant_crud_cls: + assistant_crud = assistant_crud_cls.return_value + provider.delete(collection) + + assistant_crud.delete.assert_called_once_with(collection.llm_service_id) + + +def test_delete_openai_vector_store() -> None: + client = MagicMock() + provider = OpenAIProvider(client=client) + + collection = Collection( + llm_service_id=generate_openai_id("vs_"), + llm_service_name=get_service_name("openai"), + ) + + with patch( + "app.services.collections.providers.openai.OpenAIVectorStoreCrud" + ) as vector_store_crud_cls: + vector_store_crud = vector_store_crud_cls.return_value + provider.delete(collection) + + vector_store_crud.delete.assert_called_once_with(collection.llm_service_id) + + +def test_create_propagates_exception() -> None: + provider = OpenAIProvider(client=MagicMock()) + + collection_request = SimpleNamespace( + documents=["doc1"], + batch_size=1, + model=None, + instructions=None, + temperature=None, + ) + + with patch( + "app.services.collections.providers.openai.batch_documents", + side_effect=RuntimeError("boom"), + ): + with pytest.raises(RuntimeError): + provider.create( + collection_request, + MagicMock(), + MagicMock(), + ) diff --git a/backend/app/tests/services/collections/test_create_collection.py b/backend/app/tests/services/collections/test_create_collection.py index 0ea5e495..2fefd60e 100644 --- a/backend/app/tests/services/collections/test_create_collection.py +++ b/backend/app/tests/services/collections/test_create_collection.py @@ -16,9 +16,9 @@ from app.models import CollectionJobStatus, CollectionJob, CollectionActionType, Project from app.models.collection import CreationRequest from app.services.collections.create_collection import start_job, execute_job -from app.tests.utils.openai import get_mock_openai_client_with_vector_store +from app.tests.utils.llm_provider import get_mock_provider from app.tests.utils.utils import get_project -from app.tests.utils.collection import get_collection_job, get_collection +from app.tests.utils.collection import get_collection_job, get_assistant_collection from app.tests.utils.document import DocumentStore @@ -57,12 +57,10 @@ def test_start_job_creates_collection_job_and_schedules_task(db: Session) -> Non """ project = get_project(db) request = CreationRequest( - model="gpt-4o", - instructions="string", - temperature=0.000001, 
documents=[UUID("f3e86a17-1e6f-41ec-b020-5b08eebef928")], batch_size=1, callback_url=None, + provider="openai", ) job_id = uuid4() @@ -115,9 +113,9 @@ def test_start_job_creates_collection_job_and_schedules_task(db: Session) -> Non @pytest.mark.usefixtures("aws_credentials") @mock_aws -@patch("app.services.collections.create_collection.get_openai_client") +@patch("app.services.collections.create_collection.get_llm_provider") def test_execute_job_success_flow_updates_job_and_creates_collection( - mock_get_openai_client: Any, db: Session + mock_get_llm_provider: MagicMock, db: Session ) -> None: """ execute_job should: @@ -139,16 +137,12 @@ def test_execute_job_success_flow_updates_job_and_creates_collection( aws.client.put_object(Bucket=settings.AWS_S3_BUCKET, Key=str(s3_key), Body=b"test") sample_request = CreationRequest( - model="gpt-4o", - instructions="string", - temperature=0.000001, - documents=[document.id], - batch_size=1, - callback_url=None, + documents=[document.id], batch_size=1, callback_url=None, provider="openai" ) - mock_client = get_mock_openai_client_with_vector_store() - mock_get_openai_client.return_value = mock_client + mock_get_llm_provider.return_value = get_mock_provider( + llm_service_id="mock_vector_store_id", llm_service_name="openai vector store" + ) job_id = uuid4() _ = get_collection_job( @@ -184,8 +178,8 @@ def test_execute_job_success_flow_updates_job_and_creates_collection( created_collection = CollectionCrud(db, project.id).read_one( updated_job.collection_id ) - assert created_collection.llm_service_id == "mock_assistant_id" - assert created_collection.llm_service_name == sample_request.model + assert created_collection.llm_service_id == "mock_vector_store_id" + assert created_collection.llm_service_name == "openai vector store" assert created_collection.updated_at is not None docs = DocumentCollectionCrud(db).read(created_collection, skip=0, limit=10) @@ -195,9 +189,9 @@ def test_execute_job_success_flow_updates_job_and_creates_collection( @pytest.mark.usefixtures("aws_credentials") @mock_aws -@patch("app.services.collections.create_collection.get_openai_client") -def test_execute_job_assistant_create_failure_marks_failed_and_deletes_vector( - mock_get_openai_client: Any, db: Session +@patch("app.services.collections.create_collection.get_llm_provider") +def test_execute_job_assistant_create_failure_marks_failed_and_deletes_collection( + mock_get_llm_provider: MagicMock, db ) -> None: project = get_project(db) @@ -211,32 +205,23 @@ def test_execute_job_assistant_create_failure_marks_failed_and_deletes_vector( ) req = CreationRequest( - model="gpt-4o", - instructions="string", - temperature=0.0, - documents=[], - batch_size=1, - callback_url=None, + documents=[], batch_size=1, callback_url=None, provider="openai" ) - _ = mock_get_openai_client.return_value + mock_provider = get_mock_provider( + llm_service_id="vs_123", llm_service_name="openai vector store" + ) + mock_get_llm_provider.return_value = mock_provider with patch( "app.services.collections.create_collection.Session" ) as SessionCtor, patch( - "app.services.collections.create_collection.OpenAIVectorStoreCrud" - ) as MockVS, patch( - "app.services.collections.create_collection.OpenAIAssistantCrud" - ) as MockAsst: + "app.services.collections.create_collection.CollectionCrud" + ) as MockCrud: SessionCtor.return_value.__enter__.return_value = db SessionCtor.return_value.__exit__.return_value = False - MockVS.return_value.create.return_value = type( - "Vector store", (), {"id": "vs_123"} - )() - 
MockVS.return_value.update.return_value = [] - - MockAsst.return_value.create.side_effect = RuntimeError("assistant boom") + MockCrud.return_value.create.side_effect = Exception("DB constraint violation") task_id = str(uuid4()) execute_job( @@ -249,20 +234,17 @@ def test_execute_job_assistant_create_failure_marks_failed_and_deletes_vector( task_instance=None, ) - failed = CollectionJobCrud(db, project.id).read_one(job.id) - assert failed.task_id == task_id - assert failed.status == CollectionJobStatus.FAILED - assert "assistant boom" in (failed.error_message or "") - - MockVS.return_value.delete.assert_called_once_with("vs_123") + mock_provider.delete.assert_called_once() @pytest.mark.usefixtures("aws_credentials") @mock_aws -@patch("app.services.collections.create_collection.get_openai_client") +@patch("app.services.collections.create_collection.get_llm_provider") @patch("app.services.collections.create_collection.send_callback") def test_execute_job_success_flow_callback_job_and_creates_collection( - mock_send_callback: Any, mock_get_openai_client: Any, db: Session + mock_send_callback: MagicMock, + mock_get_llm_provider: MagicMock, + db, ) -> None: """ execute_job should: @@ -286,16 +268,15 @@ def test_execute_job_success_flow_callback_job_and_creates_collection( callback_url = "https://example.com/collections/create-success" sample_request = CreationRequest( - model="gpt-4o", - instructions="string", - temperature=0.000001, documents=[document.id], batch_size=1, callback_url=callback_url, + provider="openai", ) - mock_client = get_mock_openai_client_with_vector_store() - mock_get_openai_client.return_value = mock_client + mock_get_llm_provider.return_value = get_mock_provider( + llm_service_id="mock_vector_store_id", llm_service_name="openai vector store" + ) job_id = uuid.uuid4() _ = get_collection_job( @@ -339,10 +320,12 @@ def test_execute_job_success_flow_callback_job_and_creates_collection( @pytest.mark.usefixtures("aws_credentials") @mock_aws -@patch("app.services.collections.create_collection.get_openai_client") +@patch("app.services.collections.create_collection.get_llm_provider") @patch("app.services.collections.create_collection.send_callback") def test_execute_job_success_creates_collection_with_callback( - mock_send_callback: Any, mock_get_openai_client: Any, db: Session + mock_send_callback: MagicMock, + mock_get_llm_provider: MagicMock, + db, ) -> None: """ execute_job should: @@ -366,16 +349,15 @@ def test_execute_job_success_creates_collection_with_callback( callback_url = "https://example.com/collections/create-success" sample_request = CreationRequest( - model="gpt-4o", - instructions="string", - temperature=0.000001, documents=[document.id], batch_size=1, callback_url=callback_url, + provider="openai", ) - mock_client = get_mock_openai_client_with_vector_store() - mock_get_openai_client.return_value = mock_client + mock_get_llm_provider.return_value = get_mock_provider( + llm_service_id="mock_vector_store_id", llm_service_name="gpt-4o" + ) job_id = uuid.uuid4() _ = get_collection_job( @@ -419,13 +401,13 @@ def test_execute_job_success_creates_collection_with_callback( @pytest.mark.usefixtures("aws_credentials") @mock_aws -@patch("app.services.collections.create_collection.get_openai_client") +@patch("app.services.collections.create_collection.get_llm_provider") @patch("app.services.collections.create_collection.send_callback") @patch("app.services.collections.create_collection.CollectionCrud") def test_execute_job_failure_flow_callback_job_and_marks_failed( - 
-    MockCollectionCrud: Any,
-    mock_send_callback: Any,
-    mock_get_openai_client: Any,
+    MockCollectionCrud: MagicMock,
+    mock_send_callback: MagicMock,
+    mock_get_llm_provider: MagicMock,
     db: Session,
 ) -> None:
     """
@@ -434,7 +416,7 @@ def test_execute_job_failure_flow_callback_job_and_marks_failed(
     """
     project = get_project(db)

-    collection = get_collection(db, project, assistant_id="asst_123")
+    collection = get_assistant_collection(db, project, assistant_id="asst_123")
     job = get_collection_job(
         db,
         project,
@@ -443,7 +425,7 @@ def test_execute_job_failure_flow_callback_job_and_marks_failed(
         collection_id=None,
     )

-    mock_get_openai_client.return_value = MagicMock()
+    mock_get_llm_provider.return_value = MagicMock()

     callback_url = "https://example.com/collections/create-failure"

@@ -451,12 +433,10 @@ def test_execute_job_failure_flow_callback_job_and_marks_failed(
     collection_crud_instance.read_one.return_value = collection

     sample_request = CreationRequest(
-        model="gpt-4o",
-        instructions="string",
-        temperature=0.000001,
         documents=[uuid.uuid4()],
         batch_size=1,
         callback_url=callback_url,
+        provider="openai",
     )

     task_id = uuid.uuid4()
diff --git a/backend/app/tests/services/collections/test_delete_collection.py b/backend/app/tests/services/collections/test_delete_collection.py
index 26153ee4..d1243ba5 100644
--- a/backend/app/tests/services/collections/test_delete_collection.py
+++ b/backend/app/tests/services/collections/test_delete_collection.py
@@ -1,15 +1,13 @@
-from typing import Any
 from unittest.mock import patch, MagicMock
 from uuid import uuid4, UUID

 from sqlmodel import Session
-from sqlalchemy.exc import SQLAlchemyError

 from app.models.collection import DeletionRequest
 from app.tests.utils.utils import get_project
 from app.crud import CollectionJobCrud
 from app.models import CollectionJobStatus, CollectionActionType
-from app.tests.utils.collection import get_collection, get_collection_job
+from app.tests.utils.collection import get_collection_job, get_vector_store_collection

 from app.services.collections.delete_collection import start_job, execute_job


@@ -20,7 +18,7 @@ def test_start_job_creates_collection_job_and_schedules_task(db: Session) -> Non
     - return the same job_id (UUID)
     """
     project = get_project(db)
-    created_collection = get_collection(db, project)
+    created_collection = get_vector_store_collection(db, project)

     req = DeletionRequest(collection_id=created_collection.id)

@@ -72,19 +70,20 @@ def test_start_job_creates_collection_job_and_schedules_task(db: Session) -> Non
     assert "trace_id" in kwargs


-@patch("app.services.collections.delete_collection.get_openai_client")
+@patch("app.services.collections.delete_collection.get_llm_provider")
 def test_execute_job_delete_success_updates_job_and_calls_delete(
-    mock_get_openai_client: Any, db: Session
+    mock_get_llm_provider: MagicMock, db: Session
 ) -> None:
     """
     - execute_job should set task_id on the CollectionJob
-    - call remote delete via OpenAIAssistantCrud.delete(...)
+    - call provider.delete() to delete remote resources
     - delete local record via CollectionCrud.delete_by_id(...)
     - mark job successful and clear error_message
     """
     project = get_project(db)
-    collection = get_collection(db, project, assistant_id="asst_123")
+    collection = get_vector_store_collection(db, project, vector_store_id="vs_123")
+
     job = get_collection_job(
         db,
         project,
@@ -93,13 +92,13 @@ def test_execute_job_delete_success_updates_job_and_calls_delete(
         collection_id=collection.id,
     )

-    mock_get_openai_client.return_value = MagicMock()
+    mock_provider = MagicMock()
+    mock_get_llm_provider.return_value = mock_provider

     with patch(
         "app.services.collections.delete_collection.Session"
     ) as SessionCtor, patch(
-        "app.services.collections.delete_collection.OpenAIAssistantCrud"
-    ) as MockAssistantCrud, patch(
         "app.services.collections.delete_collection.CollectionCrud"
     ) as MockCollectionCrud:
         SessionCtor.return_value.__enter__.return_value = db
@@ -108,8 +107,6 @@ def test_execute_job_delete_success_updates_job_and_calls_delete(
         collection_crud_instance = MockCollectionCrud.return_value
         collection_crud_instance.read_one.return_value = collection

-        MockAssistantCrud.return_value.delete.return_value = None
-
         task_id = uuid4()
         req = DeletionRequest(collection_id=collection.id)

@@ -128,27 +125,33 @@ def test_execute_job_delete_success_updates_job_and_calls_delete(
     assert updated_job.status == CollectionJobStatus.SUCCESSFUL
     assert updated_job.error_message in (None, "")

+    mock_provider.delete.assert_called_once_with(collection)
+
     MockCollectionCrud.assert_called_with(db, project.id)
     collection_crud_instance.read_one.assert_called_once_with(collection.id)
-
-    MockAssistantCrud.assert_called_once()
-    MockAssistantCrud.return_value.delete.assert_called_once_with("asst_123")
-
     collection_crud_instance.delete_by_id.assert_called_once_with(collection.id)
-    mock_get_openai_client.assert_called_once()
+    mock_get_llm_provider.assert_called_once()


-@patch("app.services.collections.delete_collection.get_openai_client")
+@patch("app.services.collections.delete_collection.get_llm_provider")
 def test_execute_job_delete_failure_marks_job_failed(
-    mock_get_openai_client: Any, db: Session
+    mock_get_llm_provider: MagicMock, db: Session
 ) -> None:
     """
-    When the remote delete (OpenAIAssistantCrud.delete) raises,
-    the job should be marked FAILED and error_message set.
+    When provider.delete() raises an exception:
+    - Job should be marked FAILED
+    - error_message should be set
+    - Local collection should NOT be deleted
     """
     project = get_project(db)
-    collection = get_collection(db, project, assistant_id="asst_123")
+    collection = get_vector_store_collection(
+        db,
+        project,
+        vector_store_id="vector_123",
+    )
+
     job = get_collection_job(
         db,
         project,
@@ -157,13 +160,13 @@ def test_execute_job_delete_failure_marks_job_failed(
         collection_id=collection.id,
     )

-    mock_get_openai_client.return_value = MagicMock()
+    mock_provider = MagicMock()
+    mock_provider.delete.side_effect = Exception("Remote deletion failed")
+    mock_get_llm_provider.return_value = mock_provider

     with patch(
         "app.services.collections.delete_collection.Session"
     ) as SessionCtor, patch(
-        "app.services.collections.delete_collection.OpenAIAssistantCrud"
-    ) as MockAssistantCrud, patch(
         "app.services.collections.delete_collection.CollectionCrud"
     ) as MockCollectionCrud:
         SessionCtor.return_value.__enter__.return_value = db
@@ -172,10 +175,6 @@ def test_execute_job_delete_failure_marks_job_failed(
         collection_crud_instance = MockCollectionCrud.return_value
         collection_crud_instance.read_one.return_value = collection

-        MockAssistantCrud.return_value.delete.side_effect = SQLAlchemyError(
-            "something went wrong"
-        )
-
         task_id = uuid4()
         req = DeletionRequest(collection_id=collection.id)

@@ -194,23 +193,24 @@ def test_execute_job_delete_failure_marks_job_failed(
     assert failed_job.status == CollectionJobStatus.FAILED
     assert (
         failed_job.error_message
-        and "something went wrong" in failed_job.error_message
+        and "Remote deletion failed" in failed_job.error_message
     )

+    mock_provider.delete.assert_called_once_with(collection)
+
     MockCollectionCrud.assert_called_with(db, project.id)
     collection_crud_instance.read_one.assert_called_once_with(collection.id)
-
-    MockAssistantCrud.assert_called_once()
-    MockAssistantCrud.return_value.delete.assert_called_once_with("asst_123")
-
     collection_crud_instance.delete_by_id.assert_not_called()
-    mock_get_openai_client.assert_called_once()
+    mock_get_llm_provider.assert_called_once()


-@patch("app.services.collections.delete_collection.get_openai_client")
+@patch("app.services.collections.delete_collection.get_llm_provider")
+@patch("app.services.collections.delete_collection.send_callback")
 def test_execute_job_delete_success_with_callback_sends_success_payload(
-    mock_get_openai_client: Any,
-    db: Session,
+    mock_send_callback: MagicMock,
+    mock_get_llm_provider: MagicMock,
+    db: Session,
 ) -> None:
     """
     When deletion succeeds and a callback_url is provided:
@@ -220,7 +220,12 @@ def test_execute_job_delete_success_with_callback_sends_success_payload(
     """
     project = get_project(db)

-    collection = get_collection(db, project, assistant_id="asst_123")
+    collection = get_vector_store_collection(
+        db,
+        project,
+        vector_store_id="vector_123",
+    )
+
     job = get_collection_job(
         db,
         project,
@@ -229,27 +234,23 @@ def test_execute_job_delete_success_with_callback_sends_success_payload(
         collection_id=collection.id,
     )

-    mock_get_openai_client.return_value = MagicMock()
+    mock_provider = MagicMock()
+    mock_get_llm_provider.return_value = mock_provider

     callback_url = "https://example.com/collections/delete-success"

     with patch(
         "app.services.collections.delete_collection.Session"
     ) as SessionCtor, patch(
-        "app.services.collections.delete_collection.OpenAIAssistantCrud"
-    ) as MockAssistantCrud, patch(
"app.services.collections.delete_collection.CollectionCrud" - ) as MockCollectionCrud, patch( - "app.services.collections.delete_collection.send_callback" - ) as mock_send_callback: + ) as MockCollectionCrud: SessionCtor.return_value.__enter__.return_value = db SessionCtor.return_value.__exit__.return_value = False collection_crud_instance = MockCollectionCrud.return_value collection_crud_instance.read_one.return_value = collection - MockAssistantCrud.return_value.delete.return_value = None - task_id = uuid4() req = DeletionRequest(collection_id=collection.id, callback_url=callback_url) @@ -268,12 +269,12 @@ def test_execute_job_delete_success_with_callback_sends_success_payload( assert updated_job.status == CollectionJobStatus.SUCCESSFUL assert updated_job.error_message in (None, "") + mock_provider.delete.assert_called_once_with(collection) + MockCollectionCrud.assert_called_with(db, project.id) collection_crud_instance.read_one.assert_called_once_with(collection.id) - MockAssistantCrud.assert_called_once() - MockAssistantCrud.return_value.delete.assert_called_once_with("asst_123") collection_crud_instance.delete_by_id.assert_called_once_with(collection.id) - mock_get_openai_client.assert_called_once() + mock_get_llm_provider.assert_called_once() mock_send_callback.assert_called_once() cb_url_arg, payload_arg = mock_send_callback.call_args.args @@ -285,20 +286,27 @@ def test_execute_job_delete_success_with_callback_sends_success_payload( assert UUID(payload_arg["data"]["job_id"]) == job.id -@patch("app.services.collections.delete_collection.get_openai_client") +@patch("app.services.collections.delete_collection.get_llm_provider") +@patch("app.services.collections.delete_collection.send_callback") def test_execute_job_delete_remote_failure_with_callback_sends_failure_payload( - mock_get_openai_client: Any, - db: Session, + mock_send_callback: MagicMock, + mock_get_llm_provider: MagicMock, + db, ) -> None: """ - When the remote delete raises AND a callback_url is provided: + When provider.delete() raises AND a callback_url is provided: - job is marked FAILED with error_message set - send_callback is called once - failure payload has success=False, status=FAILED, correct collection id, and error message """ project = get_project(db) - collection = get_collection(db, project, assistant_id="asst_123") + collection = get_vector_store_collection( + db, + project, + vector_store_id="vector_123", + ) + job = get_collection_job( db, project, @@ -307,28 +315,23 @@ def test_execute_job_delete_remote_failure_with_callback_sends_failure_payload( collection_id=collection.id, ) - mock_get_openai_client.return_value = MagicMock() + mock_provider = MagicMock() + mock_provider.delete.side_effect = Exception("Remote deletion failed") + mock_get_llm_provider.return_value = mock_provider + callback_url = "https://example.com/collections/delete-failed" with patch( "app.services.collections.delete_collection.Session" ) as SessionCtor, patch( - "app.services.collections.delete_collection.OpenAIAssistantCrud" - ) as MockAssistantCrud, patch( "app.services.collections.delete_collection.CollectionCrud" - ) as MockCollectionCrud, patch( - "app.services.collections.delete_collection.send_callback" - ) as mock_send_callback: + ) as MockCollectionCrud: SessionCtor.return_value.__enter__.return_value = db SessionCtor.return_value.__exit__.return_value = False collection_crud_instance = MockCollectionCrud.return_value collection_crud_instance.read_one.return_value = collection - 
         task_id = uuid4()
         req = DeletionRequest(collection_id=collection.id, callback_url=callback_url)

@@ -347,24 +350,22 @@ def test_execute_job_delete_remote_failure_with_callback_sends_failure_payload(
     assert failed_job.status == CollectionJobStatus.FAILED
     assert (
         failed_job.error_message
-        and "something went wrong" in failed_job.error_message
+        and "Remote deletion failed" in failed_job.error_message
     )

+    mock_provider.delete.assert_called_once_with(collection)
+
     MockCollectionCrud.assert_called_with(db, project.id)
     collection_crud_instance.read_one.assert_called_once_with(collection.id)
-
-    MockAssistantCrud.assert_called_once()
-    MockAssistantCrud.return_value.delete.assert_called_once_with("asst_123")
-
     collection_crud_instance.delete_by_id.assert_not_called()
-    mock_get_openai_client.assert_called_once()
+    mock_get_llm_provider.assert_called_once()

     mock_send_callback.assert_called_once()
     cb_url_arg, payload_arg = mock_send_callback.call_args.args

     assert str(cb_url_arg) == callback_url
     assert payload_arg["success"] is False
-    assert "something went wrong" in (payload_arg["error"] or "")
+    assert "Remote deletion failed" in (payload_arg["error"] or "")
     assert payload_arg["data"]["status"] == CollectionJobStatus.FAILED
     assert payload_arg["data"]["collection"]["id"] == str(collection.id)
     assert UUID(payload_arg["data"]["job_id"]) == job.id
diff --git a/backend/app/tests/services/collections/test_helpers.py b/backend/app/tests/services/collections/test_helpers.py
index be94ffeb..f53271f1 100644
--- a/backend/app/tests/services/collections/test_helpers.py
+++ b/backend/app/tests/services/collections/test_helpers.py
@@ -1,11 +1,17 @@
 from __future__ import annotations

 import json
-from typing import Any
 from types import SimpleNamespace
 from uuid import uuid4

+import pytest
+from sqlmodel import Session
+from fastapi import HTTPException
+
 from app.services.collections import helpers
+from app.tests.utils.utils import get_project
+from app.tests.utils.collection import get_vector_store_collection
+from app.services.collections.helpers import ensure_unique_name


 def test_extract_error_message_parses_json_and_strips_prefix() -> None:
@@ -84,28 +90,38 @@ def test_batch_documents_empty_input() -> None:
     assert crud.calls == []


-# _backout
+def test_ensure_unique_name_success(db: Session) -> None:
+    requested_name = "new_collection_name"
+
+    project = get_project(db)
+    result = ensure_unique_name(
+        session=db,
+        project_id=project.id,
+        requested_name=requested_name,
+    )

-def test_backout_calls_delete_and_swallows_openai_error(monkeypatch: Any) -> None:
-    class Crud:
-        def __init__(self):
-            self.calls = 0
+    assert result == requested_name

-        def delete(self, resource_id: str):
-            self.calls += 1

-    crud = Crud()
-    helpers._backout(crud, "rsrc_1")
-    assert crud.calls == 1
+def test_ensure_unique_name_conflict_with_vector_store_collection(db: Session) -> None:
+    existing_name = "vector_collection"
+    project = get_project(db)

-    class DummyOpenAIError(Exception):
-        pass
+    collection = get_vector_store_collection(
+        db=db,
+        project=project,
+    )

-    monkeypatch.setattr(helpers, "OpenAIError", DummyOpenAIError)
+    collection.name = existing_name
+    db.commit()

-    class FailingCrud:
-        def delete(self, resource_id: str):
-            raise DummyOpenAIError("nope")
+    with pytest.raises(HTTPException) as exc:
+        ensure_unique_name(
+            session=db,
+            project_id=project.id,
+            requested_name=existing_name,
+        )

-    helpers._backout(FailingCrud(), "rsrc_2")
+    assert exc.value.status_code == 409
+    assert "already exists" in exc.value.detail
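+
+
+# Note: uniqueness is enforced per project (CollectionCrud.exists_by_name filters
+# on project_id), so the same collection name may be reused across projects.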
diff --git a/backend/app/tests/services/llm/providers/test_openai.py b/backend/app/tests/services/llm/providers/test_openai.py
index 745dd00b..4dfb671e 100644
--- a/backend/app/tests/services/llm/providers/test_openai.py
+++ b/backend/app/tests/services/llm/providers/test_openai.py
@@ -12,7 +12,7 @@
 )
 from app.models.llm.request import ConversationConfig
 from app.services.llm.providers.openai import OpenAIProvider
-from app.tests.utils.openai import mock_openai_response
+from app.tests.utils.llm_provider import mock_openai_response


 class TestOpenAIProvider:
diff --git a/backend/app/tests/services/response/response/test_process_response.py b/backend/app/tests/services/response/response/test_process_response.py
index 3145dd6b..68ba24d2 100644
--- a/backend/app/tests/services/response/response/test_process_response.py
+++ b/backend/app/tests/services/response/response/test_process_response.py
@@ -17,7 +17,7 @@
 )
 from app.utils import APIResponse
 from app.tests.utils.test_data import create_test_credential
-from app.tests.utils.openai import mock_openai_response, generate_openai_id
+from app.tests.utils.llm_provider import mock_openai_response, generate_openai_id

 from app.crud import JobCrud, create_assistant
diff --git a/backend/app/tests/utils/collection.py b/backend/app/tests/utils/collection.py
index bdc68b29..f1844cb9 100644
--- a/backend/app/tests/utils/collection.py
+++ b/backend/app/tests/utils/collection.py
@@ -8,9 +8,11 @@
     CollectionActionType,
     CollectionJob,
     CollectionJobStatus,
+    ProviderType,
     Project,
 )
 from app.crud import CollectionCrud, CollectionJobCrud
+from app.services.collections.helpers import get_service_name


 class constants:
@@ -23,7 +25,7 @@ def uuid_increment(value: UUID) -> UUID:
     return UUID(int=inc)


-def get_collection(
+def get_assistant_collection(
     db: Session,
     project: Project,
     *,
@@ -44,6 +46,7 @@
-        organization_id=project.organization_id,
         llm_service_name=model,
         llm_service_id=assistant_id,
+        provider=ProviderType.openai,
     )

     return CollectionCrud(db, project.id).create(collection)
@@ -65,9 +68,9 @@ def get_vector_store_collection(
     collection = Collection(
         id=collection_id or uuid4(),
         project_id=project.id,
-        organization_id=project.organization_id,
-        llm_service_name="openai vector store",
+        llm_service_name=get_service_name("openai"),
         llm_service_id=vector_store_id,
+        provider=ProviderType.openai,
     )

     return CollectionCrud(db, project.id).create(collection)
diff --git a/backend/app/tests/utils/llm_provider.py b/backend/app/tests/utils/llm_provider.py
new file mode 100644
index 00000000..afa2dfde
--- /dev/null
+++ b/backend/app/tests/utils/llm_provider.py
@@ -0,0 +1,197 @@
+import time
+import secrets
+import string
+from typing import Optional
+from types import SimpleNamespace
+from unittest.mock import MagicMock
+
+from openai.types.beta import Assistant as OpenAIAssistant
+from openai.types.beta.assistant import ToolResources, ToolResourcesFileSearch
+from openai.types.beta.assistant_tool import FileSearchTool
+from openai.types.beta.file_search_tool import FileSearch
+
+
+def generate_openai_id(prefix: str, length: int = 40) -> str:
+    """Generate a realistic ID similar to OpenAI's format (alphanumeric only)."""
+    # Random alphanumeric suffix appended to the given prefix
+    chars = string.ascii_lowercase + string.digits
+    random_part = "".join(secrets.choice(chars) for _ in range(length))
+    return f"{prefix}{random_part}"
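+
+# Example (illustrative): generate_openai_id("resp_") yields "resp_" followed by
+# 40 random lowercase alphanumerics, e.g. "resp_0a1b2c...".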
+def mock_openai_assistant(
+    assistant_id: str = "assistant_mock",
+    vector_store_ids: Optional[list[str]] = None,
+    max_num_results: int = 30,
+) -> OpenAIAssistant:
+    # Avoid a mutable default argument; fall back to two stub store ids.
+    if vector_store_ids is None:
+        vector_store_ids = ["vs_1", "vs_2"]
+    return OpenAIAssistant(
+        id=assistant_id,
+        created_at=int(time.time()),
+        description="Mock description",
+        instructions="Mock instructions",
+        metadata={},
+        model="gpt-4o",
+        name="Mock Assistant",
+        object="assistant",
+        tools=[
+            FileSearchTool(
+                type="file_search",
+                file_search=FileSearch(
+                    max_num_results=max_num_results,
+                ),
+            )
+        ],
+        temperature=1.0,
+        tool_resources=ToolResources(
+            code_interpreter=None,
+            file_search=ToolResourcesFileSearch(vector_store_ids=vector_store_ids),
+        ),
+        top_p=1.0,
+        reasoning_effort=None,
+    )
+
+
+def mock_openai_response(
+    text: str = "Hello world",
+    previous_response_id: str | None = None,
+    model: str = "gpt-4",
+    conversation_id: str | None = None,
+) -> SimpleNamespace:
+    """Return a minimal mock OpenAI-like response object for testing.
+
+    Args:
+        text: The response text
+        previous_response_id: Optional previous response ID
+        model: Model name
+        conversation_id: Optional conversation ID. If provided, adds a
+            conversation object to the response.
+    """
+
+    usage = SimpleNamespace(
+        input_tokens=10,
+        output_tokens=20,
+        total_tokens=30,
+    )
+
+    output_item = SimpleNamespace(
+        id=generate_openai_id("out_"),
+        type="message",
+        role="assistant",
+        content=[{"type": "output_text", "text": text}],
+    )
+
+    conversation = None
+    if conversation_id:
+        conversation = SimpleNamespace(id=conversation_id)
+
+    response = SimpleNamespace(
+        id=generate_openai_id("resp_"),
+        created_at=int(time.time()),
+        model=model,
+        object="response",
+        output=[output_item],
+        output_text=text,
+        usage=usage,
+        previous_response_id=previous_response_id,
+        conversation=conversation,
+        # Lazily evaluated, so the closure can reference `response` safely.
+        model_dump=lambda: {
+            "id": response.id,
+            "model": model,
+            "output_text": text,
+            "usage": {
+                "input_tokens": 10,
+                "output_tokens": 20,
+                "total_tokens": 30,
+            },
+            "conversation": {"id": conversation_id} if conversation_id else None,
+        },
+    )
+    return response
+
+
+def get_mock_openai_client_with_vector_store() -> MagicMock:
+    mock_client = MagicMock()
+
+    mock_vector_store = MagicMock()
+    mock_vector_store.id = "mock_vector_store_id"
+    mock_client.vector_stores.create.return_value = mock_vector_store
+
+    mock_file_batch = MagicMock()
+    mock_file_batch.file_counts.completed = 2
+    mock_file_batch.file_counts.total = 2
+    mock_client.vector_stores.file_batches.upload_and_poll.return_value = (
+        mock_file_batch
+    )
+
+    mock_client.vector_stores.files.list.return_value = {"data": []}
+
+    mock_assistant = MagicMock()
+    mock_assistant.id = "mock_assistant_id"
+    mock_assistant.name = "Mock Assistant"
+    mock_assistant.model = "gpt-4o"
+    mock_assistant.instructions = "Mock instructions"
+    mock_client.beta.assistants.create.return_value = mock_assistant
+
+    return mock_client
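+
+# Illustrative wiring (the patch target below is hypothetical; use whatever
+# symbol the module under test actually imports):
+#
+#     with patch("app.some_module.get_openai_client") as factory:
+#         factory.return_value = get_mock_openai_client_with_vector_store()
+#         client = factory.return_value
+#         assert client.vector_stores.create(name="docs").id == "mock_vector_store_id"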
+def create_mock_batch(
+    batch_id: str = "batch-xyz789",
+    status: str = "completed",
+    output_file_id: str | None = "output-file-123",
+    error_file_id: str | None = None,
+    total: int = 100,
+    completed: int = 100,
+    failed: int = 0,
+) -> MagicMock:
+    """
+    Create a mock OpenAI batch object with configurable properties.
+
+    Args:
+        batch_id: The batch ID
+        status: Batch status (completed, in_progress, failed, expired, cancelled, etc.)
+        output_file_id: Output file ID (None for incomplete batches)
+        error_file_id: Error file ID (None if no errors)
+        total: Total number of requests in the batch
+        completed: Number of completed requests
+        failed: Number of failed requests
+
+    Returns:
+        MagicMock configured to represent an OpenAI batch object
+    """
+    mock_batch = MagicMock()
+    mock_batch.id = batch_id
+    mock_batch.status = status
+    mock_batch.output_file_id = output_file_id
+    mock_batch.error_file_id = error_file_id
+
+    # Configure the nested request_counts mock
+    mock_batch.request_counts = MagicMock()
+    mock_batch.request_counts.total = total
+    mock_batch.request_counts.completed = completed
+    mock_batch.request_counts.failed = failed
+
+    return mock_batch
+
+
+def get_mock_provider(
+    llm_service_id: str = "mock_service_id",
+    llm_service_name: str = "mock_service_name",
+) -> MagicMock:
+    """
+    Create a properly configured mock provider for tests.
+
+    Returns a mock that mimics BaseProvider with:
+    - create() returning a result carrying llm_service_id and llm_service_name
+    - cleanup() and delete() for the failure and deletion paths
+    """
+    mock_provider = MagicMock()
+
+    mock_result = MagicMock()
+    mock_result.llm_service_id = llm_service_id
+    mock_result.llm_service_name = llm_service_name
+
+    mock_provider.create.return_value = mock_result
+    # delete() and cleanup() need no explicit setup: MagicMock attributes are
+    # created on first access, and the tests only assert on their calls.
+
+    return mock_provider
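+
+
+# Usage sketch (illustrative): tests in this PR patch get_llm_provider and hand
+# back this mock, e.g.
+#
+#     provider = get_mock_provider(llm_service_id="vs_123")
+#     result = provider.create(...)            # arguments depend on the caller
+#     assert result.llm_service_id == "vs_123"
+#     provider.delete(collection)              # remote cleanup path on failure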