diff --git a/module_4_rag/batch_score_documents.py b/module_4_rag/batch_score_documents.py
index 5d64da0..6a39c25 100644
--- a/module_4_rag/batch_score_documents.py
+++ b/module_4_rag/batch_score_documents.py
@@ -1,11 +1,17 @@
 import os
 import pandas as pd
+from nltk.tokenize import sent_tokenize
 from transformers import AutoTokenizer, AutoModel
 import torch
 import torch.nn.functional as F
 
-INPUT_FILENAME = "./data/city_wikipedia_summaries.csv"
-EXPORT_FILENAME = "./data/city_wikipedia_summaries_with_embeddings.parquet"
+BASE_DIR = os.path.abspath(os.path.join(os.getcwd(), "feature_repo"))
+DATA_DIR = os.path.join(BASE_DIR, "data")
+INPUT_FILENAME = os.path.join(DATA_DIR, "city_wikipedia_summaries.csv")
+CHUNKED_FILENAME = os.path.join(DATA_DIR, "city_wikipedia_summaries_chunked.csv")
+EXPORT_FILENAME = os.path.join(
+    DATA_DIR, "city_wikipedia_summaries_with_embeddings.parquet"
+)
 TOKENIZER = "sentence-transformers/all-MiniLM-L6-v2"
 MODEL = "sentence-transformers/all-MiniLM-L6-v2"
 
@@ -36,23 +42,33 @@ def run_model(sentences, tokenizer, model):
 
 
 def score_data() -> None:
-    if EXPORT_FILENAME not in os.listdir():
-        print("scored data not found...generating embeddings...")
-        df = pd.read_csv(INPUT_FILENAME)
+    os.makedirs(DATA_DIR, exist_ok=True)
+
+    if not os.path.exists(EXPORT_FILENAME):
+        print("Scored data not found... generating embeddings...")
+
+        if not os.path.exists(CHUNKED_FILENAME):
+            print("Chunked data not found... generating chunked data...")
+            df = pd.read_csv(INPUT_FILENAME)
+            df["Sentence Chunks"] = df["Wiki Summary"].apply(lambda x: sent_tokenize(x))
+            chunked_df = df.explode("Sentence Chunks")
+            chunked_df.to_csv(CHUNKED_FILENAME, index=False)
+            df = chunked_df
+        else:
+            df = pd.read_csv(CHUNKED_FILENAME)
+
         tokenizer = AutoTokenizer.from_pretrained(TOKENIZER)
         model = AutoModel.from_pretrained(MODEL)
         embeddings = run_model(df["Wiki Summary"].tolist(), tokenizer, model)
-        print(embeddings)
-        print("shape = ", df.shape)
-        df["Embeddings"] = list(embeddings.detach().cpu().numpy())
         print("embeddings generated...")
+        df["Embeddings"] = list(embeddings.detach().cpu().numpy())
         df["event_timestamp"] = pd.to_datetime("today")
         df["item_id"] = df.index
-        print(df.head())
+
         df.to_parquet(EXPORT_FILENAME, index=False)
-        print("...data exported. job complete")
+        print("...data exported. Job complete")
     else:
-        print("scored data found...skipping generating embeddings.")
+        print("Scored data found... skipping generating embeddings.")
 
 
 if __name__ == "__main__":
diff --git a/module_4_rag/feature_repo/data/city_wikipedia_summaries_with_embeddings.parquet b/module_4_rag/feature_repo/data/city_wikipedia_summaries_with_embeddings.parquet
index 004d8bd..42c56eb 100644
Binary files a/module_4_rag/feature_repo/data/city_wikipedia_summaries_with_embeddings.parquet and b/module_4_rag/feature_repo/data/city_wikipedia_summaries_with_embeddings.parquet differ
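
Note: the new sent_tokenize import depends on NLTK's Punkt sentence-tokenizer
models, which are not bundled with the nltk package itself. A minimal one-time
setup sketch, assuming a plain `pip install nltk` environment with no corpora
pre-downloaded:

    import nltk

    # sent_tokenize() needs the Punkt models; download() is a cached no-op
    # if the resource is already present, and returns False (without raising)
    # if the name is unknown to the download index.
    nltk.download("punkt")
    # Newer NLTK releases (>= 3.8.2) resolve the re-packaged "punkt_tab"
    # resource instead, so fetching both covers either version.
    nltk.download("punkt_tab")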