Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions test/components/agents/test_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,8 +197,9 @@ def tools() -> list[Tool]:


class TestAgent:
def test_to_dict(self, tools, monkeypatch):
def test_to_dict(self, tools, monkeypatch, mock_mem0_memory_client):
monkeypatch.setenv("OPENAI_API_KEY", "test")
monkeypatch.setenv("MEM0_API_KEY", "test")
agent = Agent(
chat_generator=OpenAIChatGenerator(model="gpt-4o-mini"), tools=tools, chat_message_store=InMemoryChatMessageStore(),
memory_store=Mem0MemoryStore()
Expand Down Expand Up @@ -267,8 +268,9 @@ def test_to_dict(self, tools, monkeypatch):
},
}

def test_from_dict(self, tools, monkeypatch):
def test_from_dict(self, tools, monkeypatch, mock_mem0_memory_client):
monkeypatch.setenv("OPENAI_API_KEY", "test")
monkeypatch.setenv("MEM0_API_KEY", "test")
agent = Agent(
chat_generator=OpenAIChatGenerator(), tools=tools, chat_message_store=InMemoryChatMessageStore(),
memory_store=Mem0MemoryStore()
Expand Down
10 changes: 9 additions & 1 deletion test/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from pathlib import Path

import pytest

from unittest.mock import Mock, patch
from haystack.testing.test_utils import set_all_seeds

set_all_seeds(0)
Expand All @@ -14,3 +14,11 @@
@pytest.fixture()
def test_files_path():
return Path(__file__).parent / "test_files"

@pytest.fixture
def mock_mem0_memory_client():
"""Mock the Mem0 MemoryClient."""
with patch("haystack_experimental.memory_stores.mem0.memory_store.MemoryClient") as mock_client_class:
mock_client = Mock()
mock_client_class.return_value = mock_client
yield mock_client
22 changes: 7 additions & 15 deletions test/memory_stores/test_mem0_memory_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,14 +35,6 @@ def memory_store():

class TestMem0MemoryStore:

@pytest.fixture
def mock_memory_client(self):
"""Mock the Mem0 MemoryClient."""
with patch("haystack_experimental.memory_stores.mem0.memory_store.MemoryClient") as mock_client_class:
mock_client = Mock()
mock_client_class.return_value = mock_client
yield mock_client

@pytest.fixture
def sample_messages(self):
"""Sample ChatMessage objects for testing."""
Expand All @@ -51,17 +43,17 @@ def sample_messages(self):
ChatMessage.from_user("I like working with Haystack.", meta={"topic": "programming"}),
]

def test_init_with_user_id_and_api_key(self, mock_memory_client):
def test_init_with_user_id_and_api_key(self, mock_mem0_memory_client):
"""Test initialization with user_id and api_key."""
with patch.dict(os.environ, {}, clear=True):
store = Mem0MemoryStore(api_key=Secret.from_token("test_api_key_12345"))
assert store.client == mock_memory_client
assert store.client == mock_mem0_memory_client

def test_init_with_params(self, mock_memory_client):
def test_init_with_params(self, mock_mem0_memory_client):
store = Mem0MemoryStore(
api_key=Secret.from_token("test_api_key_12345")
)
assert store.client == mock_memory_client
assert store.client == mock_mem0_memory_client

def test_init_without_api_key_raises_error(self):
"""Test that initialization without API key raises ValueError."""
Expand All @@ -70,7 +62,7 @@ def test_init_without_api_key_raises_error(self):
match="None of the following authentication environment variables are set"):
Mem0MemoryStore()

def test_to_dict(self, monkeypatch, mock_memory_client):
def test_to_dict(self, monkeypatch, mock_mem0_memory_client):
"""Test serialization to dictionary."""
with patch.dict(os.environ, {}, clear=True):
monkeypatch.setenv("ENV_VAR", "test_api_key_12345")
Expand All @@ -80,15 +72,15 @@ def test_to_dict(self, monkeypatch, mock_memory_client):
result = store.to_dict()
assert result["init_parameters"]["api_key"] == {"env_vars": ["ENV_VAR"], "strict": True, "type": "env_var"}

def test_from_dict(self, monkeypatch, mock_memory_client):
def test_from_dict(self, monkeypatch, mock_mem0_memory_client):
with patch.dict(os.environ, {}, clear=True):
monkeypatch.setenv("ENV_VAR", "test_api_key_12345")
data = {
'type': 'haystack_experimental.memory_stores.mem0.memory_store.Mem0MemoryStore',
'init_parameters': { "api_key": {"env_vars": ["ENV_VAR"], "strict": True, "type": "env_var"},
}}
store = Mem0MemoryStore.from_dict(data)
assert store.client == mock_memory_client
assert store.client == mock_mem0_memory_client
assert store.api_key == Secret.from_env_var("ENV_VAR")

@pytest.mark.skipif(
Expand Down