run unit tests and change api where needed
This commit is contained in:
parent
78cdab98d9
commit
e43a756ac1
9 changed files with 80 additions and 34 deletions
15
.github/workflows/unit_tests.yml
vendored
15
.github/workflows/unit_tests.yml
vendored
|
|
@ -29,6 +29,11 @@ jobs:
|
||||||
NEO4J_AUTH: neo4j/testpass
|
NEO4J_AUTH: neo4j/testpass
|
||||||
NEO4J_PLUGINS: '["apoc"]'
|
NEO4J_PLUGINS: '["apoc"]'
|
||||||
options: --health-cmd "cypher-shell -u neo4j -p testpass 'RETURN 1'" --health-interval 10s --health-timeout 5s --health-retries 10
|
options: --health-cmd "cypher-shell -u neo4j -p testpass 'RETURN 1'" --health-interval 10s --health-timeout 5s --health-retries 10
|
||||||
|
memgraph:
|
||||||
|
image: memgraph/memgraph:latest
|
||||||
|
ports:
|
||||||
|
- 7688:7687
|
||||||
|
options: --health-cmd "mg_client --host localhost --port 7687 --use-ssl=false --query 'RETURN 1;'" --health-interval 10s --health-timeout 5s --health-retries 10
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v4
|
- uses: actions/checkout@v4
|
||||||
- name: Set up Python
|
- name: Set up Python
|
||||||
|
|
@ -77,3 +82,13 @@ jobs:
|
||||||
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||||
run: |
|
run: |
|
||||||
uv run pytest tests/test_*_int.py -k "neo4j"
|
uv run pytest tests/test_*_int.py -k "neo4j"
|
||||||
|
- name: Run Memgraph integration tests
|
||||||
|
env:
|
||||||
|
PYTHONPATH: ${{ github.workspace }}
|
||||||
|
MEMGRAPH_URI: bolt://localhost:7688
|
||||||
|
MEMGRAPH_USER:
|
||||||
|
MEMGRAPH_PASSWORD:
|
||||||
|
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||||
|
run: |
|
||||||
|
uv run pytest tests/test_*_int.py -k "memgraph"
|
||||||
|
|
||||||
|
|
@ -44,5 +44,34 @@ services:
|
||||||
environment:
|
environment:
|
||||||
- NEO4J_AUTH=${NEO4J_USER}/${NEO4J_PASSWORD}
|
- NEO4J_AUTH=${NEO4J_USER}/${NEO4J_PASSWORD}
|
||||||
|
|
||||||
|
memgraph:
|
||||||
|
image: memgraph/memgraph:latest
|
||||||
|
healthcheck:
|
||||||
|
test:
|
||||||
|
[
|
||||||
|
"CMD",
|
||||||
|
"mg_client",
|
||||||
|
"--host",
|
||||||
|
"localhost",
|
||||||
|
"--port",
|
||||||
|
"7687",
|
||||||
|
"--use-ssl=false",
|
||||||
|
"--query",
|
||||||
|
"RETURN 1;"
|
||||||
|
]
|
||||||
|
interval: 5s
|
||||||
|
timeout: 10s
|
||||||
|
retries: 10
|
||||||
|
start_period: 3s
|
||||||
|
ports:
|
||||||
|
- "7688:7687" # Bolt (using different port to avoid conflict)
|
||||||
|
volumes:
|
||||||
|
- memgraph_data:/var/lib/memgraph
|
||||||
|
environment:
|
||||||
|
- MEMGRAPH_USER=${MEMGRAPH_USER:-}
|
||||||
|
- MEMGRAPH_PASSWORD=${MEMGRAPH_PASSWORD:-}
|
||||||
|
command: ["--log-level=TRACE", "--also-log-to-stderr", "--bolt-port=7687"]
|
||||||
|
|
||||||
volumes:
|
volumes:
|
||||||
neo4j_data:
|
neo4j_data:
|
||||||
|
memgraph_data:
|
||||||
|
|
|
||||||
|
|
@ -48,7 +48,9 @@ load_dotenv()
|
||||||
# Memgraph connection parameters
|
# Memgraph connection parameters
|
||||||
# Make sure Memgraph is running (default port 7687, same as Neo4j)
|
# Make sure Memgraph is running (default port 7687, same as Neo4j)
|
||||||
memgraph_uri = os.environ.get('MEMGRAPH_URI', 'bolt://localhost:7687')
|
memgraph_uri = os.environ.get('MEMGRAPH_URI', 'bolt://localhost:7687')
|
||||||
memgraph_user = os.environ.get('MEMGRAPH_USER', '') # Memgraph often doesn't require auth by default
|
memgraph_user = os.environ.get(
|
||||||
|
'MEMGRAPH_USER', ''
|
||||||
|
) # Memgraph often doesn't require auth by default
|
||||||
memgraph_password = os.environ.get('MEMGRAPH_PASSWORD', '')
|
memgraph_password = os.environ.get('MEMGRAPH_PASSWORD', '')
|
||||||
|
|
||||||
if not memgraph_uri:
|
if not memgraph_uri:
|
||||||
|
|
|
||||||
|
|
@ -14,7 +14,7 @@ See the License for the specific language governing permissions and
|
||||||
limitations under the License.
|
limitations under the License.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from .neo4j_driver import Neo4jDriver
|
|
||||||
from .memgraph_driver import MemgraphDriver
|
from .memgraph_driver import MemgraphDriver
|
||||||
|
from .neo4j_driver import Neo4jDriver
|
||||||
|
|
||||||
__all__ = ['Neo4jDriver', 'MemgraphDriver']
|
__all__ = ['Neo4jDriver', 'MemgraphDriver']
|
||||||
|
|
|
||||||
|
|
@ -29,7 +29,9 @@ logger = logging.getLogger(__name__)
|
||||||
class MemgraphDriver(GraphDriver):
|
class MemgraphDriver(GraphDriver):
|
||||||
provider = GraphProvider.MEMGRAPH
|
provider = GraphProvider.MEMGRAPH
|
||||||
|
|
||||||
def __init__(self, uri: str, user: str | None, password: str | None, database: str = 'memgraph'):
|
def __init__(
|
||||||
|
self, uri: str, user: str | None, password: str | None, database: str = 'memgraph'
|
||||||
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.client = AsyncGraphDatabase.driver(
|
self.client = AsyncGraphDatabase.driver(
|
||||||
uri=uri,
|
uri=uri,
|
||||||
|
|
@ -37,7 +39,9 @@ class MemgraphDriver(GraphDriver):
|
||||||
)
|
)
|
||||||
self._database = database
|
self._database = database
|
||||||
|
|
||||||
async def execute_query(self, cypher_query_: LiteralString, **kwargs: Any) -> tuple[list, Any, Any]:
|
async def execute_query(
|
||||||
|
self, cypher_query_: LiteralString, **kwargs: Any
|
||||||
|
) -> tuple[list, Any, Any]:
|
||||||
"""
|
"""
|
||||||
Execute a Cypher query against Memgraph using implicit transactions.
|
Execute a Cypher query against Memgraph using implicit transactions.
|
||||||
Returns a tuple of (records, summary, keys) for compatibility with the GraphDriver interface.
|
Returns a tuple of (records, summary, keys) for compatibility with the GraphDriver interface.
|
||||||
|
|
@ -77,4 +81,4 @@ class MemgraphDriver(GraphDriver):
|
||||||
|
|
||||||
def delete_all_indexes(self) -> Coroutine[Any, Any, Any]:
|
def delete_all_indexes(self) -> Coroutine[Any, Any, Any]:
|
||||||
# TODO: Implement index deletion for Memgraph
|
# TODO: Implement index deletion for Memgraph
|
||||||
raise NotImplementedError("Index deletion not implemented for MemgraphDriver")
|
raise NotImplementedError('Index deletion not implemented for MemgraphDriver')
|
||||||
|
|
|
||||||
|
|
@ -154,7 +154,7 @@ def get_vector_cosine_func_query(vec1, vec2, provider: GraphProvider) -> str:
|
||||||
return f'array_cosine_similarity({vec1}, {vec2})'
|
return f'array_cosine_similarity({vec1}, {vec2})'
|
||||||
|
|
||||||
if provider == GraphProvider.MEMGRAPH:
|
if provider == GraphProvider.MEMGRAPH:
|
||||||
return f'cosineSimilarity({vec1}, {vec2})'
|
return f'vector_search.cosine_similarity({vec1}, {vec2})'
|
||||||
|
|
||||||
return f'vector.similarity.cosine({vec1}, {vec2})'
|
return f'vector.similarity.cosine({vec1}, {vec2})'
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -104,7 +104,7 @@ def get_entity_edge_save_query(provider: GraphProvider) -> str:
|
||||||
MATCH (target:Entity {uuid: $edge_data.target_uuid})
|
MATCH (target:Entity {uuid: $edge_data.target_uuid})
|
||||||
MERGE (source)-[e:RELATES_TO {uuid: $edge_data.uuid}]->(target)
|
MERGE (source)-[e:RELATES_TO {uuid: $edge_data.uuid}]->(target)
|
||||||
SET e = $edge_data
|
SET e = $edge_data
|
||||||
WITH e e.fact_embedding = $edge_data.fact_embedding
|
SET e.fact_embedding = $edge_data.fact_embedding
|
||||||
RETURN e.uuid AS uuid
|
RETURN e.uuid AS uuid
|
||||||
"""
|
"""
|
||||||
case _: # Neo4j
|
case _: # Neo4j
|
||||||
|
|
|
||||||
|
|
@ -49,7 +49,18 @@ def get_episode_node_save_query(provider: GraphProvider) -> str:
|
||||||
entity_edges: $entity_edges, created_at: $created_at, valid_at: $valid_at}
|
entity_edges: $entity_edges, created_at: $created_at, valid_at: $valid_at}
|
||||||
RETURN n.uuid AS uuid
|
RETURN n.uuid AS uuid
|
||||||
"""
|
"""
|
||||||
case _: # Neo4j and Memgraph
|
case GraphProvider.MEMGRAPH:
|
||||||
|
return """
|
||||||
|
MERGE (n:Episodic {uuid: $uuid})
|
||||||
|
SET n = {
|
||||||
|
uuid: $uuid, name: $name, group_id: $group_id,
|
||||||
|
source_description: $source_description, source: $source,
|
||||||
|
content: $content, entity_edges: $entity_edges,
|
||||||
|
created_at: $created_at, valid_at: $valid_at
|
||||||
|
}
|
||||||
|
RETURN n.uuid AS uuid
|
||||||
|
"""
|
||||||
|
case _: # Neo4j
|
||||||
return """
|
return """
|
||||||
MERGE (n:Episodic {uuid: $uuid})
|
MERGE (n:Episodic {uuid: $uuid})
|
||||||
SET n:$($group_label)
|
SET n:$($group_label)
|
||||||
|
|
@ -96,21 +107,9 @@ def get_episode_node_save_bulk_query(provider: GraphProvider) -> str:
|
||||||
return """
|
return """
|
||||||
UNWIND $episodes AS episode
|
UNWIND $episodes AS episode
|
||||||
MERGE (n:Episodic {uuid: episode.uuid})
|
MERGE (n:Episodic {uuid: episode.uuid})
|
||||||
FOREACH (_ IN CASE WHEN episode.group_label IS NOT NULL THEN [1] ELSE [] END |
|
SET n = {uuid: episode.uuid, name: episode.name, group_id: episode.group_id, source_description: episode.source_description, source: episode.source, content: episode.content,
|
||||||
SET n:`${episode.group_label}`
|
entity_edges: episode.entity_edges, created_at: episode.created_at, valid_at: episode.valid_at}
|
||||||
)
|
RETURN n.uuid AS uuid
|
||||||
SET n = {
|
|
||||||
uuid: episode.uuid,
|
|
||||||
name: episode.name,
|
|
||||||
group_id: episode.group_id,
|
|
||||||
source_description: episode.source_description,
|
|
||||||
source: episode.source,
|
|
||||||
content: episode.content,
|
|
||||||
entity_edges: episode.entity_edges,
|
|
||||||
created_at: episode.created_at,
|
|
||||||
valid_at: episode.valid_at
|
|
||||||
}
|
|
||||||
RETURN n.uuid AS uuid;
|
|
||||||
"""
|
"""
|
||||||
case _: # Neo4j
|
case _: # Neo4j
|
||||||
return """
|
return """
|
||||||
|
|
@ -183,7 +182,7 @@ def get_entity_node_save_query(provider: GraphProvider, labels: str) -> str:
|
||||||
RETURN n.uuid AS uuid
|
RETURN n.uuid AS uuid
|
||||||
"""
|
"""
|
||||||
case GraphProvider.MEMGRAPH:
|
case GraphProvider.MEMGRAPH:
|
||||||
return """
|
return f"""
|
||||||
MERGE (n:Entity {{uuid: $entity_data.uuid}})
|
MERGE (n:Entity {{uuid: $entity_data.uuid}})
|
||||||
SET n:{labels}
|
SET n:{labels}
|
||||||
SET n = $entity_data
|
SET n = $entity_data
|
||||||
|
|
@ -255,13 +254,10 @@ def get_entity_node_save_bulk_query(provider: GraphProvider, nodes: list[dict])
|
||||||
return """
|
return """
|
||||||
UNWIND $nodes AS node
|
UNWIND $nodes AS node
|
||||||
MERGE (n:Entity {uuid: node.uuid})
|
MERGE (n:Entity {uuid: node.uuid})
|
||||||
FOREACH (label IN CASE WHEN node.labels IS NOT NULL THEN node.labels ELSE [] END |
|
|
||||||
SET n:`${label}`
|
|
||||||
)
|
|
||||||
SET n = node
|
SET n = node
|
||||||
WITH n, node
|
WITH n, node
|
||||||
SET n.name_embedding = node.name_embedding
|
SET n.name_embedding = node.name_embedding
|
||||||
RETURN n.uuid AS uuid;
|
RETURN n.uuid AS uuid
|
||||||
"""
|
"""
|
||||||
case _: # Neo4j
|
case _: # Neo4j
|
||||||
return """
|
return """
|
||||||
|
|
|
||||||
|
|
@ -299,7 +299,7 @@ class EpisodicNode(Node):
|
||||||
'source': self.source.value,
|
'source': self.source.value,
|
||||||
}
|
}
|
||||||
|
|
||||||
if driver.provider == GraphProvider.NEO4J:
|
if driver.provider in (GraphProvider.NEO4J, GraphProvider.MEMGRAPH):
|
||||||
episode_args['group_label'] = 'Episodic_' + self.group_id.replace('-', '')
|
episode_args['group_label'] = 'Episodic_' + self.group_id.replace('-', '')
|
||||||
|
|
||||||
result = await driver.execute_query(
|
result = await driver.execute_query(
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue