feat: async close and multi-group search support (#151)

* chore: Support a list of group_ids on search + await driver.close()

* fix: formatter and linter

* chore: Version bump
This commit is contained in:
Pavlo Paliychuk 2024-09-24 16:13:04 -04:00 committed by GitHub
parent 794b705664
commit 44b016da6b
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 9 additions and 6 deletions

View file

@ -129,7 +129,7 @@ class Graphiti:
else: else:
self.llm_client = OpenAIClient() self.llm_client = OpenAIClient()
def close(self): async def close(self):
""" """
Close the connection to the Neo4j database. Close the connection to the Neo4j database.
@ -159,7 +159,7 @@ class Graphiti:
finally: finally:
graphiti.close() graphiti.close()
""" """
self.driver.close() await self.driver.close()
async def build_indices_and_constraints(self): async def build_indices_and_constraints(self):
""" """

View file

@ -1,6 +1,6 @@
[tool.poetry] [tool.poetry]
name = "graphiti-core" name = "graphiti-core"
version = "0.3.4" version = "0.3.5"
description = "A temporal graph building library" description = "A temporal graph building library"
authors = [ authors = [
"Paul Paliychuk <paul@getzep.com>", "Paul Paliychuk <paul@getzep.com>",

View file

@ -6,7 +6,7 @@ from graph_service.dto.common import Message
class SearchQuery(BaseModel): class SearchQuery(BaseModel):
group_id: str = Field(..., description='The group id of the memory to get') group_ids: list[str] = Field(description='The group ids for the memories to search')
query: str query: str
max_facts: int = Field(default=10, description='The maximum number of facts to retrieve') max_facts: int = Field(default=10, description='The maximum number of facts to retrieve')

View file

@ -1,4 +1,5 @@
from datetime import datetime from datetime import datetime
from typing import List, Optional, cast
from fastapi import APIRouter, status from fastapi import APIRouter, status
@ -17,7 +18,9 @@ router = APIRouter()
@router.post('/search', status_code=status.HTTP_200_OK) @router.post('/search', status_code=status.HTTP_200_OK)
async def search(query: SearchQuery, graphiti: ZepGraphitiDep): async def search(query: SearchQuery, graphiti: ZepGraphitiDep):
relevant_edges = await graphiti.search( relevant_edges = await graphiti.search(
group_ids=[query.group_id], group_ids=cast(
Optional[List[Optional[str]]], query.group_ids
), # Cast query.group_ids to match the expected type in graphiti.search
query=query.query, query=query.query,
num_results=query.max_facts, num_results=query.max_facts,
) )

View file

@ -79,7 +79,7 @@ async def get_graphiti(settings: ZepEnvDep):
try: try:
yield client yield client
finally: finally:
client.close() await client.close()
async def initialize_graphiti(settings: ZepEnvDep): async def initialize_graphiti(settings: ZepEnvDep):