run unit tests and change api where needed
This commit is contained in:
parent
78cdab98d9
commit
e43a756ac1
9 changed files with 80 additions and 34 deletions
15
.github/workflows/unit_tests.yml
vendored
15
.github/workflows/unit_tests.yml
vendored
|
|
@ -29,6 +29,11 @@ jobs:
|
|||
NEO4J_AUTH: neo4j/testpass
|
||||
NEO4J_PLUGINS: '["apoc"]'
|
||||
options: --health-cmd "cypher-shell -u neo4j -p testpass 'RETURN 1'" --health-interval 10s --health-timeout 5s --health-retries 10
|
||||
memgraph:
|
||||
image: memgraph/memgraph:latest
|
||||
ports:
|
||||
- 7688:7687
|
||||
options: --health-cmd "mg_client --host localhost --port 7687 --use-ssl=false --query 'RETURN 1;'" --health-interval 10s --health-timeout 5s --health-retries 10
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- name: Set up Python
|
||||
|
|
@ -77,3 +82,13 @@ jobs:
|
|||
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
run: |
|
||||
uv run pytest tests/test_*_int.py -k "neo4j"
|
||||
- name: Run Memgraph integration tests
|
||||
env:
|
||||
PYTHONPATH: ${{ github.workspace }}
|
||||
MEMGRAPH_URI: bolt://localhost:7688
|
||||
MEMGRAPH_USER:
|
||||
MEMGRAPH_PASSWORD:
|
||||
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
run: |
|
||||
uv run pytest tests/test_*_int.py -k "memgraph"
|
||||
|
||||
|
|
@ -44,5 +44,34 @@ services:
|
|||
environment:
|
||||
- NEO4J_AUTH=${NEO4J_USER}/${NEO4J_PASSWORD}
|
||||
|
||||
memgraph:
|
||||
image: memgraph/memgraph:latest
|
||||
healthcheck:
|
||||
test:
|
||||
[
|
||||
"CMD",
|
||||
"mg_client",
|
||||
"--host",
|
||||
"localhost",
|
||||
"--port",
|
||||
"7687",
|
||||
"--use-ssl=false",
|
||||
"--query",
|
||||
"RETURN 1;"
|
||||
]
|
||||
interval: 5s
|
||||
timeout: 10s
|
||||
retries: 10
|
||||
start_period: 3s
|
||||
ports:
|
||||
- "7688:7687" # Bolt (using different port to avoid conflict)
|
||||
volumes:
|
||||
- memgraph_data:/var/lib/memgraph
|
||||
environment:
|
||||
- MEMGRAPH_USER=${MEMGRAPH_USER:-}
|
||||
- MEMGRAPH_PASSWORD=${MEMGRAPH_PASSWORD:-}
|
||||
command: ["--log-level=TRACE", "--also-log-to-stderr", "--bolt-port=7687"]
|
||||
|
||||
volumes:
|
||||
neo4j_data:
|
||||
memgraph_data:
|
||||
|
|
|
|||
|
|
@ -48,7 +48,9 @@ load_dotenv()
|
|||
# Memgraph connection parameters
|
||||
# Make sure Memgraph is running (default port 7687, same as Neo4j)
|
||||
memgraph_uri = os.environ.get('MEMGRAPH_URI', 'bolt://localhost:7687')
|
||||
memgraph_user = os.environ.get('MEMGRAPH_USER', '') # Memgraph often doesn't require auth by default
|
||||
memgraph_user = os.environ.get(
|
||||
'MEMGRAPH_USER', ''
|
||||
) # Memgraph often doesn't require auth by default
|
||||
memgraph_password = os.environ.get('MEMGRAPH_PASSWORD', '')
|
||||
|
||||
if not memgraph_uri:
|
||||
|
|
@ -66,7 +68,7 @@ async def main():
|
|||
|
||||
# Initialize Memgraph driver
|
||||
memgraph_driver = MemgraphDriver(memgraph_uri, memgraph_user, memgraph_password)
|
||||
|
||||
|
||||
# Initialize Graphiti with Memgraph connection
|
||||
graphiti = Graphiti(graph_driver=memgraph_driver)
|
||||
|
||||
|
|
|
|||
|
|
@ -14,7 +14,7 @@ See the License for the specific language governing permissions and
|
|||
limitations under the License.
|
||||
"""
|
||||
|
||||
from .neo4j_driver import Neo4jDriver
|
||||
from .memgraph_driver import MemgraphDriver
|
||||
from .neo4j_driver import Neo4jDriver
|
||||
|
||||
__all__ = ['Neo4jDriver', 'MemgraphDriver']
|
||||
|
|
|
|||
|
|
@ -29,7 +29,9 @@ logger = logging.getLogger(__name__)
|
|||
class MemgraphDriver(GraphDriver):
|
||||
provider = GraphProvider.MEMGRAPH
|
||||
|
||||
def __init__(self, uri: str, user: str | None, password: str | None, database: str = 'memgraph'):
|
||||
def __init__(
|
||||
self, uri: str, user: str | None, password: str | None, database: str = 'memgraph'
|
||||
):
|
||||
super().__init__()
|
||||
self.client = AsyncGraphDatabase.driver(
|
||||
uri=uri,
|
||||
|
|
@ -37,7 +39,9 @@ class MemgraphDriver(GraphDriver):
|
|||
)
|
||||
self._database = database
|
||||
|
||||
async def execute_query(self, cypher_query_: LiteralString, **kwargs: Any) -> tuple[list, Any, Any]:
|
||||
async def execute_query(
|
||||
self, cypher_query_: LiteralString, **kwargs: Any
|
||||
) -> tuple[list, Any, Any]:
|
||||
"""
|
||||
Execute a Cypher query against Memgraph using implicit transactions.
|
||||
Returns a tuple of (records, summary, keys) for compatibility with the GraphDriver interface.
|
||||
|
|
@ -49,20 +53,20 @@ class MemgraphDriver(GraphDriver):
|
|||
# but first extract database-specific parameters
|
||||
database = kwargs.pop('database_', self._database)
|
||||
kwargs.pop('parameters_', None) # Remove if present (Neo4j async driver param)
|
||||
|
||||
|
||||
# All remaining kwargs are query parameters
|
||||
params = kwargs
|
||||
else:
|
||||
# Extract database parameter if params was provided separately
|
||||
database = kwargs.pop('database_', self._database)
|
||||
kwargs.pop('parameters_', None) # Remove if present
|
||||
|
||||
|
||||
async with self.client.session(database=database) as session:
|
||||
try:
|
||||
result = await session.run(cypher_query_, params)
|
||||
records = [record async for record in result]
|
||||
summary = await result.consume()
|
||||
keys = result.keys()
|
||||
keys = result.keys()
|
||||
return (records, summary, keys)
|
||||
except Exception as e:
|
||||
logger.error(f'Error executing Memgraph query: {e}\n{cypher_query_}\n{params}')
|
||||
|
|
@ -77,4 +81,4 @@ class MemgraphDriver(GraphDriver):
|
|||
|
||||
def delete_all_indexes(self) -> Coroutine[Any, Any, Any]:
|
||||
# TODO: Implement index deletion for Memgraph
|
||||
raise NotImplementedError("Index deletion not implemented for MemgraphDriver")
|
||||
raise NotImplementedError('Index deletion not implemented for MemgraphDriver')
|
||||
|
|
|
|||
|
|
@ -154,7 +154,7 @@ def get_vector_cosine_func_query(vec1, vec2, provider: GraphProvider) -> str:
|
|||
return f'array_cosine_similarity({vec1}, {vec2})'
|
||||
|
||||
if provider == GraphProvider.MEMGRAPH:
|
||||
return f'cosineSimilarity({vec1}, {vec2})'
|
||||
return f'vector_search.cosine_similarity({vec1}, {vec2})'
|
||||
|
||||
return f'vector.similarity.cosine({vec1}, {vec2})'
|
||||
|
||||
|
|
|
|||
|
|
@ -104,7 +104,7 @@ def get_entity_edge_save_query(provider: GraphProvider) -> str:
|
|||
MATCH (target:Entity {uuid: $edge_data.target_uuid})
|
||||
MERGE (source)-[e:RELATES_TO {uuid: $edge_data.uuid}]->(target)
|
||||
SET e = $edge_data
|
||||
WITH e e.fact_embedding = $edge_data.fact_embedding
|
||||
SET e.fact_embedding = $edge_data.fact_embedding
|
||||
RETURN e.uuid AS uuid
|
||||
"""
|
||||
case _: # Neo4j
|
||||
|
|
|
|||
|
|
@ -49,7 +49,18 @@ def get_episode_node_save_query(provider: GraphProvider) -> str:
|
|||
entity_edges: $entity_edges, created_at: $created_at, valid_at: $valid_at}
|
||||
RETURN n.uuid AS uuid
|
||||
"""
|
||||
case _: # Neo4j and Memgraph
|
||||
case GraphProvider.MEMGRAPH:
|
||||
return """
|
||||
MERGE (n:Episodic {uuid: $uuid})
|
||||
SET n = {
|
||||
uuid: $uuid, name: $name, group_id: $group_id,
|
||||
source_description: $source_description, source: $source,
|
||||
content: $content, entity_edges: $entity_edges,
|
||||
created_at: $created_at, valid_at: $valid_at
|
||||
}
|
||||
RETURN n.uuid AS uuid
|
||||
"""
|
||||
case _: # Neo4j
|
||||
return """
|
||||
MERGE (n:Episodic {uuid: $uuid})
|
||||
SET n:$($group_label)
|
||||
|
|
@ -96,23 +107,11 @@ def get_episode_node_save_bulk_query(provider: GraphProvider) -> str:
|
|||
return """
|
||||
UNWIND $episodes AS episode
|
||||
MERGE (n:Episodic {uuid: episode.uuid})
|
||||
FOREACH (_ IN CASE WHEN episode.group_label IS NOT NULL THEN [1] ELSE [] END |
|
||||
SET n:`${episode.group_label}`
|
||||
)
|
||||
SET n = {
|
||||
uuid: episode.uuid,
|
||||
name: episode.name,
|
||||
group_id: episode.group_id,
|
||||
source_description: episode.source_description,
|
||||
source: episode.source,
|
||||
content: episode.content,
|
||||
entity_edges: episode.entity_edges,
|
||||
created_at: episode.created_at,
|
||||
valid_at: episode.valid_at
|
||||
}
|
||||
RETURN n.uuid AS uuid;
|
||||
SET n = {uuid: episode.uuid, name: episode.name, group_id: episode.group_id, source_description: episode.source_description, source: episode.source, content: episode.content,
|
||||
entity_edges: episode.entity_edges, created_at: episode.created_at, valid_at: episode.valid_at}
|
||||
RETURN n.uuid AS uuid
|
||||
"""
|
||||
case _: # Neo4j
|
||||
case _: # Neo4j
|
||||
return """
|
||||
UNWIND $episodes AS episode
|
||||
MERGE (n:Episodic {uuid: episode.uuid})
|
||||
|
|
@ -183,7 +182,7 @@ def get_entity_node_save_query(provider: GraphProvider, labels: str) -> str:
|
|||
RETURN n.uuid AS uuid
|
||||
"""
|
||||
case GraphProvider.MEMGRAPH:
|
||||
return """
|
||||
return f"""
|
||||
MERGE (n:Entity {{uuid: $entity_data.uuid}})
|
||||
SET n:{labels}
|
||||
SET n = $entity_data
|
||||
|
|
@ -255,13 +254,10 @@ def get_entity_node_save_bulk_query(provider: GraphProvider, nodes: list[dict])
|
|||
return """
|
||||
UNWIND $nodes AS node
|
||||
MERGE (n:Entity {uuid: node.uuid})
|
||||
FOREACH (label IN CASE WHEN node.labels IS NOT NULL THEN node.labels ELSE [] END |
|
||||
SET n:`${label}`
|
||||
)
|
||||
SET n = node
|
||||
WITH n, node
|
||||
SET n.name_embedding = node.name_embedding
|
||||
RETURN n.uuid AS uuid;
|
||||
RETURN n.uuid AS uuid
|
||||
"""
|
||||
case _: # Neo4j
|
||||
return """
|
||||
|
|
|
|||
|
|
@ -299,7 +299,7 @@ class EpisodicNode(Node):
|
|||
'source': self.source.value,
|
||||
}
|
||||
|
||||
if driver.provider == GraphProvider.NEO4J:
|
||||
if driver.provider in (GraphProvider.NEO4J, GraphProvider.MEMGRAPH):
|
||||
episode_args['group_label'] = 'Episodic_' + self.group_id.replace('-', '')
|
||||
|
||||
result = await driver.execute_query(
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue