graphai/core/embedding/embedding.py (59 changes: 44 additions & 15 deletions)
@@ -40,6 +40,24 @@ def embedding_from_json(s):
     return np.array(json.loads(s))
 
 
+def split_text(text, max_length, split_characters=('\n', '.', ';', ',', ' ', '$')):
+    result = []
+    assert max_length > 0
+    while len(text) > max_length:
+        for split_char in split_characters:
+            pos = text[:max_length].rfind(split_char)
+            if pos > 0:
+                result.append(text[:pos + 1])
+                text = text[pos + 1:]
+                break
+        if len(text) > max_length:
+            result.append(text[:max_length])
+            text = text[max_length:]
+    if len(text) > 0:
+        result.append(text)
+    return result
+
+
 def generate_embedding_text_token(s, model_type):
     """
     Generates an md5-based token for a string
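A rough usage sketch of the new helper (not part of the diff; the sample string and the max_length of 20 are invented for illustration): split_text keeps every chunk within max_length, preferring to cut at the listed split characters, and concatenating the chunks reproduces the original text.

# Hypothetical example; values are made up.
sample = "First sentence. Second sentence that is quite a bit longer."
chunks = split_text(sample, 20)
assert all(len(chunk) <= 20 for chunk in chunks)   # no chunk exceeds the limit
assert ''.join(chunks) == sample                   # nothing is lost or duplicated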
@@ -215,22 +233,35 @@ def _get_model_output(self, model, text):
             print(e)
             return None
 
-    def _embed(self, model, text):
+    def _embed(self, model, text, force_split):
         text_too_large = False
         result = self._get_model_output(model, text)
         if result is None:
-            text_too_large = True
+            if force_split:
+                model_max_tokens = self._get_model_max_tokens(model)
+                text_parts = split_text(text, model_max_tokens)
+                total_len = sum(len(text_part) for text_part in text_parts)
+                weights = [len(current_text) / total_len for current_text in text_parts]
+                results = [self._get_model_output(model, current_text) for current_text in text_parts]
+                if any(res is None for res in results):
+                    text_too_large = True
+                else:
+                    results = np.vstack([weights[i] * np.reshape(results[i], (1, len(results[i])))
+                                         for i in range(len(results))])
+                    result = results.sum(axis=0).flatten()
+            else:
+                text_too_large = True
         return result, text_too_large
 
-    def embed(self, text, model_type='all-MiniLM-L12-v2'):
+    def embed(self, text, model_type='all-MiniLM-L12-v2', force_split=True):
         self.load_model(model_type)
         if model_type not in self.models.keys():
             raise NotImplementedError(f"Selected model type not implemented: {model_type}")
         model = self.models[model_type]
         max_tokens = self._get_model_max_tokens(model)
         if text is None or len(text) == 0:
             return None, False, max_tokens
-        results, text_too_large = self._embed(model, text)
+        results, text_too_large = self._embed(model, text, force_split)
         return results, text_too_large, max_tokens
 
 
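For context on _embed above: when the full text is over the model limit and force_split is set, the combined embedding is the length-weighted average of the per-chunk embeddings. A minimal standalone sketch of just that combination step, assuming numpy and equal-dimension 1-D chunk embeddings (the chunk texts and 4-dimensional vectors below are invented):

import numpy as np

# Hypothetical chunk embeddings; in the diff these come from _get_model_output.
chunk_texts = ["first chunk of the text. ", "second, shorter chunk."]
chunk_embeddings = [np.array([0.1, 0.2, 0.3, 0.4]), np.array([0.5, 0.1, 0.0, 0.2])]

# Each chunk is weighted by its share of the total character count.
total_len = sum(len(t) for t in chunk_texts)
weights = [len(t) / total_len for t in chunk_texts]

# Stack the weighted rows and sum them, mirroring the
# np.vstack(...).sum(axis=0).flatten() pattern in _embed.
stacked = np.vstack([w * e.reshape(1, -1) for w, e in zip(weights, chunk_embeddings)])
combined = stacked.sum(axis=0).flatten()
print(combined)  # length-weighted average of the chunk embeddings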
@@ -430,19 +461,17 @@ def embedding_text_list_embed_parallel(input_list, embedding_obj, model_type, i,
                 j -= 1
             if j == i:
                 j = i + 1
-            current_results_list = embed_text(current_embedding_obj,
-                                              [input_list[remaining_indices[k]]['text'] for k in range(i, j)],
-                                              model_type)
+            current_results_list = [embed_text(current_embedding_obj,
+                                               input_list[remaining_indices[k]]['text'],
+                                               model_type) for k in range(i, j)]
             for k in range(i, j):
                 current_results = {
-                    'result': current_results_list['result'][k - i]
-                    if not isinstance(current_results_list['result'], str)
-                    else current_results_list['result'],
-                    'successful': current_results_list['successful'],
-                    'text_too_large': current_results_list['text_too_large'],
-                    'fresh': current_results_list['fresh'],
-                    'model_type': current_results_list['model_type'],
-                    'device': current_results_list['device'],
+                    'result': current_results_list[k - i]['result'],
+                    'successful': current_results_list[k - i]['successful'],
+                    'text_too_large': current_results_list[k - i]['text_too_large'],
+                    'fresh': current_results_list[k - i]['fresh'],
+                    'model_type': current_results_list[k - i]['model_type'],
+                    'device': current_results_list[k - i]['device'],
                     'id_token': input_list[remaining_indices[k]]['token'],
                     'source': input_list[remaining_indices[k]]['text']
                 }
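The hunk above replaces one batched embed_text call, whose single result dict had to be fanned out across all texts, with one call per text so that each entry keeps its own flags. A toy illustration of the two shapes (the keys match the diff; all values are invented):

# Hypothetical old shape: one dict for the whole batch, 'result' holds a list
# (or an error string when the batch failed as a whole).
batched = {'result': [[0.1, 0.2], [0.3, 0.4]], 'successful': True,
           'text_too_large': False, 'fresh': True,
           'model_type': 'all-MiniLM-L12-v2', 'device': 'cpu'}

# Hypothetical new shape: one dict per input text, so 'successful' and
# 'text_too_large' can differ between texts.
per_item = [
    {'result': [0.1, 0.2], 'successful': True, 'text_too_large': False,
     'fresh': True, 'model_type': 'all-MiniLM-L12-v2', 'device': 'cpu'},
    {'result': None, 'successful': False, 'text_too_large': True,
     'fresh': True, 'model_type': 'all-MiniLM-L12-v2', 'device': 'cpu'},
]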
tests/test_embedding.py (18 changes: 7 additions & 11 deletions)
@@ -18,7 +18,7 @@
 
 @patch('graphai.celery.embedding.tasks.embed_text_task.run')
 @pytest.mark.usefixtures('example_word')
-def test__embedding_embed__translate_text__mock_task(mock_run, example_word):
+def test__embedding_embed__embed_text__mock_task(mock_run, example_word):
     # Mock calling the task
     embed_text_task.run(example_word, 'all-MiniLM-L12-v2')
 
@@ -30,7 +30,7 @@ def test__embedding_embed__translate_text__mock_task(mock_run, example_word):
 
 
 @pytest.mark.usefixtures('example_word', 'very_long_text')
-def test__translation_translate__translate_text__run_task(example_word, very_long_text):
+def test__embedding_embed__embed_text__run_task(example_word, very_long_text):
     # Call the task
     embedding = embed_text_task.run(example_word, "all-MiniLM-L12-v2")
 
@@ -45,12 +45,11 @@ def test__translation_translate__translate_text__run_task(example_word, very_long_text):
     # Call the task
     embedding = embed_text_task.run(very_long_text, "all-MiniLM-L12-v2")
 
-    # Assert that the results are correct
+    # Assert that a very long text is properly broken up and embedded
     assert isinstance(embedding, dict)
     assert 'result' in embedding
-    assert embedding['successful'] is False
-    assert embedding['text_too_large'] is True
-    assert embedding['result'] == "Text over token limit for selected model (128)."
+    assert embedding['successful'] is True
+    assert embedding['text_too_large'] is False
 
 
 @pytest.mark.celery(accept_content=['pickle', 'json'], result_serializer='pickle', task_serializer='pickle')
@@ -174,8 +173,5 @@ def test__embedding_embedding__embed_text__integration(fixture_app, celery_worker
     assert len(embedding['task_result']) == 2 + len(example_word_list)
     assert embedding['task_result'][0]['successful'] is True
     assert embedding['task_result'][0]['result'] == original_results
-    assert embedding['task_result'][1]['successful'] is False
-    assert embedding['task_result'][1]['result'] == "Text over token limit for selected model (128)."
-    # All except one must have been successful
-    assert sum([1 if embedding['task_result'][i]['successful'] else 0
-                for i in range(len(embedding['task_result']))]) == len(embedding['task_result']) - 1
+    assert embedding['task_result'][1]['successful'] is True
+    assert embedding['task_result'][1]['text_too_large'] is False