feat: async close and multi-group search support (#151)
* chore: Support a list of group_ids on search + await driver.close() * fix: formatter and linter * chore: Version bump
This commit is contained in:
parent
794b705664
commit
44b016da6b
5 changed files with 9 additions and 6 deletions
|
|
@ -129,7 +129,7 @@ class Graphiti:
|
|||
else:
|
||||
self.llm_client = OpenAIClient()
|
||||
|
||||
def close(self):
|
||||
async def close(self):
|
||||
"""
|
||||
Close the connection to the Neo4j database.
|
||||
|
||||
|
|
@ -159,7 +159,7 @@ class Graphiti:
|
|||
finally:
|
||||
graphiti.close()
|
||||
"""
|
||||
self.driver.close()
|
||||
await self.driver.close()
|
||||
|
||||
async def build_indices_and_constraints(self):
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
[tool.poetry]
|
||||
name = "graphiti-core"
|
||||
version = "0.3.4"
|
||||
version = "0.3.5"
|
||||
description = "A temporal graph building library"
|
||||
authors = [
|
||||
"Paul Paliychuk <paul@getzep.com>",
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ from graph_service.dto.common import Message
|
|||
|
||||
|
||||
class SearchQuery(BaseModel):
|
||||
group_id: str = Field(..., description='The group id of the memory to get')
|
||||
group_ids: list[str] = Field(description='The group ids for the memories to search')
|
||||
query: str
|
||||
max_facts: int = Field(default=10, description='The maximum number of facts to retrieve')
|
||||
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
from datetime import datetime
|
||||
from typing import List, Optional, cast
|
||||
|
||||
from fastapi import APIRouter, status
|
||||
|
||||
|
|
@ -17,7 +18,9 @@ router = APIRouter()
|
|||
@router.post('/search', status_code=status.HTTP_200_OK)
|
||||
async def search(query: SearchQuery, graphiti: ZepGraphitiDep):
|
||||
relevant_edges = await graphiti.search(
|
||||
group_ids=[query.group_id],
|
||||
group_ids=cast(
|
||||
Optional[List[Optional[str]]], query.group_ids
|
||||
), # Cast query.group_ids to match the expected type in graphiti.search
|
||||
query=query.query,
|
||||
num_results=query.max_facts,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -79,7 +79,7 @@ async def get_graphiti(settings: ZepEnvDep):
|
|||
try:
|
||||
yield client
|
||||
finally:
|
||||
client.close()
|
||||
await client.close()
|
||||
|
||||
|
||||
async def initialize_graphiti(settings: ZepEnvDep):
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue