run unit tests and change api where needed

This commit is contained in:
DavIvek 2025-09-09 16:43:23 +02:00
parent 78cdab98d9
commit e43a756ac1
9 changed files with 80 additions and 34 deletions

View file

@ -29,6 +29,11 @@ jobs:
NEO4J_AUTH: neo4j/testpass
NEO4J_PLUGINS: '["apoc"]'
options: --health-cmd "cypher-shell -u neo4j -p testpass 'RETURN 1'" --health-interval 10s --health-timeout 5s --health-retries 10
memgraph:
image: memgraph/memgraph:latest
ports:
- 7688:7687
options: --health-cmd "mg_client --host localhost --port 7687 --use-ssl=false --query 'RETURN 1;'" --health-interval 10s --health-timeout 5s --health-retries 10
steps:
- uses: actions/checkout@v4
- name: Set up Python
@ -77,3 +82,13 @@ jobs:
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
run: |
uv run pytest tests/test_*_int.py -k "neo4j"
- name: Run Memgraph integration tests
env:
PYTHONPATH: ${{ github.workspace }}
MEMGRAPH_URI: bolt://localhost:7688
MEMGRAPH_USER:
MEMGRAPH_PASSWORD:
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
run: |
uv run pytest tests/test_*_int.py -k "memgraph"

View file

@ -44,5 +44,34 @@ services:
environment:
- NEO4J_AUTH=${NEO4J_USER}/${NEO4J_PASSWORD}
memgraph:
image: memgraph/memgraph:latest
healthcheck:
test:
[
"CMD",
"mg_client",
"--host",
"localhost",
"--port",
"7687",
"--use-ssl=false",
"--query",
"RETURN 1;"
]
interval: 5s
timeout: 10s
retries: 10
start_period: 3s
ports:
- "7688:7687" # Bolt (using different port to avoid conflict)
volumes:
- memgraph_data:/var/lib/memgraph
environment:
- MEMGRAPH_USER=${MEMGRAPH_USER:-}
- MEMGRAPH_PASSWORD=${MEMGRAPH_PASSWORD:-}
command: ["--log-level=TRACE", "--also-log-to-stderr", "--bolt-port=7687"]
volumes:
neo4j_data:
memgraph_data:

View file

@ -48,7 +48,9 @@ load_dotenv()
# Memgraph connection parameters
# Make sure Memgraph is running (default port 7687, same as Neo4j)
memgraph_uri = os.environ.get('MEMGRAPH_URI', 'bolt://localhost:7687')
memgraph_user = os.environ.get('MEMGRAPH_USER', '') # Memgraph often doesn't require auth by default
memgraph_user = os.environ.get(
'MEMGRAPH_USER', ''
) # Memgraph often doesn't require auth by default
memgraph_password = os.environ.get('MEMGRAPH_PASSWORD', '')
if not memgraph_uri:
@ -66,7 +68,7 @@ async def main():
# Initialize Memgraph driver
memgraph_driver = MemgraphDriver(memgraph_uri, memgraph_user, memgraph_password)
# Initialize Graphiti with Memgraph connection
graphiti = Graphiti(graph_driver=memgraph_driver)

View file

@ -14,7 +14,7 @@ See the License for the specific language governing permissions and
limitations under the License.
"""
from .neo4j_driver import Neo4jDriver
from .memgraph_driver import MemgraphDriver
from .neo4j_driver import Neo4jDriver
__all__ = ['Neo4jDriver', 'MemgraphDriver']

View file

@ -29,7 +29,9 @@ logger = logging.getLogger(__name__)
class MemgraphDriver(GraphDriver):
provider = GraphProvider.MEMGRAPH
def __init__(self, uri: str, user: str | None, password: str | None, database: str = 'memgraph'):
def __init__(
self, uri: str, user: str | None, password: str | None, database: str = 'memgraph'
):
super().__init__()
self.client = AsyncGraphDatabase.driver(
uri=uri,
@ -37,7 +39,9 @@ class MemgraphDriver(GraphDriver):
)
self._database = database
async def execute_query(self, cypher_query_: LiteralString, **kwargs: Any) -> tuple[list, Any, Any]:
async def execute_query(
self, cypher_query_: LiteralString, **kwargs: Any
) -> tuple[list, Any, Any]:
"""
Execute a Cypher query against Memgraph using implicit transactions.
Returns a tuple of (records, summary, keys) for compatibility with the GraphDriver interface.
@ -49,20 +53,20 @@ class MemgraphDriver(GraphDriver):
# but first extract database-specific parameters
database = kwargs.pop('database_', self._database)
kwargs.pop('parameters_', None) # Remove if present (Neo4j async driver param)
# All remaining kwargs are query parameters
params = kwargs
else:
# Extract database parameter if params was provided separately
database = kwargs.pop('database_', self._database)
kwargs.pop('parameters_', None) # Remove if present
async with self.client.session(database=database) as session:
try:
result = await session.run(cypher_query_, params)
records = [record async for record in result]
summary = await result.consume()
keys = result.keys()
keys = result.keys()
return (records, summary, keys)
except Exception as e:
logger.error(f'Error executing Memgraph query: {e}\n{cypher_query_}\n{params}')
@ -77,4 +81,4 @@ class MemgraphDriver(GraphDriver):
def delete_all_indexes(self) -> Coroutine[Any, Any, Any]:
# TODO: Implement index deletion for Memgraph
raise NotImplementedError("Index deletion not implemented for MemgraphDriver")
raise NotImplementedError('Index deletion not implemented for MemgraphDriver')

