From d410d513104a097ec80c1251d3f74a3cb8e9caaa Mon Sep 17 00:00:00 2001
From: dnth
Date: Mon, 30 Dec 2024 21:53:48 +0800
Subject: [PATCH 1/2] add tqdm for indexing progress

---
 byaldi/colpali.py | 77 +++++++++++++++++++++++++++++++++++------------
 1 file changed, 57 insertions(+), 20 deletions(-)

diff --git a/byaldi/colpali.py b/byaldi/colpali.py
index cc11dcb..5c78c97 100644
--- a/byaldi/colpali.py
+++ b/byaldi/colpali.py
@@ -10,6 +10,7 @@
 from colpali_engine.models import ColPali, ColPaliProcessor, ColQwen2, ColQwen2Processor
 from pdf2image import convert_from_path
 from PIL import Image
+from tqdm.auto import tqdm

 from byaldi.objects import Result

@@ -48,10 +49,12 @@ def __init__(
         self.pretrained_model_name_or_path = pretrained_model_name_or_path
         self.model_name = self.pretrained_model_name_or_path
         self.n_gpu = torch.cuda.device_count() if n_gpu == -1 else n_gpu
-        device = (
-            device or (
-                "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
-            )
+        device = device or (
+            "cuda"
+            if torch.cuda.is_available()
+            else "mps"
+            if torch.backends.mps.is_available()
+            else "cpu"
         )
         self.index_name = index_name
         self.verbose = verbose
@@ -357,8 +360,7 @@ def index(
                 raise ValueError(
                     f"Number of metadata entries ({len(metadata)}) does not match number of documents ({len(items)})"
                 )
-        for i, item in enumerate(items):
-            print(f"Indexing file: {item}")
+        for i, item in enumerate(tqdm(items, desc="Indexing files")):
             doc_id = doc_ids[i] if doc_ids else self.highest_doc_id + 1
             doc_metadata = metadata[doc_id] if metadata else None
             self.add_to_index(
@@ -535,7 +537,11 @@ def _add_to_index(
         # Generate embedding
         with torch.inference_mode():
             processed_image = {
-                k: v.to(self.device).to(self.model.dtype if v.dtype in [torch.float16, torch.bfloat16, torch.float32] else v.dtype)
+                k: v.to(self.device).to(
+                    self.model.dtype
+                    if v.dtype in [torch.float16, torch.bfloat16, torch.float32]
+                    else v.dtype
+                )
                 for k, v in processed_image.items()
             }
             embedding = self.model(**processed_image)
@@ -592,24 +598,32 @@ def _add_to_index(
     def remove_from_index(self):
         raise NotImplementedError("This method is not implemented yet.")

-    def filter_embeddings(self,filter_metadata:Dict[str,str]):
+    def filter_embeddings(self, filter_metadata: Dict[str, str]):
         req_doc_ids = []
-        for idx,metadata_dict in self.doc_id_to_metadata.items():
-            for metadata_key,metadata_value in metadata_dict.items():
+        for idx, metadata_dict in self.doc_id_to_metadata.items():
+            for metadata_key, metadata_value in metadata_dict.items():
                 if metadata_key in filter_metadata:
                     if filter_metadata[metadata_key] == metadata_value:
                         req_doc_ids.append(idx)
-
-        req_embedding_ids = [eid for eid,doc in self.embed_id_to_doc_id.items() if doc['doc_id'] in req_doc_ids]
-        req_embeddings = [ie for idx,ie in enumerate(self.indexed_embeddings) if idx in req_embedding_ids]
+
+        req_embedding_ids = [
+            eid
+            for eid, doc in self.embed_id_to_doc_id.items()
+            if doc["doc_id"] in req_doc_ids
+        ]
+        req_embeddings = [
+            ie
+            for idx, ie in enumerate(self.indexed_embeddings)
+            if idx in req_embedding_ids
+        ]

         return req_embeddings, req_embedding_ids
-
+
     def search(
         self,
         query: Union[str, List[str]],
         k: int = 10,
-        filter_metadata: Optional[Dict[str,str]] = None,
+        filter_metadata: Optional[Dict[str, str]] = None,
         return_base64_results: Optional[bool] = None,
     ) -> Union[List[Result], List[List[Result]]]:
         # Set default value for return_base64_results if not provided
@@ -631,15 +645,24 @@ def search(
             # Process query
             with torch.inference_mode():
                 batch_query = self.processor.process_queries([q])
-                batch_query = {k: v.to(self.device).to(self.model.dtype if v.dtype in [torch.float16, torch.bfloat16, torch.float32] else v.dtype) for k, v in batch_query.items()}
+                batch_query = {
+                    k: v.to(self.device).to(
+                        self.model.dtype
+                        if v.dtype in [torch.float16, torch.bfloat16, torch.float32]
+                        else v.dtype
+                    )
+                    for k, v in batch_query.items()
+                }
                 embeddings_query = self.model(**batch_query)
                 qs = list(torch.unbind(embeddings_query.to("cpu")))
             if not filter_metadata:
                 req_embeddings = self.indexed_embeddings
             else:
-                req_embeddings, req_embedding_ids = self.filter_embeddings(filter_metadata=filter_metadata)
+                req_embeddings, req_embedding_ids = self.filter_embeddings(
+                    filter_metadata=filter_metadata
+                )
             # Compute scores
-            scores = self.processor.score(qs,req_embeddings).cpu().numpy()
+            scores = self.processor.score(qs, req_embeddings).cpu().numpy()

             # Get top k relevant pages
             top_pages = scores.argsort(axis=1)[0][-k:][::-1].tolist()
@@ -714,7 +737,14 @@ def encode_image(

         with torch.inference_mode():
             batch = self.processor.process_images(images)
-            batch = {k: v.to(self.device).to(self.model.dtype if v.dtype in [torch.float16, torch.bfloat16, torch.float32] else v.dtype) for k, v in batch.items()}
+            batch = {
+                k: v.to(self.device).to(
+                    self.model.dtype
+                    if v.dtype in [torch.float16, torch.bfloat16, torch.float32]
+                    else v.dtype
+                )
+                for k, v in batch.items()
+            }
             embeddings = self.model(**batch)

         return embeddings.cpu()
@@ -735,7 +765,14 @@ def encode_query(self, query: Union[str, List[str]]) -> torch.Tensor:

         with torch.inference_mode():
             batch = self.processor.process_queries(query)
-            batch = {k: v.to(self.device).to(self.model.dtype if v.dtype in [torch.float16, torch.bfloat16, torch.float32] else v.dtype) for k, v in batch.items()}
+            batch = {
+                k: v.to(self.device).to(
+                    self.model.dtype
+                    if v.dtype in [torch.float16, torch.bfloat16, torch.float32]
+                    else v.dtype
+                )
+                for k, v in batch.items()
+            }
             embeddings = self.model(**batch)

         return embeddings.cpu()

From 1f57fbe81fb57effe7f63c32faebfb5c84aeb866 Mon Sep 17 00:00:00 2001
From: dnth
Date: Mon, 30 Dec 2024 22:03:24 +0800
Subject: [PATCH 2/2] update tqdm without formatting

---
 byaldi/colpali.py | 75 ++++++++++++-----------------------------------
 1 file changed, 19 insertions(+), 56 deletions(-)

diff --git a/byaldi/colpali.py b/byaldi/colpali.py
index 5c78c97..edb6f9f 100644
--- a/byaldi/colpali.py
+++ b/byaldi/colpali.py
@@ -49,12 +49,10 @@ def __init__(
         self.pretrained_model_name_or_path = pretrained_model_name_or_path
         self.model_name = self.pretrained_model_name_or_path
         self.n_gpu = torch.cuda.device_count() if n_gpu == -1 else n_gpu
-        device = device or (
-            "cuda"
-            if torch.cuda.is_available()
-            else "mps"
-            if torch.backends.mps.is_available()
-            else "cpu"
+        device = (
+            device or (
+                "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
+            )
         )
         self.index_name = index_name
         self.verbose = verbose
@@ -537,11 +535,7 @@ def _add_to_index(
         # Generate embedding
         with torch.inference_mode():
             processed_image = {
-                k: v.to(self.device).to(
-                    self.model.dtype
-                    if v.dtype in [torch.float16, torch.bfloat16, torch.float32]
-                    else v.dtype
-                )
+                k: v.to(self.device).to(self.model.dtype if v.dtype in [torch.float16, torch.bfloat16, torch.float32] else v.dtype)
                 for k, v in processed_image.items()
             }
             embedding = self.model(**processed_image)
@@ -598,32 +592,24 @@ def _add_to_index(
     def remove_from_index(self):
         raise NotImplementedError("This method is not implemented yet.")

-    def filter_embeddings(self, filter_metadata: Dict[str, str]):
+    def filter_embeddings(self,filter_metadata:Dict[str,str]):
         req_doc_ids = []
-        for idx, metadata_dict in self.doc_id_to_metadata.items():
-            for metadata_key, metadata_value in metadata_dict.items():
+        for idx,metadata_dict in self.doc_id_to_metadata.items():
+            for metadata_key,metadata_value in metadata_dict.items():
                 if metadata_key in filter_metadata:
                     if filter_metadata[metadata_key] == metadata_value:
                         req_doc_ids.append(idx)
-
-        req_embedding_ids = [
-            eid
-            for eid, doc in self.embed_id_to_doc_id.items()
-            if doc["doc_id"] in req_doc_ids
-        ]
-        req_embeddings = [
-            ie
-            for idx, ie in enumerate(self.indexed_embeddings)
-            if idx in req_embedding_ids
-        ]
+
+        req_embedding_ids = [eid for eid,doc in self.embed_id_to_doc_id.items() if doc['doc_id'] in req_doc_ids]
+        req_embeddings = [ie for idx,ie in enumerate(self.indexed_embeddings) if idx in req_embedding_ids]

         return req_embeddings, req_embedding_ids
-
+
     def search(
         self,
         query: Union[str, List[str]],
         k: int = 10,
-        filter_metadata: Optional[Dict[str, str]] = None,
+        filter_metadata: Optional[Dict[str,str]] = None,
         return_base64_results: Optional[bool] = None,
     ) -> Union[List[Result], List[List[Result]]]:
         # Set default value for return_base64_results if not provided
@@ -645,24 +631,15 @@ def search(
             # Process query
             with torch.inference_mode():
                 batch_query = self.processor.process_queries([q])
-                batch_query = {
-                    k: v.to(self.device).to(
-                        self.model.dtype
-                        if v.dtype in [torch.float16, torch.bfloat16, torch.float32]
-                        else v.dtype
-                    )
-                    for k, v in batch_query.items()
-                }
+                batch_query = {k: v.to(self.device).to(self.model.dtype if v.dtype in [torch.float16, torch.bfloat16, torch.float32] else v.dtype) for k, v in batch_query.items()}
                 embeddings_query = self.model(**batch_query)
                 qs = list(torch.unbind(embeddings_query.to("cpu")))
             if not filter_metadata:
                 req_embeddings = self.indexed_embeddings
             else:
-                req_embeddings, req_embedding_ids = self.filter_embeddings(
-                    filter_metadata=filter_metadata
-                )
+                req_embeddings, req_embedding_ids = self.filter_embeddings(filter_metadata=filter_metadata)
             # Compute scores
-            scores = self.processor.score(qs, req_embeddings).cpu().numpy()
+            scores = self.processor.score(qs,req_embeddings).cpu().numpy()

             # Get top k relevant pages
             top_pages = scores.argsort(axis=1)[0][-k:][::-1].tolist()
@@ -737,14 +714,7 @@ def encode_image(

         with torch.inference_mode():
             batch = self.processor.process_images(images)
-            batch = {
-                k: v.to(self.device).to(
-                    self.model.dtype
-                    if v.dtype in [torch.float16, torch.bfloat16, torch.float32]
-                    else v.dtype
-                )
-                for k, v in batch.items()
-            }
+            batch = {k: v.to(self.device).to(self.model.dtype if v.dtype in [torch.float16, torch.bfloat16, torch.float32] else v.dtype) for k, v in batch.items()}
             embeddings = self.model(**batch)

         return embeddings.cpu()
@@ -765,17 +735,10 @@ def encode_query(self, query: Union[str, List[str]]) -> torch.Tensor:

         with torch.inference_mode():
             batch = self.processor.process_queries(query)
-            batch = {
-                k: v.to(self.device).to(
-                    self.model.dtype
-                    if v.dtype in [torch.float16, torch.bfloat16, torch.float32]
-                    else v.dtype
-                )
-                for k, v in batch.items()
-            }
+            batch = {k: v.to(self.device).to(self.model.dtype if v.dtype in [torch.float16, torch.bfloat16, torch.float32] else v.dtype) for k, v in batch.items()}
             embeddings = self.model(**batch)

         return embeddings.cpu()

     def get_doc_ids_to_file_names(self):
-        return self.doc_ids_to_file_names
+        return self.doc_ids_to_file_names
\ No newline at end of file
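Note, separate from the patch series above: a minimal, self-contained sketch of the behaviour the first commit introduces. tqdm.auto.tqdm wraps any iterable and renders a progress bar (plain console output or a notebook widget, picked automatically), which is what replaces the per-file print() in index(). The file list below is a made-up placeholder, not byaldi's actual input.

    from tqdm.auto import tqdm

    items = ["doc_1.pdf", "doc_2.pdf", "doc_3.pdf"]  # placeholder list of files to index

    for i, item in enumerate(tqdm(items, desc="Indexing files")):
        # Each iteration advances the "Indexing files" progress bar;
        # the real code would call self.add_to_index(...) here.
        pass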