From fd18a020e5366575d5521728a5ff087597d25167 Mon Sep 17 00:00:00 2001 From: nishika26 Date: Wed, 24 Dec 2025 16:45:24 +0530 Subject: [PATCH 01/11] pushing the logic first --- ..._adding_blob_column_in_collection_table.py | 47 ++++++ backend/app/models/__init__.py | 4 + backend/app/models/collection/__init__.py | 14 ++ .../{collection.py => collection/request.py} | 122 ++++++++------ backend/app/models/collection/response.py | 33 ++++ .../services/collections/create_collection.py | 118 ++++--------- .../services/collections/delete_collection.py | 28 ++-- backend/app/services/collections/helpers.py | 11 -- .../collections/providers/__init__.py | 6 + .../services/collections/providers/base.py | 84 ++++++++++ .../services/collections/providers/openai.py | 156 ++++++++++++++++++ .../collections/providers/registry.py | 71 ++++++++ 12 files changed, 533 insertions(+), 161 deletions(-) create mode 100644 backend/app/alembic/versions/041_adding_blob_column_in_collection_table.py create mode 100644 backend/app/models/collection/__init__.py rename backend/app/models/{collection.py => collection/request.py} (63%) create mode 100644 backend/app/models/collection/response.py create mode 100644 backend/app/services/collections/providers/__init__.py create mode 100644 backend/app/services/collections/providers/base.py create mode 100644 backend/app/services/collections/providers/openai.py create mode 100644 backend/app/services/collections/providers/registry.py diff --git a/backend/app/alembic/versions/041_adding_blob_column_in_collection_table.py b/backend/app/alembic/versions/041_adding_blob_column_in_collection_table.py new file mode 100644 index 00000000..8f65f055 --- /dev/null +++ b/backend/app/alembic/versions/041_adding_blob_column_in_collection_table.py @@ -0,0 +1,47 @@ +"""adding blob column in collection table + +Revision ID: 041 +Revises: 040 +Create Date: 2025-12-24 11:03:44.620424 + +""" +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects 
import postgresql + +revision = "041" +down_revision = "040" +branch_labels = None +depends_on = None + + +def upgrade(): + op.add_column( + "collection", + sa.Column( + "collection_blob", + postgresql.JSONB(astext_type=sa.Text()), + nullable=True, + comment="Provider-specific knowledge base creation parameters (name, description, chunking params etc.)", + ), + ) + op.alter_column( + "collection", + "llm_service_name", + existing_type=sa.VARCHAR(), + comment="Name of the LLM service", + existing_comment="Name of the LLM service provider", + existing_nullable=False, + ) + + +def downgrade(): + op.alter_column( + "collection", + "llm_service_name", + existing_type=sa.VARCHAR(), + comment="Name of the LLM service provider", + existing_comment="Name of the LLM service", + existing_nullable=False, + ) + op.drop_column("collection", "collection_blob") diff --git a/backend/app/models/__init__.py b/backend/app/models/__init__.py index ac7e89d6..ef08fd09 100644 --- a/backend/app/models/__init__.py +++ b/backend/app/models/__init__.py @@ -8,9 +8,13 @@ from .collection import ( Collection, + CreateCollectionParams, + CreateCollectionResult, + CreationRequest, CollectionPublic, CollectionIDPublic, CollectionWithDocsPublic, + DeletionRequest, ) from .collection_job import ( CollectionActionType, diff --git a/backend/app/models/collection/__init__.py b/backend/app/models/collection/__init__.py new file mode 100644 index 00000000..e31f65bc --- /dev/null +++ b/backend/app/models/collection/__init__.py @@ -0,0 +1,14 @@ +from app.models.collection.request import ( + Collection, + CreationRequest, + DeletionRequest, + CallbackRequest, + AssistantOptions, + CreateCollectionParams, +) +from app.models.collection.response import ( + CollectionIDPublic, + CollectionPublic, + CollectionWithDocsPublic, + CreateCollectionResult, +) diff --git a/backend/app/models/collection.py b/backend/app/models/collection/request.py similarity index 63% rename from backend/app/models/collection.py rename 
to backend/app/models/collection/request.py index 57e5a17b..9f8e106b 100644 --- a/backend/app/models/collection.py +++ b/backend/app/models/collection/request.py @@ -3,13 +3,13 @@ from uuid import UUID, uuid4 from pydantic import HttpUrl, model_validator +import sqlalchemy as sa +from sqlalchemy.dialects.postgresql import JSONB from sqlmodel import Field, Relationship, SQLModel from app.core.util import now -from app.models.document import DocumentPublic - -from .organization import Organization -from .project import Project +from app.models.organization import Organization +from app.models.project import Project class Collection(SQLModel, table=True): @@ -30,8 +30,13 @@ class Collection(SQLModel, table=True): nullable=False, sa_column_kwargs={"comment": "Name of the LLM service"}, ) - - # Foreign keys + collection_blob: dict[str, Any] | None = Field( + sa_column=sa.Column( + JSONB, + nullable=True, + comment="Provider-specific collection parameters (name, description, chunking params etc.)", + ) + ) organization_id: int = Field( foreign_key="organization.id", nullable=False, @@ -44,8 +49,6 @@ class Collection(SQLModel, table=True): ondelete="CASCADE", sa_column_kwargs={"comment": "Reference to the project"}, ) - - # Timestamps inserted_at: datetime = Field( default_factory=now, sa_column_kwargs={"comment": "Timestamp when the collection was created"}, @@ -64,27 +67,55 @@ class Collection(SQLModel, table=True): project: Project = Relationship(back_populates="collections") -# Request models -class DocumentOptions(SQLModel): - documents: list[UUID] = Field( - description="List of document IDs", +class DocumentInput(SQLModel): + """Document to be added to knowledge base.""" + + name: str | None = Field( + description="Display name for the document", ) - batch_size: int = Field( - default=1, - description=( - "Number of documents to send to OpenAI in a single " - "transaction. 
See the `file_ids` parameter in the " - "vector store [create batch](https://platform.openai.com/docs/api-reference/vector-stores-file-batches/createBatch)." - ), + id: UUID = Field( + description="Reference to uploaded file/document in Kaapi", + ) + + +class CreateCollectionParams(SQLModel): + """Request-specific parameters for knowledge base creation.""" + + name: str | None = Field( + min_length=1, + description="Name of the knowledge base to create or update", + ) + description: str | None = Field( + default=None, + description="Description of the knowledge base (required by Bedrock, optional for others)", + ) + documents: list[DocumentInput] = Field( + default_factory=list, + description="List of documents to add to the knowledge base", + ) + chunking_params: dict[str, Any] | None = Field( + default=None, + description="Chunking parameters for document processing (e.g., chunk_size, chunk_overlap)", + ) + additional_params: dict[str, Any] | None = Field( + default=None, + description="Additional provider-specific parameters", ) def model_post_init(self, __context: Any): - self.documents = list(set(self.documents)) + """Deduplicate documents by file_id.""" + seen = set() + unique_docs = [] + for doc in self.documents: + if doc.file_id not in seen: + seen.add(doc.file_id) + unique_docs.append(doc) + self.documents = unique_docs class AssistantOptions(SQLModel): # Fields to be passed along to OpenAI. They must be a subset of - # parameters accepted by the OpenAI.clien.beta.assistants.create + # parameters accepted by the OpenAI.client.beta.assistants.create # API. 
model: str | None = Field( default=None, @@ -139,6 +170,8 @@ def norm(x: Any) -> Any: class CallbackRequest(SQLModel): + """Optional callback configuration for async job notifications.""" + callback_url: HttpUrl | None = Field( default=None, description="URL to call to report endpoint status", @@ -153,40 +186,23 @@ class ProviderOptions(SQLModel): ) -class CreationRequest( - DocumentOptions, - ProviderOptions, - AssistantOptions, - CallbackRequest, -): - def extract_super_type(self, cls: "CreationRequest"): - for field_name in cls.model_fields.keys(): - field_value = getattr(self, field_name) - yield (field_name, field_value) - - -class DeletionRequest(CallbackRequest): - collection_id: UUID = Field(description="Collection to delete") - - -# Response models - - -class CollectionIDPublic(SQLModel): - id: UUID +class CreationRequest(AssistantOptions, ProviderOptions, CallbackRequest): + """API request for collection creation""" + collection_params: CreateCollectionParams = Field( + ..., + description="Collection creation specific parameters (name, documents, etc.)", + ) + batch_size: int = Field( + default=10, + ge=1, + le=500, + description="Number of documents to process in a single batch", + ) -class CollectionPublic(SQLModel): - id: UUID - llm_service_id: str - llm_service_name: str - project_id: int - organization_id: int - inserted_at: datetime - updated_at: datetime - deleted_at: datetime | None = None +class DeletionRequest(ProviderOptions, CallbackRequest): + """API request for collection deletion""" -class CollectionWithDocsPublic(CollectionPublic): - documents: list[DocumentPublic] | None = None + collection_id: UUID = Field(description="Collection to delete") diff --git a/backend/app/models/collection/response.py b/backend/app/models/collection/response.py new file mode 100644 index 00000000..f72c5ee7 --- /dev/null +++ b/backend/app/models/collection/response.py @@ -0,0 +1,33 @@ +from datetime import datetime +from typing import Any +from uuid import UUID 
+ +from sqlmodel import SQLModel + +from app.models.document import DocumentPublic + + +class CreateCollectionResult(SQLModel): + llm_service_id: str + llm_service_name: str + collection_blob: dict[str, Any] + + +class CollectionIDPublic(SQLModel): + id: UUID + + +class CollectionPublic(SQLModel): + id: UUID + llm_service_id: str + llm_service_name: str + project_id: int + organization_id: int + + inserted_at: datetime + updated_at: datetime + deleted_at: datetime | None = None + + +class CollectionWithDocsPublic(CollectionPublic): + documents: list[DocumentPublic] | None = None diff --git a/backend/app/services/collections/create_collection.py b/backend/app/services/collections/create_collection.py index ed83e4a8..1086dc71 100644 --- a/backend/app/services/collections/create_collection.py +++ b/backend/app/services/collections/create_collection.py @@ -6,7 +6,6 @@ from asgi_correlation_id import correlation_id from app.core.cloud import get_cloud_storage -from app.core.util import now from app.core.db import engine from app.crud import ( CollectionCrud, @@ -14,7 +13,6 @@ DocumentCollectionCrud, CollectionJobCrud, ) -from app.crud.rag import OpenAIVectorStoreCrud, OpenAIAssistantCrud from app.models import ( CollectionJobStatus, CollectionJob, @@ -23,18 +21,11 @@ CollectionPublic, CollectionJobPublic, ) -from app.models.collection import ( - CreationRequest, - AssistantOptions, -) -from app.services.collections.helpers import ( - _backout, - batch_documents, - extract_error_message, - OPENAI_VECTOR_STORE, -) +from app.models.collection import CreationRequest +from app.services.collections.helpers import extract_error_message +from app.services.collections.providers.registry import get_llm_provider from app.celery.utils import start_low_priority_job -from app.utils import get_openai_client, send_callback, APIResponse +from app.utils import send_callback, APIResponse logger = logging.getLogger(__name__) @@ -116,26 +107,6 @@ def build_failure_payload(collection_job: 
CollectionJob, error_message: str) -> ) -def _cleanup_remote_resources( - assistant, - assistant_crud, - vector_store, - vector_store_crud, -) -> None: - """Best-effort cleanup of partially created remote resources.""" - try: - if assistant is not None and assistant_crud is not None: - _backout(assistant_crud, assistant.id) - elif vector_store is not None and vector_store_crud is not None: - _backout(vector_store_crud, vector_store.id) - else: - logger.warning( - "[create_collection._backout] Skipping: no resource/crud available" - ) - except Exception: - logger.warning("[create_collection.execute_job] Backout failed") - - def _mark_job_failed( project_id: int, job_id: str, @@ -172,17 +143,15 @@ def execute_job( ) -> None: """ Worker entrypoint scheduled by start_job. - Orchestrates: job state, client/storage init, batching, vector-store upload, + Orchestrates: job state, provider init, collection creation, optional assistant creation, collection persistence, linking, callbacks, and cleanup. 
""" start_time = time.time() - # Keep references for potential backout/cleanup on failure - assistant = None - assistant_crud = None - vector_store = None - vector_store_crud = None + # Keeping the references for potential backout/cleanup on failure collection_job = None + result = None + provider = None try: creation_request = CreationRequest(**request) @@ -199,49 +168,32 @@ def execute_job( ), ) - client = get_openai_client(session, organization_id, project_id) storage = get_cloud_storage(session=session, project_id=project_id) - - # Batch documents for upload, and flatten for linking/metrics later document_crud = DocumentCrud(session, project_id) - docs_batches = batch_documents( - document_crud, - creation_request.documents, - creation_request.batch_size, + + provider = get_llm_provider( + session=session, + provider=creation_request.provider, + project_id=project_id, + organization_id=organization_id, ) - flat_docs = [doc for batch in docs_batches for doc in batch] - vector_store_crud = OpenAIVectorStoreCrud(client) - vector_store = vector_store_crud.create() - list(vector_store_crud.update(vector_store.id, storage, docs_batches)) + result = provider.create( + collection_request=creation_request, + storage=storage, + document_crud=document_crud, + ) - # if with_assistant is true, create assistant backed by the vector store - if with_assistant: - assistant_crud = OpenAIAssistantCrud(client) + llm_service_id = result.llm_service_id + llm_service_name = result.llm_service_name + # Storing collection params (name, description, chunking_params, etc.) 
in DB + # for future reference and to support different providers with varying configurations + collection_blob = result.collection_blob - # Filter out None to avoid sending unset options - assistant_options = dict( - creation_request.extract_super_type(AssistantOptions) - ) - assistant_options = { - k: v for k, v in assistant_options.items() if v is not None - } - - assistant = assistant_crud.create(vector_store.id, **assistant_options) - llm_service_id = assistant.id - llm_service_name = assistant_options.get("model") or "assistant" - - logger.info( - "[execute_job] Assistant created | assistant_id=%s, vector_store_id=%s", - assistant.id, - vector_store.id, - ) - else: - # If no assistant, the collection points directly at the vector store - llm_service_id = vector_store.id - llm_service_name = OPENAI_VECTOR_STORE - logger.info( - "[execute_job] Skipping assistant creation | with_assistant=False" + with Session(engine) as session: + document_crud = DocumentCrud(session, project_id) + flat_docs = document_crud.read_many_by_ids( + [doc.id for doc in creation_request.collection_params.documents] ) file_exts = {doc.fname.split(".")[-1] for doc in flat_docs if "." 
in doc.fname} @@ -259,6 +211,7 @@ def execute_job( organization_id=organization_id, llm_service_id=llm_service_id, llm_service_name=llm_service_name, + collection_blob=collection_blob, ) collection_crud.create(collection) collection = collection_crud.read_one(collection.id) @@ -299,12 +252,13 @@ def execute_job( exc_info=True, ) - _cleanup_remote_resources( - assistant=assistant, - assistant_crud=assistant_crud, - vector_store=vector_store, - vector_store_crud=vector_store_crud, - ) + if provider is not None and result is not None: + try: + provider.cleanup(result) + except Exception: + logger.warning( + "[create_collection.execute_job] Provider cleanup failed" + ) collection_job = _mark_job_failed( project_id=project_id, diff --git a/backend/app/services/collections/delete_collection.py b/backend/app/services/collections/delete_collection.py index ca337b79..a49fa4b1 100644 --- a/backend/app/services/collections/delete_collection.py +++ b/backend/app/services/collections/delete_collection.py @@ -6,7 +6,6 @@ from app.core.db import engine from app.crud import CollectionCrud, CollectionJobCrud -from app.crud.rag import OpenAIAssistantCrud, OpenAIVectorStoreCrud from app.models import ( CollectionJobStatus, CollectionJobUpdate, @@ -15,9 +14,10 @@ CollectionIDPublic, ) from app.models.collection import DeletionRequest -from app.services.collections.helpers import extract_error_message, OPENAI_VECTOR_STORE +from app.services.collections.helpers import extract_error_message +from app.services.collections.providers.registry import get_llm_provider from app.celery.utils import start_low_priority_job -from app.utils import get_openai_client, send_callback, APIResponse +from app.utils import send_callback, APIResponse logger = logging.getLogger(__name__) @@ -155,7 +155,6 @@ def execute_job( job_uuid = UUID(job_id) collection_job = None - client = None try: with Session(engine) as session: @@ -169,20 +168,19 @@ def execute_job( ), ) - client = get_openai_client(session, 
organization_id, project_id) - collection = CollectionCrud(session, project_id).read_one(collection_id) - # Identify which external service (assistant/vector store) this collection belongs to - service = (collection.llm_service_name or "").strip().lower() - is_vector = service == OPENAI_VECTOR_STORE - llm_service_id = collection.llm_service_id + provider = get_llm_provider( + session=session, + provider=deletion_request.provider, + project_id=project_id, + organization_id=organization_id, + ) - # Delete the corresponding OpenAI resource (vector store or assistant) - if is_vector: - OpenAIVectorStoreCrud(client).delete(llm_service_id) - else: - OpenAIAssistantCrud(client).delete(llm_service_id) + provider.delete( + llm_service_id=collection.llm_service_id, + llm_service_name=collection.llm_service_name, + ) with Session(engine) as session: CollectionCrud(session, project_id).delete_by_id(collection_id) diff --git a/backend/app/services/collections/helpers.py b/backend/app/services/collections/helpers.py index 795b04cd..6995e081 100644 --- a/backend/app/services/collections/helpers.py +++ b/backend/app/services/collections/helpers.py @@ -69,17 +69,6 @@ def batch_documents( return docs_batches -def _backout(crud, llm_service_id: str): - """Best-effort cleanup: attempt to delete the assistant by ID""" - try: - crud.delete(llm_service_id) - except OpenAIError as err: - logger.error( - f"[backout] Failed to delete resource | {{'llm_service_id': '{llm_service_id}', 'error': '{str(err)}'}}", - exc_info=True, - ) - - # Even though this function is used in the documents router, it's kept here for now since the assistant creation logic will # eventually be removed from Kaapi. 
Once that happens, this function can be safely deleted - def pick_service_for_documennt(session, doc_id: UUID, a_crud, v_crud): diff --git a/backend/app/services/collections/providers/__init__.py b/backend/app/services/collections/providers/__init__.py new file mode 100644 index 00000000..5a9b6a55 --- /dev/null +++ b/backend/app/services/collections/providers/__init__.py @@ -0,0 +1,6 @@ +from app.services.collections.providers.base import BaseProvider +from app.services.collections.providers.openai import OpenAIProvider +from app.services.collections.providers.registry import ( + LLMProvider, + get_llm_provider, +) diff --git a/backend/app/services/collections/providers/base.py b/backend/app/services/collections/providers/base.py new file mode 100644 index 00000000..9fb21f3e --- /dev/null +++ b/backend/app/services/collections/providers/base.py @@ -0,0 +1,84 @@ +from abc import ABC, abstractmethod +from typing import Any + +from app.crud import DocumentCrud +from app.core.cloud.storage import CloudStorage +from app.models import CreationRequest, CreateCollectionResult, Collection + + +class BaseProvider(ABC): + """Abstract base class for collection providers. + + All provider implementations (OpenAI, Bedrock, etc.) must inherit from + this class and implement the required methods. + + Providers handle creation of knowledge bases (vector stores) and + optional assistant/agent creation backed by those knowledge bases. + + Attributes: + client: The provider-specific client instance + """ + + def __init__(self, client: Any): + """Initialize provider with client. + + Args: + client: Provider-specific client instance + """ + self.client = client + + @abstractmethod + def create( + self, + collection_request: CreationRequest, + storage: CloudStorage, + document_crud: DocumentCrud, + ) -> CreateCollectionResult: + """Create collection with documents and optionally an assistant. + + Args: + collection_params: Collection parameters (name, description, chunking_params, etc.) 
+        storage: Cloud storage instance for file access
+        document_crud: DocumentCrud instance for fetching documents
+        batch_size: Number of documents to process per batch
+        with_assistant: Whether to create an assistant/agent
+        assistant_options: Options for assistant creation (provider-specific)
+
+        Returns:
+            CreateCollectionResult containing:
+            - llm_service_id: ID of the created resource (vector store or assistant)
+            - llm_service_name: Name of the service
+            - collection_blob: All collection params except documents
+        """
+        raise NotImplementedError("Providers must implement execute method")
+
+    @abstractmethod
+    def delete(self, collection: Collection) -> None:
+        """Delete remote resources associated with a collection.
+
+        Called when a collection is being deleted and remote resources need to be cleaned up.
+
+        Args:
+            collection: Collection record whose llm_service_id and
+                llm_service_name identify the remote resource to delete
+        """
+        raise NotImplementedError("Providers must implement delete method")
+
+    @abstractmethod
+    def cleanup(self, collection_result: CreateCollectionResult) -> None:
+        """Clean up/rollback resources created during execute.
+
+        Called when collection creation fails and remote resources need to be deleted.
+
+        Args:
+            collection_result: The CreateCollectionResult returned from execute, containing resource IDs
+        """
+        raise NotImplementedError("Providers must implement cleanup method")
+
+    def get_provider_name(self) -> str:
+        """Get the name of the provider.
+ + Returns: + Provider name (e.g., "openai", "bedrock", "pinecone") + """ + return self.__class__.__name__.replace("Provider", "").lower() diff --git a/backend/app/services/collections/providers/openai.py b/backend/app/services/collections/providers/openai.py new file mode 100644 index 00000000..36c76af1 --- /dev/null +++ b/backend/app/services/collections/providers/openai.py @@ -0,0 +1,156 @@ +import logging +from typing import Any + +from openai import OpenAI + +from app.services.collections.providers import BaseProvider +from app.crud import DocumentCrud +from app.core.cloud.storage import CloudStorage +from app.crud.rag import OpenAIVectorStoreCrud, OpenAIAssistantCrud +from app.services.collections.helpers import batch_documents, OPENAI_VECTOR_STORE +from app.models import CreateCollectionResult, CreationRequest, Collection + + +logger = logging.getLogger(__name__) + + +class OpenAIProvider(BaseProvider): + """OpenAI-specific collection provider for vector stores and assistants.""" + + def __init__(self, client: OpenAI): + super().__init__(client) + self.client = client + + def create( + self, + collection_request: CreationRequest, + storage: CloudStorage, + document_crud: DocumentCrud, + ) -> CreateCollectionResult: + """Create OpenAI vector store with documents and optionally an assistant. + + Args: + collection_params: Collection parameters (name, description, chunking_params, etc.) + storage: Cloud storage instance for file access + document_crud: DocumentCrud instance for fetching documents + batch_size: Number of documents to process per batch + with_assistant: Whether to create an assistant + assistant_options: Options for assistant creation (model, instructions, etc.) 
+ + Returns: + CreateCollectionResult containing llm_service_id, llm_service_name, and collection_blob + """ + try: + collection_params = collection_request.collection_params + document_ids = [doc.id for doc in collection_params.documents] + + docs_batches = batch_documents( + document_crud, + document_ids, + collection_request.batch_size, + ) + + vector_store_crud = OpenAIVectorStoreCrud(self.client) + vector_store = vector_store_crud.create() + + list(vector_store_crud.update(vector_store.id, storage, docs_batches)) + + logger.info( + "[OpenAIProvider.execute] Vector store created | " + f"vector_store_id={vector_store.id}, batches={len(docs_batches)}" + ) + + collection_blob = { + "name": collection_params.name, + "description": collection_params.description, + "chunking_params": collection_params.chunking_params, + "additional_params": collection_params.additional_params, + } + + # Check if we need to create an assistant (based on assistant options in request) + with_assistant = ( + collection_request.model is not None + and collection_request.instructions is not None + ) + if with_assistant: + assistant_crud = OpenAIAssistantCrud(self.client) + + assistant_options = { + "model": collection_request.model, + "instructions": collection_request.instructions, + "temperature": collection_request.temperature, + } + filtered_options = { + k: v for k, v in assistant_options.items() if v is not None + } + + assistant = assistant_crud.create(vector_store.id, **filtered_options) + + logger.info( + "[OpenAIProvider.execute] Assistant created | " + f"assistant_id={assistant.id}, vector_store_id={vector_store.id}" + ) + + return CreateCollectionResult( + llm_service_id=assistant.id, + llm_service_name=filtered_options.get("model", "assistant"), + collection_blob=collection_blob, + ) + else: + logger.info( + "[OpenAIProvider.execute] Skipping assistant creation | with_assistant=False" + ) + + return CreateCollectionResult( + llm_service_id=vector_store.id, + 
llm_service_name=OPENAI_VECTOR_STORE, + collection_blob=collection_blob, + ) + + except Exception as e: + logger.error( + f"[OpenAIProvider.execute] Failed to create knowledge base: {str(e)}", + exc_info=True, + ) + raise + + def delete(self, collection: Collection) -> None: + """Delete OpenAI resources (assistant or vector store). + + Determines what to delete based on llm_service_name: + - If assistant was created, delete the assistant (which also removes the vector store) + - If only vector store was created, delete the vector store + + Args: + collection: Collection that has been requested to be deleted + """ + try: + if collection.llm_service_name != OPENAI_VECTOR_STORE: + OpenAIAssistantCrud(self.client).delete(collection.llm_service_id) + logger.info( + f"[OpenAIProvider.delete] Deleted assistant | assistant_id={collection.llm_service_id}" + ) + else: + OpenAIVectorStoreCrud(self.client).delete(collection.llm_service_id) + logger.info( + f"[OpenAIProvider.delete] Deleted vector store | vector_store_id={collection.llm_service_id}" + ) + except Exception as e: + logger.error( + f"[OpenAIProvider.delete] Failed to delete resource | " + f"llm_service_id={collection.llm_service_id}, error={str(e)}", + exc_info=True, + ) + raise + + def cleanup(self, result: CreateCollectionResult) -> None: + """Clean up OpenAI resources (assistant or vector store). 
+ + Determines what to delete based on llm_service_name: + - If assistant was created, delete the assistant (which also removes the vector store) + - If only vector store was created, delete the vector store + + Args: + result: The CreateCollectionResult from execute containing resource IDs + """ + self.delete(result.llm_service_id, result.llm_service_name) diff --git a/backend/app/services/collections/providers/registry.py b/backend/app/services/collections/providers/registry.py new file mode 100644 index 00000000..10d07d45 --- /dev/null +++ b/backend/app/services/collections/providers/registry.py @@ -0,0 +1,71 @@ +import logging + +from sqlmodel import Session +from openai import OpenAI + +from app.crud import get_provider_credential +from app.services.collections.providers.base import BaseProvider +from app.services.collections.providers.openai import OpenAIProvider + + +logger = logging.getLogger(__name__) + + +class LLMProvider: + OPENAI = "openai" + # Future constants for providers: + # ANTHROPIC = "ANTHROPIC" + # GEMINI = "gemini" + + _registry: dict[str, type[BaseProvider]] = { + OPENAI: OpenAIProvider, + # Future providers: + # ANTHROPIC: BedrockProvider, + # GEMINI: GeminiProvider, + } + + @classmethod + def get(cls, name: str) -> type[BaseProvider]: + """Return the provider class for a given name.""" + provider = cls._registry.get(name) + if not provider: + raise ValueError( + f"Provider '{name}' is not supported. 
" + f"Supported providers: {', '.join(cls._registry.keys())}" + ) + return provider + + @classmethod + def supported_providers(cls) -> list[str]: + """Return a list of supported provider names.""" + return list(cls._registry.keys()) + + +def get_llm_provider( + session: Session, provider: str, project_id: int, organization_id: int +) -> BaseProvider: + provider_class = LLMProvider.get(provider) + + credentials = get_provider_credential( + session=session, + provider=provider, + project_id=project_id, + org_id=organization_id, + ) + + if not credentials: + raise ValueError( + f"Credentials for provider '{provider}' not configured for this project." + ) + + if provider == LLMProvider.OPENAI: + if "api_key" not in credentials: + raise ValueError("OpenAI credentials not configured for this project.") + client = OpenAI(api_key=credentials["api_key"]) + else: + logger.error( + f"[get_llm_provider] Unsupported provider type requested: {provider}" + ) + raise ValueError(f"Provider '{provider}' is not supported.") + + return provider_class(client=client) From 364a2b52a627dfe34cf284d62e20ed1214f4f56d Mon Sep 17 00:00:00 2001 From: nishika26 Date: Wed, 24 Dec 2025 16:57:59 +0530 Subject: [PATCH 02/11] fixing a delete mistake --- backend/app/services/collections/delete_collection.py | 5 +---- backend/app/services/collections/helpers.py | 11 +++++++++++ backend/app/services/collections/providers/openai.py | 8 ++++++-- 3 files changed, 18 insertions(+), 6 deletions(-) diff --git a/backend/app/services/collections/delete_collection.py b/backend/app/services/collections/delete_collection.py index a49fa4b1..e9570964 100644 --- a/backend/app/services/collections/delete_collection.py +++ b/backend/app/services/collections/delete_collection.py @@ -177,10 +177,7 @@ def execute_job( organization_id=organization_id, ) - provider.delete( - llm_service_id=collection.llm_service_id, - llm_service_name=collection.llm_service_name, - ) + provider.delete(collection) with Session(engine) as 
session: CollectionCrud(session, project_id).delete_by_id(collection_id) diff --git a/backend/app/services/collections/helpers.py b/backend/app/services/collections/helpers.py index 6995e081..795b04cd 100644 --- a/backend/app/services/collections/helpers.py +++ b/backend/app/services/collections/helpers.py @@ -69,6 +69,17 @@ def batch_documents( return docs_batches +def _backout(crud, llm_service_id: str): + """Best-effort cleanup: attempt to delete the assistant by ID""" + try: + crud.delete(llm_service_id) + except OpenAIError as err: + logger.error( + f"[backout] Failed to delete resource | {{'llm_service_id': '{llm_service_id}', 'error': '{str(err)}'}}", + exc_info=True, + ) + + # Even though this function is used in the documents router, it's kept here for now since the assistant creation logic will # eventually be removed from Kaapi. Once that happens, this function can be safely deleted - def pick_service_for_documennt(session, doc_id: UUID, a_crud, v_crud): diff --git a/backend/app/services/collections/providers/openai.py b/backend/app/services/collections/providers/openai.py index 36c76af1..ba734d85 100644 --- a/backend/app/services/collections/providers/openai.py +++ b/backend/app/services/collections/providers/openai.py @@ -7,7 +7,11 @@ from app.crud import DocumentCrud from app.core.cloud.storage import CloudStorage from app.crud.rag import OpenAIVectorStoreCrud, OpenAIAssistantCrud -from app.services.collections.helpers import batch_documents, OPENAI_VECTOR_STORE +from app.services.collections.helpers import ( + batch_documents, + OPENAI_VECTOR_STORE, + _backout, +) from app.models import CreateCollectionResult, CreationRequest, Collection @@ -153,4 +157,4 @@ def cleanup(self, result: CreateCollectionResult) -> None: Args: result: The CreateCollectionResult from execute containing resource IDs """ - self.delete(result.llm_service_id, result.llm_service_name) + _backout(result.llm_service_id, result.llm_service_name) From 
541e311d90c982a64166797cf17b95a881cf0d78 Mon Sep 17 00:00:00 2001 From: nishika26 Date: Tue, 20 Jan 2026 11:38:49 +0530 Subject: [PATCH 03/11] pushing everything together --- ...1_extend_collection_table_for_provider_.py | 109 +++++++++ backend/app/api/routes/collections.py | 4 + backend/app/crud/collection/collection.py | 10 + backend/app/models/__init__.py | 3 +- backend/app/models/collection.py | 214 ++++++++++++++++++ backend/app/models/organization.py | 3 - .../services/collections/create_collection.py | 24 +- .../services/collections/delete_collection.py | 2 +- backend/app/services/collections/helpers.py | 36 ++- .../services/collections/providers/base.py | 18 +- .../services/collections/providers/openai.py | 64 ++---- .../collections/test_collection_delete.py | 6 +- .../collections/test_collection_info.py | 18 +- .../collections/test_collection_job_info.py | 8 +- .../collections/test_collection_list.py | 11 +- .../app/tests/api/routes/test_assistants.py | 4 +- .../api/routes/test_openai_conversation.py | 2 +- .../collection/test_crud_collection_create.py | 2 +- .../collection/test_crud_collection_delete.py | 14 +- .../test_crud_collection_read_all.py | 4 +- .../test_crud_collection_read_one.py | 4 +- backend/app/tests/crud/test_assistants.py | 2 +- .../tests/crud/test_openai_conversation.py | 2 +- .../collections/test_create_collection.py | 120 ++++------ .../collections/test_delete_collection.py | 161 ++++++------- .../services/llm/providers/test_openai.py | 2 +- .../response/test_process_response.py | 2 +- backend/app/tests/utils/collection.py | 10 +- backend/app/tests/utils/llm_provider.py | 198 ++++++++++++++++ 29 files changed, 788 insertions(+), 269 deletions(-) create mode 100644 backend/app/alembic/versions/041_extend_collection_table_for_provider_.py create mode 100644 backend/app/models/collection.py create mode 100644 backend/app/tests/utils/llm_provider.py diff --git a/backend/app/alembic/versions/041_extend_collection_table_for_provider_.py 
b/backend/app/alembic/versions/041_extend_collection_table_for_provider_.py new file mode 100644 index 00000000..84c386e8 --- /dev/null +++ b/backend/app/alembic/versions/041_extend_collection_table_for_provider_.py @@ -0,0 +1,109 @@ +"""extend collection table for provider agnostic support + +Revision ID: 041 +Revises: 040 +Create Date: 2026-01-15 16:53:19.495583 + +""" +from alembic import op +import sqlalchemy as sa +import sqlmodel.sql.sqltypes +from sqlalchemy.dialects import postgresql + + +# revision identifiers, used by Alembic. +revision = "041" +down_revision = "040" +branch_labels = None +depends_on = None + +provider_type = postgresql.ENUM( + "OPENAI", + # aws + # gemini + name="providertype", + create_type=False, +) + + +def upgrade(): + provider_type.create(op.get_bind(), checkfirst=True) + op.add_column( + "collection", + sa.Column( + "provider", + provider_type, + nullable=False, + comment="LLM provider used for this collection", + ), + ) + op.execute("UPDATE collection SET provider = 'OPENAI' WHERE provider IS NULL") + op.add_column( + "collection", + sa.Column( + "name", + sqlmodel.sql.sqltypes.AutoString(), + nullable=True, + comment="Name of the collection", + ), + ) + op.add_column( + "collection", + sa.Column( + "description", + sqlmodel.sql.sqltypes.AutoString(), + nullable=True, + comment="Description of the collection", + ), + ) + op.alter_column( + "collection", + "llm_service_name", + existing_type=sa.VARCHAR(), + comment="Name of the LLM service", + existing_comment="Name of the LLM service provider", + existing_nullable=False, + ) + op.create_unique_constraint(None, "collection", ["name"]) + op.drop_constraint( + op.f("collection_organization_id_fkey"), "collection", type_="foreignkey" + ) + op.drop_column("collection", "organization_id") + + +def downgrade(): + op.add_column( + "collection", + sa.Column( + "organization_id", + sa.INTEGER(), + autoincrement=False, + nullable=True, + comment="Reference to the organization", + ), + ) + 
op.execute( + """UPDATE collection SET organization_id = (SELECT organization_id FROM project + WHERE project.id = collection.project_id)""" + ) + op.alter_column("collection", "organization_id", nullable=False) + op.create_foreign_key( + op.f("collection_organization_id_fkey"), + "collection", + "organization", + ["organization_id"], + ["id"], + ondelete="CASCADE", + ) + op.drop_constraint("collection_name_key", "collection", type_="unique") + op.alter_column( + "collection", + "llm_service_name", + existing_type=sa.VARCHAR(), + comment="Name of the LLM service provider", + existing_comment="Name of the LLM service", + existing_nullable=False, + ) + op.drop_column("collection", "description") + op.drop_column("collection", "name") + op.drop_column("collection", "provider") diff --git a/backend/app/api/routes/collections.py b/backend/app/api/routes/collections.py index d19fad31..614871fb 100644 --- a/backend/app/api/routes/collections.py +++ b/backend/app/api/routes/collections.py @@ -28,6 +28,7 @@ CollectionPublic, ) from app.utils import APIResponse, load_description, validate_callback_url +from app.services.collections.helpers import ensure_unique_name from app.services.collections import ( create_collection as create_service, delete_collection as delete_service, @@ -87,6 +88,9 @@ def create_collection( if request.callback_url: validate_callback_url(str(request.callback_url)) + if request.name: + ensure_unique_name(session, current_user.project_.id, request.name) + collection_job_crud = CollectionJobCrud(session, current_user.project_.id) collection_job = collection_job_crud.create( CollectionJobCreate( diff --git a/backend/app/crud/collection/collection.py b/backend/app/crud/collection/collection.py index 3c83912a..cb8b6e27 100644 --- a/backend/app/crud/collection/collection.py +++ b/backend/app/crud/collection/collection.py @@ -93,6 +93,16 @@ def read_all(self): collections = self.session.exec(statement).all() return collections + def exists_by_name(self, 
collection_name: str) -> bool: + statement = ( + select(Collection.id) + .where(Collection.project_id == self.project_id) + .where(Collection.name == collection_name) + .where(Collection.deleted_at.is_(None)) + ) + result = self.session.exec(statement).first() + return result is not None + def delete_by_id(self, collection_id: UUID) -> Collection: coll = self.read_one(collection_id) coll.deleted_at = now() diff --git a/backend/app/models/__init__.py b/backend/app/models/__init__.py index ef08fd09..a4d76ee2 100644 --- a/backend/app/models/__init__.py +++ b/backend/app/models/__init__.py @@ -8,13 +8,12 @@ from .collection import ( Collection, - CreateCollectionParams, - CreateCollectionResult, CreationRequest, CollectionPublic, CollectionIDPublic, CollectionWithDocsPublic, DeletionRequest, + ProviderType, ) from .collection_job import ( CollectionActionType, diff --git a/backend/app/models/collection.py b/backend/app/models/collection.py new file mode 100644 index 00000000..322262bb --- /dev/null +++ b/backend/app/models/collection.py @@ -0,0 +1,214 @@ +from datetime import datetime +from enum import Enum +from typing import Any, Literal +from uuid import UUID, uuid4 + +from pydantic import HttpUrl, model_validator +from sqlmodel import Field, Relationship, SQLModel + +from app.core.util import now +from app.models.document import DocumentPublic +from .project import Project + + +class ProviderType(str, Enum): + """Supported LLM providers for collections.""" + + OPENAI = "OPENAI" + # BEDROCK = "bedrock" + # GEMINI = "gemini" + + +class Collection(SQLModel, table=True): + """Database model for Collection operations.""" + + id: UUID = Field( + default_factory=uuid4, + primary_key=True, + description="Unique identifier for the collection", + sa_column_kwargs={"comment": "Unique identifier for the collection"}, + ) + provider: ProviderType = ( + Field( + nullable=False, + description="LLM provider used for this collection (e.g., 'openai', 'bedrock', 'gemini', etc)", + 
sa_column_kwargs={"LLM provider used for this collection"}, + ), + ) + llm_service_id: str = Field( + nullable=False, + description="External LLM service identifier (e.g., OpenAI vector store ID)", + sa_column_kwargs={ + "comment": "External LLM service identifier (e.g., OpenAI vector store ID)" + }, + ) + llm_service_name: str = Field( + nullable=False, + description="Name of the LLM service", + sa_column_kwargs={"comment": "Name of the LLM service"}, + ) + name: str = Field( + nullable=True, + unique=True, + description="Name of the collection", + sa_column_kwargs={"comment": "Name of the collection"}, + ) + description: str = Field( + nullable=True, + description="Description of the collection", + sa_column_kwargs={"comment": "Description of the collection"}, + ) + project_id: int = Field( + foreign_key="project.id", + nullable=False, + ondelete="CASCADE", + description="Project the collection belongs to", + sa_column_kwargs={"comment": "Reference to the project"}, + ) + inserted_at: datetime = Field( + default_factory=now, + description="Timestamp when the collection was created", + sa_column_kwargs={"comment": "Timestamp when the collection was created"}, + ) + updated_at: datetime = Field( + default_factory=now, + description="Timestamp when the collection was updated", + sa_column_kwargs={"comment": "Timestamp when the collection was last updated"}, + ) + deleted_at: datetime | None = Field( + default=None, + description="Timestamp when the collection was deleted", + sa_column_kwargs={"comment": "Timestamp when the collection was deleted"}, + ) + project: Project = Relationship(back_populates="collections") + + +# Request models +class CollectionOptions(SQLModel): + name: str | None = Field(default=None, description="Name of the collection") + description: str | None = Field( + default=None, description="Description of the collection" + ) + documents: list[UUID] = Field( + description="List of document IDs", + ) + batch_size: int = Field( + default=1, + 
description=( + "Number of documents to send to OpenAI in a single " + "transaction. See the `file_ids` parameter in the " + "vector store [create batch](https://platform.openai.com/docs/api-reference/vector-stores-file-batches/createBatch)." + ), + ) + + def model_post_init(self, __context: Any): + self.documents = list(set(self.documents)) + + +class AssistantOptions(SQLModel): + # Fields to be passed along to OpenAI. They must be a subset of + # parameters accepted by the OpenAI.clien.beta.assistants.create + # API. + model: str | None = Field( + default=None, + description=( + "**[Deprecated]** " + "OpenAI model to attach to this assistant. The model " + "must be compatable with the assistants API; see the " + "OpenAI [model documentation](https://platform.openai.com/docs/models/compare) for more." + ), + ) + + instructions: str | None = Field( + default=None, + description=( + "**[Deprecated]** " + "Assistant instruction. Sometimes referred to as the " + '"system" prompt.' + ), + ) + temperature: float = Field( + default=1e-6, + description=( + "**[Deprecated]** " + "Model temperature. The default is slightly " + "greater-than zero because it is [unknown how OpenAI " + "handles zero](https://community.openai.com/t/clarifications-on-setting-temperature-0/886447/5)." + ), + ) + + @model_validator(mode="before") + def _assistant_fields_all_or_none(cls, values: dict[str, Any]) -> dict[str, Any]: + def norm(x: Any) -> Any: + if x is None: + return None + if isinstance(x, str): + s = x.strip() + return s if s else None + return x # let Pydantic handle non-strings + + model = norm(values.get("model")) + instructions = norm(values.get("instructions")) + + if (model is None) ^ (instructions is None): + raise ValueError( + "To create an Assistant, provide BOTH 'model' and 'instructions'. " + "If you only want a vector store, remove both fields." 
+ ) + + values["model"] = model + values["instructions"] = instructions + return values + + +class CallbackRequest(SQLModel): + callback_url: HttpUrl | None = Field( + default=None, + description="URL to call to report endpoint status", + ) + + +class ProviderOptions(SQLModel): + """LLM provider configuration.""" + + provider: Literal["openai"] = Field( + default="openai", description="LLM provider to use for this collection" + ) + + +class CreationRequest( + AssistantOptions, + CollectionOptions, + ProviderOptions, + CallbackRequest, +): + def extract_super_type(self, cls: "CreationRequest"): + for field_name in cls.model_fields.keys(): + field_value = getattr(self, field_name) + yield (field_name, field_value) + + +class DeletionRequest(CallbackRequest): + collection_id: UUID = Field(description="Collection to delete") + + +# Response models + + +class CollectionIDPublic(SQLModel): + id: UUID + + +class CollectionPublic(SQLModel): + id: UUID + llm_service_id: str + llm_service_name: str + project_id: int + + inserted_at: datetime + updated_at: datetime + deleted_at: datetime | None = None + + +class CollectionWithDocsPublic(CollectionPublic): + documents: list[DocumentPublic] | None = None diff --git a/backend/app/models/organization.py b/backend/app/models/organization.py index 0f936607..b9ff9a7f 100644 --- a/backend/app/models/organization.py +++ b/backend/app/models/organization.py @@ -74,9 +74,6 @@ class Organization(OrganizationBase, table=True): assistants: list["Assistant"] = Relationship( back_populates="organization", cascade_delete=True ) - collections: list["Collection"] = Relationship( - back_populates="organization", cascade_delete=True - ) openai_conversations: list["OpenAIConversation"] = Relationship( back_populates="organization", cascade_delete=True ) diff --git a/backend/app/services/collections/create_collection.py b/backend/app/services/collections/create_collection.py index 1086dc71..1522f863 100644 --- 
a/backend/app/services/collections/create_collection.py +++ b/backend/app/services/collections/create_collection.py @@ -20,8 +20,8 @@ CollectionJobUpdate, CollectionPublic, CollectionJobPublic, + CreationRequest, ) -from app.models.collection import CreationRequest from app.services.collections.helpers import extract_error_message from app.services.collections.providers.registry import get_llm_provider from app.celery.utils import start_low_priority_job @@ -51,8 +51,8 @@ def start_job( project_id=project_id, job_id=str(collection_job_id), trace_id=trace_id, - with_assistant=with_assistant, request=request.model_dump(mode="json"), + with_assistant=with_assistant, organization_id=organization_id, ) @@ -134,11 +134,11 @@ def _mark_job_failed( def execute_job( request: dict, + with_assistant: bool, project_id: int, organization_id: int, task_id: str, job_id: str, - with_assistant: bool, task_instance, ) -> None: """ @@ -155,6 +155,11 @@ def execute_job( try: creation_request = CreationRequest(**request) + if ( + with_assistant == True + ): # this will be removed once dalgo switches to vector store creation only + creation_request.provider = "openai" + job_uuid = UUID(job_id) with Session(engine) as session: @@ -186,15 +191,10 @@ def execute_job( llm_service_id = result.llm_service_id llm_service_name = result.llm_service_name - # Storing collection params (name, description, chunking_params, etc.) in DB - # for future reference and to support different providers with varying configurations - collection_blob = result.collection_blob with Session(engine) as session: document_crud = DocumentCrud(session, project_id) - flat_docs = document_crud.read_many_by_ids( - [doc.id for doc in creation_request.collection_params.documents] - ) + flat_docs = document_crud.read_each(creation_request.documents) file_exts = {doc.fname.split(".")[-1] for doc in flat_docs if "." 
in doc.fname} file_sizes_kb = [ @@ -208,15 +208,15 @@ def execute_job( collection = Collection( id=collection_id, project_id=project_id, - organization_id=organization_id, llm_service_id=llm_service_id, llm_service_name=llm_service_name, - collection_blob=collection_blob, + provider=creation_request.provider.upper(), + name=creation_request.name, + description=creation_request.description, ) collection_crud.create(collection) collection = collection_crud.read_one(collection.id) - # Link documents to the new collection if flat_docs: DocumentCollectionCrud(session).create(collection, flat_docs) diff --git a/backend/app/services/collections/delete_collection.py b/backend/app/services/collections/delete_collection.py index e9570964..e175301c 100644 --- a/backend/app/services/collections/delete_collection.py +++ b/backend/app/services/collections/delete_collection.py @@ -172,7 +172,7 @@ def execute_job( provider = get_llm_provider( session=session, - provider=deletion_request.provider, + provider=collection.provider.lower(), project_id=project_id, organization_id=organization_id, ) diff --git a/backend/app/services/collections/helpers.py b/backend/app/services/collections/helpers.py index 795b04cd..0665260a 100644 --- a/backend/app/services/collections/helpers.py +++ b/backend/app/services/collections/helpers.py @@ -5,17 +5,26 @@ from uuid import UUID from typing import List +from fastapi import HTTPException from sqlmodel import select from openai import OpenAIError -from app.crud.document.document import DocumentCrud +from app.crud import DocumentCrud, CollectionCrud +from app.api.deps import SessionDep from app.models import DocumentCollection, Collection logger = logging.getLogger(__name__) -# llm service name for when only an openai vector store is being made -OPENAI_VECTOR_STORE = "openai vector store" + +def get_service_name(provider: str) -> str: + """Get the collection service name for a provider.""" + names = { + "openai": "openai vector store", + # "bedrock": 
"bedrock knowledge base", + # "gemini": "gemini file search store", + } + return names.get(provider.lower(), "") def extract_error_message(err: Exception) -> str: @@ -101,4 +110,23 @@ def pick_service_for_documennt(session, doc_id: UUID, a_crud, v_crud): service = ( (getattr(coll, "llm_service_name", "") or "").strip().lower() if coll else "" ) - return v_crud if service == OPENAI_VECTOR_STORE else a_crud + return v_crud if service == get_service_name("openai") else a_crud + + +def ensure_unique_name( + session: SessionDep, + project_id: int, + requested_name: str, +) -> str: + """ + Ensure collection name is unique based on strategy. + + """ + existing = CollectionCrud(session, project_id).exists_by_name(requested_name) + if existing: + raise HTTPException( + status_code=409, + detail=f"Collection '{requested_name}' already exists. Choose a different name.", + ) + + return requested_name diff --git a/backend/app/services/collections/providers/base.py b/backend/app/services/collections/providers/base.py index 9fb21f3e..e4c91d9c 100644 --- a/backend/app/services/collections/providers/base.py +++ b/backend/app/services/collections/providers/base.py @@ -3,7 +3,7 @@ from app.crud import DocumentCrud from app.core.cloud.storage import CloudStorage -from app.models import CreationRequest, CreateCollectionResult, Collection +from app.models import CreationRequest, Collection class BaseProvider(ABC): @@ -12,8 +12,8 @@ class BaseProvider(ABC): All provider implementations (OpenAI, Bedrock, etc.) must inherit from this class and implement the required methods. - Providers handle creation of knowledge bases (vector stores) and - optional assistant/agent creation backed by those knowledge bases. + Providers handle creation of collection and + optional assistant/agent creation backed by those collections. 
Attributes: client: The provider-specific client instance @@ -33,11 +33,11 @@ def create( collection_request: CreationRequest, storage: CloudStorage, document_crud: DocumentCrud, - ) -> CreateCollectionResult: + ) -> Collection: """Create collection with documents and optionally an assistant. Args: - collection_params: Collection parameters (name, description, chunking_params, etc.) + collection_request: Collection parameters (name, description, document list, etc.) storage: Cloud storage instance for file access document_crud: DocumentCrud instance for fetching documents batch_size: Number of documents to process per batch @@ -45,10 +45,8 @@ def create( assistant_options: Options for assistant creation (provider-specific) Returns: - CreateCollectionresult containing: - - llm_service_id: ID of the created resource (vector store or assistant) - - llm_service_name: Name of the service - - kb_blob: All collection params except documents + llm_service_id: ID of the resource to delete + llm_service_name: Name of the service (determines resource type) """ raise NotImplementedError("Providers must implement execute method") @@ -65,7 +63,7 @@ def delete(self, collection: Collection) -> None: raise NotImplementedError("Providers must implement delete method") @abstractmethod - def cleanup(self, collection_result: CreateCollectionResult) -> None: + def cleanup(self, collection: Collection) -> None: """Clean up/rollback resources created during execute. Called when collection creation fails and remote resources need to be deleted. 
diff --git a/backend/app/services/collections/providers/openai.py b/backend/app/services/collections/providers/openai.py index ba734d85..a33bd854 100644 --- a/backend/app/services/collections/providers/openai.py +++ b/backend/app/services/collections/providers/openai.py @@ -1,5 +1,4 @@ import logging -from typing import Any from openai import OpenAI @@ -7,12 +6,8 @@ from app.crud import DocumentCrud from app.core.cloud.storage import CloudStorage from app.crud.rag import OpenAIVectorStoreCrud, OpenAIAssistantCrud -from app.services.collections.helpers import ( - batch_documents, - OPENAI_VECTOR_STORE, - _backout, -) -from app.models import CreateCollectionResult, CreationRequest, Collection +from app.services.collections.helpers import batch_documents, get_service_name, _backout +from app.models import CreationRequest, Collection logger = logging.getLogger(__name__) @@ -30,27 +25,14 @@ def create( collection_request: CreationRequest, storage: CloudStorage, document_crud: DocumentCrud, - ) -> CreateCollectionResult: - """Create OpenAI vector store with documents and optionally an assistant. - - Args: - collection_params: Collection parameters (name, description, chunking_params, etc.) - storage: Cloud storage instance for file access - document_crud: DocumentCrud instance for fetching documents - batch_size: Number of documents to process per batch - with_assistant: Whether to create an assistant - assistant_options: Options for assistant creation (model, instructions, etc.) - - Returns: - CreateCollectionResult containing llm_service_id, llm_service_name, and collection_blob + ) -> Collection: + """ + Create OpenAI vector store with documents and optionally an assistant. 
""" try: - collection_params = collection_request.collection_params - document_ids = [doc.id for doc in collection_params.documents] - docs_batches = batch_documents( document_crud, - document_ids, + collection_request.documents, collection_request.batch_size, ) @@ -64,13 +46,6 @@ def create( f"vector_store_id={vector_store.id}, batches={len(docs_batches)}" ) - collection_blob = { - "name": collection_params.name, - "description": collection_params.description, - "chunking_params": collection_params.chunking_params, - "additional_params": collection_params.additional_params, - } - # Check if we need to create an assistant (based on assistant options in request) with_assistant = ( collection_request.model is not None @@ -95,25 +70,23 @@ def create( f"assistant_id={assistant.id}, vector_store_id={vector_store.id}" ) - return CreateCollectionResult( + return Collection( llm_service_id=assistant.id, llm_service_name=filtered_options.get("model", "assistant"), - collection_blob=collection_blob, ) else: logger.info( "[OpenAIProvider.execute] Skipping assistant creation | with_assistant=False" ) - return CreateCollectionResult( + return Collection( llm_service_id=vector_store.id, - llm_service_name=OPENAI_VECTOR_STORE, - collection_blob=collection_blob, + llm_service_name=get_service_name("openai"), ) except Exception as e: logger.error( - f"[OpenAIProvider.execute] Failed to create knowledge base: {str(e)}", + f"[OpenAIProvider.execute] Failed to create collection: {str(e)}", exc_info=True, ) raise @@ -124,12 +97,9 @@ def delete(self, collection: Collection) -> None: Determines what to delete based on llm_service_name: - If assistant was created, delete the assistant (which also removes the vector store) - If only vector store was created, delete the vector store - - Args: - collection: Collection that has been requested to be deleted """ try: - if collection.llm_service_name != OPENAI_VECTOR_STORE: + if collection.llm_service_name != get_service_name("openai"): 
OpenAIAssistantCrud(self.client).delete(collection.llm_service_id) logger.info( f"[OpenAIProvider.delete] Deleted assistant | assistant_id={collection.llm_service_id}" @@ -147,14 +117,8 @@ def delete(self, collection: Collection) -> None: ) raise - def cleanup(self, result: CreateCollectionResult) -> None: - """Clean up OpenAI resources (assistant or vector store). - - Determines what to delete based on llm_service_name: - - If assistant was created, delete the assistant (which also removes the vector store) - - If only vector store was created, delete the vector store - - Args: - result: The CreateCollectionResult from execute containing resource IDs + def cleanup(self, result: Collection) -> None: + """ + Clean up OpenAI resources (assistant or vector store). """ _backout(result.llm_service_id, result.llm_service_name) diff --git a/backend/app/tests/api/routes/collections/test_collection_delete.py b/backend/app/tests/api/routes/collections/test_collection_delete.py index d55b461a..f7ed400d 100644 --- a/backend/app/tests/api/routes/collections/test_collection_delete.py +++ b/backend/app/tests/api/routes/collections/test_collection_delete.py @@ -9,7 +9,7 @@ from app.tests.utils.auth import TestAuthContext from app.models import CollectionJobStatus from app.tests.utils.utils import get_project -from app.tests.utils.collection import get_collection +from app.tests.utils.collection import get_assistant_collection @patch("app.api.routes.collections.delete_service.start_job") @@ -28,7 +28,7 @@ def test_delete_collection_calls_start_job_and_returns_job( - Calls delete_service.start_job with correct arguments """ project = get_project(db, "Dalgo") - collection = get_collection(db, project) + collection = get_assistant_collection(db, project) resp = client.request( "DELETE", @@ -72,7 +72,7 @@ def test_delete_collection_with_callback_url_passes_it_to_start_job( into the DeletionRequest and then into delete_service.start_job. 
""" project = get_project(db, "Dalgo") - collection = get_collection(db, project) + collection = get_assistant_collection(db, project) payload = { "callback_url": "https://example.com/collections/delete-callback", diff --git a/backend/app/tests/api/routes/collections/test_collection_info.py b/backend/app/tests/api/routes/collections/test_collection_info.py index 09d59dc6..ff38687d 100644 --- a/backend/app/tests/api/routes/collections/test_collection_info.py +++ b/backend/app/tests/api/routes/collections/test_collection_info.py @@ -6,9 +6,13 @@ from app.core.config import settings from app.tests.utils.utils import get_project, get_document -from app.tests.utils.collection import get_collection, get_vector_store_collection +from app.tests.utils.collection import ( + get_assistant_collection, + get_vector_store_collection, +) from app.crud import DocumentCollectionCrud from app.models import Collection, Document +from app.services.collections.helpers import get_service_name def link_document_to_collection( @@ -40,13 +44,13 @@ def test_collection_info_returns_assistant_collection_with_docs( ) -> None: """ Happy path: - - Assistant-style collection (get_collection) + - Assistant-style collection (get_assistant_collection) - include_docs = True (default) - At least one document linked """ project = get_project(db, "Dalgo") - collection = get_collection(db, project) + collection = get_assistant_collection(db, project) document = link_document_to_collection(db, collection) @@ -82,7 +86,7 @@ def test_collection_info_include_docs_false_returns_no_docs( When include_docs=false, the endpoint should not populate the documents list. """ project = get_project(db, "Dalgo") - collection = get_collection(db, project) + collection = get_assistant_collection(db, project) link_document_to_collection(db, collection) @@ -113,7 +117,7 @@ def test_collection_info_pagination_skip_and_limit( We create multiple document links and then request a paginated slice. 
""" project = get_project(db, "Dalgo") - collection = get_collection(db, project) + collection = get_assistant_collection(db, project) documents = db.exec( select(Document).where(Document.deleted_at.is_(None)).limit(2) @@ -148,7 +152,7 @@ def test_collection_info_vector_store_collection( via get_vector_store_collection. """ project = get_project(db, "Dalgo") - collection = get_vector_store_collection(db, project) + collection = get_vector_store_collection(db, project, provider="openai") link_document_to_collection(db, collection) @@ -163,7 +167,7 @@ def test_collection_info_vector_store_collection( payload = data["data"] assert payload["id"] == str(collection.id) - assert payload["llm_service_name"] == "openai vector store" + assert payload["llm_service_name"] == get_service_name("openai") assert payload["llm_service_id"] == collection.llm_service_id docs = payload.get("documents", []) diff --git a/backend/app/tests/api/routes/collections/test_collection_job_info.py b/backend/app/tests/api/routes/collections/test_collection_job_info.py index abc2e3cb..ad95b676 100644 --- a/backend/app/tests/api/routes/collections/test_collection_job_info.py +++ b/backend/app/tests/api/routes/collections/test_collection_job_info.py @@ -5,7 +5,7 @@ from app.core.config import settings from app.tests.utils.utils import get_project -from app.tests.utils.collection import get_collection, get_collection_job +from app.tests.utils.collection import get_assistant_collection, get_collection_job from app.models import ( CollectionActionType, CollectionJobStatus, @@ -41,7 +41,7 @@ def test_collection_info_create_successful( headers = user_api_key_header project = get_project(db, "Dalgo") - collection = get_collection(db, project) + collection = get_assistant_collection(db, project) collection_job = get_collection_job( db, project, collection_id=collection.id, status=CollectionJobStatus.SUCCESSFUL @@ -101,7 +101,7 @@ def test_collection_info_delete_successful( headers = user_api_key_header 
project = get_project(db, "Dalgo") - collection = get_collection(db, project) + collection = get_assistant_collection(db, project) collection_job = get_collection_job( db, @@ -133,7 +133,7 @@ def test_collection_info_delete_failed( headers = user_api_key_header project = get_project(db, "Dalgo") - collection = get_collection(db, project) + collection = get_assistant_collection(db, project) collection_job = get_collection_job( db, diff --git a/backend/app/tests/api/routes/collections/test_collection_list.py b/backend/app/tests/api/routes/collections/test_collection_list.py index de713532..779a56a5 100644 --- a/backend/app/tests/api/routes/collections/test_collection_list.py +++ b/backend/app/tests/api/routes/collections/test_collection_list.py @@ -4,9 +4,10 @@ from app.core.config import settings from app.tests.utils.utils import get_project from app.tests.utils.collection import ( - get_collection, + get_assistant_collection, get_vector_store_collection, ) +from app.services.collections.helpers import get_service_name def test_list_collections_returns_api_response( @@ -39,7 +40,7 @@ def test_list_collections_includes_assistant_collection( user_api_key_header: dict[str, str], ) -> None: """ - Ensure that a newly created assistant-style collection (get_collection) + Ensure that a newly created assistant-style collection (get_assistant_collection) appears in the list for the current project. """ @@ -51,7 +52,7 @@ def test_list_collections_includes_assistant_collection( ) assert response_before.status_code == 200 - collection = get_collection(db, project) + collection = get_assistant_collection(db, project) response_after = client.get( f"{settings.API_V1_STR}/collections/", @@ -82,7 +83,7 @@ def test_list_collections_includes_vector_store_collection_with_fields( appear in the list and expose the expected LLM fields. 
""" project = get_project(db, "Dalgo") - collection = get_vector_store_collection(db, project) + collection = get_vector_store_collection(db, project, provider="openai") response = client.get( f"{settings.API_V1_STR}/collections/", @@ -101,7 +102,7 @@ def test_list_collections_includes_vector_store_collection_with_fields( row = matching[0] assert row["project_id"] == project.id - assert row["llm_service_name"] == "openai vector store" + assert row["llm_service_name"] == get_service_name("openai") assert row["llm_service_id"] == collection.llm_service_id diff --git a/backend/app/tests/api/routes/test_assistants.py b/backend/app/tests/api/routes/test_assistants.py index 3d2e9aa9..d54651c2 100644 --- a/backend/app/tests/api/routes/test_assistants.py +++ b/backend/app/tests/api/routes/test_assistants.py @@ -6,8 +6,8 @@ from sqlmodel import Session from fastapi import HTTPException from fastapi.testclient import TestClient - -from app.tests.utils.openai import mock_openai_assistant +from unittest.mock import patch +from app.tests.utils.llm_provider import mock_openai_assistant from app.tests.utils.utils import get_assistant from app.tests.utils.auth import TestAuthContext diff --git a/backend/app/tests/api/routes/test_openai_conversation.py b/backend/app/tests/api/routes/test_openai_conversation.py index 7957c3d9..2620d80b 100644 --- a/backend/app/tests/api/routes/test_openai_conversation.py +++ b/backend/app/tests/api/routes/test_openai_conversation.py @@ -3,7 +3,7 @@ from app.crud.openai_conversation import create_conversation from app.models import OpenAIConversationCreate -from app.tests.utils.openai import generate_openai_id +from app.tests.utils.llm_provider import generate_openai_id from app.tests.utils.auth import TestAuthContext diff --git a/backend/app/tests/crud/collections/collection/test_crud_collection_create.py b/backend/app/tests/crud/collections/collection/test_crud_collection_create.py index f8b0e1f2..aaa32e97 100644 --- 
a/backend/app/tests/crud/collections/collection/test_crud_collection_create.py +++ b/backend/app/tests/crud/collections/collection/test_crud_collection_create.py @@ -18,9 +18,9 @@ def test_create_associates_documents(self, db: Session) -> None: collection = Collection( id=uuid4(), project_id=project.id, - organization_id=project.organization_id, llm_service_id="asst_dummy", llm_service_name="gpt-4o", + provider="OPENAI", ) store = DocumentStore(db, project_id=collection.project_id) diff --git a/backend/app/tests/crud/collections/collection/test_crud_collection_delete.py b/backend/app/tests/crud/collections/collection/test_crud_collection_delete.py index 2c2ff45c..7ea32e7d 100644 --- a/backend/app/tests/crud/collections/collection/test_crud_collection_delete.py +++ b/backend/app/tests/crud/collections/collection/test_crud_collection_delete.py @@ -9,7 +9,7 @@ from app.tests.utils.document import DocumentStore -def get_collection_for_delete( +def get_assistant_collection_for_delete( db: Session, client=None, project_id: int = None ) -> Collection: project = get_project(db) @@ -24,10 +24,10 @@ def get_collection_for_delete( ) return Collection( - organization_id=project.organization_id, project_id=project_id, llm_service_id=assistant.id, llm_service_name="gpt-4o", + provider="OPENAI", ) @@ -40,7 +40,9 @@ def test_delete_marks_deleted(self, db: Session) -> None: client = OpenAI(api_key="sk-test-key") assistant = OpenAIAssistantCrud(client) - collection = get_collection_for_delete(db, client, project_id=project.id) + collection = get_assistant_collection_for_delete( + db, client, project_id=project.id + ) crud = CollectionCrud(db, collection.project_id) collection_ = crud.delete(collection, assistant) @@ -53,7 +55,7 @@ def test_delete_follows_insert(self, db: Session) -> None: assistant = OpenAIAssistantCrud(client) project = get_project(db) - collection = get_collection_for_delete(db, project_id=project.id) + collection = get_assistant_collection_for_delete(db, 
project_id=project.id) crud = CollectionCrud(db, collection.project_id) collection_ = crud.delete(collection, assistant) @@ -74,7 +76,9 @@ def test_delete_document_deletes_collections(self, db: Session) -> None: client = OpenAI(api_key="sk-test-key") resources = [] for _ in range(self._n_collections): - coll = get_collection_for_delete(db, client, project_id=project.id) + coll = get_assistant_collection_for_delete( + db, client, project_id=project.id + ) crud = CollectionCrud(db, project_id=project.id) collection = crud.create(coll, documents) resources.append((crud, collection)) diff --git a/backend/app/tests/crud/collections/collection/test_crud_collection_read_all.py b/backend/app/tests/crud/collections/collection/test_crud_collection_read_all.py index fcf1cafc..0382b883 100644 --- a/backend/app/tests/crud/collections/collection/test_crud_collection_read_all.py +++ b/backend/app/tests/crud/collections/collection/test_crud_collection_read_all.py @@ -6,7 +6,7 @@ from app.models import Collection from app.tests.utils.document import DocumentStore from app.tests.utils.utils import get_project -from app.tests.utils.collection import get_collection +from app.tests.utils.collection import get_assistant_collection def create_collections(db: Session, n: int) -> Collection: @@ -16,7 +16,7 @@ def create_collections(db: Session, n: int) -> Collection: with openai_mock.router: client = OpenAI(api_key="sk-test-key") for _ in range(n): - collection = get_collection(db, project=project) + collection = get_assistant_collection(db, project=project) store = DocumentStore(db, project_id=collection.project_id) documents = store.fill(1) if crud is None: diff --git a/backend/app/tests/crud/collections/collection/test_crud_collection_read_one.py b/backend/app/tests/crud/collections/collection/test_crud_collection_read_one.py index fe601196..6f6473f3 100644 --- a/backend/app/tests/crud/collections/collection/test_crud_collection_read_one.py +++ 
b/backend/app/tests/crud/collections/collection/test_crud_collection_read_one.py @@ -8,7 +8,7 @@ from app.models import Collection from app.tests.utils.document import DocumentStore from app.tests.utils.utils import get_project -from app.tests.utils.collection import get_collection +from app.tests.utils.collection import get_assistant_collection def mk_collection(db: Session) -> Collection: @@ -16,7 +16,7 @@ def mk_collection(db: Session) -> Collection: project = get_project(db) with openai_mock.router: client = OpenAI(api_key="sk-test-key") - collection = get_collection(db, project=project) + collection = get_assistant_collection(db, project=project) store = DocumentStore(db, project_id=collection.project_id) documents = store.fill(1) crud = CollectionCrud(db, collection.project_id) diff --git a/backend/app/tests/crud/test_assistants.py b/backend/app/tests/crud/test_assistants.py index e40669c1..12aa4994 100644 --- a/backend/app/tests/crud/test_assistants.py +++ b/backend/app/tests/crud/test_assistants.py @@ -15,7 +15,7 @@ get_assistant_by_id, get_assistants_by_project, ) -from app.tests.utils.openai import mock_openai_assistant +from app.tests.utils.llm_provider import mock_openai_assistant from app.tests.utils.utils import ( get_project, get_assistant, diff --git a/backend/app/tests/crud/test_openai_conversation.py b/backend/app/tests/crud/test_openai_conversation.py index fec3da1e..314238bc 100644 --- a/backend/app/tests/crud/test_openai_conversation.py +++ b/backend/app/tests/crud/test_openai_conversation.py @@ -15,7 +15,7 @@ ) from app.models import OpenAIConversationCreate, Project from app.tests.utils.utils import get_project, get_organization -from app.tests.utils.openai import generate_openai_id +from app.tests.utils.llm_provider import generate_openai_id def test_get_conversation_by_id_success(db: Session) -> None: diff --git a/backend/app/tests/services/collections/test_create_collection.py 
b/backend/app/tests/services/collections/test_create_collection.py index 0ea5e495..bcd1c4fa 100644 --- a/backend/app/tests/services/collections/test_create_collection.py +++ b/backend/app/tests/services/collections/test_create_collection.py @@ -16,9 +16,9 @@ from app.models import CollectionJobStatus, CollectionJob, CollectionActionType, Project from app.models.collection import CreationRequest from app.services.collections.create_collection import start_job, execute_job -from app.tests.utils.openai import get_mock_openai_client_with_vector_store +from app.tests.utils.llm_provider import get_mock_provider from app.tests.utils.utils import get_project -from app.tests.utils.collection import get_collection_job, get_collection +from app.tests.utils.collection import get_collection_job, get_assistant_collection from app.tests.utils.document import DocumentStore @@ -57,12 +57,10 @@ def test_start_job_creates_collection_job_and_schedules_task(db: Session) -> Non """ project = get_project(db) request = CreationRequest( - model="gpt-4o", - instructions="string", - temperature=0.000001, documents=[UUID("f3e86a17-1e6f-41ec-b020-5b08eebef928")], batch_size=1, callback_url=None, + provider="openai", ) job_id = uuid4() @@ -115,10 +113,10 @@ def test_start_job_creates_collection_job_and_schedules_task(db: Session) -> Non @pytest.mark.usefixtures("aws_credentials") @mock_aws -@patch("app.services.collections.create_collection.get_openai_client") +@patch("app.services.collections.create_collection.get_llm_provider") def test_execute_job_success_flow_updates_job_and_creates_collection( - mock_get_openai_client: Any, db: Session -) -> None: + mock_get_llm_provider, db: Session +): """ execute_job should: - set task_id on the CollectionJob @@ -139,16 +137,12 @@ def test_execute_job_success_flow_updates_job_and_creates_collection( aws.client.put_object(Bucket=settings.AWS_S3_BUCKET, Key=str(s3_key), Body=b"test") sample_request = CreationRequest( - model="gpt-4o", - 
instructions="string", - temperature=0.000001, - documents=[document.id], - batch_size=1, - callback_url=None, + documents=[document.id], batch_size=1, callback_url=None, provider="openai" ) - mock_client = get_mock_openai_client_with_vector_store() - mock_get_openai_client.return_value = mock_client + mock_get_llm_provider.return_value = get_mock_provider( + llm_service_id="mock_vector_store_id", llm_service_name="openai vector store" + ) job_id = uuid4() _ = get_collection_job( @@ -184,8 +178,8 @@ def test_execute_job_success_flow_updates_job_and_creates_collection( created_collection = CollectionCrud(db, project.id).read_one( updated_job.collection_id ) - assert created_collection.llm_service_id == "mock_assistant_id" - assert created_collection.llm_service_name == sample_request.model + assert created_collection.llm_service_id == "mock_vector_store_id" + assert created_collection.llm_service_name == "openai vector store" assert created_collection.updated_at is not None docs = DocumentCollectionCrud(db).read(created_collection, skip=0, limit=10) @@ -195,10 +189,10 @@ def test_execute_job_success_flow_updates_job_and_creates_collection( @pytest.mark.usefixtures("aws_credentials") @mock_aws -@patch("app.services.collections.create_collection.get_openai_client") -def test_execute_job_assistant_create_failure_marks_failed_and_deletes_vector( - mock_get_openai_client: Any, db: Session -) -> None: +@patch("app.services.collections.create_collection.get_llm_provider") +def test_execute_job_assistant_create_failure_marks_failed_and_deletes_collection( + mock_get_llm_provider, db +): project = get_project(db) job = get_collection_job( @@ -211,32 +205,23 @@ def test_execute_job_assistant_create_failure_marks_failed_and_deletes_vector( ) req = CreationRequest( - model="gpt-4o", - instructions="string", - temperature=0.0, - documents=[], - batch_size=1, - callback_url=None, + documents=[], batch_size=1, callback_url=None, provider="openai" ) - _ = 
mock_get_openai_client.return_value + mock_provider = get_mock_provider( + llm_service_id="vs_123", llm_service_name="openai vector store" + ) + mock_get_llm_provider.return_value = mock_provider with patch( "app.services.collections.create_collection.Session" ) as SessionCtor, patch( - "app.services.collections.create_collection.OpenAIVectorStoreCrud" - ) as MockVS, patch( - "app.services.collections.create_collection.OpenAIAssistantCrud" - ) as MockAsst: + "app.services.collections.create_collection.CollectionCrud" + ) as MockCrud: SessionCtor.return_value.__enter__.return_value = db SessionCtor.return_value.__exit__.return_value = False - MockVS.return_value.create.return_value = type( - "Vector store", (), {"id": "vs_123"} - )() - MockVS.return_value.update.return_value = [] - - MockAsst.return_value.create.side_effect = RuntimeError("assistant boom") + MockCrud.return_value.create.side_effect = Exception("DB constraint violation") task_id = str(uuid4()) execute_job( @@ -249,21 +234,18 @@ def test_execute_job_assistant_create_failure_marks_failed_and_deletes_vector( task_instance=None, ) - failed = CollectionJobCrud(db, project.id).read_one(job.id) - assert failed.task_id == task_id - assert failed.status == CollectionJobStatus.FAILED - assert "assistant boom" in (failed.error_message or "") - - MockVS.return_value.delete.assert_called_once_with("vs_123") + mock_provider.cleanup.assert_called_once() @pytest.mark.usefixtures("aws_credentials") @mock_aws -@patch("app.services.collections.create_collection.get_openai_client") +@patch("app.services.collections.create_collection.get_llm_provider") @patch("app.services.collections.create_collection.send_callback") def test_execute_job_success_flow_callback_job_and_creates_collection( - mock_send_callback: Any, mock_get_openai_client: Any, db: Session -) -> None: + mock_send_callback, + mock_get_llm_provider, + db, +): """ execute_job should: - set task_id on the CollectionJob @@ -286,16 +268,15 @@ def 
test_execute_job_success_flow_callback_job_and_creates_collection( callback_url = "https://example.com/collections/create-success" sample_request = CreationRequest( - model="gpt-4o", - instructions="string", - temperature=0.000001, documents=[document.id], batch_size=1, callback_url=callback_url, + provider="openai", ) - mock_client = get_mock_openai_client_with_vector_store() - mock_get_openai_client.return_value = mock_client + mock_get_llm_provider.return_value = get_mock_provider( + llm_service_id="mock_vector_store_id", llm_service_name="openai vector store" + ) job_id = uuid.uuid4() _ = get_collection_job( @@ -339,11 +320,13 @@ def test_execute_job_success_flow_callback_job_and_creates_collection( @pytest.mark.usefixtures("aws_credentials") @mock_aws -@patch("app.services.collections.create_collection.get_openai_client") +@patch("app.services.collections.create_collection.get_llm_provider") @patch("app.services.collections.create_collection.send_callback") def test_execute_job_success_creates_collection_with_callback( - mock_send_callback: Any, mock_get_openai_client: Any, db: Session -) -> None: + mock_send_callback, + mock_get_llm_provider, + db, +): """ execute_job should: - set task_id on the CollectionJob @@ -366,16 +349,15 @@ def test_execute_job_success_creates_collection_with_callback( callback_url = "https://example.com/collections/create-success" sample_request = CreationRequest( - model="gpt-4o", - instructions="string", - temperature=0.000001, documents=[document.id], batch_size=1, callback_url=callback_url, + provider="openai", ) - mock_client = get_mock_openai_client_with_vector_store() - mock_get_openai_client.return_value = mock_client + mock_get_llm_provider.return_value = get_mock_provider( + llm_service_id="mock_vector_store_id", llm_service_name="gpt-4o" + ) job_id = uuid.uuid4() _ = get_collection_job( @@ -419,13 +401,13 @@ def test_execute_job_success_creates_collection_with_callback( @pytest.mark.usefixtures("aws_credentials") @mock_aws 
-@patch("app.services.collections.create_collection.get_openai_client") +@patch("app.services.collections.create_collection.get_llm_provider") @patch("app.services.collections.create_collection.send_callback") @patch("app.services.collections.create_collection.CollectionCrud") def test_execute_job_failure_flow_callback_job_and_marks_failed( - MockCollectionCrud: Any, - mock_send_callback: Any, - mock_get_openai_client: Any, + MockCollectionCrud, + mock_send_callback, + mock_get_llm_provider, db: Session, ) -> None: """ @@ -434,7 +416,7 @@ def test_execute_job_failure_flow_callback_job_and_marks_failed( """ project = get_project(db) - collection = get_collection(db, project, assistant_id="asst_123") + collection = get_assistant_collection(db, project, assistant_id="asst_123") job = get_collection_job( db, project, @@ -443,7 +425,7 @@ def test_execute_job_failure_flow_callback_job_and_marks_failed( collection_id=None, ) - mock_get_openai_client.return_value = MagicMock() + mock_get_llm_provider.return_value = MagicMock() callback_url = "https://example.com/collections/create-failure" @@ -451,12 +433,10 @@ def test_execute_job_failure_flow_callback_job_and_marks_failed( collection_crud_instance.read_one.return_value = collection sample_request = CreationRequest( - model="gpt-4o", - instructions="string", - temperature=0.000001, documents=[uuid.uuid4()], batch_size=1, callback_url=callback_url, + provider="openai", ) task_id = uuid.uuid4() diff --git a/backend/app/tests/services/collections/test_delete_collection.py b/backend/app/tests/services/collections/test_delete_collection.py index 26153ee4..11d91d0b 100644 --- a/backend/app/tests/services/collections/test_delete_collection.py +++ b/backend/app/tests/services/collections/test_delete_collection.py @@ -2,14 +2,13 @@ from unittest.mock import patch, MagicMock from uuid import uuid4, UUID -from sqlmodel import Session -from sqlalchemy.exc import SQLAlchemyError from app.models.collection import DeletionRequest + from 
app.tests.utils.utils import get_project from app.crud import CollectionJobCrud from app.models import CollectionJobStatus, CollectionActionType -from app.tests.utils.collection import get_collection, get_collection_job +from app.tests.utils.collection import get_collection_job, get_vector_store_collection from app.services.collections.delete_collection import start_job, execute_job @@ -20,7 +19,7 @@ def test_start_job_creates_collection_job_and_schedules_task(db: Session) -> Non - return the same job_id (UUID) """ project = get_project(db) - created_collection = get_collection(db, project) + created_collection = get_vector_store_collection(db, project, provider="OPENAI") req = DeletionRequest(collection_id=created_collection.id) @@ -72,19 +71,22 @@ def test_start_job_creates_collection_job_and_schedules_task(db: Session) -> Non assert "trace_id" in kwargs -@patch("app.services.collections.delete_collection.get_openai_client") +@patch("app.services.collections.delete_collection.get_llm_provider") def test_execute_job_delete_success_updates_job_and_calls_delete( - mock_get_openai_client: Any, db: Session -) -> None: + mock_get_llm_provider, db +): """ - execute_job should set task_id on the CollectionJob - - call remote delete via OpenAIAssistantCrud.delete(...) + - call provider.delete() to delete remote resources - delete local record via CollectionCrud.delete_by_id(...) 
- mark job successful and clear error_message """ project = get_project(db) - collection = get_collection(db, project, assistant_id="asst_123") + collection = get_vector_store_collection( + db, project, vector_store_id="asst_123", provider="OPENAI" + ) + job = get_collection_job( db, project, @@ -93,13 +95,13 @@ def test_execute_job_delete_success_updates_job_and_calls_delete( collection_id=collection.id, ) - mock_get_openai_client.return_value = MagicMock() + mock_provider = MagicMock() + mock_provider.delete = MagicMock() + mock_get_llm_provider.return_value = mock_provider with patch( "app.services.collections.delete_collection.Session" ) as SessionCtor, patch( - "app.services.collections.delete_collection.OpenAIAssistantCrud" - ) as MockAssistantCrud, patch( "app.services.collections.delete_collection.CollectionCrud" ) as MockCollectionCrud: SessionCtor.return_value.__enter__.return_value = db @@ -108,8 +110,6 @@ def test_execute_job_delete_success_updates_job_and_calls_delete( collection_crud_instance = MockCollectionCrud.return_value collection_crud_instance.read_one.return_value = collection - MockAssistantCrud.return_value.delete.return_value = None - task_id = uuid4() req = DeletionRequest(collection_id=collection.id) @@ -128,27 +128,32 @@ def test_execute_job_delete_success_updates_job_and_calls_delete( assert updated_job.status == CollectionJobStatus.SUCCESSFUL assert updated_job.error_message in (None, "") + mock_provider.delete.assert_called_once_with(collection) + MockCollectionCrud.assert_called_with(db, project.id) collection_crud_instance.read_one.assert_called_once_with(collection.id) - - MockAssistantCrud.assert_called_once() - MockAssistantCrud.return_value.delete.assert_called_once_with("asst_123") - collection_crud_instance.delete_by_id.assert_called_once_with(collection.id) - mock_get_openai_client.assert_called_once() + mock_get_llm_provider.assert_called_once() -@patch("app.services.collections.delete_collection.get_openai_client") -def 
test_execute_job_delete_failure_marks_job_failed( - mock_get_openai_client: Any, db: Session -) -> None: + +@patch("app.services.collections.delete_collection.get_llm_provider") +def test_execute_job_delete_failure_marks_job_failed(mock_get_llm_provider, db): """ - When the remote delete (OpenAIAssistantCrud.delete) raises, - the job should be marked FAILED and error_message set. + When provider.delete() raises an exception: + - Job should be marked FAILED + - error_message should be set + - Local collection should NOT be deleted """ project = get_project(db) - collection = get_collection(db, project, assistant_id="asst_123") + collection = get_vector_store_collection( + db, + project, + vector_store_id="vector_123", + provider="OPENAI", + ) + job = get_collection_job( db, project, @@ -157,13 +162,13 @@ def test_execute_job_delete_failure_marks_job_failed( collection_id=collection.id, ) - mock_get_openai_client.return_value = MagicMock() + mock_provider = MagicMock() + mock_provider.delete.side_effect = Exception("Remote deletion failed") + mock_get_llm_provider.return_value = mock_provider with patch( "app.services.collections.delete_collection.Session" ) as SessionCtor, patch( - "app.services.collections.delete_collection.OpenAIAssistantCrud" - ) as MockAssistantCrud, patch( "app.services.collections.delete_collection.CollectionCrud" ) as MockCollectionCrud: SessionCtor.return_value.__enter__.return_value = db @@ -172,10 +177,6 @@ def test_execute_job_delete_failure_marks_job_failed( collection_crud_instance = MockCollectionCrud.return_value collection_crud_instance.read_one.return_value = collection - MockAssistantCrud.return_value.delete.side_effect = SQLAlchemyError( - "something went wrong" - ) - task_id = uuid4() req = DeletionRequest(collection_id=collection.id) @@ -194,24 +195,25 @@ def test_execute_job_delete_failure_marks_job_failed( assert failed_job.status == CollectionJobStatus.FAILED assert ( failed_job.error_message - and "something went wrong" in 
failed_job.error_message + and "Remote deletion failed" in failed_job.error_message ) + mock_provider.delete.assert_called_once_with(collection) + MockCollectionCrud.assert_called_with(db, project.id) collection_crud_instance.read_one.assert_called_once_with(collection.id) - - MockAssistantCrud.assert_called_once() - MockAssistantCrud.return_value.delete.assert_called_once_with("asst_123") - collection_crud_instance.delete_by_id.assert_not_called() - mock_get_openai_client.assert_called_once() + + mock_get_llm_provider.assert_called_once() -@patch("app.services.collections.delete_collection.get_openai_client") +@patch("app.services.collections.delete_collection.get_llm_provider") +@patch("app.services.collections.delete_collection.send_callback") def test_execute_job_delete_success_with_callback_sends_success_payload( - mock_get_openai_client: Any, - db: Session, -) -> None: + mock_send_callback, + mock_get_llm_provider, + db, +): """ When deletion succeeds and a callback_url is provided: - job is marked SUCCESSFUL @@ -220,7 +222,13 @@ def test_execute_job_delete_success_with_callback_sends_success_payload( """ project = get_project(db) - collection = get_collection(db, project, assistant_id="asst_123") + collection = get_vector_store_collection( + db, + project, + vector_store_id="vector 123", + provider="OPENAI", + ) + job = get_collection_job( db, project, @@ -229,27 +237,23 @@ def test_execute_job_delete_success_with_callback_sends_success_payload( collection_id=collection.id, ) - mock_get_openai_client.return_value = MagicMock() + mock_provider = MagicMock() + mock_provider.delete = MagicMock() + mock_get_llm_provider.return_value = mock_provider callback_url = "https://example.com/collections/delete-success" with patch( "app.services.collections.delete_collection.Session" ) as SessionCtor, patch( - "app.services.collections.delete_collection.OpenAIAssistantCrud" - ) as MockAssistantCrud, patch( "app.services.collections.delete_collection.CollectionCrud" - ) 
as MockCollectionCrud, patch( - "app.services.collections.delete_collection.send_callback" - ) as mock_send_callback: + ) as MockCollectionCrud: SessionCtor.return_value.__enter__.return_value = db SessionCtor.return_value.__exit__.return_value = False collection_crud_instance = MockCollectionCrud.return_value collection_crud_instance.read_one.return_value = collection - MockAssistantCrud.return_value.delete.return_value = None - task_id = uuid4() req = DeletionRequest(collection_id=collection.id, callback_url=callback_url) @@ -268,12 +272,12 @@ def test_execute_job_delete_success_with_callback_sends_success_payload( assert updated_job.status == CollectionJobStatus.SUCCESSFUL assert updated_job.error_message in (None, "") + mock_provider.delete.assert_called_once_with(collection) + MockCollectionCrud.assert_called_with(db, project.id) collection_crud_instance.read_one.assert_called_once_with(collection.id) - MockAssistantCrud.assert_called_once() - MockAssistantCrud.return_value.delete.assert_called_once_with("asst_123") collection_crud_instance.delete_by_id.assert_called_once_with(collection.id) - mock_get_openai_client.assert_called_once() + mock_get_llm_provider.assert_called_once() mock_send_callback.assert_called_once() cb_url_arg, payload_arg = mock_send_callback.call_args.args @@ -285,20 +289,28 @@ def test_execute_job_delete_success_with_callback_sends_success_payload( assert UUID(payload_arg["data"]["job_id"]) == job.id -@patch("app.services.collections.delete_collection.get_openai_client") +@patch("app.services.collections.delete_collection.get_llm_provider") +@patch("app.services.collections.delete_collection.send_callback") def test_execute_job_delete_remote_failure_with_callback_sends_failure_payload( - mock_get_openai_client: Any, - db: Session, -) -> None: + mock_send_callback, + mock_get_llm_provider, + db, +): """ - When the remote delete raises AND a callback_url is provided: + When provider.delete() raises AND a callback_url is provided: - job is 
marked FAILED with error_message set - send_callback is called once - failure payload has success=False, status=FAILED, correct collection id, and error message """ project = get_project(db) - collection = get_collection(db, project, assistant_id="asst_123") + collection = get_vector_store_collection( + db, + project, + vector_store_id="vector_123", + provider="OPENAI", + ) + job = get_collection_job( db, project, @@ -307,28 +319,23 @@ def test_execute_job_delete_remote_failure_with_callback_sends_failure_payload( collection_id=collection.id, ) - mock_get_openai_client.return_value = MagicMock() + mock_provider = MagicMock() + mock_provider.delete.side_effect = Exception("Remote deletion failed") + mock_get_llm_provider.return_value = mock_provider + callback_url = "https://example.com/collections/delete-failed" with patch( "app.services.collections.delete_collection.Session" ) as SessionCtor, patch( - "app.services.collections.delete_collection.OpenAIAssistantCrud" - ) as MockAssistantCrud, patch( "app.services.collections.delete_collection.CollectionCrud" - ) as MockCollectionCrud, patch( - "app.services.collections.delete_collection.send_callback" - ) as mock_send_callback: + ) as MockCollectionCrud: SessionCtor.return_value.__enter__.return_value = db SessionCtor.return_value.__exit__.return_value = False collection_crud_instance = MockCollectionCrud.return_value collection_crud_instance.read_one.return_value = collection - MockAssistantCrud.return_value.delete.side_effect = SQLAlchemyError( - "something went wrong" - ) - task_id = uuid4() req = DeletionRequest(collection_id=collection.id, callback_url=callback_url) @@ -347,24 +354,22 @@ def test_execute_job_delete_remote_failure_with_callback_sends_failure_payload( assert failed_job.status == CollectionJobStatus.FAILED assert ( failed_job.error_message - and "something went wrong" in failed_job.error_message + and "Remote deletion failed" in failed_job.error_message ) + 
mock_provider.delete.assert_called_once_with(collection) + MockCollectionCrud.assert_called_with(db, project.id) collection_crud_instance.read_one.assert_called_once_with(collection.id) - - MockAssistantCrud.assert_called_once() - MockAssistantCrud.return_value.delete.assert_called_once_with("asst_123") - collection_crud_instance.delete_by_id.assert_not_called() - mock_get_openai_client.assert_called_once() + mock_get_llm_provider.assert_called_once() mock_send_callback.assert_called_once() cb_url_arg, payload_arg = mock_send_callback.call_args.args assert str(cb_url_arg) == callback_url assert payload_arg["success"] is False - assert "something went wrong" in (payload_arg["error"] or "") + assert "Remote deletion failed" in (payload_arg["error"] or "") assert payload_arg["data"]["status"] == CollectionJobStatus.FAILED assert payload_arg["data"]["collection"]["id"] == str(collection.id) assert UUID(payload_arg["data"]["job_id"]) == job.id diff --git a/backend/app/tests/services/llm/providers/test_openai.py b/backend/app/tests/services/llm/providers/test_openai.py index 745dd00b..4dfb671e 100644 --- a/backend/app/tests/services/llm/providers/test_openai.py +++ b/backend/app/tests/services/llm/providers/test_openai.py @@ -12,7 +12,7 @@ ) from app.models.llm.request import ConversationConfig from app.services.llm.providers.openai import OpenAIProvider -from app.tests.utils.openai import mock_openai_response +from app.tests.utils.llm_provider import mock_openai_response class TestOpenAIProvider: diff --git a/backend/app/tests/services/response/response/test_process_response.py b/backend/app/tests/services/response/response/test_process_response.py index 3145dd6b..68ba24d2 100644 --- a/backend/app/tests/services/response/response/test_process_response.py +++ b/backend/app/tests/services/response/response/test_process_response.py @@ -17,7 +17,7 @@ ) from app.utils import APIResponse from app.tests.utils.test_data import create_test_credential -from app.tests.utils.openai 
import mock_openai_response, generate_openai_id +from app.tests.utils.llm_provider import mock_openai_response, generate_openai_id from app.crud import JobCrud, create_assistant diff --git a/backend/app/tests/utils/collection.py b/backend/app/tests/utils/collection.py index bdc68b29..9b5f8b99 100644 --- a/backend/app/tests/utils/collection.py +++ b/backend/app/tests/utils/collection.py @@ -8,9 +8,11 @@ CollectionActionType, CollectionJob, CollectionJobStatus, + ProviderType, Project, ) from app.crud import CollectionCrud, CollectionJobCrud +from app.services.collections.helpers import get_service_name class constants: @@ -23,7 +25,7 @@ def uuid_increment(value: UUID) -> UUID: return UUID(int=inc) -def get_collection( +def get_assistant_collection( db: Session, project: Project, *, @@ -44,6 +46,7 @@ def get_collection( organization_id=project.organization_id, llm_service_name=model, llm_service_id=assistant_id, + provider=ProviderType.OPENAI, ) return CollectionCrud(db, project.id).create(collection) @@ -54,6 +57,7 @@ def get_vector_store_collection( *, vector_store_id: Optional[str] = None, collection_id: Optional[UUID] = None, + provider: str, ) -> Collection: """ Create a Collection configured for the Vector Store path. 
@@ -65,9 +69,9 @@ def get_vector_store_collection( collection = Collection( id=collection_id or uuid4(), project_id=project.id, - organization_id=project.organization_id, - llm_service_name="openai vector store", + llm_service_name=get_service_name("openai"), llm_service_id=vector_store_id, + provider=provider.upper(), ) return CollectionCrud(db, project.id).create(collection) diff --git a/backend/app/tests/utils/llm_provider.py b/backend/app/tests/utils/llm_provider.py new file mode 100644 index 00000000..542a0baf --- /dev/null +++ b/backend/app/tests/utils/llm_provider.py @@ -0,0 +1,198 @@ +import time +import secrets +import string +from typing import Optional +from types import SimpleNamespace +from unittest.mock import MagicMock + +from openai.types.beta import Assistant as OpenAIAssistant +from openai.types.beta.assistant import ToolResources, ToolResourcesFileSearch +from openai.types.beta.assistant_tool import FileSearchTool +from openai.types.beta.file_search_tool import FileSearch + + +def generate_openai_id(prefix: str, length: int = 40) -> str: + """Generate a realistic ID similar to OpenAI's format (alphanumeric only)""" + # Generate random alphanumeric string + chars = string.ascii_lowercase + string.digits + random_part = "".join(secrets.choice(chars) for _ in range(length)) + return f"{prefix}{random_part}" + + +def mock_openai_assistant( + assistant_id: str = "assistant_mock", + vector_store_ids: Optional[list[str]] = ["vs_1", "vs_2"], + max_num_results: int = 30, +) -> OpenAIAssistant: + return OpenAIAssistant( + id=assistant_id, + created_at=int(time.time()), + description="Mock description", + instructions="Mock instructions", + metadata={}, + model="gpt-4o", + name="Mock Assistant", + object="assistant", + tools=[ + FileSearchTool( + type="file_search", + file_search=FileSearch( + max_num_results=max_num_results, + ), + ) + ], + temperature=1.0, + tool_resources=ToolResources( + code_interpreter=None, + 
file_search=ToolResourcesFileSearch(vector_store_ids=vector_store_ids), + ), + top_p=1.0, + reasoning_effort=None, + ) + + +def mock_openai_response( + text: str = "Hello world", + previous_response_id: str | None = None, + model: str = "gpt-4", + conversation_id: str | None = None, +) -> SimpleNamespace: + """Return a minimal mock OpenAI-like response object for testing. + + Args: + text: The response text + previous_response_id: Optional previous response ID + model: Model name + conversation_id: Optional conversation ID. If provided, adds conversation object to response. + """ + + usage = SimpleNamespace( + input_tokens=10, + output_tokens=20, + total_tokens=30, + ) + + output_item = SimpleNamespace( + id=generate_openai_id("out_"), + type="message", + role="assistant", + content=[{"type": "output_text", "text": text}], + ) + + conversation = None + if conversation_id: + conversation = SimpleNamespace(id=conversation_id) + + response = SimpleNamespace( + id=generate_openai_id("resp_"), + created_at=int(time.time()), + model=model, + object="response", + output=[output_item], + output_text=text, + usage=usage, + previous_response_id=previous_response_id, + conversation=conversation, + model_dump=lambda: { + "id": response.id, + "model": model, + "output_text": text, + "usage": { + "input_tokens": 10, + "output_tokens": 20, + "total_tokens": 30, + }, + "conversation": {"id": conversation_id} if conversation_id else None, + }, + ) + return response + + +def get_mock_openai_client_with_vector_store() -> MagicMock: + mock_client = MagicMock() + + mock_vector_store = MagicMock() + mock_vector_store.id = "mock_vector_store_id" + mock_client.vector_stores.create.return_value = mock_vector_store + + mock_file_batch = MagicMock() + mock_file_batch.file_counts.completed = 2 + mock_file_batch.file_counts.total = 2 + mock_client.vector_stores.file_batches.upload_and_poll.return_value = ( + mock_file_batch + ) + + mock_client.vector_stores.files.list.return_value = {"data": 
[]} + + mock_assistant = MagicMock() + mock_assistant.id = "mock_assistant_id" + mock_assistant.name = "Mock Assistant" + mock_assistant.model = "gpt-4o" + mock_assistant.instructions = "Mock instructions" + mock_client.beta.assistants.create.return_value = mock_assistant + + return mock_client + + +def create_mock_batch( + batch_id: str = "batch-xyz789", + status: str = "completed", + output_file_id: str | None = "output-file-123", + error_file_id: str | None = None, + total: int = 100, + completed: int = 100, + failed: int = 0, +) -> MagicMock: + """ + Create a mock OpenAI batch object with configurable properties. + + Args: + batch_id: The batch ID + status: Batch status (completed, in_progress, failed, expired, cancelled, etc.) + output_file_id: Output file ID (None for incomplete batches) + error_file_id: Error file ID (None if no errors) + total: Total number of requests in the batch + completed: Number of completed requests + failed: Number of failed requests + + Returns: + MagicMock configured to represent an OpenAI batch object + """ + mock_batch = MagicMock() + mock_batch.id = batch_id + mock_batch.status = status + mock_batch.output_file_id = output_file_id + mock_batch.error_file_id = error_file_id + + # Create request_counts mock + mock_batch.request_counts = MagicMock() + mock_batch.request_counts.total = total + mock_batch.request_counts.completed = completed + mock_batch.request_counts.failed = failed + + return mock_batch + + +def get_mock_provider( + llm_service_id: str = "mock_service_id", + llm_service_name: str = "mock_service_name", +): + """ + Create a properly configured mock provider for tests. 
+ + Returns a mock that mimics BaseProvider with: + - create() method returning result with llm_service_id and llm_service_name + - cleanup() method for cleanup on failure + - delete() method for deletion + """ + mock_provider = MagicMock() + + mock_result = MagicMock() + mock_result.llm_service_id = llm_service_id + mock_result.llm_service_name = llm_service_name + + mock_provider.create.return_value = mock_result + mock_provider.cleanup = MagicMock() + mock_provider.delete = MagicMock() + + return mock_provider From 9ee861f447a2e0b5e5937ef2916eee257557bd65 Mon Sep 17 00:00:00 2001 From: nishika26 Date: Tue, 20 Jan 2026 11:47:16 +0530 Subject: [PATCH 04/11] removing un-needed files --- ..._adding_blob_column_in_collection_table.py | 47 ---- backend/app/models/collection/__init__.py | 14 -- backend/app/models/collection/request.py | 208 ------------------ backend/app/models/collection/response.py | 33 --- 4 files changed, 302 deletions(-) delete mode 100644 backend/app/alembic/versions/041_adding_blob_column_in_collection_table.py delete mode 100644 backend/app/models/collection/__init__.py delete mode 100644 backend/app/models/collection/request.py delete mode 100644 backend/app/models/collection/response.py diff --git a/backend/app/alembic/versions/041_adding_blob_column_in_collection_table.py b/backend/app/alembic/versions/041_adding_blob_column_in_collection_table.py deleted file mode 100644 index 8f65f055..00000000 --- a/backend/app/alembic/versions/041_adding_blob_column_in_collection_table.py +++ /dev/null @@ -1,47 +0,0 @@ -"""adding blob column in collection table - -Revision ID: 041 -Revises: 040 -Create Date: 2025-12-24 11:03:44.620424 - -""" -from alembic import op -import sqlalchemy as sa -from sqlalchemy.dialects import postgresql - -revision = "041" -down_revision = "040" -branch_labels = None -depends_on = None - - -def upgrade(): - op.add_column( - "collection", - sa.Column( - "collection_blob", - postgresql.JSONB(astext_type=sa.Text()), - 
nullable=True, - comment="Provider-specific knowledge base creation parameters (name, description, chunking params etc.)", - ), - ) - op.alter_column( - "collection", - "llm_service_name", - existing_type=sa.VARCHAR(), - comment="Name of the LLM service", - existing_comment="Name of the LLM service provider", - existing_nullable=False, - ) - - -def downgrade(): - op.alter_column( - "collection", - "llm_service_name", - existing_type=sa.VARCHAR(), - comment="Name of the LLM service provider", - existing_comment="Name of the LLM service", - existing_nullable=False, - ) - op.drop_column("collection", "collection_blob") diff --git a/backend/app/models/collection/__init__.py b/backend/app/models/collection/__init__.py deleted file mode 100644 index e31f65bc..00000000 --- a/backend/app/models/collection/__init__.py +++ /dev/null @@ -1,14 +0,0 @@ -from app.models.collection.request import ( - Collection, - CreationRequest, - DeletionRequest, - CallbackRequest, - AssistantOptions, - CreateCollectionParams, -) -from app.models.collection.response import ( - CollectionIDPublic, - CollectionPublic, - CollectionWithDocsPublic, - CreateCollectionResult, -) diff --git a/backend/app/models/collection/request.py b/backend/app/models/collection/request.py deleted file mode 100644 index 9f8e106b..00000000 --- a/backend/app/models/collection/request.py +++ /dev/null @@ -1,208 +0,0 @@ -from datetime import datetime -from typing import Any, Literal -from uuid import UUID, uuid4 - -from pydantic import HttpUrl, model_validator -import sqlalchemy as sa -from sqlalchemy.dialects.postgresql import JSONB -from sqlmodel import Field, Relationship, SQLModel - -from app.core.util import now -from app.models.organization import Organization -from app.models.project import Project - - -class Collection(SQLModel, table=True): - """Database model for Collection operations.""" - - id: UUID = Field( - default_factory=uuid4, - primary_key=True, - sa_column_kwargs={"comment": "Unique identifier for 
the collection"}, - ) - llm_service_id: str = Field( - nullable=False, - sa_column_kwargs={ - "comment": "External LLM service identifier (e.g., OpenAI vector store ID)" - }, - ) - llm_service_name: str = Field( - nullable=False, - sa_column_kwargs={"comment": "Name of the LLM service"}, - ) - collection_blob: dict[str, Any] | None = Field( - sa_column=sa.Column( - JSONB, - nullable=True, - comment="Provider-specific collection parameters (name, description, chunking params etc.)", - ) - ) - organization_id: int = Field( - foreign_key="organization.id", - nullable=False, - ondelete="CASCADE", - sa_column_kwargs={"comment": "Reference to the organization"}, - ) - project_id: int = Field( - foreign_key="project.id", - nullable=False, - ondelete="CASCADE", - sa_column_kwargs={"comment": "Reference to the project"}, - ) - inserted_at: datetime = Field( - default_factory=now, - sa_column_kwargs={"comment": "Timestamp when the collection was created"}, - ) - updated_at: datetime = Field( - default_factory=now, - sa_column_kwargs={"comment": "Timestamp when the collection was last updated"}, - ) - deleted_at: datetime | None = Field( - default=None, - sa_column_kwargs={"comment": "Timestamp when the collection was deleted"}, - ) - - # Relationships - organization: Organization = Relationship(back_populates="collections") - project: Project = Relationship(back_populates="collections") - - -class DocumentInput(SQLModel): - """Document to be added to knowledge base.""" - - name: str | None = Field( - description="Display name for the document", - ) - id: UUID = Field( - description="Reference to uploaded file/document in Kaapi", - ) - - -class CreateCollectionParams(SQLModel): - """Request-specific parameters for knowledge base creation.""" - - name: str | None = Field( - min_length=1, - description="Name of the knowledge base to create or update", - ) - description: str | None = Field( - default=None, - description="Description of the knowledge base (required by Bedrock, 
optional for others)", - ) - documents: list[DocumentInput] = Field( - default_factory=list, - description="List of documents to add to the knowledge base", - ) - chunking_params: dict[str, Any] | None = Field( - default=None, - description="Chunking parameters for document processing (e.g., chunk_size, chunk_overlap)", - ) - additional_params: dict[str, Any] | None = Field( - default=None, - description="Additional provider-specific parameters", - ) - - def model_post_init(self, __context: Any): - """Deduplicate documents by file_id.""" - seen = set() - unique_docs = [] - for doc in self.documents: - if doc.file_id not in seen: - seen.add(doc.file_id) - unique_docs.append(doc) - self.documents = unique_docs - - -class AssistantOptions(SQLModel): - # Fields to be passed along to OpenAI. They must be a subset of - # parameters accepted by the OpenAI.client.beta.assistants.create - # API. - model: str | None = Field( - default=None, - description=( - "**[Deprecated]** " - "OpenAI model to attach to this assistant. The model " - "must be compatable with the assistants API; see the " - "OpenAI [model documentation](https://platform.openai.com/docs/models/compare) for more." - ), - ) - - instructions: str | None = Field( - default=None, - description=( - "**[Deprecated]** " - "Assistant instruction. Sometimes referred to as the " - '"system" prompt.' - ), - ) - temperature: float = Field( - default=1e-6, - description=( - "**[Deprecated]** " - "Model temperature. The default is slightly " - "greater-than zero because it is [unknown how OpenAI " - "handles zero](https://community.openai.com/t/clarifications-on-setting-temperature-0/886447/5)." 
- ), - ) - - @model_validator(mode="before") - def _assistant_fields_all_or_none(cls, values: dict[str, Any]) -> dict[str, Any]: - def norm(x: Any) -> Any: - if x is None: - return None - if isinstance(x, str): - s = x.strip() - return s if s else None - return x # let Pydantic handle non-strings - - model = norm(values.get("model")) - instructions = norm(values.get("instructions")) - - if (model is None) ^ (instructions is None): - raise ValueError( - "To create an Assistant, provide BOTH 'model' and 'instructions'. " - "If you only want a vector store, remove both fields." - ) - - values["model"] = model - values["instructions"] = instructions - return values - - -class CallbackRequest(SQLModel): - """Optional callback configuration for async job notifications.""" - - callback_url: HttpUrl | None = Field( - default=None, - description="URL to call to report endpoint status", - ) - - -class ProviderOptions(SQLModel): - """LLM provider configuration.""" - - provider: Literal["openai"] = Field( - default="openai", description="LLM provider to use for this collection" - ) - - -class CreationRequest(AssistantOptions, ProviderOptions, CallbackRequest): - """API request for collection creation""" - - collection_params: CreateCollectionParams = Field( - ..., - description="Collection creation specific parameters (name, documents, etc.)", - ) - batch_size: int = Field( - default=10, - ge=1, - le=500, - description="Number of documents to process in a single batch", - ) - - -class DeletionRequest(ProviderOptions, CallbackRequest): - - """API request for collection deletion""" - - collection_id: UUID = Field(description="Collection to delete") diff --git a/backend/app/models/collection/response.py b/backend/app/models/collection/response.py deleted file mode 100644 index f72c5ee7..00000000 --- a/backend/app/models/collection/response.py +++ /dev/null @@ -1,33 +0,0 @@ -from datetime import datetime -from typing import Any -from uuid import UUID - -from sqlmodel import 
SQLModel - -from app.models.document import DocumentPublic - - -class CreateCollectionResult(SQLModel): - llm_service_id: str - llm_service_name: str - collection_blob: dict[str, Any] - - -class CollectionIDPublic(SQLModel): - id: UUID - - -class CollectionPublic(SQLModel): - id: UUID - llm_service_id: str - llm_service_name: str - project_id: int - organization_id: int - - inserted_at: datetime - updated_at: datetime - deleted_at: datetime | None = None - - -class CollectionWithDocsPublic(CollectionPublic): - documents: list[DocumentPublic] | None = None From 9ed9ca52647428720db98b81bd2f10dfea4a7029 Mon Sep 17 00:00:00 2001 From: nishika26 Date: Tue, 20 Jan 2026 11:55:21 +0530 Subject: [PATCH 05/11] adding a missed import --- .../app/tests/services/collections/test_delete_collection.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/backend/app/tests/services/collections/test_delete_collection.py b/backend/app/tests/services/collections/test_delete_collection.py index 11d91d0b..ef4508ce 100644 --- a/backend/app/tests/services/collections/test_delete_collection.py +++ b/backend/app/tests/services/collections/test_delete_collection.py @@ -1,10 +1,9 @@ -from typing import Any from unittest.mock import patch, MagicMock from uuid import uuid4, UUID +from sqlmodel import Session from app.models.collection import DeletionRequest - from app.tests.utils.utils import get_project from app.crud import CollectionJobCrud from app.models import CollectionJobStatus, CollectionActionType From aed5e24f6e49068910792c648d496c27f9b1bbbb Mon Sep 17 00:00:00 2001 From: nishika26 Date: Tue, 20 Jan 2026 13:02:33 +0530 Subject: [PATCH 06/11] adding more test cases --- ...1_extend_collection_table_for_provider_.py | 3 +- backend/app/models/collection.py | 2 +- .../services/collections/create_collection.py | 4 +- backend/app/services/collections/helpers.py | 11 -- .../services/collections/providers/base.py | 13 +- .../services/collections/providers/openai.py | 16 +- 
.../test_crud_collection_read_one.py | 1 - .../collections/providers/test_openai.py | 154 ++++++++++++++++++ .../collections/test_create_collection.py | 24 +-- .../collections/test_delete_collection.py | 20 ++- .../services/collections/test_helpers.py | 53 ++++-- 11 files changed, 223 insertions(+), 78 deletions(-) create mode 100644 backend/app/tests/services/collections/providers/test_openai.py diff --git a/backend/app/alembic/versions/041_extend_collection_table_for_provider_.py b/backend/app/alembic/versions/041_extend_collection_table_for_provider_.py index 84c386e8..b1bb90e4 100644 --- a/backend/app/alembic/versions/041_extend_collection_table_for_provider_.py +++ b/backend/app/alembic/versions/041_extend_collection_table_for_provider_.py @@ -33,11 +33,12 @@ def upgrade(): sa.Column( "provider", provider_type, - nullable=False, + nullable=True, comment="LLM provider used for this collection", ), ) op.execute("UPDATE collection SET provider = 'OPENAI' WHERE provider IS NULL") + op.alter_column("collection", "provider", nullable=False) op.add_column( "collection", sa.Column( diff --git a/backend/app/models/collection.py b/backend/app/models/collection.py index 322262bb..42624bdf 100644 --- a/backend/app/models/collection.py +++ b/backend/app/models/collection.py @@ -32,7 +32,7 @@ class Collection(SQLModel, table=True): Field( nullable=False, description="LLM provider used for this collection (e.g., 'openai', 'bedrock', 'gemini', etc)", - sa_column_kwargs={"LLM provider used for this collection"}, + sa_column_kwargs={"comment": "LLM provider used for this collection"}, ), ) llm_service_id: str = Field( diff --git a/backend/app/services/collections/create_collection.py b/backend/app/services/collections/create_collection.py index 1522f863..552bef99 100644 --- a/backend/app/services/collections/create_collection.py +++ b/backend/app/services/collections/create_collection.py @@ -156,7 +156,7 @@ def execute_job( try: creation_request = CreationRequest(**request) if 
( - with_assistant == True + with_assistant ): # this will be removed once dalgo switches to vector store creation only creation_request.provider = "openai" @@ -254,7 +254,7 @@ def execute_job( if provider is not None and result is not None: try: - provider.cleanup(result) + provider.delete(result) except Exception: logger.warning( "[create_collection.execute_job] Provider cleanup failed" diff --git a/backend/app/services/collections/helpers.py b/backend/app/services/collections/helpers.py index 0665260a..7965e2e2 100644 --- a/backend/app/services/collections/helpers.py +++ b/backend/app/services/collections/helpers.py @@ -78,17 +78,6 @@ def batch_documents( return docs_batches -def _backout(crud, llm_service_id: str): - """Best-effort cleanup: attempt to delete the assistant by ID""" - try: - crud.delete(llm_service_id) - except OpenAIError as err: - logger.error( - f"[backout] Failed to delete resource | {{'llm_service_id': '{llm_service_id}', 'error': '{str(err)}'}}", - exc_info=True, - ) - - # Even though this function is used in the documents router, it's kept here for now since the assistant creation logic will # eventually be removed from Kaapi. Once that happens, this function can be safely deleted - def pick_service_for_documennt(session, doc_id: UUID, a_crud, v_crud): diff --git a/backend/app/services/collections/providers/base.py b/backend/app/services/collections/providers/base.py index e4c91d9c..d76fb618 100644 --- a/backend/app/services/collections/providers/base.py +++ b/backend/app/services/collections/providers/base.py @@ -19,7 +19,7 @@ class BaseProvider(ABC): client: The provider-specific client instance """ - def __init__(self, client: Any): + def __init__(self, client: Any) -> None: """Initialize provider with client. 
Args: @@ -62,17 +62,6 @@ def delete(self, collection: Collection) -> None: """ raise NotImplementedError("Providers must implement delete method") - @abstractmethod - def cleanup(self, collection: Collection) -> None: - """Clean up/rollback resources created during execute. - - Called when collection creation fails and remote resources need to be deleted. - - Args: - collection_result: The CreateCollectionresult returned from execute, containing resource IDs - """ - raise NotImplementedError("Providers must implement cleanup method") - def get_provider_name(self) -> str: """Get the name of the provider. diff --git a/backend/app/services/collections/providers/openai.py b/backend/app/services/collections/providers/openai.py index a33bd854..b8a73412 100644 --- a/backend/app/services/collections/providers/openai.py +++ b/backend/app/services/collections/providers/openai.py @@ -6,7 +6,7 @@ from app.crud import DocumentCrud from app.core.cloud.storage import CloudStorage from app.crud.rag import OpenAIVectorStoreCrud, OpenAIAssistantCrud -from app.services.collections.helpers import batch_documents, get_service_name, _backout +from app.services.collections.helpers import batch_documents, get_service_name from app.models import CreationRequest, Collection @@ -42,7 +42,7 @@ def create( list(vector_store_crud.update(vector_store.id, storage, docs_batches)) logger.info( - "[OpenAIProvider.execute] Vector store created | " + "[OpenAIProvider.create] Vector store created | " f"vector_store_id={vector_store.id}, batches={len(docs_batches)}" ) @@ -66,7 +66,7 @@ def create( assistant = assistant_crud.create(vector_store.id, **filtered_options) logger.info( - "[OpenAIProvider.execute] Assistant created | " + "[OpenAIProvider.create] Assistant created | " f"assistant_id={assistant.id}, vector_store_id={vector_store.id}" ) @@ -76,7 +76,7 @@ def create( ) else: logger.info( - "[OpenAIProvider.execute] Skipping assistant creation | with_assistant=False" + "[OpenAIProvider.create] 
Skipping assistant creation | with_assistant=False" ) return Collection( @@ -86,7 +86,7 @@ def create( except Exception as e: logger.error( - f"[OpenAIProvider.execute] Failed to create collection: {str(e)}", + f"[OpenAIProvider.create] Failed to create collection: {str(e)}", exc_info=True, ) raise @@ -116,9 +116,3 @@ def delete(self, collection: Collection) -> None: exc_info=True, ) raise - - def cleanup(self, result: Collection) -> None: - """ - Clean up OpenAI resources (assistant or vector store). - """ - _backout(result.llm_service_id, result.llm_service_name) diff --git a/backend/app/tests/crud/collections/collection/test_crud_collection_read_one.py b/backend/app/tests/crud/collections/collection/test_crud_collection_read_one.py index 6f6473f3..2fc5f767 100644 --- a/backend/app/tests/crud/collections/collection/test_crud_collection_read_one.py +++ b/backend/app/tests/crud/collections/collection/test_crud_collection_read_one.py @@ -15,7 +15,6 @@ def mk_collection(db: Session) -> Collection: openai_mock = OpenAIMock() project = get_project(db) with openai_mock.router: - client = OpenAI(api_key="sk-test-key") collection = get_assistant_collection(db, project=project) store = DocumentStore(db, project_id=collection.project_id) documents = store.fill(1) diff --git a/backend/app/tests/services/collections/providers/test_openai.py b/backend/app/tests/services/collections/providers/test_openai.py new file mode 100644 index 00000000..bee3c95a --- /dev/null +++ b/backend/app/tests/services/collections/providers/test_openai.py @@ -0,0 +1,154 @@ +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest + +from app.services.collections.providers.openai import OpenAIProvider +from app.models.collection import Collection +from app.services.collections.helpers import get_service_name +from app.tests.utils.llm_provider import ( + generate_openai_id, + get_mock_openai_client_with_vector_store, +) + + +def 
test_create_openai_vector_store_only() -> None: + client = get_mock_openai_client_with_vector_store() + provider = OpenAIProvider(client=client) + + collection_request = SimpleNamespace( + documents=["doc1", "doc2"], + batch_size=1, + model=None, + instructions=None, + temperature=None, + ) + + storage = MagicMock() + document_crud = MagicMock() + + fake_batches = [["doc1"], ["doc2"]] + vector_store_id = generate_openai_id("vs_") + + with patch( + "app.services.collections.providers.openai.batch_documents", + return_value=fake_batches, + ), patch( + "app.services.collections.providers.openai.OpenAIVectorStoreCrud" + ) as vector_store_crud_cls: + vector_store_crud = vector_store_crud_cls.return_value + vector_store_crud.create.return_value = MagicMock(id=vector_store_id) + vector_store_crud.update.return_value = iter([None]) + + collection = provider.create( + collection_request, + storage, + document_crud, + ) + + assert isinstance(collection, Collection) + assert collection.llm_service_id == vector_store_id + assert collection.llm_service_name == get_service_name("openai") + + +def test_create_openai_with_assistant() -> None: + client = get_mock_openai_client_with_vector_store() + provider = OpenAIProvider(client=client) + + collection_request = SimpleNamespace( + documents=["doc1"], + batch_size=1, + model="gpt-4o", + instructions="You are helpful", + temperature=0.7, + ) + + storage = MagicMock() + document_crud = MagicMock() + + fake_batches = [["doc1"]] + vector_store_id = generate_openai_id("vs_") + assistant_id = generate_openai_id("asst_") + + with patch( + "app.services.collections.providers.openai.batch_documents", + return_value=fake_batches, + ), patch( + "app.services.collections.providers.openai.OpenAIVectorStoreCrud" + ) as vector_store_crud_cls, patch( + "app.services.collections.providers.openai.OpenAIAssistantCrud" + ) as assistant_crud_cls: + vector_store_crud = vector_store_crud_cls.return_value + vector_store_crud.create.return_value = 
MagicMock(id=vector_store_id) + vector_store_crud.update.return_value = iter([None]) + + assistant_crud = assistant_crud_cls.return_value + assistant_crud.create.return_value = MagicMock(id=assistant_id) + + collection = provider.create( + collection_request, + storage, + document_crud, + ) + + assert collection.llm_service_id == assistant_id + assert collection.llm_service_name == "gpt-4o" + + +def test_delete_openai_assistant() -> None: + client = MagicMock() + provider = OpenAIProvider(client=client) + + collection = Collection( + llm_service_id=generate_openai_id("asst_"), + llm_service_name="gpt-4o", + ) + + with patch( + "app.services.collections.providers.openai.OpenAIAssistantCrud" + ) as assistant_crud_cls: + assistant_crud = assistant_crud_cls.return_value + provider.delete(collection) + + assistant_crud.delete.assert_called_once_with(collection.llm_service_id) + + +def test_delete_openai_vector_store() -> None: + client = MagicMock() + provider = OpenAIProvider(client=client) + + collection = Collection( + llm_service_id=generate_openai_id("vs_"), + llm_service_name=get_service_name("openai"), + ) + + with patch( + "app.services.collections.providers.openai.OpenAIVectorStoreCrud" + ) as vector_store_crud_cls: + vector_store_crud = vector_store_crud_cls.return_value + provider.delete(collection) + + vector_store_crud.delete.assert_called_once_with(collection.llm_service_id) + + +def test_create_propagates_exception() -> None: + provider = OpenAIProvider(client=MagicMock()) + + collection_request = SimpleNamespace( + documents=["doc1"], + batch_size=1, + model=None, + instructions=None, + temperature=None, + ) + + with patch( + "app.services.collections.providers.openai.batch_documents", + side_effect=RuntimeError("boom"), + ): + with pytest.raises(RuntimeError): + provider.create( + collection_request, + MagicMock(), + MagicMock(), + ) diff --git a/backend/app/tests/services/collections/test_create_collection.py 
b/backend/app/tests/services/collections/test_create_collection.py index bcd1c4fa..0d61f217 100644 --- a/backend/app/tests/services/collections/test_create_collection.py +++ b/backend/app/tests/services/collections/test_create_collection.py @@ -115,8 +115,8 @@ def test_start_job_creates_collection_job_and_schedules_task(db: Session) -> Non @mock_aws @patch("app.services.collections.create_collection.get_llm_provider") def test_execute_job_success_flow_updates_job_and_creates_collection( - mock_get_llm_provider, db: Session -): + mock_get_llm_provider: MagicMock, db: Session +) -> None: """ execute_job should: - set task_id on the CollectionJob @@ -191,8 +191,8 @@ def test_execute_job_success_flow_updates_job_and_creates_collection( @mock_aws @patch("app.services.collections.create_collection.get_llm_provider") def test_execute_job_assistant_create_failure_marks_failed_and_deletes_collection( - mock_get_llm_provider, db -): + mock_get_llm_provider: MagicMock, db +) -> None: project = get_project(db) job = get_collection_job( @@ -242,10 +242,10 @@ def test_execute_job_assistant_create_failure_marks_failed_and_deletes_collectio @patch("app.services.collections.create_collection.get_llm_provider") @patch("app.services.collections.create_collection.send_callback") def test_execute_job_success_flow_callback_job_and_creates_collection( - mock_send_callback, - mock_get_llm_provider, + mock_send_callback: MagicMock, + mock_get_llm_provider: MagicMock, db, -): +) -> None: """ execute_job should: - set task_id on the CollectionJob @@ -323,10 +323,10 @@ def test_execute_job_success_flow_callback_job_and_creates_collection( @patch("app.services.collections.create_collection.get_llm_provider") @patch("app.services.collections.create_collection.send_callback") def test_execute_job_success_creates_collection_with_callback( - mock_send_callback, - mock_get_llm_provider, + mock_send_callback: MagicMock, + mock_get_llm_provider: MagicMock, db, -): +) -> None: """ execute_job should: 
- set task_id on the CollectionJob @@ -406,8 +406,8 @@ def test_execute_job_success_creates_collection_with_callback( @patch("app.services.collections.create_collection.CollectionCrud") def test_execute_job_failure_flow_callback_job_and_marks_failed( MockCollectionCrud, - mock_send_callback, - mock_get_llm_provider, + mock_send_callback: MagicMock, + mock_get_llm_provider: MagicMock, db: Session, ) -> None: """ diff --git a/backend/app/tests/services/collections/test_delete_collection.py b/backend/app/tests/services/collections/test_delete_collection.py index ef4508ce..010a0e3a 100644 --- a/backend/app/tests/services/collections/test_delete_collection.py +++ b/backend/app/tests/services/collections/test_delete_collection.py @@ -72,8 +72,8 @@ def test_start_job_creates_collection_job_and_schedules_task(db: Session) -> Non @patch("app.services.collections.delete_collection.get_llm_provider") def test_execute_job_delete_success_updates_job_and_calls_delete( - mock_get_llm_provider, db -): + mock_get_llm_provider: MagicMock, db +) -> None: """ - execute_job should set task_id on the CollectionJob - call provider.delete() to delete remote resources @@ -137,7 +137,9 @@ def test_execute_job_delete_success_updates_job_and_calls_delete( @patch("app.services.collections.delete_collection.get_llm_provider") -def test_execute_job_delete_failure_marks_job_failed(mock_get_llm_provider, db): +def test_execute_job_delete_failure_marks_job_failed( + mock_get_llm_provider: MagicMock, db +) -> None: """ When provider.delete() raises an exception: - Job should be marked FAILED @@ -209,10 +211,10 @@ def test_execute_job_delete_failure_marks_job_failed(mock_get_llm_provider, db): @patch("app.services.collections.delete_collection.get_llm_provider") @patch("app.services.collections.delete_collection.send_callback") def test_execute_job_delete_success_with_callback_sends_success_payload( - mock_send_callback, - mock_get_llm_provider, + mock_send_callback: MagicMock, + 
mock_get_llm_provider: MagicMock, db, -): +) -> None: """ When deletion succeeds and a callback_url is provided: - job is marked SUCCESSFUL @@ -291,10 +293,10 @@ def test_execute_job_delete_success_with_callback_sends_success_payload( @patch("app.services.collections.delete_collection.get_llm_provider") @patch("app.services.collections.delete_collection.send_callback") def test_execute_job_delete_remote_failure_with_callback_sends_failure_payload( - mock_send_callback, - mock_get_llm_provider, + mock_send_callback: MagicMock, + mock_get_llm_provider: MagicMock, db, -): +) -> None: """ When provider.delete() raises AND a callback_url is provided: - job is marked FAILED with error_message set diff --git a/backend/app/tests/services/collections/test_helpers.py b/backend/app/tests/services/collections/test_helpers.py index be94ffeb..4a3b0e35 100644 --- a/backend/app/tests/services/collections/test_helpers.py +++ b/backend/app/tests/services/collections/test_helpers.py @@ -1,11 +1,17 @@ from __future__ import annotations import json -from typing import Any from types import SimpleNamespace from uuid import uuid4 +import pytest +from sqlmodel import Session +from fastapi import HTTPException + from app.services.collections import helpers +from app.tests.utils.utils import get_project +from app.tests.utils.collection import get_vector_store_collection +from app.services.collections.helpers import ensure_unique_name def test_extract_error_message_parses_json_and_strips_prefix() -> None: @@ -84,28 +90,39 @@ def test_batch_documents_empty_input() -> None: assert crud.calls == [] -# _backout +def test_ensure_unique_name_success(db: Session) -> None: + requested_name = "new_collection_name" + + project = get_project(db) + result = ensure_unique_name( + session=db, + project_id=project.id, + requested_name=requested_name, + ) -def test_backout_calls_delete_and_swallows_openai_error(monkeypatch: Any) -> None: - class Crud: - def __init__(self): - self.calls = 0 + assert result 
== requested_name - def delete(self, resource_id: str): - self.calls += 1 - crud = Crud() - helpers._backout(crud, "rsrc_1") - assert crud.calls == 1 +def test_ensure_unique_name_conflict_with_vector_store_collection(db: Session) -> None: + existing_name = "vector_collection" + project = get_project(db) - class DummyOpenAIError(Exception): - pass + collection = get_vector_store_collection( + db=db, + project=project, + provider="openai", + ) - monkeypatch.setattr(helpers, "OpenAIError", DummyOpenAIError) + collection.name = existing_name + db.commit() - class FailingCrud: - def delete(self, resource_id: str): - raise DummyOpenAIError("nope") + with pytest.raises(HTTPException) as exc: + ensure_unique_name( + session=db, + project_id=project.id, + requested_name=existing_name, + ) - helpers._backout(FailingCrud(), "rsrc_2") + assert exc.value.status_code == 409 + assert "already exists" in exc.value.detail From 0d6c444524b34ab22d05488fd899c7c209a8876b Mon Sep 17 00:00:00 2001 From: nishika26 Date: Tue, 20 Jan 2026 13:11:00 +0530 Subject: [PATCH 07/11] fixing small issue --- .../providers/{test_openai.py => test_openai_provider.py} | 0 .../app/tests/services/collections/test_create_collection.py | 2 +- backend/app/tests/utils/llm_provider.py | 1 - 3 files changed, 1 insertion(+), 2 deletions(-) rename backend/app/tests/services/collections/providers/{test_openai.py => test_openai_provider.py} (100%) diff --git a/backend/app/tests/services/collections/providers/test_openai.py b/backend/app/tests/services/collections/providers/test_openai_provider.py similarity index 100% rename from backend/app/tests/services/collections/providers/test_openai.py rename to backend/app/tests/services/collections/providers/test_openai_provider.py diff --git a/backend/app/tests/services/collections/test_create_collection.py b/backend/app/tests/services/collections/test_create_collection.py index 0d61f217..2fefd60e 100644 --- 
a/backend/app/tests/services/collections/test_create_collection.py +++ b/backend/app/tests/services/collections/test_create_collection.py @@ -234,7 +234,7 @@ def test_execute_job_assistant_create_failure_marks_failed_and_deletes_collectio task_instance=None, ) - mock_provider.cleanup.assert_called_once() + mock_provider.delete.assert_called_once() @pytest.mark.usefixtures("aws_credentials") diff --git a/backend/app/tests/utils/llm_provider.py b/backend/app/tests/utils/llm_provider.py index 542a0baf..afa2dfde 100644 --- a/backend/app/tests/utils/llm_provider.py +++ b/backend/app/tests/utils/llm_provider.py @@ -192,7 +192,6 @@ def get_mock_provider( mock_result.llm_service_name = llm_service_name mock_provider.create.return_value = mock_result - mock_provider.cleanup = MagicMock() mock_provider.delete = MagicMock() return mock_provider From 24018c204ee3f5fff78709e950526d94945028ba Mon Sep 17 00:00:00 2001 From: nishika26 Date: Tue, 20 Jan 2026 13:48:55 +0530 Subject: [PATCH 08/11] coderabbit fixes --- .../041_extend_collection_table_for_provider_.py | 2 +- backend/app/models/collection.py | 10 +++++++++- .../collections/providers/test_openai_provider.py | 2 ++ 3 files changed, 12 insertions(+), 2 deletions(-) diff --git a/backend/app/alembic/versions/041_extend_collection_table_for_provider_.py b/backend/app/alembic/versions/041_extend_collection_table_for_provider_.py index b1bb90e4..10869c47 100644 --- a/backend/app/alembic/versions/041_extend_collection_table_for_provider_.py +++ b/backend/app/alembic/versions/041_extend_collection_table_for_provider_.py @@ -65,7 +65,7 @@ def upgrade(): existing_comment="Name of the LLM service provider", existing_nullable=False, ) - op.create_unique_constraint(None, "collection", ["name"]) + op.create_unique_constraint(None, "collection", ["project_id", "name"]) op.drop_constraint( op.f("collection_organization_id_fkey"), "collection", type_="foreignkey" ) diff --git a/backend/app/models/collection.py 
b/backend/app/models/collection.py index 42624bdf..d54f7b6b 100644 --- a/backend/app/models/collection.py +++ b/backend/app/models/collection.py @@ -4,6 +4,7 @@ from uuid import UUID, uuid4 from pydantic import HttpUrl, model_validator +from sqlalchemy import UniqueConstraint from sqlmodel import Field, Relationship, SQLModel from app.core.util import now @@ -22,6 +23,14 @@ class ProviderType(str, Enum): class Collection(SQLModel, table=True): """Database model for Collection operations.""" + __table_args__ = ( + UniqueConstraint( + "project_id", + "name", + name="uq_collection_project_id_name", + ), + ) + id: UUID = Field( default_factory=uuid4, primary_key=True, @@ -49,7 +58,6 @@ class Collection(SQLModel, table=True): ) name: str = Field( nullable=True, - unique=True, description="Name of the collection", sa_column_kwargs={"comment": "Name of the collection"}, ) diff --git a/backend/app/tests/services/collections/providers/test_openai_provider.py b/backend/app/tests/services/collections/providers/test_openai_provider.py index bee3c95a..a4193b2f 100644 --- a/backend/app/tests/services/collections/providers/test_openai_provider.py +++ b/backend/app/tests/services/collections/providers/test_openai_provider.py @@ -102,6 +102,8 @@ def test_delete_openai_assistant() -> None: collection = Collection( llm_service_id=generate_openai_id("asst_"), llm_service_name="gpt-4o", + provider="openai", + project_id=1, ) with patch( From a6fb950e3d87d898445aaabf958eeb81e498ee1c Mon Sep 17 00:00:00 2001 From: nishika26 Date: Tue, 20 Jan 2026 14:04:14 +0530 Subject: [PATCH 09/11] small test case fox --- .../app/tests/api/routes/collections/test_collection_info.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/backend/app/tests/api/routes/collections/test_collection_info.py b/backend/app/tests/api/routes/collections/test_collection_info.py index 77735ae9..d993035f 100644 --- a/backend/app/tests/api/routes/collections/test_collection_info.py +++ 
b/backend/app/tests/api/routes/collections/test_collection_info.py @@ -201,7 +201,7 @@ def test_collection_info_include_docs_and_url( the endpoint returns documents with their URLs. """ project = get_project(db, "Dalgo") - collection = get_collection(db, project) + collection = get_assistant_collection(db, project) document = link_document_to_collection(db, collection) From 26fa79243051c9cf953e5fc0a5334a968b54b45b Mon Sep 17 00:00:00 2001 From: nishika26 Date: Tue, 20 Jan 2026 16:46:06 +0530 Subject: [PATCH 10/11] small fixes --- .../041_extend_collection_table_for_provider_.py | 10 ++++++---- backend/app/api/docs/collections/create.md | 14 +++++++------- backend/app/models/collection.py | 2 +- .../app/services/collections/create_collection.py | 3 ++- .../app/services/collections/delete_collection.py | 2 +- .../api/routes/collections/test_collection_info.py | 2 +- .../api/routes/collections/test_collection_list.py | 2 +- .../collection/test_crud_collection_create.py | 4 ++-- .../collection/test_crud_collection_delete.py | 4 ++-- .../services/collections/test_delete_collection.py | 9 ++------- .../app/tests/services/collections/test_helpers.py | 1 - backend/app/tests/utils/collection.py | 5 ++--- 12 files changed, 27 insertions(+), 31 deletions(-) diff --git a/backend/app/alembic/versions/041_extend_collection_table_for_provider_.py b/backend/app/alembic/versions/041_extend_collection_table_for_provider_.py index 10869c47..ac0bcd89 100644 --- a/backend/app/alembic/versions/041_extend_collection_table_for_provider_.py +++ b/backend/app/alembic/versions/041_extend_collection_table_for_provider_.py @@ -18,7 +18,7 @@ depends_on = None provider_type = postgresql.ENUM( - "OPENAI", + "openai", # aws # gemini name="providertype", @@ -37,7 +37,7 @@ def upgrade(): comment="LLM provider used for this collection", ), ) - op.execute("UPDATE collection SET provider = 'OPENAI' WHERE provider IS NULL") + op.execute("UPDATE collection SET provider = 'openai' WHERE provider IS NULL") 
op.alter_column("collection", "provider", nullable=False) op.add_column( "collection", @@ -65,7 +65,9 @@ def upgrade(): existing_comment="Name of the LLM service provider", existing_nullable=False, ) - op.create_unique_constraint(None, "collection", ["project_id", "name"]) + op.create_unique_constraint( + "uq_collection_project_id_name", "collection", ["project_id", "name"] + ) op.drop_constraint( op.f("collection_organization_id_fkey"), "collection", type_="foreignkey" ) @@ -96,7 +98,7 @@ def downgrade(): ["id"], ondelete="CASCADE", ) - op.drop_constraint("collection_name_key", "collection", type_="unique") + op.drop_constraint("uq_collection_project_id_name", "collection", type_="unique") op.alter_column( "collection", "llm_service_name", diff --git a/backend/app/api/docs/collections/create.md b/backend/app/api/docs/collections/create.md index c3a5f440..52787f75 100644 --- a/backend/app/api/docs/collections/create.md +++ b/backend/app/api/docs/collections/create.md @@ -10,17 +10,17 @@ pipeline: "model" and "instruction" in the request body otherwise only a vector store will be created from the documents given. -If any one of the OpenAI interactions fail, all OpenAI resources are -cleaned up. If a Vector Store is unable to be created, for example, +If any one of the LLM service interactions fail, all service resources are +cleaned up. If an Openai vector Store is unable to be created, for example, all file(s) that were uploaded to OpenAI are removed from OpenAI. Failure can occur from OpenAI being down, or some parameter value being invalid. It can also fail due to document types not being accepted. This is especially true for PDFs that may not be parseable. -Vector store/assistant will be created asynchronously. The immediate response -from this endpoint is `collection_job` object which is going to contain -the collection "job ID" and status. 
Once the collection has been created, -information about the collection will be returned to the user via the -callback URL. If a callback URL is not provided, clients can check the +In the case of Openai, Vector store/assistant will be created asynchronously. +The immediate response from this endpoint is `collection_job` object which is +going to contain the collection "job ID" and status. Once the collection has +been created, information about the collection will be returned to the user via +the callback URL. If a callback URL is not provided, clients can check the `collection job info` endpoint with the `job_id`, to retrieve information about the creation of collection. diff --git a/backend/app/models/collection.py b/backend/app/models/collection.py index d54f7b6b..7a537079 100644 --- a/backend/app/models/collection.py +++ b/backend/app/models/collection.py @@ -15,7 +15,7 @@ class ProviderType(str, Enum): """Supported LLM providers for collections.""" - OPENAI = "OPENAI" + openai = "openai" # BEDROCK = "bedrock" # GEMINI = "gemini" diff --git a/backend/app/services/collections/create_collection.py b/backend/app/services/collections/create_collection.py index 552bef99..a4321576 100644 --- a/backend/app/services/collections/create_collection.py +++ b/backend/app/services/collections/create_collection.py @@ -205,12 +205,13 @@ def execute_job( collection_crud = CollectionCrud(session, project_id) collection_id = uuid4() + collection = Collection( id=collection_id, project_id=project_id, llm_service_id=llm_service_id, llm_service_name=llm_service_name, - provider=creation_request.provider.upper(), + provider=creation_request.provider, name=creation_request.name, description=creation_request.description, ) diff --git a/backend/app/services/collections/delete_collection.py b/backend/app/services/collections/delete_collection.py index e175301c..db02a353 100644 --- a/backend/app/services/collections/delete_collection.py +++ 
b/backend/app/services/collections/delete_collection.py @@ -172,7 +172,7 @@ def execute_job( provider = get_llm_provider( session=session, - provider=collection.provider.lower(), + provider=collection.provider, project_id=project_id, organization_id=organization_id, ) diff --git a/backend/app/tests/api/routes/collections/test_collection_info.py b/backend/app/tests/api/routes/collections/test_collection_info.py index d993035f..88cc7ed3 100644 --- a/backend/app/tests/api/routes/collections/test_collection_info.py +++ b/backend/app/tests/api/routes/collections/test_collection_info.py @@ -152,7 +152,7 @@ def test_collection_info_vector_store_collection( via get_vector_store_collection. """ project = get_project(db, "Dalgo") - collection = get_vector_store_collection(db, project, provider="openai") + collection = get_vector_store_collection(db, project) link_document_to_collection(db, collection) diff --git a/backend/app/tests/api/routes/collections/test_collection_list.py b/backend/app/tests/api/routes/collections/test_collection_list.py index 779a56a5..e9b5626d 100644 --- a/backend/app/tests/api/routes/collections/test_collection_list.py +++ b/backend/app/tests/api/routes/collections/test_collection_list.py @@ -83,7 +83,7 @@ def test_list_collections_includes_vector_store_collection_with_fields( appear in the list and expose the expected LLM fields. 
""" project = get_project(db, "Dalgo") - collection = get_vector_store_collection(db, project, provider="openai") + collection = get_vector_store_collection(db, project) response = client.get( f"{settings.API_V1_STR}/collections/", diff --git a/backend/app/tests/crud/collections/collection/test_crud_collection_create.py b/backend/app/tests/crud/collections/collection/test_crud_collection_create.py index aaa32e97..87c8ac67 100644 --- a/backend/app/tests/crud/collections/collection/test_crud_collection_create.py +++ b/backend/app/tests/crud/collections/collection/test_crud_collection_create.py @@ -4,7 +4,7 @@ from sqlmodel import Session, select from app.crud import CollectionCrud -from app.models import DocumentCollection, Collection +from app.models import DocumentCollection, Collection, ProviderType from app.tests.utils.document import DocumentStore from app.tests.utils.utils import get_project @@ -20,7 +20,7 @@ def test_create_associates_documents(self, db: Session) -> None: project_id=project.id, llm_service_id="asst_dummy", llm_service_name="gpt-4o", - provider="OPENAI", + provider=ProviderType.openai, ) store = DocumentStore(db, project_id=collection.project_id) diff --git a/backend/app/tests/crud/collections/collection/test_crud_collection_delete.py b/backend/app/tests/crud/collections/collection/test_crud_collection_delete.py index 7ea32e7d..5cf4643d 100644 --- a/backend/app/tests/crud/collections/collection/test_crud_collection_delete.py +++ b/backend/app/tests/crud/collections/collection/test_crud_collection_delete.py @@ -3,7 +3,7 @@ from sqlmodel import Session, select from app.crud import CollectionCrud -from app.models import APIKey, Collection +from app.models import APIKey, Collection, ProviderType from app.crud.rag import OpenAIAssistantCrud from app.tests.utils.utils import get_project from app.tests.utils.document import DocumentStore @@ -27,7 +27,7 @@ def get_assistant_collection_for_delete( project_id=project_id, llm_service_id=assistant.id, 
llm_service_name="gpt-4o", - provider="OPENAI", + provider=ProviderType.openai, ) diff --git a/backend/app/tests/services/collections/test_delete_collection.py b/backend/app/tests/services/collections/test_delete_collection.py index 010a0e3a..d1243ba5 100644 --- a/backend/app/tests/services/collections/test_delete_collection.py +++ b/backend/app/tests/services/collections/test_delete_collection.py @@ -18,7 +18,7 @@ def test_start_job_creates_collection_job_and_schedules_task(db: Session) -> Non - return the same job_id (UUID) """ project = get_project(db) - created_collection = get_vector_store_collection(db, project, provider="OPENAI") + created_collection = get_vector_store_collection(db, project) req = DeletionRequest(collection_id=created_collection.id) @@ -82,9 +82,7 @@ def test_execute_job_delete_success_updates_job_and_calls_delete( """ project = get_project(db) - collection = get_vector_store_collection( - db, project, vector_store_id="asst_123", provider="OPENAI" - ) + collection = get_vector_store_collection(db, project, vector_store_id="asst_123") job = get_collection_job( db, @@ -152,7 +150,6 @@ def test_execute_job_delete_failure_marks_job_failed( db, project, vector_store_id="vector_123", - provider="OPENAI", ) job = get_collection_job( @@ -227,7 +224,6 @@ def test_execute_job_delete_success_with_callback_sends_success_payload( db, project, vector_store_id="vector 123", - provider="OPENAI", ) job = get_collection_job( @@ -309,7 +305,6 @@ def test_execute_job_delete_remote_failure_with_callback_sends_failure_payload( db, project, vector_store_id="vector_123", - provider="OPENAI", ) job = get_collection_job( diff --git a/backend/app/tests/services/collections/test_helpers.py b/backend/app/tests/services/collections/test_helpers.py index 4a3b0e35..f53271f1 100644 --- a/backend/app/tests/services/collections/test_helpers.py +++ b/backend/app/tests/services/collections/test_helpers.py @@ -111,7 +111,6 @@ def 
test_ensure_unique_name_conflict_with_vector_store_collection(db: Session) - collection = get_vector_store_collection( db=db, project=project, - provider="openai", ) collection.name = existing_name diff --git a/backend/app/tests/utils/collection.py b/backend/app/tests/utils/collection.py index 9b5f8b99..f1844cb9 100644 --- a/backend/app/tests/utils/collection.py +++ b/backend/app/tests/utils/collection.py @@ -46,7 +46,7 @@ def get_assistant_collection( organization_id=project.organization_id, llm_service_name=model, llm_service_id=assistant_id, - provider=ProviderType.OPENAI, + provider=ProviderType.openai, ) return CollectionCrud(db, project.id).create(collection) @@ -57,7 +57,6 @@ def get_vector_store_collection( *, vector_store_id: Optional[str] = None, collection_id: Optional[UUID] = None, - provider: str, ) -> Collection: """ Create a Collection configured for the Vector Store path. @@ -71,7 +70,7 @@ def get_vector_store_collection( project_id=project.id, llm_service_name=get_service_name("openai"), llm_service_id=vector_store_id, - provider=provider.upper(), + provider=ProviderType.openai, ) return CollectionCrud(db, project.id).create(collection) From 8b5427b75466240b3cb8d8e1930fd54e9f8f22d1 Mon Sep 17 00:00:00 2001 From: nishika26 Date: Tue, 20 Jan 2026 16:58:20 +0530 Subject: [PATCH 11/11] small coderabbit reviews --- backend/app/api/docs/collections/create.md | 2 +- backend/app/services/collections/create_collection.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/backend/app/api/docs/collections/create.md b/backend/app/api/docs/collections/create.md index 52787f75..915951b9 100644 --- a/backend/app/api/docs/collections/create.md +++ b/backend/app/api/docs/collections/create.md @@ -11,7 +11,7 @@ pipeline: created from the documents given. If any one of the LLM service interactions fail, all service resources are -cleaned up. If an Openai vector Store is unable to be created, for example, +cleaned up. 
If an OpenAI vector Store is unable to be created, for example, all file(s) that were uploaded to OpenAI are removed from OpenAI. Failure can occur from OpenAI being down, or some parameter value being invalid. It can also fail due to document types not being diff --git a/backend/app/services/collections/create_collection.py b/backend/app/services/collections/create_collection.py index a4321576..5e7389db 100644 --- a/backend/app/services/collections/create_collection.py +++ b/backend/app/services/collections/create_collection.py @@ -151,6 +151,7 @@ def execute_job( # Keeping the references for potential backout/cleanup on failure collection_job = None result = None + creation_request = None provider = None try: