run unit tests and change api where needed

This commit is contained in:
DavIvek 2025-09-09 16:43:23 +02:00
parent 78cdab98d9
commit e43a756ac1
9 changed files with 80 additions and 34 deletions

View file

@ -29,6 +29,11 @@ jobs:
NEO4J_AUTH: neo4j/testpass NEO4J_AUTH: neo4j/testpass
NEO4J_PLUGINS: '["apoc"]' NEO4J_PLUGINS: '["apoc"]'
options: --health-cmd "cypher-shell -u neo4j -p testpass 'RETURN 1'" --health-interval 10s --health-timeout 5s --health-retries 10 options: --health-cmd "cypher-shell -u neo4j -p testpass 'RETURN 1'" --health-interval 10s --health-timeout 5s --health-retries 10
memgraph:
image: memgraph/memgraph:latest
ports:
- 7688:7687
options: --health-cmd "mg_client --host localhost --port 7687 --use-ssl=false --query 'RETURN 1;'" --health-interval 10s --health-timeout 5s --health-retries 10
steps: steps:
- uses: actions/checkout@v4 - uses: actions/checkout@v4
- name: Set up Python - name: Set up Python
@ -77,3 +82,13 @@ jobs:
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }} OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
run: | run: |
uv run pytest tests/test_*_int.py -k "neo4j" uv run pytest tests/test_*_int.py -k "neo4j"
- name: Run Memgraph integration tests
env:
PYTHONPATH: ${{ github.workspace }}
MEMGRAPH_URI: bolt://localhost:7688
MEMGRAPH_USER:
MEMGRAPH_PASSWORD:
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
run: |
uv run pytest tests/test_*_int.py -k "memgraph"

View file

@ -44,5 +44,34 @@ services:
environment: environment:
- NEO4J_AUTH=${NEO4J_USER}/${NEO4J_PASSWORD} - NEO4J_AUTH=${NEO4J_USER}/${NEO4J_PASSWORD}
memgraph:
image: memgraph/memgraph:latest
healthcheck:
test:
[
"CMD",
"mg_client",
"--host",
"localhost",
"--port",
"7687",
"--use-ssl=false",
"--query",
"RETURN 1;"
]
interval: 5s
timeout: 10s
retries: 10
start_period: 3s
ports:
- "7688:7687" # Bolt (using different port to avoid conflict)
volumes:
- memgraph_data:/var/lib/memgraph
environment:
- MEMGRAPH_USER=${MEMGRAPH_USER:-}
- MEMGRAPH_PASSWORD=${MEMGRAPH_PASSWORD:-}
command: ["--log-level=TRACE", "--also-log-to-stderr", "--bolt-port=7687"]
volumes: volumes:
neo4j_data: neo4j_data:
memgraph_data:

View file

@ -48,7 +48,9 @@ load_dotenv()
# Memgraph connection parameters # Memgraph connection parameters
# Make sure Memgraph is running (default port 7687, same as Neo4j) # Make sure Memgraph is running (default port 7687, same as Neo4j)
memgraph_uri = os.environ.get('MEMGRAPH_URI', 'bolt://localhost:7687') memgraph_uri = os.environ.get('MEMGRAPH_URI', 'bolt://localhost:7687')
memgraph_user = os.environ.get('MEMGRAPH_USER', '') # Memgraph often doesn't require auth by default memgraph_user = os.environ.get(
'MEMGRAPH_USER', ''
) # Memgraph often doesn't require auth by default
memgraph_password = os.environ.get('MEMGRAPH_PASSWORD', '') memgraph_password = os.environ.get('MEMGRAPH_PASSWORD', '')
if not memgraph_uri: if not memgraph_uri:

View file

@ -14,7 +14,7 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
""" """
from .neo4j_driver import Neo4jDriver
from .memgraph_driver import MemgraphDriver from .memgraph_driver import MemgraphDriver
from .neo4j_driver import Neo4jDriver
__all__ = ['Neo4jDriver', 'MemgraphDriver'] __all__ = ['Neo4jDriver', 'MemgraphDriver']

View file

@ -29,7 +29,9 @@ logger = logging.getLogger(__name__)
class MemgraphDriver(GraphDriver): class MemgraphDriver(GraphDriver):
provider = GraphProvider.MEMGRAPH provider = GraphProvider.MEMGRAPH
def __init__(self, uri: str, user: str | None, password: str | None, database: str = 'memgraph'): def __init__(
self, uri: str, user: str | None, password: str | None, database: str = 'memgraph'
):
super().__init__() super().__init__()
self.client = AsyncGraphDatabase.driver( self.client = AsyncGraphDatabase.driver(
uri=uri, uri=uri,
@ -37,7 +39,9 @@ class MemgraphDriver(GraphDriver):
) )
self._database = database self._database = database
async def execute_query(self, cypher_query_: LiteralString, **kwargs: Any) -> tuple[list, Any, Any]: async def execute_query(
self, cypher_query_: LiteralString, **kwargs: Any
) -> tuple[list, Any, Any]:
""" """
Execute a Cypher query against Memgraph using implicit transactions. Execute a Cypher query against Memgraph using implicit transactions.
Returns a tuple of (records, summary, keys) for compatibility with the GraphDriver interface. Returns a tuple of (records, summary, keys) for compatibility with the GraphDriver interface.
@ -77,4 +81,4 @@ class MemgraphDriver(GraphDriver):
def delete_all_indexes(self) -> Coroutine[Any, Any, Any]: def delete_all_indexes(self) -> Coroutine[Any, Any, Any]:
# TODO: Implement index deletion for Memgraph # TODO: Implement index deletion for Memgraph
raise NotImplementedError("Index deletion not implemented for MemgraphDriver") raise NotImplementedError('Index deletion not implemented for MemgraphDriver')

