feat(graph-service): add entity node handling and update Docker configurations (#100)
* feat: Add entity node request + service maintenance * chore: Fix linter
This commit is contained in:
parent
3f12254916
commit
ad2962c6ba
10 changed files with 59 additions and 23 deletions
|
|
@ -37,6 +37,7 @@ COPY ./server /app
|
||||||
|
|
||||||
# Set environment variables
|
# Set environment variables
|
||||||
ENV PYTHONUNBUFFERED=1
|
ENV PYTHONUNBUFFERED=1
|
||||||
|
ENV PORT=8000
|
||||||
# Command to run the application
|
# Command to run the application
|
||||||
CMD ["uvicorn", "graph_service.main:app", "--host", "0.0.0.0", "--port", "8000"]
|
|
||||||
|
CMD uvicorn graph_service.main:app --host 0.0.0.0 --port $PORT
|
||||||
|
|
@ -2,7 +2,7 @@ version: '3.8'
|
||||||
|
|
||||||
services:
|
services:
|
||||||
graph:
|
graph:
|
||||||
build: .
|
image: zepai/graphiti:latest
|
||||||
ports:
|
ports:
|
||||||
- "8000:8000"
|
- "8000:8000"
|
||||||
|
|
||||||
|
|
@ -11,6 +11,7 @@ services:
|
||||||
- NEO4J_URI=bolt://neo4j:${NEO4J_PORT}
|
- NEO4J_URI=bolt://neo4j:${NEO4J_PORT}
|
||||||
- NEO4J_USER=${NEO4J_USER}
|
- NEO4J_USER=${NEO4J_USER}
|
||||||
- NEO4J_PASSWORD=${NEO4J_PASSWORD}
|
- NEO4J_PASSWORD=${NEO4J_PASSWORD}
|
||||||
|
- PORT=8000
|
||||||
neo4j:
|
neo4j:
|
||||||
image: neo4j:5.22.0
|
image: neo4j:5.22.0
|
||||||
|
|
||||||
|
|
|
||||||
0
server/graph_service/app.py
Normal file
0
server/graph_service/app.py
Normal file
|
|
@ -1,5 +1,5 @@
|
||||||
from .common import Message, Result
|
from .common import Message, Result
|
||||||
from .ingest import AddMessagesRequest
|
from .ingest import AddEntityNodeRequest, AddMessagesRequest
|
||||||
from .retrieve import (
|
from .retrieve import (
|
||||||
FactResult,
|
FactResult,
|
||||||
GetMemoryRequest,
|
GetMemoryRequest,
|
||||||
|
|
@ -12,6 +12,7 @@ __all__ = [
|
||||||
'SearchQuery',
|
'SearchQuery',
|
||||||
'Message',
|
'Message',
|
||||||
'AddMessagesRequest',
|
'AddMessagesRequest',
|
||||||
|
'AddEntityNodeRequest',
|
||||||
'SearchResults',
|
'SearchResults',
|
||||||
'FactResult',
|
'FactResult',
|
||||||
'Result',
|
'Result',
|
||||||
|
|
|
||||||
|
|
@ -11,8 +11,9 @@ class Result(BaseModel):
|
||||||
|
|
||||||
class Message(BaseModel):
|
class Message(BaseModel):
|
||||||
content: str = Field(..., description='The content of the message')
|
content: str = Field(..., description='The content of the message')
|
||||||
|
uuid: str | None = Field(default=None, description='The uuid of the message (optional)')
|
||||||
name: str = Field(
|
name: str = Field(
|
||||||
default='', description='The name of the episodic node for the message (message uuid)'
|
default='', description='The name of the episodic node for the message (optional)'
|
||||||
)
|
)
|
||||||
role_type: Literal['user', 'assistant', 'system'] = Field(
|
role_type: Literal['user', 'assistant', 'system'] = Field(
|
||||||
..., description='The role type of the message (user, assistant or system)'
|
..., description='The role type of the message (user, assistant or system)'
|
||||||
|
|
|
||||||
|
|
@ -6,3 +6,10 @@ from graph_service.dto.common import Message
|
||||||
class AddMessagesRequest(BaseModel):
|
class AddMessagesRequest(BaseModel):
|
||||||
group_id: str = Field(..., description='The group id of the messages to add')
|
group_id: str = Field(..., description='The group id of the messages to add')
|
||||||
messages: list[Message] = Field(..., description='The messages to add')
|
messages: list[Message] = Field(..., description='The messages to add')
|
||||||
|
|
||||||
|
|
||||||
|
class AddEntityNodeRequest(BaseModel):
|
||||||
|
uuid: str = Field(..., description='The uuid of the node to add')
|
||||||
|
group_id: str = Field(..., description='The group id of the node to add')
|
||||||
|
name: str = Field(..., description='The name of the node to add')
|
||||||
|
summary: str = Field(default='', description='The summary of the node to add')
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,4 @@
|
||||||
from datetime import datetime
|
from datetime import datetime, timezone
|
||||||
from typing import Literal
|
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
|
@ -10,9 +9,6 @@ class SearchQuery(BaseModel):
|
||||||
group_id: str = Field(..., description='The group id of the memory to get')
|
group_id: str = Field(..., description='The group id of the memory to get')
|
||||||
query: str
|
query: str
|
||||||
max_facts: int = Field(default=10, description='The maximum number of facts to retrieve')
|
max_facts: int = Field(default=10, description='The maximum number of facts to retrieve')
|
||||||
search_type: Literal['facts', 'user_centered_facts'] = Field(
|
|
||||||
default='facts', description='The type of search to perform'
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class FactResult(BaseModel):
|
class FactResult(BaseModel):
|
||||||
|
|
@ -24,6 +20,9 @@ class FactResult(BaseModel):
|
||||||
created_at: datetime
|
created_at: datetime
|
||||||
expired_at: datetime | None
|
expired_at: datetime | None
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
json_encoders = {datetime: lambda v: v.astimezone(timezone.utc).isoformat()}
|
||||||
|
|
||||||
|
|
||||||
class SearchResults(BaseModel):
|
class SearchResults(BaseModel):
|
||||||
facts: list[FactResult]
|
facts: list[FactResult]
|
||||||
|
|
@ -32,6 +31,9 @@ class SearchResults(BaseModel):
|
||||||
class GetMemoryRequest(BaseModel):
|
class GetMemoryRequest(BaseModel):
|
||||||
group_id: str = Field(..., description='The group id of the memory to get')
|
group_id: str = Field(..., description='The group id of the memory to get')
|
||||||
max_facts: int = Field(default=10, description='The maximum number of facts to retrieve')
|
max_facts: int = Field(default=10, description='The maximum number of facts to retrieve')
|
||||||
|
center_node_uuid: str | None = Field(
|
||||||
|
..., description='The uuid of the node to center the retrieval on'
|
||||||
|
)
|
||||||
messages: list[Message] = Field(
|
messages: list[Message] = Field(
|
||||||
..., description='The messages to build the retrieval query from '
|
..., description='The messages to build the retrieval query from '
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -6,7 +6,7 @@ from fastapi import APIRouter, FastAPI, status
|
||||||
from graphiti_core.nodes import EpisodeType # type: ignore
|
from graphiti_core.nodes import EpisodeType # type: ignore
|
||||||
from graphiti_core.utils import clear_data # type: ignore
|
from graphiti_core.utils import clear_data # type: ignore
|
||||||
|
|
||||||
from graph_service.dto import AddMessagesRequest, Message, Result
|
from graph_service.dto import AddEntityNodeRequest, AddMessagesRequest, Message, Result
|
||||||
from graph_service.zep_graphiti import ZepGraphitiDep
|
from graph_service.zep_graphiti import ZepGraphitiDep
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -69,6 +69,20 @@ async def add_messages(
|
||||||
return Result(message='Messages added to processing queue', success=True)
|
return Result(message='Messages added to processing queue', success=True)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post('/entity-node', status_code=status.HTTP_201_CREATED)
|
||||||
|
async def add_entity_node(
|
||||||
|
request: AddEntityNodeRequest,
|
||||||
|
graphiti: ZepGraphitiDep,
|
||||||
|
):
|
||||||
|
node = await graphiti.save_entity_node(
|
||||||
|
uuid=request.uuid,
|
||||||
|
group_id=request.group_id,
|
||||||
|
name=request.name,
|
||||||
|
summary=request.summary,
|
||||||
|
)
|
||||||
|
return node
|
||||||
|
|
||||||
|
|
||||||
@router.post('/clear', status_code=status.HTTP_200_OK)
|
@router.post('/clear', status_code=status.HTTP_200_OK)
|
||||||
async def clear(
|
async def clear(
|
||||||
graphiti: ZepGraphitiDep,
|
graphiti: ZepGraphitiDep,
|
||||||
|
|
|
||||||
|
|
@ -1,3 +1,5 @@
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
from fastapi import APIRouter, status
|
from fastapi import APIRouter, status
|
||||||
|
|
||||||
from graph_service.dto import (
|
from graph_service.dto import (
|
||||||
|
|
@ -14,16 +16,10 @@ router = APIRouter()
|
||||||
|
|
||||||
@router.post('/search', status_code=status.HTTP_200_OK)
|
@router.post('/search', status_code=status.HTTP_200_OK)
|
||||||
async def search(query: SearchQuery, graphiti: ZepGraphitiDep):
|
async def search(query: SearchQuery, graphiti: ZepGraphitiDep):
|
||||||
center_node_uuid: str | None = None
|
|
||||||
if query.search_type == 'user_centered_facts':
|
|
||||||
user_node = await graphiti.get_user_node(query.group_id)
|
|
||||||
if user_node:
|
|
||||||
center_node_uuid = user_node.uuid
|
|
||||||
relevant_edges = await graphiti.search(
|
relevant_edges = await graphiti.search(
|
||||||
group_ids=[query.group_id],
|
group_ids=[query.group_id],
|
||||||
query=query.query,
|
query=query.query,
|
||||||
num_results=query.max_facts,
|
num_results=query.max_facts,
|
||||||
center_node_uuid=center_node_uuid,
|
|
||||||
)
|
)
|
||||||
facts = [get_fact_result_from_edge(edge) for edge in relevant_edges]
|
facts = [get_fact_result_from_edge(edge) for edge in relevant_edges]
|
||||||
return SearchResults(
|
return SearchResults(
|
||||||
|
|
@ -31,6 +27,14 @@ async def search(query: SearchQuery, graphiti: ZepGraphitiDep):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get('/episodes/{group_id}', status_code=status.HTTP_200_OK)
|
||||||
|
async def get_episodes(group_id: str, last_n: int, graphiti: ZepGraphitiDep):
|
||||||
|
episodes = await graphiti.retrieve_episodes(
|
||||||
|
group_ids=[group_id], last_n=last_n, reference_time=datetime.now()
|
||||||
|
)
|
||||||
|
return episodes
|
||||||
|
|
||||||
|
|
||||||
@router.post('/get-memory', status_code=status.HTTP_200_OK)
|
@router.post('/get-memory', status_code=status.HTTP_200_OK)
|
||||||
async def get_memory(
|
async def get_memory(
|
||||||
request: GetMemoryRequest,
|
request: GetMemoryRequest,
|
||||||
|
|
|
||||||
|
|
@ -11,13 +11,19 @@ from graph_service.dto import FactResult
|
||||||
|
|
||||||
|
|
||||||
class ZepGraphiti(Graphiti):
|
class ZepGraphiti(Graphiti):
|
||||||
def __init__(
|
def __init__(self, uri: str, user: str, password: str, llm_client: LLMClient | None = None):
|
||||||
self, uri: str, user: str, password: str, user_id: str, llm_client: LLMClient | None = None
|
|
||||||
):
|
|
||||||
super().__init__(uri, user, password, llm_client)
|
super().__init__(uri, user, password, llm_client)
|
||||||
self.user_id = user_id
|
|
||||||
|
|
||||||
async def get_user_node(self, user_id: str) -> EntityNode | None: ...
|
async def save_entity_node(self, name: str, uuid: str, group_id: str, summary: str = ''):
|
||||||
|
new_node = EntityNode(
|
||||||
|
name=name,
|
||||||
|
uuid=uuid,
|
||||||
|
group_id=group_id,
|
||||||
|
summary=summary,
|
||||||
|
)
|
||||||
|
await new_node.generate_name_embedding(self.llm_client.get_embedder())
|
||||||
|
await new_node.save(self.driver)
|
||||||
|
return new_node
|
||||||
|
|
||||||
|
|
||||||
async def get_graphiti(settings: ZepEnvDep):
|
async def get_graphiti(settings: ZepEnvDep):
|
||||||
|
|
@ -25,7 +31,6 @@ async def get_graphiti(settings: ZepEnvDep):
|
||||||
uri=settings.neo4j_uri,
|
uri=settings.neo4j_uri,
|
||||||
user=settings.neo4j_user,
|
user=settings.neo4j_user,
|
||||||
password=settings.neo4j_password,
|
password=settings.neo4j_password,
|
||||||
user_id='test1234',
|
|
||||||
)
|
)
|
||||||
try:
|
try:
|
||||||
yield client
|
yield client
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue