From 6891dc089311716000b166c9cf21aebaaa5dc0c1 Mon Sep 17 00:00:00 2001 From: Jainil Gosalia Date: Wed, 20 Nov 2024 01:16:13 +0530 Subject: [PATCH] Added cache directory support for Colpali --- byaldi/RAGModel.py | 2 ++ byaldi/colpali.py | 1 + 2 files changed, 3 insertions(+) diff --git a/byaldi/RAGModel.py b/byaldi/RAGModel.py index 32b66bf..401d8ad 100644 --- a/byaldi/RAGModel.py +++ b/byaldi/RAGModel.py @@ -45,6 +45,7 @@ def from_pretrained( index_root: str = ".byaldi", device: str = "cuda", verbose: int = 1, + cache_dir: Optional[str] = None, ): """Load a ColPali model from a pre-trained checkpoint. @@ -61,6 +62,7 @@ def from_pretrained( index_root=index_root, device=device, verbose=verbose, + cache_dir=cache_dir, ) return instance diff --git a/byaldi/colpali.py b/byaldi/colpali.py index cc11dcb..139ab30 100644 --- a/byaldi/colpali.py +++ b/byaldi/colpali.py @@ -76,6 +76,7 @@ def __init__( else None ), token=kwargs.get("hf_token", None) or os.environ.get("HF_TOKEN"), + cache_dir=kwargs.get("cache_dir", None), ) elif "colqwen2" in pretrained_model_name_or_path.lower(): self.model = ColQwen2.from_pretrained(