diff --git a/graphai/core/embedding/embedding.py b/graphai/core/embedding/embedding.py
index c7b631db..5a9ab042 100644
--- a/graphai/core/embedding/embedding.py
+++ b/graphai/core/embedding/embedding.py
@@ -40,6 +40,24 @@ def embedding_from_json(s):
     return np.array(json.loads(s))
 
 
+def split_text(text, max_length, split_characters=('\n', '.', ';', ',', ' ', '$')):
+    result = []
+    assert max_length > 0
+    while len(text) > max_length:
+        for split_char in split_characters:
+            pos = text[:max_length].rfind(split_char)
+            if pos > 0:
+                result.append(text[:pos + 1])
+                text = text[pos + 1:]
+                break
+        if len(text) > max_length:
+            result.append(text[:max_length])
+            text = text[max_length:]
+    if len(text) > 0:
+        result.append(text)
+    return result
+
+
 def generate_embedding_text_token(s, model_type):
     """
     Generates an md5-based token for a string
@@ -215,14 +233,27 @@ def _get_model_output(self, model, text):
             print(e)
             return None
 
-    def _embed(self, model, text):
+    def _embed(self, model, text, force_split):
         text_too_large = False
         result = self._get_model_output(model, text)
         if result is None:
-            text_too_large = True
+            if force_split:
+                model_max_tokens = self._get_model_max_tokens(model)
+                text_parts = split_text(text, model_max_tokens)
+                total_len = sum(len(text_part) for text_part in text_parts)
+                weights = [len(current_text) / total_len for current_text in text_parts]
+                results = [self._get_model_output(model, current_text) for current_text in text_parts]
+                if any(res is None for res in results):
+                    text_too_large = True
+                else:
+                    results = np.vstack([weights[i] * np.reshape(results[i], (1, len(results[i])))
+                                         for i in range(len(results))])
+                    result = results.sum(axis=0).flatten()
+            else:
+                text_too_large = True
         return result, text_too_large
 
-    def embed(self, text, model_type='all-MiniLM-L12-v2'):
+    def embed(self, text, model_type='all-MiniLM-L12-v2', force_split=True):
         self.load_model(model_type)
         if model_type not in self.models.keys():
             raise NotImplementedError(f"Selected model type not implemented: {model_type}")
@@ -230,7 +261,7 @@ def embed(self, text, model_type='all-MiniLM-L12-v2'):
         max_tokens = self._get_model_max_tokens(model)
         if text is None or len(text) == 0:
             return None, False, max_tokens
-        results, text_too_large = self._embed(model, text)
+        results, text_too_large = self._embed(model, text, force_split)
         return results, text_too_large, max_tokens
 
 
@@ -430,19 +461,17 @@ def embedding_text_list_embed_parallel(input_list, embedding_obj, model_type, i,
         j -= 1
     if j == i:
         j = i + 1
-    current_results_list = embed_text(current_embedding_obj,
-                                      [input_list[remaining_indices[k]]['text'] for k in range(i, j)],
-                                      model_type)
+    current_results_list = [embed_text(current_embedding_obj,
+                                       input_list[remaining_indices[k]]['text'],
+                                       model_type) for k in range(i, j)]
     for k in range(i, j):
         current_results = {
-            'result': current_results_list['result'][k - i]
-            if not isinstance(current_results_list['result'], str)
-            else current_results_list['result'],
-            'successful': current_results_list['successful'],
-            'text_too_large': current_results_list['text_too_large'],
-            'fresh': current_results_list['fresh'],
-            'model_type': current_results_list['model_type'],
-            'device': current_results_list['device'],
+            'result': current_results_list[k - i]['result'],
+            'successful': current_results_list[k - i]['successful'],
+            'text_too_large': current_results_list[k - i]['text_too_large'],
+            'fresh': current_results_list[k - i]['fresh'],
+            'model_type': current_results_list[k - i]['model_type'],
+            'device': current_results_list[k - i]['device'],
             'id_token': input_list[remaining_indices[k]]['token'],
             'source': input_list[remaining_indices[k]]['text']
         }
diff --git a/tests/test_embedding.py b/tests/test_embedding.py
index dfe75dba..323ae7cd 100644
--- a/tests/test_embedding.py
+++ b/tests/test_embedding.py
@@ -18,7 +18,7 @@
 
 @patch('graphai.celery.embedding.tasks.embed_text_task.run')
 @pytest.mark.usefixtures('example_word')
-def test__embedding_embed__translate_text__mock_task(mock_run, example_word):
+def test__embedding_embed__embed_text__mock_task(mock_run, example_word):
     # Mock calling the task
     embed_text_task.run(example_word, 'all-MiniLM-L12-v2')
 
@@ -30,7 +30,7 @@ def test__embedding_embed__translate_text__mock_task(mock_run, example_word):
 
 
 @pytest.mark.usefixtures('example_word', 'very_long_text')
-def test__translation_translate__translate_text__run_task(example_word, very_long_text):
+def test__embedding_embed__embed_text__run_task(example_word, very_long_text):
     # Call the task
     embedding = embed_text_task.run(example_word, "all-MiniLM-L12-v2")
 
@@ -45,12 +45,11 @@ def test__translation_translate__translate_text__run_task(example_word, very_lon
     # Call the task
     embedding = embed_text_task.run(very_long_text, "all-MiniLM-L12-v2")
 
-    # Assert that the results are correct
+    # Assert that a very long text is properly broken up and embedded
    assert isinstance(embedding, dict)
     assert 'result' in embedding
-    assert embedding['successful'] is False
-    assert embedding['text_too_large'] is True
-    assert embedding['result'] == "Text over token limit for selected model (128)."
+    assert embedding['successful'] is True
+    assert embedding['text_too_large'] is False
 
 
 @pytest.mark.celery(accept_content=['pickle', 'json'], result_serializer='pickle', task_serializer='pickle')
@@ -174,8 +173,5 @@ def test__embedding_embedding__embed_text__integration(fixture_app, celery_worke
     assert len(embedding['task_result']) == 2 + len(example_word_list)
     assert embedding['task_result'][0]['successful'] is True
     assert embedding['task_result'][0]['result'] == original_results
-    assert embedding['task_result'][1]['successful'] is False
-    assert embedding['task_result'][1]['result'] == "Text over token limit for selected model (128)."
-    # All except one must have been successful
-    assert sum([1 if embedding['task_result'][i]['successful'] else 0
-                for i in range(len(embedding['task_result']))]) == len(embedding['task_result']) - 1
+    assert embedding['task_result'][1]['successful'] is True
+    assert embedding['task_result'][1]['text_too_large'] is False