View file

@ -154,7 +154,7 @@ def get_vector_cosine_func_query(vec1, vec2, provider: GraphProvider) -> str:
return f'array_cosine_similarity({vec1}, {vec2})'
if provider == GraphProvider.MEMGRAPH:
return f'cosineSimilarity({vec1}, {vec2})'
return f'vector_search.cosine_similarity({vec1}, {vec2})'
return f'vector.similarity.cosine({vec1}, {vec2})'

View file

@ -104,7 +104,7 @@ def get_entity_edge_save_query(provider: GraphProvider) -> str:
MATCH (target:Entity {uuid: $edge_data.target_uuid})
MERGE (source)-[e:RELATES_TO {uuid: $edge_data.uuid}]->(target)
SET e = $edge_data
WITH e e.fact_embedding = $edge_data.fact_embedding
SET e.fact_embedding = $edge_data.fact_embedding
RETURN e.uuid AS uuid
"""
case _: # Neo4j

View file

@ -49,7 +49,18 @@ def get_episode_node_save_query(provider: GraphProvider) -> str:
entity_edges: $entity_edges, created_at: $created_at, valid_at: $valid_at}
RETURN n.uuid AS uuid
"""
case _: # Neo4j and Memgraph
case GraphProvider.MEMGRAPH:
return """
MERGE (n:Episodic {uuid: $uuid})
SET n = {
uuid: $uuid, name: $name, group_id: $group_id,
source_description: $source_description, source: $source,
content: $content, entity_edges: $entity_edges,
created_at: $created_at, valid_at: $valid_at
}
RETURN n.uuid AS uuid
"""
case _: # Neo4j
return """
MERGE (n:Episodic {uuid: $uuid})
SET n:$($group_label)
@ -96,23 +107,11 @@ def get_episode_node_save_bulk_query(provider: GraphProvider) -> str:
return """
UNWIND $episodes AS episode
MERGE (n:Episodic {uuid: episode.uuid})
FOREACH (_ IN CASE WHEN episode.group_label IS NOT NULL THEN [1] ELSE [] END |
SET n:`${episode.group_label}`
)
SET n = {
uuid: episode.uuid,
name: episode.name,
group_id: episode.group_id,
source_description: episode.source_description,
source: episode.source,
content: episode.content,
entity_edges: episode.entity_edges,
created_at: episode.created_at,
valid_at: episode.valid_at
}
RETURN n.uuid AS uuid;
SET n = {uuid: episode.uuid, name: episode.name, group_id: episode.group_id, source_description: episode.source_description, source: episode.source, content: episode.content,
entity_edges: episode.entity_edges, created_at: episode.created_at, valid_at: episode.valid_at}
RETURN n.uuid AS uuid
"""
case _: # Neo4j
case _: # Neo4j
return """
UNWIND $episodes AS episode
MERGE (n:Episodic {uuid: episode.uuid})
@ -183,7 +182,7 @@ def get_entity_node_save_query(provider: GraphProvider, labels: str) -> str:
RETURN n.uuid AS uuid
"""
case GraphProvider.MEMGRAPH:
return """
return f"""
MERGE (n:Entity {{uuid: $entity_data.uuid}})
SET n:{labels}
SET n = $entity_data
@ -255,13 +254,10 @@ def get_entity_node_save_bulk_query(provider: GraphProvider, nodes: list[dict])
return """
UNWIND $nodes AS node
MERGE (n:Entity {uuid: node.uuid})
FOREACH (label IN CASE WHEN node.labels IS NOT NULL THEN node.labels ELSE [] END |
SET n:`${label}`
)
SET n = node
WITH n, node
SET n.name_embedding = node.name_embedding
RETURN n.uuid AS uuid;
RETURN n.uuid AS uuid
"""
case _: # Neo4j
return """

View file

@ -299,7 +299,7 @@ class EpisodicNode(Node):
'source': self.source.value,
}
if driver.provider == GraphProvider.NEO4J:
if driver.provider in (GraphProvider.NEO4J, GraphProvider.MEMGRAPH):
episode_args['group_label'] = 'Episodic_' + self.group_id.replace('-', '')
result = await driver.execute_query(