View file

@ -154,7 +154,7 @@ def get_vector_cosine_func_query(vec1, vec2, provider: GraphProvider) -> str:
return f'array_cosine_similarity({vec1}, {vec2})' return f'array_cosine_similarity({vec1}, {vec2})'
if provider == GraphProvider.MEMGRAPH: if provider == GraphProvider.MEMGRAPH:
return f'cosineSimilarity({vec1}, {vec2})' return f'vector_search.cosine_similarity({vec1}, {vec2})'
return f'vector.similarity.cosine({vec1}, {vec2})' return f'vector.similarity.cosine({vec1}, {vec2})'

View file

@ -104,7 +104,7 @@ def get_entity_edge_save_query(provider: GraphProvider) -> str:
MATCH (target:Entity {uuid: $edge_data.target_uuid}) MATCH (target:Entity {uuid: $edge_data.target_uuid})
MERGE (source)-[e:RELATES_TO {uuid: $edge_data.uuid}]->(target) MERGE (source)-[e:RELATES_TO {uuid: $edge_data.uuid}]->(target)
SET e = $edge_data SET e = $edge_data
WITH e e.fact_embedding = $edge_data.fact_embedding SET e.fact_embedding = $edge_data.fact_embedding
RETURN e.uuid AS uuid RETURN e.uuid AS uuid
""" """
case _: # Neo4j case _: # Neo4j

View file

@ -49,7 +49,18 @@ def get_episode_node_save_query(provider: GraphProvider) -> str:
entity_edges: $entity_edges, created_at: $created_at, valid_at: $valid_at} entity_edges: $entity_edges, created_at: $created_at, valid_at: $valid_at}
RETURN n.uuid AS uuid RETURN n.uuid AS uuid
""" """
case _: # Neo4j and Memgraph case GraphProvider.MEMGRAPH:
return """
MERGE (n:Episodic {uuid: $uuid})
SET n = {
uuid: $uuid, name: $name, group_id: $group_id,
source_description: $source_description, source: $source,
content: $content, entity_edges: $entity_edges,
created_at: $created_at, valid_at: $valid_at
}
RETURN n.uuid AS uuid
"""
case _: # Neo4j
return """ return """
MERGE (n:Episodic {uuid: $uuid}) MERGE (n:Episodic {uuid: $uuid})
SET n:$($group_label) SET n:$($group_label)
@ -96,21 +107,9 @@ def get_episode_node_save_bulk_query(provider: GraphProvider) -> str:
return """ return """
UNWIND $episodes AS episode UNWIND $episodes AS episode
MERGE (n:Episodic {uuid: episode.uuid}) MERGE (n:Episodic {uuid: episode.uuid})
FOREACH (_ IN CASE WHEN episode.group_label IS NOT NULL THEN [1] ELSE [] END | SET n = {uuid: episode.uuid, name: episode.name, group_id: episode.group_id, source_description: episode.source_description, source: episode.source, content: episode.content,
SET n:`${episode.group_label}` entity_edges: episode.entity_edges, created_at: episode.created_at, valid_at: episode.valid_at}
) RETURN n.uuid AS uuid
SET n = {
uuid: episode.uuid,
name: episode.name,
group_id: episode.group_id,
source_description: episode.source_description,
source: episode.source,
content: episode.content,
entity_edges: episode.entity_edges,
created_at: episode.created_at,
valid_at: episode.valid_at
}
RETURN n.uuid AS uuid;
""" """
case _: # Neo4j case _: # Neo4j
return """ return """
@ -183,7 +182,7 @@ def get_entity_node_save_query(provider: GraphProvider, labels: str) -> str:
RETURN n.uuid AS uuid RETURN n.uuid AS uuid
""" """
case GraphProvider.MEMGRAPH: case GraphProvider.MEMGRAPH:
return """ return f"""
MERGE (n:Entity {{uuid: $entity_data.uuid}}) MERGE (n:Entity {{uuid: $entity_data.uuid}})
SET n:{labels} SET n:{labels}
SET n = $entity_data SET n = $entity_data
@ -255,13 +254,10 @@ def get_entity_node_save_bulk_query(provider: GraphProvider, nodes: list[dict])
return """ return """
UNWIND $nodes AS node UNWIND $nodes AS node
MERGE (n:Entity {uuid: node.uuid}) MERGE (n:Entity {uuid: node.uuid})
FOREACH (label IN CASE WHEN node.labels IS NOT NULL THEN node.labels ELSE [] END |
SET n:`${label}`
)
SET n = node SET n = node
WITH n, node WITH n, node
SET n.name_embedding = node.name_embedding SET n.name_embedding = node.name_embedding
RETURN n.uuid AS uuid; RETURN n.uuid AS uuid
""" """
case _: # Neo4j case _: # Neo4j
return """ return """

View file

@ -299,7 +299,7 @@ class EpisodicNode(Node):
'source': self.source.value, 'source': self.source.value,
} }
if driver.provider == GraphProvider.NEO4J: if driver.provider in (GraphProvider.NEO4J, GraphProvider.MEMGRAPH):
episode_args['group_label'] = 'Episodic_' + self.group_id.replace('-', '') episode_args['group_label'] = 'Episodic_' + self.group_id.replace('-', '')
result = await driver.execute_query( result = await driver.execute_query(