feat: async close and multi-group search support (#151)
* chore: Support a list of group_ids on search + await driver.close() * fix: formatter and linter * chore: Version bump
This commit is contained in:
parent
794b705664
commit
44b016da6b
5 changed files with 9 additions and 6 deletions
|
|
@ -129,7 +129,7 @@ class Graphiti:
|
||||||
else:
|
else:
|
||||||
self.llm_client = OpenAIClient()
|
self.llm_client = OpenAIClient()
|
||||||
|
|
||||||
def close(self):
|
async def close(self):
|
||||||
"""
|
"""
|
||||||
Close the connection to the Neo4j database.
|
Close the connection to the Neo4j database.
|
||||||
|
|
||||||
|
|
@ -159,7 +159,7 @@ class Graphiti:
|
||||||
finally:
|
finally:
|
||||||
graphiti.close()
|
graphiti.close()
|
||||||
"""
|
"""
|
||||||
self.driver.close()
|
await self.driver.close()
|
||||||
|
|
||||||
async def build_indices_and_constraints(self):
|
async def build_indices_and_constraints(self):
|
||||||
"""
|
"""
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,6 @@
|
||||||
[tool.poetry]
|
[tool.poetry]
|
||||||
name = "graphiti-core"
|
name = "graphiti-core"
|
||||||
version = "0.3.4"
|
version = "0.3.5"
|
||||||
description = "A temporal graph building library"
|
description = "A temporal graph building library"
|
||||||
authors = [
|
authors = [
|
||||||
"Paul Paliychuk <paul@getzep.com>",
|
"Paul Paliychuk <paul@getzep.com>",
|
||||||
|
|
|
||||||
|
|
@ -6,7 +6,7 @@ from graph_service.dto.common import Message
|
||||||
|
|
||||||
|
|
||||||
class SearchQuery(BaseModel):
|
class SearchQuery(BaseModel):
|
||||||
group_id: str = Field(..., description='The group id of the memory to get')
|
group_ids: list[str] = Field(description='The group ids for the memories to search')
|
||||||
query: str
|
query: str
|
||||||
max_facts: int = Field(default=10, description='The maximum number of facts to retrieve')
|
max_facts: int = Field(default=10, description='The maximum number of facts to retrieve')
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,5 @@
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
from typing import List, Optional, cast
|
||||||
|
|
||||||
from fastapi import APIRouter, status
|
from fastapi import APIRouter, status
|
||||||
|
|
||||||
|
|
@ -17,7 +18,9 @@ router = APIRouter()
|
||||||
@router.post('/search', status_code=status.HTTP_200_OK)
|
@router.post('/search', status_code=status.HTTP_200_OK)
|
||||||
async def search(query: SearchQuery, graphiti: ZepGraphitiDep):
|
async def search(query: SearchQuery, graphiti: ZepGraphitiDep):
|
||||||
relevant_edges = await graphiti.search(
|
relevant_edges = await graphiti.search(
|
||||||
group_ids=[query.group_id],
|
group_ids=cast(
|
||||||
|
Optional[List[Optional[str]]], query.group_ids
|
||||||
|
), # Cast query.group_ids to match the expected type in graphiti.search
|
||||||
query=query.query,
|
query=query.query,
|
||||||
num_results=query.max_facts,
|
num_results=query.max_facts,
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -79,7 +79,7 @@ async def get_graphiti(settings: ZepEnvDep):
|
||||||
try:
|
try:
|
||||||
yield client
|
yield client
|
||||||
finally:
|
finally:
|
||||||
client.close()
|
await client.close()
|
||||||
|
|
||||||
|
|
||||||
async def initialize_graphiti(settings: ZepEnvDep):
|
async def initialize_graphiti(settings: ZepEnvDep):
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue