From 44b016da6b2cb170f05d55090de6d41bddc9862d Mon Sep 17 00:00:00 2001 From: Pavlo Paliychuk Date: Tue, 24 Sep 2024 16:13:04 -0400 Subject: [PATCH] feat: async close and multi-group search support (#151) * chore: Support a list of group_ids on search + await driver.close() * fix: formatter and linter * chore: Version bump --- graphiti_core/graphiti.py | 4 ++-- pyproject.toml | 2 +- server/graph_service/dto/retrieve.py | 2 +- server/graph_service/routers/retrieve.py | 5 ++++- server/graph_service/zep_graphiti.py | 2 +- 5 files changed, 9 insertions(+), 6 deletions(-) diff --git a/graphiti_core/graphiti.py b/graphiti_core/graphiti.py index de1416b3..540c9ff1 100644 --- a/graphiti_core/graphiti.py +++ b/graphiti_core/graphiti.py @@ -129,7 +129,7 @@ class Graphiti: else: self.llm_client = OpenAIClient() - def close(self): + async def close(self): """ Close the connection to the Neo4j database. @@ -159,7 +159,7 @@ class Graphiti: finally: graphiti.close() """ - self.driver.close() + await self.driver.close() async def build_indices_and_constraints(self): """ diff --git a/pyproject.toml b/pyproject.toml index 3132e80c..304b397a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "graphiti-core" -version = "0.3.4" +version = "0.3.5" description = "A temporal graph building library" authors = [ "Paul Paliychuk ", diff --git a/server/graph_service/dto/retrieve.py b/server/graph_service/dto/retrieve.py index 5b02a3f3..1e7513ae 100644 --- a/server/graph_service/dto/retrieve.py +++ b/server/graph_service/dto/retrieve.py @@ -6,7 +6,7 @@ from graph_service.dto.common import Message class SearchQuery(BaseModel): - group_id: str = Field(..., description='The group id of the memory to get') + group_ids: list[str] = Field(description='The group ids for the memories to search') query: str max_facts: int = Field(default=10, description='The maximum number of facts to retrieve') diff --git a/server/graph_service/routers/retrieve.py b/server/graph_service/routers/retrieve.py index 2be5e0a8..333ca0b4 100644 --- a/server/graph_service/routers/retrieve.py +++ b/server/graph_service/routers/retrieve.py @@ -1,4 +1,5 @@ from datetime import datetime +from typing import List, Optional, cast from fastapi import APIRouter, status @@ -17,7 +18,9 @@ router = APIRouter() @router.post('/search', status_code=status.HTTP_200_OK) async def search(query: SearchQuery, graphiti: ZepGraphitiDep): relevant_edges = await graphiti.search( - group_ids=[query.group_id], + group_ids=cast( + Optional[List[Optional[str]]], query.group_ids + ), # Cast query.group_ids to match the expected type in graphiti.search query=query.query, num_results=query.max_facts, ) diff --git a/server/graph_service/zep_graphiti.py b/server/graph_service/zep_graphiti.py index 13054f37..66457130 100644 --- a/server/graph_service/zep_graphiti.py +++ b/server/graph_service/zep_graphiti.py @@ -79,7 +79,7 @@ async def get_graphiti(settings: ZepEnvDep): try: yield client finally: - client.close() + await client.close() async def initialize_graphiti(settings: ZepEnvDep):