From 71f1f66d112bca26c54109809293fc59855bf9ac Mon Sep 17 00:00:00 2001 From: Preston Rasmussen <109292228+prasmussen15@users.noreply.github.com> Date: Sun, 26 Oct 2025 22:07:36 -0400 Subject: [PATCH 1/4] Search client update (#1026) * update bulk interfae handling * bump version * format --- graphiti_core/driver/falkordb_driver.py | 10 +++++----- graphiti_core/driver/neo4j_driver.py | 4 ++-- graphiti_core/utils/bulk_utils.py | 6 ++---- mcp_server/graphiti_mcp_server.py | 16 ++++++++++++---- pyproject.toml | 2 +- uv.lock | 4 ++-- 6 files changed, 24 insertions(+), 18 deletions(-) diff --git a/graphiti_core/driver/falkordb_driver.py b/graphiti_core/driver/falkordb_driver.py index 793f0545..d0b4ffe8 100644 --- a/graphiti_core/driver/falkordb_driver.py +++ b/graphiti_core/driver/falkordb_driver.py @@ -14,8 +14,8 @@ See the License for the specific language governing permissions and limitations under the License. """ -import datetime import asyncio +import datetime import logging from typing import TYPE_CHECKING, Any @@ -231,17 +231,17 @@ class FalkorDriver(GraphDriver): """ cloned = FalkorDriver(falkor_db=self.client, database=database) - return cloned + return cloned async def health_check(self) -> None: """Check FalkorDB connectivity by running a simple query.""" try: - await self.execute_query("MATCH (n) RETURN 1 LIMIT 1") + await self.execute_query('MATCH (n) RETURN 1 LIMIT 1') return None except Exception as e: - print(f"FalkorDB health check failed: {e}") + print(f'FalkorDB health check failed: {e}') raise - + @staticmethod def convert_datetimes_to_strings(obj): if isinstance(obj, dict): diff --git a/graphiti_core/driver/neo4j_driver.py b/graphiti_core/driver/neo4j_driver.py index 4e114943..4a0baf79 100644 --- a/graphiti_core/driver/neo4j_driver.py +++ b/graphiti_core/driver/neo4j_driver.py @@ -72,12 +72,12 @@ class Neo4jDriver(GraphDriver): return self.client.execute_query( 'CALL db.indexes() YIELD name DROP INDEX name', ) - + async def health_check(self) -> None: """Check Neo4j connectivity by running the driver's verify_connectivity method.""" try: await self.client.verify_connectivity() return None except Exception as e: - print(f"Neo4j health check failed: {e}") + print(f'Neo4j health check failed: {e}') raise diff --git a/graphiti_core/utils/bulk_utils.py b/graphiti_core/utils/bulk_utils.py index 049aa53e..4a861b1b 100644 --- a/graphiti_core/utils/bulk_utils.py +++ b/graphiti_core/utils/bulk_utils.py @@ -214,12 +214,10 @@ async def add_nodes_and_edges_bulk_tx( edges.append(edge_data) if driver.graph_operations_interface: - await driver.graph_operations_interface.episodic_node_save_bulk( - None, driver, tx, episodic_nodes - ) + await driver.graph_operations_interface.episodic_node_save_bulk(None, driver, tx, episodes) await driver.graph_operations_interface.node_save_bulk(None, driver, tx, nodes) await driver.graph_operations_interface.episodic_edge_save_bulk( - None, driver, tx, episodic_edges + None, driver, tx, [edge.model_dump() for edge in episodic_edges] ) await driver.graph_operations_interface.edge_save_bulk(None, driver, tx, edges) diff --git a/mcp_server/graphiti_mcp_server.py b/mcp_server/graphiti_mcp_server.py index 5a650b24..3919fd78 100644 --- a/mcp_server/graphiti_mcp_server.py +++ b/mcp_server/graphiti_mcp_server.py @@ -467,6 +467,7 @@ class Neo4jConfig(BaseModel): password=os.environ.get('NEO4J_PASSWORD', 'password'), ) + class FalkorConfig(BaseModel): """Configuration for FalkorDB database connection.""" @@ -483,6 +484,7 @@ class FalkorConfig(BaseModel): password = 
os.environ.get('FALKORDB_PASSWORD', '') return cls(host=host, port=port, user=user, password=password) + class GraphitiConfig(BaseModel): """Configuration for Graphiti client. @@ -504,7 +506,9 @@ class GraphitiConfig(BaseModel): """Create a configuration instance from environment variables.""" db_type = os.environ.get('DATABASE_TYPE') if not db_type: - raise ValueError('DATABASE_TYPE environment variable must be set (e.g., "neo4j" or "falkordb")') + raise ValueError( + 'DATABASE_TYPE environment variable must be set (e.g., "neo4j" or "falkordb")' + ) if db_type == 'neo4j': return cls( llm=GraphitiLLMConfig.from_env(), @@ -622,7 +626,9 @@ async def initialize_graphiti(): raise ValueError('NEO4J_URI, NEO4J_USER, and NEO4J_PASSWORD must be set') # Validate FalkorDB configuration - if config.database_type == 'falkordb' and (not config.falkordb.host or not config.falkordb.port): + if config.database_type == 'falkordb' and ( + not config.falkordb.host or not config.falkordb.port + ): raise ValueError('FALKORDB_HOST and FALKORDB_PORT must be set for FalkorDB') embedder_client = config.embedder.create_client() @@ -637,6 +643,7 @@ async def initialize_graphiti(): ) elif config.database_type == 'falkordb': from graphiti_core.driver.falkordb_driver import FalkorDriver + host = config.falkordb.host if hasattr(config.falkordb, 'host') else 'localhost' port = int(config.falkordb.port) if hasattr(config.falkordb, 'port') else 6379 username = config.falkordb.user or None @@ -1205,10 +1212,11 @@ async def get_status() -> StatusResponse: client = cast(Graphiti, graphiti_client) # Test database connection - await client.driver.health_check() # type: ignore # type: ignore + await client.driver.health_check() # type: ignore # type: ignore return StatusResponse( - status='ok', message=f'Graphiti MCP server is running and connected to {config.database_type}' + status='ok', + message=f'Graphiti MCP server is running and connected to {config.database_type}', ) except Exception as e: error_msg = str(e) diff --git a/pyproject.toml b/pyproject.toml index 1932e770..59e88e21 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,7 +1,7 @@ [project] name = "graphiti-core" description = "A temporal graph building library" -version = "0.22.0" +version = "0.22.1pre1" authors = [ { name = "Paul Paliychuk", email = "paul@getzep.com" }, { name = "Preston Rasmussen", email = "preston@getzep.com" }, diff --git a/uv.lock b/uv.lock index a489a57d..2bf14336 100644 --- a/uv.lock +++ b/uv.lock @@ -1,5 +1,5 @@ version = 1 -revision = 3 +revision = 2 requires-python = ">=3.10, <4" resolution-markers = [ "python_full_version >= '3.14'", @@ -783,7 +783,7 @@ wheels = [ [[package]] name = "graphiti-core" -version = "0.22.0rc5" +version = "0.22.1rc1" source = { editable = "." 
} dependencies = [ { name = "diskcache" }, From 56f6d09df0da8cabdc900e74041ab4a7a37f7852 Mon Sep 17 00:00:00 2001 From: Daniel Chalef <131175+danielchalef@users.noreply.github.com> Date: Mon, 27 Oct 2025 12:06:38 -0700 Subject: [PATCH 2/4] Add MCP server release workflow (#1025) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * conductor-checkpoint-start * conductor-checkpoint-msg_01B1n4yHQFoVrWWdKcqPQ4Qa * conductor-checkpoint-msg_01LS1v8ok5qtzAofv1TFRDHt * conductor-checkpoint-msg_01H5pxrRKDpizF4wv1irnvRz * conductor-checkpoint-msg_01EFo2gQBKSFkGcJoJ4bUWNS * conductor-checkpoint-msg_01QW92pnqMv17sfV4CxFKn7Y * conductor-checkpoint-msg_01VqPRMaBRGpBf9E8sdpPeFa * Fix critical issues in MCP server release workflow - Fix Docker tag format: use version only (0.4.0) instead of mcp-v0.4.0 - Add Python 3.11 setup for tomllib compatibility - Add workflow_dispatch trigger for testing without creating tags - Add conditional push logic for manual testing 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude * Remove workflow_dispatch trigger from MCP server release Simplify workflow to only trigger on mcp-v*.*.* tags. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude * conductor-checkpoint-msg_019AX8ymwf9eec2KF979CJCM * conductor-checkpoint-msg_01LMofTLUNkicSq5vpFmsd1C * Add semantic version validation to MCP server release Validate tag follows X.Y.Z format before processing. Rejects malformed tags like mcp-v1.0 or mcp-v1.0.0.0. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude * conductor-checkpoint-msg_01Ndj59qcprSNRfe3vuciwwA * conductor-checkpoint-msg_01PmA8bfCLKv7yHiaBz2MypS --------- Co-authored-by: Claude --- .github/workflows/mcp-server-docker.yml | 73 ----------------------- .github/workflows/release-mcp-server.yml | 74 ++++++++++++++++++++++++ 2 files changed, 74 insertions(+), 73 deletions(-) delete mode 100644 .github/workflows/mcp-server-docker.yml create mode 100644 .github/workflows/release-mcp-server.yml diff --git a/.github/workflows/mcp-server-docker.yml b/.github/workflows/mcp-server-docker.yml deleted file mode 100644 index 67002d1a..00000000 --- a/.github/workflows/mcp-server-docker.yml +++ /dev/null @@ -1,73 +0,0 @@ -name: Build and Push MCP Server Docker Image - -on: - push: - paths: - - "mcp_server/pyproject.toml" - branches: - - main - pull_request: - paths: - - "mcp_server/pyproject.toml" - branches: - - main - workflow_dispatch: - inputs: - push_image: - description: "Push image to registry (unchecked for testing)" - required: false - default: false - type: boolean - -env: - REGISTRY: docker.io - IMAGE_NAME: zepai/knowledge-graph-mcp - -jobs: - build-and-push: - runs-on: depot-ubuntu-24.04-small - environment: development - permissions: - contents: read - id-token: write - steps: - - name: Checkout repository - uses: actions/checkout@v4 - - - name: Extract version from pyproject.toml - id: version - run: | - VERSION=$(python -c "import tomllib; print(tomllib.load(open('mcp_server/pyproject.toml', 'rb'))['project']['version'])") - echo "version=$VERSION" >> $GITHUB_OUTPUT - echo "tag=v$VERSION" >> $GITHUB_OUTPUT - - name: Log in to Docker Hub - if: github.event_name != 'pull_request' && (github.event_name != 'workflow_dispatch' || inputs.push_image) - uses: docker/login-action@v3 - with: - registry: ${{ env.REGISTRY }} - username: ${{ secrets.DOCKERHUB_USERNAME }} - password: ${{ secrets.DOCKERHUB_TOKEN }} - - - name: Set up Depot 
CLI - uses: depot/setup-action@v1 - - - name: Extract metadata - id: meta - uses: docker/metadata-action@v5 - with: - images: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }} - tags: | - type=ref,event=branch - type=ref,event=pr - type=raw,value=${{ steps.version.outputs.tag }} - type=raw,value=latest,enable={{is_default_branch}} - - - name: Depot build and push image - uses: depot/build-push-action@v1 - with: - project: v9jv1mlpwc - context: ./mcp_server - platforms: linux/amd64,linux/arm64 - push: ${{ github.event_name != 'pull_request' && (github.event_name != 'workflow_dispatch' || inputs.push_image) }} - tags: ${{ steps.meta.outputs.tags }} - labels: ${{ steps.meta.outputs.labels }} diff --git a/.github/workflows/release-mcp-server.yml b/.github/workflows/release-mcp-server.yml new file mode 100644 index 00000000..7790d66f --- /dev/null +++ b/.github/workflows/release-mcp-server.yml @@ -0,0 +1,74 @@ +name: Release MCP Server + +on: + push: + tags: ["mcp-v*.*.*"] + +env: + REGISTRY: docker.io + IMAGE_NAME: zepai/knowledge-graph-mcp + +jobs: + release: + runs-on: depot-ubuntu-24.04-small + permissions: + contents: write + id-token: write + environment: + name: release + steps: + - name: Checkout repository + uses: actions/checkout@v4 + + - name: Set up Python 3.11 + uses: actions/setup-python@v5 + with: + python-version: "3.11" + + - name: Extract and validate version + id: version + run: | + TAG_VERSION=${GITHUB_REF#refs/tags/mcp-v} + + if ! [[ $TAG_VERSION =~ ^[0-9]+\.[0-9]+\.[0-9]+$ ]]; then + echo "Tag must follow semantic versioning: mcp-vX.Y.Z" + exit 1 + fi + + PROJECT_VERSION=$(python -c "import tomllib; print(tomllib.load(open('mcp_server/pyproject.toml', 'rb'))['project']['version'])") + + if [ "$TAG_VERSION" != "$PROJECT_VERSION" ]; then + echo "Tag version mcp-v$TAG_VERSION does not match mcp_server/pyproject.toml version $PROJECT_VERSION" + exit 1 + fi + + echo "version=$PROJECT_VERSION" >> $GITHUB_OUTPUT + + - name: Log in to Docker Hub + uses: docker/login-action@v3 + with: + registry: ${{ env.REGISTRY }} + username: ${{ secrets.DOCKERHUB_USERNAME }} + password: ${{ secrets.DOCKERHUB_TOKEN }} + + - name: Set up Depot CLI + uses: depot/setup-action@v1 + + - name: Extract metadata + id: meta + uses: docker/metadata-action@v5 + with: + images: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }} + tags: | + type=raw,value=${{ steps.version.outputs.version }} + type=raw,value=latest + + - name: Build and push Docker image + uses: depot/build-push-action@v1 + with: + project: v9jv1mlpwc + context: ./mcp_server + platforms: linux/amd64,linux/arm64 + push: true + tags: ${{ steps.meta.outputs.tags }} + labels: ${{ steps.meta.outputs.labels }} From ae227ce927c1b165a052251a026aa25702d95eb2 Mon Sep 17 00:00:00 2001 From: Daniel Chalef <131175+danielchalef@users.noreply.github.com> Date: Wed, 29 Oct 2025 02:32:04 -0700 Subject: [PATCH 3/4] @didier-durand has signed the CLA in getzep/graphiti#1028 --- signatures/version1/cla.json | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/signatures/version1/cla.json b/signatures/version1/cla.json index 315720a8..e3878a72 100644 --- a/signatures/version1/cla.json +++ b/signatures/version1/cla.json @@ -431,6 +431,14 @@ "created_at": "2025-10-22T09:52:01Z", "repoId": 840056306, "pullRequestNo": 1020 + }, + { + "name": "didier-durand", + "id": 2927957, + "comment_id": 3460571645, + "created_at": "2025-10-29T09:31:25Z", + "repoId": 840056306, + "pullRequestNo": 1028 } ] } \ No newline at end of file From c29f4da21ee237b4c9ea22559218946810149e25 Mon Sep 17 
00:00:00 2001 From: Preston Rasmussen <109292228+prasmussen15@users.noreply.github.com> Date: Wed, 29 Oct 2025 09:51:58 -0400 Subject: [PATCH 4/4] update mmr to use bulk load overrides (#1029) * update mmr to use bulk load overrides * update returns * update --- .../graph_operations/graph_operations.py | 12 +- graphiti_core/search/search_utils.py | 112 +++++++++--------- pyproject.toml | 2 +- uv.lock | 2 +- 4 files changed, 64 insertions(+), 64 deletions(-) diff --git a/graphiti_core/driver/graph_operations/graph_operations.py b/graphiti_core/driver/graph_operations/graph_operations.py index e4887923..54a59053 100644 --- a/graphiti_core/driver/graph_operations/graph_operations.py +++ b/graphiti_core/driver/graph_operations/graph_operations.py @@ -77,14 +77,12 @@ class GraphOperationsInterface(BaseModel): async def node_load_embeddings_bulk( self, - _cls: Any, driver: Any, - transaction: Any, nodes: list[Any], batch_size: int = 100, - ) -> None: + ) -> dict[str, list[float]]: """ - Load embedding vectors for many nodes in batches. Mutates the provided node instances. + Load embedding vectors for many nodes in batches. """ raise NotImplementedError @@ -183,13 +181,11 @@ class GraphOperationsInterface(BaseModel): async def edge_load_embeddings_bulk( self, - _cls: Any, driver: Any, - transaction: Any, edges: list[Any], batch_size: int = 100, - ) -> None: + ) -> dict[str, list[float]]: """ - Load embedding vectors for many edges in batches. Mutates the provided edge instances. + Load embedding vectors for many edges in batches """ raise NotImplementedError diff --git a/graphiti_core/search/search_utils.py b/graphiti_core/search/search_utils.py index 104aede6..4c0e84fa 100644 --- a/graphiti_core/search/search_utils.py +++ b/graphiti_core/search/search_utils.py @@ -217,11 +217,11 @@ async def edge_fulltext_search( # Match the edge ids and return the values query = ( """ - UNWIND $ids as id - MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity) - WHERE e.group_id IN $group_ids - AND id(e)=id - """ + UNWIND $ids as id + MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity) + WHERE e.group_id IN $group_ids + AND id(e)=id + """ + filter_query + """ AND id(e)=id @@ -339,8 +339,8 @@ async def edge_similarity_search( if driver.provider == GraphProvider.NEPTUNE: query = ( """ - MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity) - """ + MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity) + """ + filter_query + """ RETURN DISTINCT id(e) as id, e.fact_embedding as embedding @@ -596,11 +596,11 @@ async def node_fulltext_search( # Match the edge ides and return the values query = ( """ - UNWIND $ids as i - MATCH (n:Entity) - WHERE n.uuid=i.id - RETURN - """ + UNWIND $ids as i + MATCH (n:Entity) + WHERE n.uuid=i.id + RETURN + """ + get_entity_node_return_query(driver.provider) + """ ORDER BY i.score DESC @@ -678,8 +678,8 @@ async def node_similarity_search( if driver.provider == GraphProvider.NEPTUNE: query = ( """ - MATCH (n:Entity) - """ + MATCH (n:Entity) + """ + filter_query + """ RETURN DISTINCT id(n) as id, n.name_embedding as embedding @@ -708,11 +708,11 @@ async def node_similarity_search( # Match the edge ides and return the values query = ( """ - UNWIND $ids as i - MATCH (n:Entity) - WHERE id(n)=i.id - RETURN - """ + UNWIND $ids as i + MATCH (n:Entity) + WHERE id(n)=i.id + RETURN + """ + get_entity_node_return_query(driver.provider) + """ ORDER BY i.score DESC @@ -733,8 +733,8 @@ async def node_similarity_search( else: query = ( """ - MATCH (n:Entity) - """ + MATCH (n:Entity) + """ + filter_query + """ WITH n, """ @@ -1037,8 +1037,8 
@@ async def community_similarity_search( if driver.provider == GraphProvider.NEPTUNE: query = ( """ - MATCH (n:Community) - """ + MATCH (n:Community) + """ + group_filter_query + """ RETURN DISTINCT id(n) as id, n.name_embedding as embedding @@ -1097,8 +1097,8 @@ async def community_similarity_search( query = ( """ - MATCH (c:Community) - """ + MATCH (c:Community) + """ + group_filter_query + """ WITH c, @@ -1240,9 +1240,9 @@ async def get_relevant_nodes( # FIXME: Kuzu currently does not support using variables such as `node.fulltext_query` as an input to FTS, which means `get_relevant_nodes()` won't work with Kuzu as the graph driver. query = ( """ - UNWIND $nodes AS node - MATCH (n:Entity {group_id: $group_id}) - """ + UNWIND $nodes AS node + MATCH (n:Entity {group_id: $group_id}) + """ + filter_query + """ WITH node, n, """ @@ -1287,9 +1287,9 @@ async def get_relevant_nodes( else: query = ( """ - UNWIND $nodes AS node - MATCH (n:Entity {group_id: $group_id}) - """ + UNWIND $nodes AS node + MATCH (n:Entity {group_id: $group_id}) + """ + filter_query + """ WITH node, n, """ @@ -1378,9 +1378,9 @@ async def get_relevant_edges( if driver.provider == GraphProvider.NEPTUNE: query = ( """ - UNWIND $edges AS edge - MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid}) - """ + UNWIND $edges AS edge + MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid}) + """ + filter_query + """ WITH e, edge @@ -1450,9 +1450,9 @@ async def get_relevant_edges( query = ( """ - UNWIND $edges AS edge - MATCH (n:Entity {uuid: edge.source_node_uuid})-[:RELATES_TO]-(e:RelatesToNode_ {group_id: edge.group_id})-[:RELATES_TO]-(m:Entity {uuid: edge.target_node_uuid}) - """ + UNWIND $edges AS edge + MATCH (n:Entity {uuid: edge.source_node_uuid})-[:RELATES_TO]-(e:RelatesToNode_ {group_id: edge.group_id})-[:RELATES_TO]-(m:Entity {uuid: edge.target_node_uuid}) + """ + filter_query + """ WITH e, edge, n, m, """ @@ -1488,9 +1488,9 @@ async def get_relevant_edges( else: query = ( """ - UNWIND $edges AS edge - MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid}) - """ + UNWIND $edges AS edge + MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid}) + """ + filter_query + """ WITH e, edge, """ @@ -1563,10 +1563,10 @@ async def get_edge_invalidation_candidates( if driver.provider == GraphProvider.NEPTUNE: query = ( """ - UNWIND $edges AS edge - MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity) - WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid] - """ + UNWIND $edges AS edge + MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity) + WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid] + """ + filter_query + """ WITH e, edge @@ -1636,10 +1636,10 @@ async def get_edge_invalidation_candidates( query = ( """ - UNWIND $edges AS edge - MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_ {group_id: edge.group_id})-[:RELATES_TO]->(m:Entity) - WHERE (n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid]) - """ + UNWIND $edges AS edge + MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_ {group_id: 
edge.group_id})-[:RELATES_TO]->(m:Entity) + WHERE (n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid]) + """ + filter_query + """ WITH edge, e, n, m, """ @@ -1675,10 +1675,10 @@ async def get_edge_invalidation_candidates( else: query = ( """ - UNWIND $edges AS edge - MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity) - WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid] - """ + UNWIND $edges AS edge + MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity) + WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid] + """ + filter_query + """ WITH edge, e, """ @@ -1879,7 +1879,9 @@ def maximal_marginal_relevance( async def get_embeddings_for_nodes( driver: GraphDriver, nodes: list[EntityNode] ) -> dict[str, list[float]]: - if driver.provider == GraphProvider.NEPTUNE: + if driver.graph_operations_interface: + return await driver.graph_operations_interface.node_load_embeddings_bulk(driver, nodes) + elif driver.provider == GraphProvider.NEPTUNE: query = """ MATCH (n:Entity) WHERE n.uuid IN $node_uuids @@ -1949,7 +1951,9 @@ async def get_embeddings_for_communities( async def get_embeddings_for_edges( driver: GraphDriver, edges: list[EntityEdge] ) -> dict[str, list[float]]: - if driver.provider == GraphProvider.NEPTUNE: + if driver.graph_operations_interface: + return await driver.graph_operations_interface.edge_load_embeddings_bulk(driver, edges) + elif driver.provider == GraphProvider.NEPTUNE: query = """ MATCH (n:Entity)-[e:RELATES_TO]-(m:Entity) WHERE e.uuid IN $edge_uuids diff --git a/pyproject.toml b/pyproject.toml index 59e88e21..5281f22b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,7 +1,7 @@ [project] name = "graphiti-core" description = "A temporal graph building library" -version = "0.22.1pre1" +version = "0.22.1pre2" authors = [ { name = "Paul Paliychuk", email = "paul@getzep.com" }, { name = "Preston Rasmussen", email = "preston@getzep.com" }, diff --git a/uv.lock b/uv.lock index 2bf14336..84d7228c 100644 --- a/uv.lock +++ b/uv.lock @@ -783,7 +783,7 @@ wheels = [ [[package]] name = "graphiti-core" -version = "0.22.1rc1" +version = "0.22.1rc2" source = { editable = "." } dependencies = [ { name = "diskcache" },
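
Note on PATCH 4/4: get_embeddings_for_nodes() and get_embeddings_for_edges() now delegate to the GraphOperationsInterface bulk loaders whenever a driver configures one, and node_load_embeddings_bulk()/edge_load_embeddings_bulk() were changed to take (driver, nodes/edges, batch_size) and return a dict mapping uuid to embedding instead of mutating the passed instances. The sketch below shows what an override with the new signatures could look like. It is illustrative only: the class name MyGraphOperations, the Cypher statements, the keyword-argument parameter passing, and the three-tuple unpacking of driver.execute_query() are assumptions, not anything defined in this series; the embedding property names (n.name_embedding, e.fact_embedding) follow the queries used elsewhere in search_utils.py.

    # Minimal sketch, assuming a Neo4j-style driver whose execute_query()
    # returns (records, summary, keys) and accepts query parameters as kwargs.
    from typing import Any

    from graphiti_core.driver.graph_operations.graph_operations import (
        GraphOperationsInterface,
    )


    class MyGraphOperations(GraphOperationsInterface):
        async def node_load_embeddings_bulk(
            self,
            driver: Any,
            nodes: list[Any],
            batch_size: int = 100,
        ) -> dict[str, list[float]]:
            embeddings: dict[str, list[float]] = {}
            for i in range(0, len(nodes), batch_size):
                batch = nodes[i : i + batch_size]
                # Hypothetical query; adjust to your schema and driver result shape.
                records, _, _ = await driver.execute_query(
                    'MATCH (n:Entity) WHERE n.uuid IN $uuids '
                    'RETURN n.uuid AS uuid, n.name_embedding AS embedding',
                    uuids=[node.uuid for node in batch],
                )
                for record in records:
                    if record['embedding'] is not None:
                        embeddings[record['uuid']] = record['embedding']
            return embeddings

        async def edge_load_embeddings_bulk(
            self,
            driver: Any,
            edges: list[Any],
            batch_size: int = 100,
        ) -> dict[str, list[float]]:
            embeddings: dict[str, list[float]] = {}
            for i in range(0, len(edges), batch_size):
                batch = edges[i : i + batch_size]
                records, _, _ = await driver.execute_query(
                    'MATCH (:Entity)-[e:RELATES_TO]->(:Entity) WHERE e.uuid IN $uuids '
                    'RETURN e.uuid AS uuid, e.fact_embedding AS embedding',
                    uuids=[edge.uuid for edge in batch],
                )
                for record in records:
                    if record['embedding'] is not None:
                        embeddings[record['uuid']] = record['embedding']
            return embeddings

With such an override attached to a driver's graph_operations_interface, the MMR reranker's embedding lookups in search_utils.py bypass the provider-specific query branches entirely, which is the point of this change.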